cleanup and improved reporting for csr
This commit is contained in:
parent
03f7f6d3c4
commit
bb7b90e938
@ -1,4 +1,5 @@
|
|||||||
import numpy
|
import numpy
|
||||||
|
from numpy.linalg import norm
|
||||||
|
|
||||||
import pyopencl
|
import pyopencl
|
||||||
import pyopencl.array
|
import pyopencl.array
|
||||||
@ -52,17 +53,15 @@ def create_ops(context):
|
|||||||
v_out[i] = dot;
|
v_out[i] = dot;
|
||||||
'''
|
'''
|
||||||
|
|
||||||
v_out_args = ctype + ' *v_out, int v_len_half'
|
v_out_args = ctype + ' *v_out'
|
||||||
m_args = 'int m_nnz, int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
|
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
|
||||||
v_in_args = ctype + ' *v_in'
|
v_in_args = ctype + ' *v_in'
|
||||||
|
|
||||||
spmv_kernel = ElementwiseKernel(context, operation=spmv_source, preamble=preamble,
|
spmv_kernel = ElementwiseKernel(context, operation=spmv_source, preamble=preamble,
|
||||||
arguments=', '.join((v_out_args, m_args, v_in_args)))
|
arguments=', '.join((v_out_args, m_args, v_in_args)))
|
||||||
|
|
||||||
def spmv(v_out, m, v_in, e):
|
def spmv(v_out, m, v_in, e):
|
||||||
return spmv_kernel(v_out, (v_out.size - 1)//2,
|
return spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)
|
||||||
m.data.size, m.row_ptr, m.col_ind, m.data,
|
|
||||||
v_in, wait_for=e)
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
|
|
||||||
@ -93,7 +92,7 @@ def create_ops(context):
|
|||||||
dtype_out=ri_dtype,
|
dtype_out=ri_dtype,
|
||||||
neutral='(double3)(0.0, 0.0, 0.0)',
|
neutral='(double3)(0.0, 0.0, 0.0)',
|
||||||
map_expr=update_ri_source,
|
map_expr=update_ri_source,
|
||||||
reduce_expr='a+b',
|
reduce_expr='a + b',
|
||||||
arguments=ctype + ' *r')
|
arguments=ctype + ' *r')
|
||||||
|
|
||||||
def ri_update(r, e):
|
def ri_update(r, e):
|
||||||
@ -149,7 +148,7 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose
|
|||||||
ops = create_ops(context)
|
ops = create_ops(context)
|
||||||
|
|
||||||
x = pyopencl.array.zeros(queue, dtype=numpy.complex128, shape=b.shape)
|
x = pyopencl.array.zeros(queue, dtype=numpy.complex128, shape=b.shape)
|
||||||
v = pyopencl.array.empty_like(x)
|
v = pyopencl.array.zeros_like(x)
|
||||||
p = pyopencl.array.zeros_like(x)
|
p = pyopencl.array.zeros_like(x)
|
||||||
r = pyopencl.array.to_device(queue, b)
|
r = pyopencl.array.to_device(queue, b)
|
||||||
alpha = 1.0 + 0j
|
alpha = 1.0 + 0j
|
||||||
@ -158,20 +157,19 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose
|
|||||||
|
|
||||||
m = CSRMatrix(queue, a)
|
m = CSRMatrix(queue, a)
|
||||||
|
|
||||||
e = ops['spmv'](v, m, x, [])
|
_, err2 = ops['ri_update'](r, [])
|
||||||
e = ops['xr_update'](x, p, r, v, 0.0, [e])
|
|
||||||
_, err2 = ops['ri_update'](r, [e])
|
|
||||||
|
|
||||||
b_norm = numpy.sqrt(err2)
|
b_norm = numpy.sqrt(err2)
|
||||||
print('b_norm check: ', b_norm)
|
print('b_norm check: ', b_norm)
|
||||||
|
|
||||||
start_time2 = time.perf_counter()
|
start_time2 = time.perf_counter()
|
||||||
|
|
||||||
|
success = False
|
||||||
for k in range(max_iters):
|
for k in range(max_iters):
|
||||||
if verbose:
|
if verbose:
|
||||||
print('[{:06d}] rho {:.4} alpha {:4.4}'.format(k, rho, alpha), end=' ')
|
print('[{:06d}] rho {:.4} alpha {:4.4}'.format(k, rho, alpha), end=' ')
|
||||||
rho_prev = rho
|
rho_prev = rho
|
||||||
e = ops['xr_update'](x, p, r, v, alpha, [e])
|
e = ops['xr_update'](x, p, r, v, alpha, [])
|
||||||
rho, err2 = ops['ri_update'](r, [e])
|
rho, err2 = ops['ri_update'](r, [e])
|
||||||
errs += [numpy.sqrt(err2) / b_norm]
|
errs += [numpy.sqrt(err2) / b_norm]
|
||||||
|
|
||||||
@ -179,17 +177,35 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose
|
|||||||
print('err', errs[-1])
|
print('err', errs[-1])
|
||||||
|
|
||||||
if errs[-1] < err_thresh:
|
if errs[-1] < err_thresh:
|
||||||
time_elapsed = time.perf_counter() - start_time
|
success = True
|
||||||
print('Success, {} iterations in {} sec: {} iterations/sec'.format(k,
|
break
|
||||||
time_elapsed, k/time_elapsed))
|
|
||||||
print('overhead', start_time2-start_time)
|
|
||||||
return x.get(), errs, True
|
|
||||||
e = ops['p_update'](p, r, rho/rho_prev, [])
|
e = ops['p_update'](p, r, rho/rho_prev, [])
|
||||||
e.wait()
|
|
||||||
ops['spmv'](v, m, p, [e]).wait()
|
ops['spmv'](v, m, p, [e]).wait()
|
||||||
|
|
||||||
|
# v2 = a @ p.get()
|
||||||
|
# print('norm', norm(v2 - v.get()))
|
||||||
|
|
||||||
alpha = rho / pyopencl.array.dot(p, v).get()
|
alpha = rho / pyopencl.array.dot(p, v).get()
|
||||||
|
|
||||||
if k % 1000 == 0:
|
if k % 1000 == 0:
|
||||||
print(k)
|
print(k)
|
||||||
|
|
||||||
return x.get(), errs, False
|
'''
|
||||||
|
Done solving
|
||||||
|
'''
|
||||||
|
time_elapsed = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
x = x.get()
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print('Success', end='')
|
||||||
|
else:
|
||||||
|
print('Failure', end=', ')
|
||||||
|
print(', {} iterations in {} sec: {} iterations/sec \
|
||||||
|
'.format(k, time_elapsed, k / time_elapsed))
|
||||||
|
print('final error', errs[-1])
|
||||||
|
print('overhead {} sec'.format(start_time2 - start_time))
|
||||||
|
|
||||||
|
print('Post-everything residual:', norm(a @ x - b) / norm(b))
|
||||||
|
return x
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user