Add CSR sparse matrix-vector multiply operation

This commit is contained in:
jan 2016-07-04 22:55:59 -07:00
parent 12baa97592
commit c9cb48d175

View File

@ -224,3 +224,35 @@ def create_dot(context):
return g.get() return g.get()
return ri_update return ri_update
def create_a_csr(context):
spmv_source = '''
int start = m_row_ptr[i];
int stop = m_row_ptr[i+1];
cdouble_t dot = cdouble_new(0.0, 0.0);
int col_ind, d_ind;
for (int j=start; j<stop; j++) {
col_ind = m_col_ind[j];
d_ind = j;
dot = cdouble_add(dot, cdouble_mul(v_in[col_ind], m_data[d_ind]));
}
v_out[i] = dot;
'''
v_out_args = ctype + ' *v_out'
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
v_in_args = ctype + ' *v_in'
spmv_kernel = ElementwiseKernel(context,
name='csr_spmv',
preamble=preamble,
operation=spmv_source,
arguments=', '.join((v_out_args, m_args, v_in_args)))
def spmv(v_out, m, v_in, e):
return spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)
return spmv