Add CSR sparse matrix-vector multiply operation
This commit is contained in:
parent
12baa97592
commit
c9cb48d175
@ -224,3 +224,35 @@ def create_dot(context):
|
||||
return g.get()
|
||||
|
||||
return ri_update
|
||||
|
||||
|
||||
def create_a_csr(context):
|
||||
spmv_source = '''
|
||||
int start = m_row_ptr[i];
|
||||
int stop = m_row_ptr[i+1];
|
||||
cdouble_t dot = cdouble_new(0.0, 0.0);
|
||||
|
||||
int col_ind, d_ind;
|
||||
for (int j=start; j<stop; j++) {
|
||||
col_ind = m_col_ind[j];
|
||||
d_ind = j;
|
||||
|
||||
dot = cdouble_add(dot, cdouble_mul(v_in[col_ind], m_data[d_ind]));
|
||||
}
|
||||
v_out[i] = dot;
|
||||
'''
|
||||
|
||||
v_out_args = ctype + ' *v_out'
|
||||
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
|
||||
v_in_args = ctype + ' *v_in'
|
||||
|
||||
spmv_kernel = ElementwiseKernel(context,
|
||||
name='csr_spmv',
|
||||
preamble=preamble,
|
||||
operation=spmv_source,
|
||||
arguments=', '.join((v_out_args, m_args, v_in_args)))
|
||||
|
||||
def spmv(v_out, m, v_in, e):
|
||||
return spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)
|
||||
|
||||
return spmv
|
||||
|
Loading…
Reference in New Issue
Block a user