diff --git a/opencl_fdfd/csr.py b/opencl_fdfd/csr.py index ed2a338..d3314e4 100644 --- a/opencl_fdfd/csr.py +++ b/opencl_fdfd/csr.py @@ -1,4 +1,5 @@ import numpy +from numpy.linalg import norm import pyopencl import pyopencl.array @@ -52,17 +53,15 @@ def create_ops(context): v_out[i] = dot; ''' - v_out_args = ctype + ' *v_out, int v_len_half' - m_args = 'int m_nnz, int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data' + v_out_args = ctype + ' *v_out' + m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data' v_in_args = ctype + ' *v_in' spmv_kernel = ElementwiseKernel(context, operation=spmv_source, preamble=preamble, arguments=', '.join((v_out_args, m_args, v_in_args))) def spmv(v_out, m, v_in, e): - return spmv_kernel(v_out, (v_out.size - 1)//2, - m.data.size, m.row_ptr, m.col_ind, m.data, - v_in, wait_for=e) + return spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e) # ------------------------------------- @@ -93,7 +92,7 @@ def create_ops(context): dtype_out=ri_dtype, neutral='(double3)(0.0, 0.0, 0.0)', map_expr=update_ri_source, - reduce_expr='a+b', + reduce_expr='a + b', arguments=ctype + ' *r') def ri_update(r, e): @@ -149,7 +148,7 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose ops = create_ops(context) x = pyopencl.array.zeros(queue, dtype=numpy.complex128, shape=b.shape) - v = pyopencl.array.empty_like(x) + v = pyopencl.array.zeros_like(x) p = pyopencl.array.zeros_like(x) r = pyopencl.array.to_device(queue, b) alpha = 1.0 + 0j @@ -158,20 +157,19 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose m = CSRMatrix(queue, a) - e = ops['spmv'](v, m, x, []) - e = ops['xr_update'](x, p, r, v, 0.0, [e]) - _, err2 = ops['ri_update'](r, [e]) + _, err2 = ops['ri_update'](r, []) b_norm = numpy.sqrt(err2) print('b_norm check: ', b_norm) start_time2 = time.perf_counter() + success = False for k in range(max_iters): if verbose: print('[{:06d}] rho {:.4} alpha {:4.4}'.format(k, rho, alpha), end=' ') rho_prev = rho - e = ops['xr_update'](x, p, r, v, alpha, [e]) + e = ops['xr_update'](x, p, r, v, alpha, []) rho, err2 = ops['ri_update'](r, [e]) errs += [numpy.sqrt(err2) / b_norm] @@ -179,17 +177,35 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose print('err', errs[-1]) if errs[-1] < err_thresh: - time_elapsed = time.perf_counter() - start_time - print('Success, {} iterations in {} sec: {} iterations/sec'.format(k, - time_elapsed, k/time_elapsed)) - print('overhead', start_time2-start_time) - return x.get(), errs, True + success = True + break e = ops['p_update'](p, r, rho/rho_prev, []) - e.wait() ops['spmv'](v, m, p, [e]).wait() + +# v2 = a @ p.get() +# print('norm', norm(v2 - v.get())) + alpha = rho / pyopencl.array.dot(p, v).get() if k % 1000 == 0: print(k) - return x.get(), errs, False + ''' + Done solving + ''' + time_elapsed = time.perf_counter() - start_time + + x = x.get() + + if success: + print('Success', end='') + else: + print('Failure', end=', ') + print(', {} iterations in {} sec: {} iterations/sec \ + '.format(k, time_elapsed, k / time_elapsed)) + print('final error', errs[-1]) + print('overhead {} sec'.format(start_time2 - start_time)) + + print('Post-everything residual:', norm(a @ x - b) / norm(b)) + return x +