refactor solver (untested)

This commit is contained in:
jan 2016-08-03 14:51:25 -07:00
commit ff3951ba35
5 changed files with 194 additions and 124 deletions

View file

@ -29,65 +29,10 @@ def type_to_C(float_type: numpy.float32 or numpy.float64) -> str:
return types[float_type]
def shape_source(shape) -> str:
"""
Defines sx, sy, sz C constants specifying the shape of the grid in each of the 3 dimensions.
:param shape: [sx, sy, sz] values.
:return: String containing C source.
"""
sxyz = """
// Field sizes
const int sx = {shape[0]};
const int sy = {shape[1]};
const int sz = {shape[2]};
""".format(shape=shape)
return sxyz
# Defines dix, diy, diz constants used for stepping in the x, y, z directions in a linear array
# (ie, given Ex[i] referring to position (x, y, z), Ex[i+diy] will refer to position (x, y+1, z))
dixyz_source = """
// Convert offset in field xyz to linear index offset
const int dix = 1;
const int diy = sx;
const int diz = sx * sy;
"""
# Given a linear index i and shape sx, sy, sz, defines x, y, and z
# as the 3D indices of the current element (i).
xyz_source = """
// Convert linear index to field index (xyz)
const int z = i / (sx * sy);
const int y = (i - z * sx * sy) / sx;
const int x = (i - y * sx - z * sx * sy);
"""
vec_source = """
if (i >= sx * sy * sz) {
PYOPENCL_ELWISE_CONTINUE;
}
//Pointers into the components of a vectorized vector-field
const int XX = 0;
const int YY = sx * sy * sz;
const int ZZ = sx * sy * sz * 2;
"""
E_ptrs = """
__global cdouble_t *Ex = E + XX;
__global cdouble_t *Ey = E + YY;
__global cdouble_t *Ez = E + ZZ;
"""
H_ptrs = """
__global cdouble_t *Hx = H + XX;
__global cdouble_t *Hy = H + YY;
__global cdouble_t *Hz = H + ZZ;
"""
preamble = '''
#define PYOPENCL_DEFINE_CDOUBLE
#include <pyopencl-complex.h>
'''
ctype = type_to_C(numpy.complex128)
@ -98,15 +43,17 @@ def ptrs(*args):
def create_a(context, shape, mu=False, pec=False, pmc=False):
header = shape_source(shape) + dixyz_source + xyz_source
vec_h = vec_source + E_ptrs + H_ptrs
common_source = jinja_env.get_template('common.cl').render(shape=shape,
ctype=ctype)
pec_arg = ['char *pec']
pmc_arg = ['char *pmc']
des = [ctype + ' *inv_de' + a for a in 'xyz']
dhs = [ctype + ' *inv_dh' + a for a in 'xyz']
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec,
ctype=ctype)
P2E_kernel = ElementwiseKernel(context,
name='P2E',
preamble=preamble,
@ -115,8 +62,7 @@ def create_a(context, shape, mu=False, pec=False, pmc=False):
e2h_source = jinja_env.get_template('e2h.cl').render(mu=mu,
pmc=pmc,
dixyz_source=header,
vec_source=vec_h)
common_cl=common_source)
E2H_kernel = ElementwiseKernel(context,
name='E2H',
preamble=preamble,
@ -124,8 +70,7 @@ def create_a(context, shape, mu=False, pec=False, pmc=False):
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des))
h2e_source = jinja_env.get_template('h2e.cl').render(pec=pec,
dixyz_source=header,
vec_source=vec_h)
common_cl=common_source)
H2E_kernel = ElementwiseKernel(context,
name='H2E',
preamble=preamble,
@ -143,8 +88,8 @@ def create_a(context, shape, mu=False, pec=False, pmc=False):
def create_xr_step(context):
update_xr_source = '''
x[i] = cdouble_add(x[i], cdouble_mul(alpha, p[i]));
r[i] = cdouble_sub(r[i], cdouble_mul(alpha, v[i]));
x[i] = add(x[i], mul(alpha, p[i]));
r[i] = sub(r[i], mul(alpha, v[i]));
'''
xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha'])
@ -191,7 +136,7 @@ def create_rhoerr_step(context):
def create_p_step(context):
update_p_source = '''
p[i] = cdouble_add(r[i], cdouble_mul(beta, p[i]));
p[i] = add(r[i], mul(beta, p[i]));
'''
p_args = ptrs('p', 'r') + [ctype + ' beta']
@ -214,9 +159,9 @@ def create_dot(context):
name='dot',
preamble=preamble,
dtype_out=dot_dtype,
neutral='cdouble_new(0.0, 0.0)',
map_expr='cdouble_mul(p[i], v[i])',
reduce_expr='cdouble_add(a, b)',
neutral='zero',
map_expr='mul(p[i], v[i])',
reduce_expr='add(a, b)',
arguments=ptrs('p', 'v'))
def ri_update(p, v, e):
@ -230,14 +175,14 @@ def create_a_csr(context):
spmv_source = '''
int start = m_row_ptr[i];
int stop = m_row_ptr[i+1];
cdouble_t dot = cdouble_new(0.0, 0.0);
dtype dot = zero;
int col_ind, d_ind;
for (int j=start; j<stop; j++) {
col_ind = m_col_ind[j];
d_ind = j;
dot = cdouble_add(dot, cdouble_mul(v_in[col_ind], m_data[d_ind]));
dot = add(dot, mul(v_in[col_ind], m_data[d_ind]));
}
v_out[i] = dot;
'''