diff --git a/opencl_fdfd/__init__.py b/opencl_fdfd/__init__.py new file mode 100644 index 0000000..b2f1f79 --- /dev/null +++ b/opencl_fdfd/__init__.py @@ -0,0 +1 @@ +from .main import cg_solver \ No newline at end of file diff --git a/opencl_fdfd/kernels/e2h.cl b/opencl_fdfd/kernels/e2h.cl new file mode 100644 index 0000000..6f4f9fa --- /dev/null +++ b/opencl_fdfd/kernels/e2h.cl @@ -0,0 +1,87 @@ +/* + * + * H update equations + * + */ + +//Define sx, x, dix (and y, z versions of those) +{{dixyz_source}} + +//Define vectorized fields and pointers (eg. Hx = H + XX) +{{vec_source}} + + +// Wrap indices if necessary +int ipx, ipy, ipz; +if ( x == sx - 1 ) { + ipx = i - (sx - 1) * dix; +} else { + ipx = i + dix; +} + +if ( y == sy - 1 ) { + ipy = i - (sy - 1) * diy; +} else { + ipy = i + diy; +} + +if ( z == sz - 1 ) { + ipz = i - (sz - 1) * diz; +} else { + ipz = i + diz; +} + + +//Update H components; set them to 0 if PMC is enabled there. +// Also divide by mu only if requested. +{% if pmc -%} +if (pmc[XX + i]) { + Hx[i] = cdouble_new(0.0, 0.0); +} else +{%- endif -%} +{ + cdouble_t Dzy = cdouble_mul(cdouble_sub(Ez[ipy], Ez[i]), inv_dey[y]); + cdouble_t Dyz = cdouble_mul(cdouble_sub(Ey[ipz], Ey[i]), inv_dez[z]); + + {%- if mu -%} + Hx[i] = cdouble_mul(inv_mu[XX + i], cdouble_sub(Dzy, Dyz)); + {%- else -%} + Hx[i] = cdouble_sub(Dzy, Dyz); + {%- endif %} +} + +{% if pmc -%} +if (pmc[YY + i]) { + Hy[i] = cdouble_new(0.0, 0.0); +} else +{%- endif -%} +{ + cdouble_t Dxz = cdouble_mul(cdouble_sub(Ex[ipz], Ex[i]), inv_dez[z]); + cdouble_t Dzx = cdouble_mul(cdouble_sub(Ez[ipx], Ez[i]), inv_dex[x]); + + {%- if mu -%} + Hy[i] = cdouble_mul(inv_mu[YY + i], cdouble_sub(Dxz, Dzx)); + {%- else -%} + Hy[i] = cdouble_sub(Dxz, Dzx); + {%- endif %} +} + +{% if pmc -%} +if (pmc[XX + i]) { + Hx[i] = cdouble_new(0.0, 0.0); +} else +{%- endif -%} +{ + cdouble_t Dyx = cdouble_mul(cdouble_sub(Ey[ipx], Ey[i]), inv_dex[x]); + cdouble_t Dxy = cdouble_mul(cdouble_sub(Ex[ipy], Ex[i]), inv_dey[y]); + + {%- if mu -%} + Hz[i] = cdouble_mul(inv_mu[ZZ + i], cdouble_sub(Dyx, Dxy)); + {%- else -%} + Hz[i] = cdouble_sub(Dyx, Dxy); + {%- endif %} +} + +/* + * End H update equations + */ \ No newline at end of file diff --git a/opencl_fdfd/kernels/h2e.cl b/opencl_fdfd/kernels/h2e.cl new file mode 100644 index 0000000..d4fda62 --- /dev/null +++ b/opencl_fdfd/kernels/h2e.cl @@ -0,0 +1,77 @@ +/* + * + * E update equations + * + */ + +//Define sx, x, dix (and y, z versions of those) +{{dixyz_source}} + +//Define vectorized fields and pointers (eg. Hx = H + XX) +{{vec_source}} + + +// Wrap indices if necessary +int imx, imy, imz; +if ( x == 0 ) { + imx = i + (sx - 1) * dix; +} else { + imx = i - dix; +} + +if ( y == 0 ) { + imy = i + (sy - 1) * diy; +} else { + imy = i - diy; +} + +if ( z == 0 ) { + imz = i + (sz - 1) * diz; +} else { + imz = i - diz; +} + + +//Update E components; set them to 0 if PEC is enabled there. +{% if pec -%} +if (pec[XX + i]) { + Ex[i] = cdouble_new(0.0, 0.0); +} else +{%- endif -%} +{ + cdouble_t tEx = cdouble_mul(Ex[i], oeps[XX + i]); + cdouble_t Dzy = cdouble_mul(cdouble_sub(Hz[i], Hz[imy]), inv_dhy[y]); + cdouble_t Dyz = cdouble_mul(cdouble_sub(Hy[i], Hy[imz]), inv_dhz[z]); + tEx = cdouble_add(tEx, cdouble_sub(Dzy, Dyz)); + Ex[i] = cdouble_mul(tEx, Pl[XX + i]); +} + +{% if pec -%} +if (pec[YY + i]) { + Ey[i] = cdouble_new(0.0, 0.0); +} else +{%- endif -%} +{ + cdouble_t tEy = cdouble_mul(Ey[i], oeps[YY + i]); + cdouble_t Dxz = cdouble_mul(cdouble_sub(Hx[i], Hx[imz]), inv_dhz[z]); + cdouble_t Dzx = cdouble_mul(cdouble_sub(Hz[i], Hz[imx]), inv_dhx[x]); + tEy = cdouble_add(tEy, cdouble_sub(Dxz, Dzx)); + Ey[i] = cdouble_mul(tEy, Pl[YY + i]); +} + +{% if pec -%} +if (pec[ZZ + i]) { + Ez[i] = cdouble_new(0.0, 0.0); +} else +{%- endif -%} +{ + cdouble_t tEz = cdouble_mul(Ez[i], oeps[ZZ + i]); + cdouble_t Dyx = cdouble_mul(cdouble_sub(Hy[i], Hy[imx]), inv_dhx[x]); + cdouble_t Dxy = cdouble_mul(cdouble_sub(Hx[i], Hx[imy]), inv_dhy[y]); + tEz = cdouble_add(tEz, cdouble_sub(Dyx, Dxy)); + Ez[i] = cdouble_mul(tEz, Pl[ZZ + i]); +} + +/* + * End H update equations + */ \ No newline at end of file diff --git a/opencl_fdfd/main.py b/opencl_fdfd/main.py index 1e131fa..69b6070 100644 --- a/opencl_fdfd/main.py +++ b/opencl_fdfd/main.py @@ -1,364 +1,13 @@ import numpy from numpy.linalg import norm - -import jinja2 import pyopencl import pyopencl.array -from pyopencl.elementwise import ElementwiseKernel -from pyopencl.reduction import ReductionKernel import time import fdfd_tools.operators - -def type_to_C(float_type: numpy.float32 or numpy.float64) -> str: - """ - Returns a string corresponding to the C equivalent of a numpy type. - - :param float_type: numpy type: float32, float64, complex64, complex128 - :return: string containing the corresponding C type (eg. 'double') - """ - types = { - numpy.float32: 'float', - numpy.float64: 'double', - numpy.complex64: 'cfloat_t', - numpy.complex128: 'cdouble_t', - } - if float_type not in types: - raise Exception('Unsupported type') - - return types[float_type] - - -def shape_source(shape) -> str: - """ - Defines sx, sy, sz C constants specifying the shape of the grid in each of the 3 dimensions. - - :param shape: [sx, sy, sz] values. - :return: String containing C source. - """ - sxyz = """ -// Field sizes -const int sx = {shape[0]}; -const int sy = {shape[1]}; -const int sz = {shape[2]}; -""".format(shape=shape) - return sxyz - -# Defines dix, diy, diz constants used for stepping in the x, y, z directions in a linear array -# (ie, given Ex[i] referring to position (x, y, z), Ex[i+diy] will refer to position (x, y+1, z)) -dixyz_source = """ -// Convert offset in field xyz to linear index offset -const int dix = 1; -const int diy = sx; -const int diz = sx * sy; -""" - -# Given a linear index i and shape sx, sy, sz, defines x, y, and z -# as the 3D indices of the current element (i). -xyz_source = """ -// Convert linear index to field index (xyz) -const int z = i / (sx * sy); -const int y = (i - z * sx * sy) / sx; -const int x = (i - y * sx - z * sx * sy); -""" - -vec_source = """ -if (i >= sx * sy * sz) { - PYOPENCL_ELWISE_CONTINUE; -} - -//Pointers into the components of a vectorized vector-field -const int XX = 0; -const int YY = sx * sy * sz; -const int ZZ = sx * sy * sz * 2; -""" - -E_ptrs = """ -__global cdouble_t *Ex = E + XX; -__global cdouble_t *Ey = E + YY; -__global cdouble_t *Ez = E + ZZ; -""" - -H_ptrs = """ -__global cdouble_t *Hx = H + XX; -__global cdouble_t *Hy = H + YY; -__global cdouble_t *Hz = H + ZZ; -""" - -# Source code for updating the E field; maxes use of dixyz_source. -maxwell_E_source = """ -// E update equations -int imx, imy, imz; -if ( x == 0 ) { - imx = i + (sx - 1) * dix; -} else { - imx = i - dix; -} - -if ( y == 0 ) { - imy = i + (sy - 1) * diy; -} else { - imy = i - diy; -} - -if ( z == 0 ) { - imz = i + (sz - 1) * diz; -} else { - imz = i - diz; -} - -// E update equations - -{% if pec -%} -if (pec[XX + i]) { - Ex[i] = cdouble_new(0.0, 0.0); -} else -{%- endif -%} -{ - cdouble_t tEx = cdouble_mul(Ex[i], oeps[XX + i]); - cdouble_t Dzy = cdouble_mul(cdouble_sub(Hz[i], Hz[imy]), inv_dhy[y]); - cdouble_t Dyz = cdouble_mul(cdouble_sub(Hy[i], Hy[imz]), inv_dhz[z]); - tEx = cdouble_add(tEx, cdouble_sub(Dzy, Dyz)); - Ex[i] = cdouble_mul(tEx, Pl[XX + i]); -} - -{% if pec -%} -if (pec[YY + i]) { - Ey[i] = cdouble_new(0.0, 0.0); -} else -{%- endif -%} -{ - cdouble_t tEy = cdouble_mul(Ey[i], oeps[YY + i]); - cdouble_t Dxz = cdouble_mul(cdouble_sub(Hx[i], Hx[imz]), inv_dhz[z]); - cdouble_t Dzx = cdouble_mul(cdouble_sub(Hz[i], Hz[imx]), inv_dhx[x]); - tEy = cdouble_add(tEy, cdouble_sub(Dxz, Dzx)); - Ey[i] = cdouble_mul(tEy, Pl[YY + i]); -} - -{% if pec -%} -if (pec[ZZ + i]) { - Ez[i] = cdouble_new(0.0, 0.0); -} else -{%- endif -%} -{ - cdouble_t tEz = cdouble_mul(Ez[i], oeps[ZZ + i]); - cdouble_t Dyx = cdouble_mul(cdouble_sub(Hy[i], Hy[imx]), inv_dhx[x]); - cdouble_t Dxy = cdouble_mul(cdouble_sub(Hx[i], Hx[imy]), inv_dhy[y]); - tEz = cdouble_add(tEz, cdouble_sub(Dyx, Dxy)); - Ez[i] = cdouble_mul(tEz, Pl[ZZ + i]); -} -""" - -# Source code for updating the H field; maxes use of dixyz_source and assumes mu=0 -maxwell_H_source = """ -// H update equations -int ipx, ipy, ipz; -if ( x == sx - 1 ) { - ipx = i - (sx - 1) * dix; -} else { - ipx = i + dix; -} - -if ( y == sy - 1 ) { - ipy = i - (sy - 1) * diy; -} else { - ipy = i + diy; -} - -if ( z == sz - 1 ) { - ipz = i - (sz - 1) * diz; -} else { - ipz = i + diz; -} - -{% if pmc -%} -if (pmc[XX + i]) { - Hx[i] = cdouble_new(0.0, 0.0); -} else -{%- endif -%} -{ - cdouble_t Dzy = cdouble_mul(cdouble_sub(Ez[ipy], Ez[i]), inv_dey[y]); - cdouble_t Dyz = cdouble_mul(cdouble_sub(Ey[ipz], Ey[i]), inv_dez[z]); - - {%- if mu -%} - Hx[i] = cdouble_mul(inv_mu[XX + i], cdouble_sub(Dzy, Dyz)); - {%- else -%} - Hx[i] = cdouble_sub(Dzy, Dyz); - {%- endif %} -} - -{% if pmc -%} -if (pmc[YY + i]) { - Hy[i] = cdouble_new(0.0, 0.0); -} else -{%- endif -%} -{ - cdouble_t Dxz = cdouble_mul(cdouble_sub(Ex[ipz], Ex[i]), inv_dez[z]); - cdouble_t Dzx = cdouble_mul(cdouble_sub(Ez[ipx], Ez[i]), inv_dex[x]); - - {%- if mu -%} - Hy[i] = cdouble_mul(inv_mu[YY + i], cdouble_sub(Dxz, Dzx)); - {%- else -%} - Hy[i] = cdouble_sub(Dxz, Dzx); - {%- endif %} -} - -{% if pmc -%} -if (pmc[XX + i]) { - Hx[i] = cdouble_new(0.0, 0.0); -} else -{%- endif -%} -{ - cdouble_t Dyx = cdouble_mul(cdouble_sub(Ey[ipx], Ey[i]), inv_dex[x]); - cdouble_t Dxy = cdouble_mul(cdouble_sub(Ex[ipy], Ex[i]), inv_dey[y]); - - {%- if mu -%} - Hz[i] = cdouble_mul(inv_mu[ZZ + i], cdouble_sub(Dyx, Dxy)); - {%- else -%} - Hz[i] = cdouble_sub(Dyx, Dxy); - {%- endif %} -} -""" - -p2e_source = ''' -Ex[i] = cdouble_mul(Pr[XX + i], p[XX + i]); -Ey[i] = cdouble_mul(Pr[YY + i], p[YY + i]); -Ez[i] = cdouble_mul(Pr[ZZ + i], p[ZZ + i]); -''' - -preamble = ''' -#define PYOPENCL_DEFINE_CDOUBLE -#include -''' - -ctype = type_to_C(numpy.complex128) - - -def ptrs(*args): - return [ctype + ' *' + s for s in args] - - -def create_a(context, shape, mu=False, pec=False, pmc=False): - dhs = [ctype + ' *inv_dh' + a for a in 'xyz'] - des = [ctype + ' *inv_de' + a for a in 'xyz'] - - header = shape_source(shape) + dixyz_source + xyz_source + vec_source + E_ptrs - P2E_kernel = ElementwiseKernel(context, - name='P2E', - preamble=preamble, - operation=header + p2e_source, - arguments=', '.join(ptrs('E', 'p', 'Pr'))) - - pmc_arg = ['int *pmc'] - e2h_source = header + H_ptrs + jinja2.Template(maxwell_H_source).render(mu=mu, pmc=pmc) - E2H_kernel = ElementwiseKernel(context, - name='E2H', - preamble=preamble, - operation=e2h_source, - arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des)) - - pec_arg = ['int *pec'] - h2e_source = header + H_ptrs + jinja2.Template(maxwell_E_source).render(pec=pec) - H2E_kernel = ElementwiseKernel(context, - name='H2E', - preamble=preamble, - operation=h2e_source, - arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs)) - - def spmv(E, H, p, idxes, oeps, inv_mu, pec, pmc, Pl, Pr, e): - e2 = P2E_kernel(E, p, Pr, wait_for=e) - e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2]) - e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2]) - return [e2] - - return spmv - - -def create_xr_step(context): - update_xr_source = ''' - x[i] = cdouble_add(x[i], cdouble_mul(alpha, p[i])); - r[i] = cdouble_sub(r[i], cdouble_mul(alpha, v[i])); - ''' - - xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha']) - - xr_kernel = ElementwiseKernel(context, - name='XR', - preamble=preamble, - operation=update_xr_source, - arguments=xr_args) - - def xr_update(x, p, r, v, alpha, e): - return [xr_kernel(x, p, r, v, alpha, wait_for=e)] - - return xr_update - - -def create_rhoerr_step(context): - update_ri_source = ''' - (double3)(r[i].real * r[i].real, \ - r[i].real * r[i].imag, \ - r[i].imag * r[i].imag) - ''' - - ri_dtype = pyopencl.array.vec.double3 - - ri_kernel = ReductionKernel(context, - name='RHOERR', - preamble=preamble, - dtype_out=ri_dtype, - neutral='(double3)(0.0, 0.0, 0.0)', - map_expr=update_ri_source, - reduce_expr='a+b', - arguments=ctype + ' *r') - - def ri_update(r, e): - g = ri_kernel(r, wait_for=e).astype(ri_dtype).get() - rr, ri, ii = [g[q] for q in 'xyz'] - rho = rr + 2j * ri - ii - err = rr + ii - return rho, err - - return ri_update - - -def create_p_step(context): - update_p_source = ''' - p[i] = cdouble_add(r[i], cdouble_mul(beta, p[i])); - ''' - p_args = ptrs('p', 'r') + [ctype + ' beta'] - - p_kernel = ElementwiseKernel(context, - name='P', - preamble=preamble, - operation=update_p_source, - arguments=', '.join(p_args)) - - def p_update(p, r, beta, e): - return [p_kernel(p, r, beta, wait_for=e)] - - return p_update - - -def create_dot(context): - dot_dtype = numpy.complex128 - - dot_kernel = ReductionKernel(context, - name='dot', - preamble=preamble, - dtype_out=dot_dtype, - neutral='cdouble_new(0.0, 0.0)', - map_expr='cdouble_mul(p[i], v[i])', - reduce_expr='cdouble_add(a, b)', - arguments=ptrs('p', 'v')) - - def ri_update(p, v, e): - g = dot_kernel(p, v, wait_for=e) - return g.get() - - return ri_update +from . import ops def cg_solver(omega, dxes, J, epsilon, mu=None, pec=None, pmc=None, adjoint=False, @@ -444,25 +93,25 @@ def cg_solver(omega, dxes, J, epsilon, mu=None, pec=None, pmc=None, adjoint=Fals invm = load_field(1 / mu) if pec is None: - gpec = load_field(numpy.array([]), dtype=int) + gpec = load_field(numpy.array([]), dtype=numpy.int8) else: - gpec = load_field(pec, dtype=int) + gpec = load_field(pec, dtype=numpy.int8) if pmc is None: - gpmc = load_field(numpy.array([]), dtype=int) + gpmc = load_field(numpy.array([]), dtype=numpy.int8) else: - gpmc = load_field(pmc, dtype=int) + gpmc = load_field(pmc, dtype=numpy.int8) ''' Generate OpenCL kernels ''' has_mu, has_pec, has_pmc = [q is not None for q in (mu, pec, pmc)] - a_step_full = create_a(context, shape, has_mu, has_pec, has_pmc) - xr_step = create_xr_step(context) - rhoerr_step = create_rhoerr_step(context) - p_step = create_p_step(context) - dot = create_dot(context) + a_step_full = ops.create_a(context, shape, has_mu, has_pec, has_pmc) + xr_step = ops.create_xr_step(context) + rhoerr_step = ops.create_rhoerr_step(context) + p_step = ops.create_p_step(context) + dot = ops.create_dot(context) def a_step(E, H, p, events): return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events) diff --git a/opencl_fdfd/ops.py b/opencl_fdfd/ops.py new file mode 100644 index 0000000..bb686d4 --- /dev/null +++ b/opencl_fdfd/ops.py @@ -0,0 +1,225 @@ +import numpy +import jinja2 + +import pyopencl +import pyopencl.array +from pyopencl.elementwise import ElementwiseKernel +from pyopencl.reduction import ReductionKernel + +# Create jinja2 env on module load +jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels')) + + +def type_to_C(float_type: numpy.float32 or numpy.float64) -> str: + """ + Returns a string corresponding to the C equivalent of a numpy type. + + :param float_type: numpy type: float32, float64, complex64, complex128 + :return: string containing the corresponding C type (eg. 'double') + """ + types = { + numpy.float32: 'float', + numpy.float64: 'double', + numpy.complex64: 'cfloat_t', + numpy.complex128: 'cdouble_t', + } + if float_type not in types: + raise Exception('Unsupported type') + + return types[float_type] + + +def shape_source(shape) -> str: + """ + Defines sx, sy, sz C constants specifying the shape of the grid in each of the 3 dimensions. + + :param shape: [sx, sy, sz] values. + :return: String containing C source. + """ + sxyz = """ +// Field sizes +const int sx = {shape[0]}; +const int sy = {shape[1]}; +const int sz = {shape[2]}; +""".format(shape=shape) + return sxyz + +# Defines dix, diy, diz constants used for stepping in the x, y, z directions in a linear array +# (ie, given Ex[i] referring to position (x, y, z), Ex[i+diy] will refer to position (x, y+1, z)) +dixyz_source = """ +// Convert offset in field xyz to linear index offset +const int dix = 1; +const int diy = sx; +const int diz = sx * sy; +""" + +# Given a linear index i and shape sx, sy, sz, defines x, y, and z +# as the 3D indices of the current element (i). +xyz_source = """ +// Convert linear index to field index (xyz) +const int z = i / (sx * sy); +const int y = (i - z * sx * sy) / sx; +const int x = (i - y * sx - z * sx * sy); +""" + +vec_source = """ +if (i >= sx * sy * sz) { + PYOPENCL_ELWISE_CONTINUE; +} + +//Pointers into the components of a vectorized vector-field +const int XX = 0; +const int YY = sx * sy * sz; +const int ZZ = sx * sy * sz * 2; +""" + +E_ptrs = """ +__global cdouble_t *Ex = E + XX; +__global cdouble_t *Ey = E + YY; +__global cdouble_t *Ez = E + ZZ; +""" + +H_ptrs = """ +__global cdouble_t *Hx = H + XX; +__global cdouble_t *Hy = H + YY; +__global cdouble_t *Hz = H + ZZ; +""" + +preamble = ''' +#define PYOPENCL_DEFINE_CDOUBLE +#include +''' + +ctype = type_to_C(numpy.complex128) + + +def ptrs(*args): + return [ctype + ' *' + s for s in args] + + +def create_a(context, shape, mu=False, pec=False, pmc=False): + header = shape_source(shape) + dixyz_source + xyz_source + vec_h = vec_source + E_ptrs + H_ptrs + + p2e_source = 'E[i] = cdouble_mul(Pr[i], p[i]);' + P2E_kernel = ElementwiseKernel(context, + name='P2E', + preamble=preamble, + operation=p2e_source, + arguments=', '.join(ptrs('E', 'p', 'Pr'))) + + pmc_arg = ['char *pmc'] + des = [ctype + ' *inv_de' + a for a in 'xyz'] + e2h_source = jinja_env.get_template('e2h.cl').render(mu=mu, + pmc=pmc, + dixyz_source=header, + vec_source=vec_h) + E2H_kernel = ElementwiseKernel(context, + name='E2H', + preamble=preamble, + operation=e2h_source, + arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des)) + + pec_arg = ['char *pec'] + dhs = [ctype + ' *inv_dh' + a for a in 'xyz'] + h2e_source = jinja_env.get_template('h2e.cl').render(pmc=pec, + dixyz_source=header, + vec_source=vec_h) + H2E_kernel = ElementwiseKernel(context, + name='H2E', + preamble=preamble, + operation=h2e_source, + arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs)) + + def spmv(E, H, p, idxes, oeps, inv_mu, pec, pmc, Pl, Pr, e): + e2 = P2E_kernel(E, p, Pr, wait_for=e) + e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2]) + e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2]) + return [e2] + + return spmv + + +def create_xr_step(context): + update_xr_source = ''' + x[i] = cdouble_add(x[i], cdouble_mul(alpha, p[i])); + r[i] = cdouble_sub(r[i], cdouble_mul(alpha, v[i])); + ''' + + xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha']) + + xr_kernel = ElementwiseKernel(context, + name='XR', + preamble=preamble, + operation=update_xr_source, + arguments=xr_args) + + def xr_update(x, p, r, v, alpha, e): + return [xr_kernel(x, p, r, v, alpha, wait_for=e)] + + return xr_update + + +def create_rhoerr_step(context): + update_ri_source = ''' + (double3)(r[i].real * r[i].real, \ + r[i].real * r[i].imag, \ + r[i].imag * r[i].imag) + ''' + + ri_dtype = pyopencl.array.vec.double3 + + ri_kernel = ReductionKernel(context, + name='RHOERR', + preamble=preamble, + dtype_out=ri_dtype, + neutral='(double3)(0.0, 0.0, 0.0)', + map_expr=update_ri_source, + reduce_expr='a+b', + arguments=ctype + ' *r') + + def ri_update(r, e): + g = ri_kernel(r, wait_for=e).astype(ri_dtype).get() + rr, ri, ii = [g[q] for q in 'xyz'] + rho = rr + 2j * ri - ii + err = rr + ii + return rho, err + + return ri_update + + +def create_p_step(context): + update_p_source = ''' + p[i] = cdouble_add(r[i], cdouble_mul(beta, p[i])); + ''' + p_args = ptrs('p', 'r') + [ctype + ' beta'] + + p_kernel = ElementwiseKernel(context, + name='P', + preamble=preamble, + operation=update_p_source, + arguments=', '.join(p_args)) + + def p_update(p, r, beta, e): + return [p_kernel(p, r, beta, wait_for=e)] + + return p_update + + +def create_dot(context): + dot_dtype = numpy.complex128 + + dot_kernel = ReductionKernel(context, + name='dot', + preamble=preamble, + dtype_out=dot_dtype, + neutral='cdouble_new(0.0, 0.0)', + map_expr='cdouble_mul(p[i], v[i])', + reduce_expr='cdouble_add(a, b)', + arguments=ptrs('p', 'v')) + + def ri_update(p, v, e): + g = dot_kernel(p, v, wait_for=e) + return g.get() + + return ri_update