Simplify csr.cg and use operations from .ops
This commit is contained in:
parent
c9cb48d175
commit
e94a07db28
@ -1,128 +1,12 @@
|
|||||||
import numpy
|
import numpy
|
||||||
from numpy.linalg import norm
|
from numpy.linalg import norm
|
||||||
|
|
||||||
import pyopencl
|
import pyopencl
|
||||||
import pyopencl.array
|
import pyopencl.array
|
||||||
from pyopencl.elementwise import ElementwiseKernel
|
|
||||||
from pyopencl.reduction import ReductionKernel
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
def type_to_C(float_type: numpy.float32 or numpy.float64) -> str:
|
from . import ops
|
||||||
"""
|
|
||||||
Returns a string corresponding to the C equivalent of a numpy type.
|
|
||||||
|
|
||||||
:param float_type: numpy type: float32, float64, complex64, complex128
|
|
||||||
:return: string containing the corresponding C type (eg. 'double')
|
|
||||||
"""
|
|
||||||
types = {
|
|
||||||
numpy.float32: 'float',
|
|
||||||
numpy.float64: 'double',
|
|
||||||
numpy.complex64: 'cfloat_t',
|
|
||||||
numpy.complex128: 'cdouble_t',
|
|
||||||
}
|
|
||||||
if float_type not in types:
|
|
||||||
raise Exception('Unsupported type')
|
|
||||||
|
|
||||||
return types[float_type]
|
|
||||||
|
|
||||||
|
|
||||||
def create_ops(context):
|
|
||||||
preamble = '''
|
|
||||||
#define PYOPENCL_DEFINE_CDOUBLE
|
|
||||||
#include <pyopencl-complex.h>
|
|
||||||
'''
|
|
||||||
|
|
||||||
ctype = type_to_C(numpy.complex128)
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
|
|
||||||
spmv_source = '''
|
|
||||||
int start = m_row_ptr[i];
|
|
||||||
int stop = m_row_ptr[i+1];
|
|
||||||
cdouble_t dot = cdouble_new(0.0, 0.0);
|
|
||||||
|
|
||||||
int col_ind, d_ind;
|
|
||||||
for (int j=start; j<stop; j++) {
|
|
||||||
col_ind = m_col_ind[j];
|
|
||||||
d_ind = j;
|
|
||||||
|
|
||||||
dot = cdouble_add(dot, cdouble_mul(v_in[col_ind], m_data[d_ind]));
|
|
||||||
}
|
|
||||||
v_out[i] = dot;
|
|
||||||
'''
|
|
||||||
|
|
||||||
v_out_args = ctype + ' *v_out'
|
|
||||||
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
|
|
||||||
v_in_args = ctype + ' *v_in'
|
|
||||||
|
|
||||||
spmv_kernel = ElementwiseKernel(context, operation=spmv_source, preamble=preamble,
|
|
||||||
arguments=', '.join((v_out_args, m_args, v_in_args)))
|
|
||||||
|
|
||||||
def spmv(v_out, m, v_in, e):
|
|
||||||
return spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
|
|
||||||
update_xr_source = '''
|
|
||||||
x[i] = cdouble_add(x[i], cdouble_mul(alpha, p[i]));
|
|
||||||
r[i] = cdouble_sub(r[i], cdouble_mul(alpha, v[i]));
|
|
||||||
'''
|
|
||||||
|
|
||||||
xr_args = ', '.join([ctype + ' ' + f for f in ('*x', '*p', '*r', '*v', 'alpha')])
|
|
||||||
|
|
||||||
xr_kernel = ElementwiseKernel(context, operation=update_xr_source, preamble=preamble,
|
|
||||||
arguments=xr_args)
|
|
||||||
|
|
||||||
def xr_update(x, p, r, v, alpha, e):
|
|
||||||
return xr_kernel(x, p, r, v, alpha, wait_for=e)
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
|
|
||||||
update_ri_source = '''
|
|
||||||
(double3)(r[i].real * r[i].real, \
|
|
||||||
r[i].real * r[i].imag, \
|
|
||||||
r[i].imag * r[i].imag)
|
|
||||||
'''
|
|
||||||
|
|
||||||
ri_dtype = pyopencl.array.vec.double3
|
|
||||||
|
|
||||||
ri_kernel = ReductionKernel(context, preamble=preamble,
|
|
||||||
dtype_out=ri_dtype,
|
|
||||||
neutral='(double3)(0.0, 0.0, 0.0)',
|
|
||||||
map_expr=update_ri_source,
|
|
||||||
reduce_expr='a + b',
|
|
||||||
arguments=ctype + ' *r')
|
|
||||||
|
|
||||||
def ri_update(r, e):
|
|
||||||
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
|
|
||||||
rr, ri, ii = [g[q] for q in 'xyz']
|
|
||||||
rho = rr + 2j * ri - ii
|
|
||||||
err = rr + ii
|
|
||||||
return rho, err
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
|
|
||||||
update_p_source = '''
|
|
||||||
p[i] = cdouble_add(r[i], cdouble_mul(beta, p[i]));
|
|
||||||
'''
|
|
||||||
p_args = ctype + ' *p, ' + ctype + ' *r, ' + ctype + ' beta'
|
|
||||||
|
|
||||||
p_kernel = ElementwiseKernel(context, preamble=preamble, operation=update_p_source,
|
|
||||||
arguments=p_args)
|
|
||||||
|
|
||||||
def p_update(p, r, beta, e):
|
|
||||||
return p_kernel(p, r, beta, wait_for=e)
|
|
||||||
|
|
||||||
ops = {
|
|
||||||
'spmv': spmv,
|
|
||||||
'p_update': p_update,
|
|
||||||
'ri_update': ri_update,
|
|
||||||
'xr_update': xr_update,
|
|
||||||
}
|
|
||||||
|
|
||||||
return ops
|
|
||||||
|
|
||||||
|
|
||||||
class CSRMatrix(object):
|
class CSRMatrix(object):
|
||||||
@ -145,32 +29,51 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose
|
|||||||
if queue is None:
|
if queue is None:
|
||||||
queue = pyopencl.CommandQueue(context)
|
queue = pyopencl.CommandQueue(context)
|
||||||
|
|
||||||
ops = create_ops(context)
|
def load_field(v, dtype=numpy.complex128):
|
||||||
|
return pyopencl.array.to_device(queue, v.astype(dtype))
|
||||||
|
|
||||||
|
r = load_field(b)
|
||||||
|
x = pyopencl.array.zeros_like(r)
|
||||||
|
v = pyopencl.array.zeros_like(r)
|
||||||
|
p = pyopencl.array.zeros_like(r)
|
||||||
|
|
||||||
x = pyopencl.array.zeros(queue, dtype=numpy.complex128, shape=b.shape)
|
|
||||||
v = pyopencl.array.zeros_like(x)
|
|
||||||
p = pyopencl.array.zeros_like(x)
|
|
||||||
r = pyopencl.array.to_device(queue, b)
|
|
||||||
alpha = 1.0 + 0j
|
alpha = 1.0 + 0j
|
||||||
rho = 1.0 + 0j
|
rho = 1.0 + 0j
|
||||||
errs = []
|
errs = []
|
||||||
|
|
||||||
m = CSRMatrix(queue, a)
|
m = CSRMatrix(queue, a)
|
||||||
|
|
||||||
_, err2 = ops['ri_update'](r, [])
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
Generate OpenCL kernels
|
||||||
|
'''
|
||||||
|
a_step = ops.create_a_csr(context)
|
||||||
|
xr_step = ops.create_xr_step(context)
|
||||||
|
rhoerr_step = ops.create_rhoerr_step(context)
|
||||||
|
p_step = ops.create_p_step(context)
|
||||||
|
dot = ops.create_dot(context)
|
||||||
|
|
||||||
|
def a_step(E, H, p, events):
|
||||||
|
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)
|
||||||
|
|
||||||
|
'''
|
||||||
|
Start the solve
|
||||||
|
'''
|
||||||
|
start_time2 = time.perf_counter()
|
||||||
|
|
||||||
|
_, err2 = rhoerr_step(r, [])
|
||||||
b_norm = numpy.sqrt(err2)
|
b_norm = numpy.sqrt(err2)
|
||||||
print('b_norm check: ', b_norm)
|
print('b_norm check: ', b_norm)
|
||||||
|
|
||||||
start_time2 = time.perf_counter()
|
|
||||||
|
|
||||||
success = False
|
success = False
|
||||||
for k in range(max_iters):
|
for k in range(max_iters):
|
||||||
if verbose:
|
if verbose:
|
||||||
print('[{:06d}] rho {:.4} alpha {:4.4}'.format(k, rho, alpha), end=' ')
|
print('[{:06d}] rho {:.4} alpha {:4.4}'.format(k, rho, alpha), end=' ')
|
||||||
|
|
||||||
rho_prev = rho
|
rho_prev = rho
|
||||||
e = ops['xr_update'](x, p, r, v, alpha, [])
|
e = xr_step(x, p, r, v, alpha, [])
|
||||||
rho, err2 = ops['ri_update'](r, [e])
|
rho, err2 = rhoerr_step(r, e)
|
||||||
|
|
||||||
errs += [numpy.sqrt(err2) / b_norm]
|
errs += [numpy.sqrt(err2) / b_norm]
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
@ -179,13 +82,10 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose
|
|||||||
if errs[-1] < err_thresh:
|
if errs[-1] < err_thresh:
|
||||||
success = True
|
success = True
|
||||||
break
|
break
|
||||||
e = ops['p_update'](p, r, rho/rho_prev, [])
|
|
||||||
ops['spmv'](v, m, p, [e]).wait()
|
|
||||||
|
|
||||||
# v2 = a @ p.get()
|
e = p_step(p, r, rho/rho_prev, [])
|
||||||
# print('norm', norm(v2 - v.get()))
|
e = a_step(v, m, p, e)
|
||||||
|
alpha = rho / dot(p, v, e)
|
||||||
alpha = rho / pyopencl.array.dot(p, v).get()
|
|
||||||
|
|
||||||
if k % 1000 == 0:
|
if k % 1000 == 0:
|
||||||
print(k)
|
print(k)
|
||||||
|
Loading…
Reference in New Issue
Block a user