split up into multiple files

release
jan 8 years ago
parent 20949f56ff
commit a379d8b794

@ -0,0 +1 @@
from .main import cg_solver

@ -0,0 +1,87 @@
/*
*
* H update equations
*
*/
//Define sx, x, dix (and y, z versions of those)
{{dixyz_source}}
//Define vectorized fields and pointers (eg. Hx = H + XX)
{{vec_source}}
// Wrap indices if necessary
int ipx, ipy, ipz;
if ( x == sx - 1 ) {
ipx = i - (sx - 1) * dix;
} else {
ipx = i + dix;
}
if ( y == sy - 1 ) {
ipy = i - (sy - 1) * diy;
} else {
ipy = i + diy;
}
if ( z == sz - 1 ) {
ipz = i - (sz - 1) * diz;
} else {
ipz = i + diz;
}
//Update H components; set them to 0 if PMC is enabled there.
// Also divide by mu only if requested.
{% if pmc -%}
if (pmc[XX + i]) {
Hx[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t Dzy = cdouble_mul(cdouble_sub(Ez[ipy], Ez[i]), inv_dey[y]);
cdouble_t Dyz = cdouble_mul(cdouble_sub(Ey[ipz], Ey[i]), inv_dez[z]);
{%- if mu -%}
Hx[i] = cdouble_mul(inv_mu[XX + i], cdouble_sub(Dzy, Dyz));
{%- else -%}
Hx[i] = cdouble_sub(Dzy, Dyz);
{%- endif %}
}
{% if pmc -%}
if (pmc[YY + i]) {
Hy[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t Dxz = cdouble_mul(cdouble_sub(Ex[ipz], Ex[i]), inv_dez[z]);
cdouble_t Dzx = cdouble_mul(cdouble_sub(Ez[ipx], Ez[i]), inv_dex[x]);
{%- if mu -%}
Hy[i] = cdouble_mul(inv_mu[YY + i], cdouble_sub(Dxz, Dzx));
{%- else -%}
Hy[i] = cdouble_sub(Dxz, Dzx);
{%- endif %}
}
{% if pmc -%}
if (pmc[XX + i]) {
Hx[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t Dyx = cdouble_mul(cdouble_sub(Ey[ipx], Ey[i]), inv_dex[x]);
cdouble_t Dxy = cdouble_mul(cdouble_sub(Ex[ipy], Ex[i]), inv_dey[y]);
{%- if mu -%}
Hz[i] = cdouble_mul(inv_mu[ZZ + i], cdouble_sub(Dyx, Dxy));
{%- else -%}
Hz[i] = cdouble_sub(Dyx, Dxy);
{%- endif %}
}
/*
* End H update equations
*/

@ -0,0 +1,77 @@
/*
*
* E update equations
*
*/
//Define sx, x, dix (and y, z versions of those)
{{dixyz_source}}
//Define vectorized fields and pointers (eg. Hx = H + XX)
{{vec_source}}
// Wrap indices if necessary
int imx, imy, imz;
if ( x == 0 ) {
imx = i + (sx - 1) * dix;
} else {
imx = i - dix;
}
if ( y == 0 ) {
imy = i + (sy - 1) * diy;
} else {
imy = i - diy;
}
if ( z == 0 ) {
imz = i + (sz - 1) * diz;
} else {
imz = i - diz;
}
//Update E components; set them to 0 if PEC is enabled there.
{% if pec -%}
if (pec[XX + i]) {
Ex[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t tEx = cdouble_mul(Ex[i], oeps[XX + i]);
cdouble_t Dzy = cdouble_mul(cdouble_sub(Hz[i], Hz[imy]), inv_dhy[y]);
cdouble_t Dyz = cdouble_mul(cdouble_sub(Hy[i], Hy[imz]), inv_dhz[z]);
tEx = cdouble_add(tEx, cdouble_sub(Dzy, Dyz));
Ex[i] = cdouble_mul(tEx, Pl[XX + i]);
}
{% if pec -%}
if (pec[YY + i]) {
Ey[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t tEy = cdouble_mul(Ey[i], oeps[YY + i]);
cdouble_t Dxz = cdouble_mul(cdouble_sub(Hx[i], Hx[imz]), inv_dhz[z]);
cdouble_t Dzx = cdouble_mul(cdouble_sub(Hz[i], Hz[imx]), inv_dhx[x]);
tEy = cdouble_add(tEy, cdouble_sub(Dxz, Dzx));
Ey[i] = cdouble_mul(tEy, Pl[YY + i]);
}
{% if pec -%}
if (pec[ZZ + i]) {
Ez[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t tEz = cdouble_mul(Ez[i], oeps[ZZ + i]);
cdouble_t Dyx = cdouble_mul(cdouble_sub(Hy[i], Hy[imx]), inv_dhx[x]);
cdouble_t Dxy = cdouble_mul(cdouble_sub(Hx[i], Hx[imy]), inv_dhy[y]);
tEz = cdouble_add(tEz, cdouble_sub(Dyx, Dxy));
Ez[i] = cdouble_mul(tEz, Pl[ZZ + i]);
}
/*
* End H update equations
*/

@ -1,364 +1,13 @@
import numpy import numpy
from numpy.linalg import norm from numpy.linalg import norm
import jinja2
import pyopencl import pyopencl
import pyopencl.array import pyopencl.array
from pyopencl.elementwise import ElementwiseKernel
from pyopencl.reduction import ReductionKernel
import time import time
import fdfd_tools.operators import fdfd_tools.operators
from . import ops
def type_to_C(float_type: numpy.float32 or numpy.float64) -> str:
"""
Returns a string corresponding to the C equivalent of a numpy type.
:param float_type: numpy type: float32, float64, complex64, complex128
:return: string containing the corresponding C type (eg. 'double')
"""
types = {
numpy.float32: 'float',
numpy.float64: 'double',
numpy.complex64: 'cfloat_t',
numpy.complex128: 'cdouble_t',
}
if float_type not in types:
raise Exception('Unsupported type')
return types[float_type]
def shape_source(shape) -> str:
"""
Defines sx, sy, sz C constants specifying the shape of the grid in each of the 3 dimensions.
:param shape: [sx, sy, sz] values.
:return: String containing C source.
"""
sxyz = """
// Field sizes
const int sx = {shape[0]};
const int sy = {shape[1]};
const int sz = {shape[2]};
""".format(shape=shape)
return sxyz
# Defines dix, diy, diz constants used for stepping in the x, y, z directions in a linear array
# (ie, given Ex[i] referring to position (x, y, z), Ex[i+diy] will refer to position (x, y+1, z))
dixyz_source = """
// Convert offset in field xyz to linear index offset
const int dix = 1;
const int diy = sx;
const int diz = sx * sy;
"""
# Given a linear index i and shape sx, sy, sz, defines x, y, and z
# as the 3D indices of the current element (i).
xyz_source = """
// Convert linear index to field index (xyz)
const int z = i / (sx * sy);
const int y = (i - z * sx * sy) / sx;
const int x = (i - y * sx - z * sx * sy);
"""
vec_source = """
if (i >= sx * sy * sz) {
PYOPENCL_ELWISE_CONTINUE;
}
//Pointers into the components of a vectorized vector-field
const int XX = 0;
const int YY = sx * sy * sz;
const int ZZ = sx * sy * sz * 2;
"""
E_ptrs = """
__global cdouble_t *Ex = E + XX;
__global cdouble_t *Ey = E + YY;
__global cdouble_t *Ez = E + ZZ;
"""
H_ptrs = """
__global cdouble_t *Hx = H + XX;
__global cdouble_t *Hy = H + YY;
__global cdouble_t *Hz = H + ZZ;
"""
# Source code for updating the E field; maxes use of dixyz_source.
maxwell_E_source = """
// E update equations
int imx, imy, imz;
if ( x == 0 ) {
imx = i + (sx - 1) * dix;
} else {
imx = i - dix;
}
if ( y == 0 ) {
imy = i + (sy - 1) * diy;
} else {
imy = i - diy;
}
if ( z == 0 ) {
imz = i + (sz - 1) * diz;
} else {
imz = i - diz;
}
// E update equations
{% if pec -%}
if (pec[XX + i]) {
Ex[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t tEx = cdouble_mul(Ex[i], oeps[XX + i]);
cdouble_t Dzy = cdouble_mul(cdouble_sub(Hz[i], Hz[imy]), inv_dhy[y]);
cdouble_t Dyz = cdouble_mul(cdouble_sub(Hy[i], Hy[imz]), inv_dhz[z]);
tEx = cdouble_add(tEx, cdouble_sub(Dzy, Dyz));
Ex[i] = cdouble_mul(tEx, Pl[XX + i]);
}
{% if pec -%}
if (pec[YY + i]) {
Ey[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t tEy = cdouble_mul(Ey[i], oeps[YY + i]);
cdouble_t Dxz = cdouble_mul(cdouble_sub(Hx[i], Hx[imz]), inv_dhz[z]);
cdouble_t Dzx = cdouble_mul(cdouble_sub(Hz[i], Hz[imx]), inv_dhx[x]);
tEy = cdouble_add(tEy, cdouble_sub(Dxz, Dzx));
Ey[i] = cdouble_mul(tEy, Pl[YY + i]);
}
{% if pec -%}
if (pec[ZZ + i]) {
Ez[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t tEz = cdouble_mul(Ez[i], oeps[ZZ + i]);
cdouble_t Dyx = cdouble_mul(cdouble_sub(Hy[i], Hy[imx]), inv_dhx[x]);
cdouble_t Dxy = cdouble_mul(cdouble_sub(Hx[i], Hx[imy]), inv_dhy[y]);
tEz = cdouble_add(tEz, cdouble_sub(Dyx, Dxy));
Ez[i] = cdouble_mul(tEz, Pl[ZZ + i]);
}
"""
# Source code for updating the H field; maxes use of dixyz_source and assumes mu=0
maxwell_H_source = """
// H update equations
int ipx, ipy, ipz;
if ( x == sx - 1 ) {
ipx = i - (sx - 1) * dix;
} else {
ipx = i + dix;
}
if ( y == sy - 1 ) {
ipy = i - (sy - 1) * diy;
} else {
ipy = i + diy;
}
if ( z == sz - 1 ) {
ipz = i - (sz - 1) * diz;
} else {
ipz = i + diz;
}
{% if pmc -%}
if (pmc[XX + i]) {
Hx[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t Dzy = cdouble_mul(cdouble_sub(Ez[ipy], Ez[i]), inv_dey[y]);
cdouble_t Dyz = cdouble_mul(cdouble_sub(Ey[ipz], Ey[i]), inv_dez[z]);
{%- if mu -%}
Hx[i] = cdouble_mul(inv_mu[XX + i], cdouble_sub(Dzy, Dyz));
{%- else -%}
Hx[i] = cdouble_sub(Dzy, Dyz);
{%- endif %}
}
{% if pmc -%}
if (pmc[YY + i]) {
Hy[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t Dxz = cdouble_mul(cdouble_sub(Ex[ipz], Ex[i]), inv_dez[z]);
cdouble_t Dzx = cdouble_mul(cdouble_sub(Ez[ipx], Ez[i]), inv_dex[x]);
{%- if mu -%}
Hy[i] = cdouble_mul(inv_mu[YY + i], cdouble_sub(Dxz, Dzx));
{%- else -%}
Hy[i] = cdouble_sub(Dxz, Dzx);
{%- endif %}
}
{% if pmc -%}
if (pmc[XX + i]) {
Hx[i] = cdouble_new(0.0, 0.0);
} else
{%- endif -%}
{
cdouble_t Dyx = cdouble_mul(cdouble_sub(Ey[ipx], Ey[i]), inv_dex[x]);
cdouble_t Dxy = cdouble_mul(cdouble_sub(Ex[ipy], Ex[i]), inv_dey[y]);
{%- if mu -%}
Hz[i] = cdouble_mul(inv_mu[ZZ + i], cdouble_sub(Dyx, Dxy));
{%- else -%}
Hz[i] = cdouble_sub(Dyx, Dxy);
{%- endif %}
}
"""
p2e_source = '''
Ex[i] = cdouble_mul(Pr[XX + i], p[XX + i]);
Ey[i] = cdouble_mul(Pr[YY + i], p[YY + i]);
Ez[i] = cdouble_mul(Pr[ZZ + i], p[ZZ + i]);
'''
preamble = '''
#define PYOPENCL_DEFINE_CDOUBLE
#include <pyopencl-complex.h>
'''
ctype = type_to_C(numpy.complex128)
def ptrs(*args):
return [ctype + ' *' + s for s in args]
def create_a(context, shape, mu=False, pec=False, pmc=False):
dhs = [ctype + ' *inv_dh' + a for a in 'xyz']
des = [ctype + ' *inv_de' + a for a in 'xyz']
header = shape_source(shape) + dixyz_source + xyz_source + vec_source + E_ptrs
P2E_kernel = ElementwiseKernel(context,
name='P2E',
preamble=preamble,
operation=header + p2e_source,
arguments=', '.join(ptrs('E', 'p', 'Pr')))
pmc_arg = ['int *pmc']
e2h_source = header + H_ptrs + jinja2.Template(maxwell_H_source).render(mu=mu, pmc=pmc)
E2H_kernel = ElementwiseKernel(context,
name='E2H',
preamble=preamble,
operation=e2h_source,
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des))
pec_arg = ['int *pec']
h2e_source = header + H_ptrs + jinja2.Template(maxwell_E_source).render(pec=pec)
H2E_kernel = ElementwiseKernel(context,
name='H2E',
preamble=preamble,
operation=h2e_source,
arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs))
def spmv(E, H, p, idxes, oeps, inv_mu, pec, pmc, Pl, Pr, e):
e2 = P2E_kernel(E, p, Pr, wait_for=e)
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
return [e2]
return spmv
def create_xr_step(context):
update_xr_source = '''
x[i] = cdouble_add(x[i], cdouble_mul(alpha, p[i]));
r[i] = cdouble_sub(r[i], cdouble_mul(alpha, v[i]));
'''
xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha'])
xr_kernel = ElementwiseKernel(context,
name='XR',
preamble=preamble,
operation=update_xr_source,
arguments=xr_args)
def xr_update(x, p, r, v, alpha, e):
return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
return xr_update
def create_rhoerr_step(context):
update_ri_source = '''
(double3)(r[i].real * r[i].real, \
r[i].real * r[i].imag, \
r[i].imag * r[i].imag)
'''
ri_dtype = pyopencl.array.vec.double3
ri_kernel = ReductionKernel(context,
name='RHOERR',
preamble=preamble,
dtype_out=ri_dtype,
neutral='(double3)(0.0, 0.0, 0.0)',
map_expr=update_ri_source,
reduce_expr='a+b',
arguments=ctype + ' *r')
def ri_update(r, e):
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
rr, ri, ii = [g[q] for q in 'xyz']
rho = rr + 2j * ri - ii
err = rr + ii
return rho, err
return ri_update
def create_p_step(context):
update_p_source = '''
p[i] = cdouble_add(r[i], cdouble_mul(beta, p[i]));
'''
p_args = ptrs('p', 'r') + [ctype + ' beta']
p_kernel = ElementwiseKernel(context,
name='P',
preamble=preamble,
operation=update_p_source,
arguments=', '.join(p_args))
def p_update(p, r, beta, e):
return [p_kernel(p, r, beta, wait_for=e)]
return p_update
def create_dot(context):
dot_dtype = numpy.complex128
dot_kernel = ReductionKernel(context,
name='dot',
preamble=preamble,
dtype_out=dot_dtype,
neutral='cdouble_new(0.0, 0.0)',
map_expr='cdouble_mul(p[i], v[i])',
reduce_expr='cdouble_add(a, b)',
arguments=ptrs('p', 'v'))
def ri_update(p, v, e):
g = dot_kernel(p, v, wait_for=e)
return g.get()
return ri_update
def cg_solver(omega, dxes, J, epsilon, mu=None, pec=None, pmc=None, adjoint=False, def cg_solver(omega, dxes, J, epsilon, mu=None, pec=None, pmc=None, adjoint=False,
@ -444,25 +93,25 @@ def cg_solver(omega, dxes, J, epsilon, mu=None, pec=None, pmc=None, adjoint=Fals
invm = load_field(1 / mu) invm = load_field(1 / mu)
if pec is None: if pec is None:
gpec = load_field(numpy.array([]), dtype=int) gpec = load_field(numpy.array([]), dtype=numpy.int8)
else: else:
gpec = load_field(pec, dtype=int) gpec = load_field(pec, dtype=numpy.int8)
if pmc is None: if pmc is None:
gpmc = load_field(numpy.array([]), dtype=int) gpmc = load_field(numpy.array([]), dtype=numpy.int8)
else: else:
gpmc = load_field(pmc, dtype=int) gpmc = load_field(pmc, dtype=numpy.int8)
''' '''
Generate OpenCL kernels Generate OpenCL kernels
''' '''
has_mu, has_pec, has_pmc = [q is not None for q in (mu, pec, pmc)] has_mu, has_pec, has_pmc = [q is not None for q in (mu, pec, pmc)]
a_step_full = create_a(context, shape, has_mu, has_pec, has_pmc) a_step_full = ops.create_a(context, shape, has_mu, has_pec, has_pmc)
xr_step = create_xr_step(context) xr_step = ops.create_xr_step(context)
rhoerr_step = create_rhoerr_step(context) rhoerr_step = ops.create_rhoerr_step(context)
p_step = create_p_step(context) p_step = ops.create_p_step(context)
dot = create_dot(context) dot = ops.create_dot(context)
def a_step(E, H, p, events): def a_step(E, H, p, events):
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events) return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)

@ -0,0 +1,225 @@
import numpy
import jinja2
import pyopencl
import pyopencl.array
from pyopencl.elementwise import ElementwiseKernel
from pyopencl.reduction import ReductionKernel
# Create jinja2 env on module load
jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels'))
def type_to_C(float_type: numpy.float32 or numpy.float64) -> str:
"""
Returns a string corresponding to the C equivalent of a numpy type.
:param float_type: numpy type: float32, float64, complex64, complex128
:return: string containing the corresponding C type (eg. 'double')
"""
types = {
numpy.float32: 'float',
numpy.float64: 'double',
numpy.complex64: 'cfloat_t',
numpy.complex128: 'cdouble_t',
}
if float_type not in types:
raise Exception('Unsupported type')
return types[float_type]
def shape_source(shape) -> str:
"""
Defines sx, sy, sz C constants specifying the shape of the grid in each of the 3 dimensions.
:param shape: [sx, sy, sz] values.
:return: String containing C source.
"""
sxyz = """
// Field sizes
const int sx = {shape[0]};
const int sy = {shape[1]};
const int sz = {shape[2]};
""".format(shape=shape)
return sxyz
# Defines dix, diy, diz constants used for stepping in the x, y, z directions in a linear array
# (ie, given Ex[i] referring to position (x, y, z), Ex[i+diy] will refer to position (x, y+1, z))
dixyz_source = """
// Convert offset in field xyz to linear index offset
const int dix = 1;
const int diy = sx;
const int diz = sx * sy;
"""
# Given a linear index i and shape sx, sy, sz, defines x, y, and z
# as the 3D indices of the current element (i).
xyz_source = """
// Convert linear index to field index (xyz)
const int z = i / (sx * sy);
const int y = (i - z * sx * sy) / sx;
const int x = (i - y * sx - z * sx * sy);
"""
vec_source = """
if (i >= sx * sy * sz) {
PYOPENCL_ELWISE_CONTINUE;
}
//Pointers into the components of a vectorized vector-field
const int XX = 0;
const int YY = sx * sy * sz;
const int ZZ = sx * sy * sz * 2;
"""
E_ptrs = """
__global cdouble_t *Ex = E + XX;
__global cdouble_t *Ey = E + YY;
__global cdouble_t *Ez = E + ZZ;
"""
H_ptrs = """
__global cdouble_t *Hx = H + XX;
__global cdouble_t *Hy = H + YY;
__global cdouble_t *Hz = H + ZZ;
"""
preamble = '''
#define PYOPENCL_DEFINE_CDOUBLE
#include <pyopencl-complex.h>
'''
ctype = type_to_C(numpy.complex128)
def ptrs(*args):
return [ctype + ' *' + s for s in args]
def create_a(context, shape, mu=False, pec=False, pmc=False):
header = shape_source(shape) + dixyz_source + xyz_source
vec_h = vec_source + E_ptrs + H_ptrs
p2e_source = 'E[i] = cdouble_mul(Pr[i], p[i]);'
P2E_kernel = ElementwiseKernel(context,
name='P2E',
preamble=preamble,
operation=p2e_source,
arguments=', '.join(ptrs('E', 'p', 'Pr')))
pmc_arg = ['char *pmc']
des = [ctype + ' *inv_de' + a for a in 'xyz']
e2h_source = jinja_env.get_template('e2h.cl').render(mu=mu,
pmc=pmc,
dixyz_source=header,
vec_source=vec_h)
E2H_kernel = ElementwiseKernel(context,
name='E2H',
preamble=preamble,
operation=e2h_source,
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des))
pec_arg = ['char *pec']
dhs = [ctype + ' *inv_dh' + a for a in 'xyz']
h2e_source = jinja_env.get_template('h2e.cl').render(pmc=pec,
dixyz_source=header,
vec_source=vec_h)
H2E_kernel = ElementwiseKernel(context,
name='H2E',
preamble=preamble,
operation=h2e_source,
arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs))
def spmv(E, H, p, idxes, oeps, inv_mu, pec, pmc, Pl, Pr, e):
e2 = P2E_kernel(E, p, Pr, wait_for=e)
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
return [e2]
return spmv
def create_xr_step(context):
update_xr_source = '''
x[i] = cdouble_add(x[i], cdouble_mul(alpha, p[i]));
r[i] = cdouble_sub(r[i], cdouble_mul(alpha, v[i]));
'''
xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha'])
xr_kernel = ElementwiseKernel(context,
name='XR',
preamble=preamble,
operation=update_xr_source,
arguments=xr_args)
def xr_update(x, p, r, v, alpha, e):
return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
return xr_update
def create_rhoerr_step(context):
update_ri_source = '''
(double3)(r[i].real * r[i].real, \
r[i].real * r[i].imag, \
r[i].imag * r[i].imag)
'''
ri_dtype = pyopencl.array.vec.double3
ri_kernel = ReductionKernel(context,
name='RHOERR',
preamble=preamble,
dtype_out=ri_dtype,
neutral='(double3)(0.0, 0.0, 0.0)',
map_expr=update_ri_source,
reduce_expr='a+b',
arguments=ctype + ' *r')
def ri_update(r, e):
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
rr, ri, ii = [g[q] for q in 'xyz']
rho = rr + 2j * ri - ii
err = rr + ii
return rho, err
return ri_update
def create_p_step(context):
update_p_source = '''
p[i] = cdouble_add(r[i], cdouble_mul(beta, p[i]));
'''
p_args = ptrs('p', 'r') + [ctype + ' beta']
p_kernel = ElementwiseKernel(context,
name='P',
preamble=preamble,
operation=update_p_source,
arguments=', '.join(p_args))
def p_update(p, r, beta, e):
return [p_kernel(p, r, beta, wait_for=e)]
return p_update
def create_dot(context):
dot_dtype = numpy.complex128
dot_kernel = ReductionKernel(context,
name='dot',
preamble=preamble,
dtype_out=dot_dtype,
neutral='cdouble_new(0.0, 0.0)',
map_expr='cdouble_mul(p[i], v[i])',
reduce_expr='cdouble_add(a, b)',
arguments=ptrs('p', 'v'))
def ri_update(p, v, e):
g = dot_kernel(p, v, wait_for=e)
return g.get()
return ri_update
Loading…
Cancel
Save