Add CSR sparse matrix-vector multiply operation
This commit is contained in:
parent
12baa97592
commit
c9cb48d175
@ -224,3 +224,35 @@ def create_dot(context):
|
|||||||
return g.get()
|
return g.get()
|
||||||
|
|
||||||
return ri_update
|
return ri_update
|
||||||
|
|
||||||
|
|
||||||
|
def create_a_csr(context):
|
||||||
|
spmv_source = '''
|
||||||
|
int start = m_row_ptr[i];
|
||||||
|
int stop = m_row_ptr[i+1];
|
||||||
|
cdouble_t dot = cdouble_new(0.0, 0.0);
|
||||||
|
|
||||||
|
int col_ind, d_ind;
|
||||||
|
for (int j=start; j<stop; j++) {
|
||||||
|
col_ind = m_col_ind[j];
|
||||||
|
d_ind = j;
|
||||||
|
|
||||||
|
dot = cdouble_add(dot, cdouble_mul(v_in[col_ind], m_data[d_ind]));
|
||||||
|
}
|
||||||
|
v_out[i] = dot;
|
||||||
|
'''
|
||||||
|
|
||||||
|
v_out_args = ctype + ' *v_out'
|
||||||
|
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
|
||||||
|
v_in_args = ctype + ' *v_in'
|
||||||
|
|
||||||
|
spmv_kernel = ElementwiseKernel(context,
|
||||||
|
name='csr_spmv',
|
||||||
|
preamble=preamble,
|
||||||
|
operation=spmv_source,
|
||||||
|
arguments=', '.join((v_out_args, m_args, v_in_args)))
|
||||||
|
|
||||||
|
def spmv(v_out, m, v_in, e):
|
||||||
|
return spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)
|
||||||
|
|
||||||
|
return spmv
|
||||||
|
Loading…
Reference in New Issue
Block a user