|
|
|
@ -224,3 +224,35 @@ def create_dot(context):
|
|
|
|
|
return g.get()
|
|
|
|
|
|
|
|
|
|
return ri_update
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_a_csr(context):
|
|
|
|
|
spmv_source = '''
|
|
|
|
|
int start = m_row_ptr[i];
|
|
|
|
|
int stop = m_row_ptr[i+1];
|
|
|
|
|
cdouble_t dot = cdouble_new(0.0, 0.0);
|
|
|
|
|
|
|
|
|
|
int col_ind, d_ind;
|
|
|
|
|
for (int j=start; j<stop; j++) {
|
|
|
|
|
col_ind = m_col_ind[j];
|
|
|
|
|
d_ind = j;
|
|
|
|
|
|
|
|
|
|
dot = cdouble_add(dot, cdouble_mul(v_in[col_ind], m_data[d_ind]));
|
|
|
|
|
}
|
|
|
|
|
v_out[i] = dot;
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
v_out_args = ctype + ' *v_out'
|
|
|
|
|
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
|
|
|
|
|
v_in_args = ctype + ' *v_in'
|
|
|
|
|
|
|
|
|
|
spmv_kernel = ElementwiseKernel(context,
|
|
|
|
|
name='csr_spmv',
|
|
|
|
|
preamble=preamble,
|
|
|
|
|
operation=spmv_source,
|
|
|
|
|
arguments=', '.join((v_out_args, m_args, v_in_args)))
|
|
|
|
|
|
|
|
|
|
def spmv(v_out, m, v_in, e):
|
|
|
|
|
return spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)
|
|
|
|
|
|
|
|
|
|
return spmv
|
|
|
|
|