From ff3951ba35e0e6189e5845251bed43e26acf73b3 Mon Sep 17 00:00:00 2001 From: jan Date: Wed, 3 Aug 2016 14:51:25 -0700 Subject: [PATCH] refactor solver (untested) --- opencl_fdfd/kernels/common.cl | 52 +++++++++++++++++++++ opencl_fdfd/kernels/e2h.cl | 77 ++++++++++++++++++++----------- opencl_fdfd/kernels/h2e.cl | 80 +++++++++++++++++++++----------- opencl_fdfd/kernels/p2e.cl | 24 +++++++++- opencl_fdfd/ops.py | 87 +++++++---------------------------- 5 files changed, 195 insertions(+), 125 deletions(-) create mode 100644 opencl_fdfd/kernels/common.cl diff --git a/opencl_fdfd/kernels/common.cl b/opencl_fdfd/kernels/common.cl new file mode 100644 index 0000000..d429089 --- /dev/null +++ b/opencl_fdfd/kernels/common.cl @@ -0,0 +1,52 @@ +/* Common code for E, H updates + * + * Template parameters: + * ctype string denoting type for storing complex field values + * shape list of 3 ints specifying shape of fields + */ + +//Defines to clean up operation names +#define ctype {{ctype}}_t +#define zero {{ctype}}_new(0.0, 0.0) +#define add {{ctype}}_add +#define sub {{ctype}}_sub +#define mul {{ctype}}_mul + +// Field sizes +const int sx = {shape[0]}; +const int sy = {shape[1]}; +const int sz = {shape[2]}; + +//Since we use i to index into Ex[], E[], ... rather than E[], do nothing if +// i is outside the bounds of Ex[]. +if (i >= sx * sy * sz) { + PYOPENCL_ELWISE_CONTINUE; +} + +// Given a linear index i and shape (sx, sy, sz), defines x, y, and z +// as the 3D indices of the current element (i). +// (ie, converts linear index [i] to field indices (x, y, z) +const int z = i / (sx * sy); +const int y = (i - z * sx * sy) / sx; +const int x = (i - y * sx - z * sx * sy); + +// Calculate linear index offsets corresponding to offsets in 3D +// (ie, if E[i] <-> E(x, y, z), then E[i + diy] <-> E(x, y + 1, z) +const int dix = 1; +const int diy = sx; +const int diz = sx * sy; + +//Pointer offsets into the components of a linearized vector-field +// (eg. Hx = H + XX, where H and Hx are pointers) +const int XX = 0; +const int YY = sx * sy * sz; +const int ZZ = sx * sy * sz * 2; + +//Define pointers to vector components of each field (eg. Hx = H + XX) +__global ctype *Ex = E + XX; +__global ctype *Ey = E + YY; +__global ctype *Ez = E + ZZ; + +__global ctype *Hx = H + XX; +__global ctype *Hy = H + YY; +__global ctype *Hz = H + ZZ; diff --git a/opencl_fdfd/kernels/e2h.cl b/opencl_fdfd/kernels/e2h.cl index 2227fdb..0332252 100644 --- a/opencl_fdfd/kernels/e2h.cl +++ b/opencl_fdfd/kernels/e2h.cl @@ -1,17 +1,39 @@ /* - * * H update equations * + * Template parameters: + * mu False if (mu == 1) everywhere + * pmc False if no PMC anywhere + * common_cl Rendered code from common.cl + * + * Arguments: + * ctype *E E-field + * ctype *H H-field + * ctype *inv_mu 1/mu (at H-field locations) + * char *pmc Boolean mask denoting presence of PMC (at H-field locations) + * ctype *inv_dex 1/dx_e (complex cell widths for x direction at E locations) + * ctype *inv_dey 1/dy_e (complex cell widths for y direction at E locations) + * ctype *inv_dez 1/dz_e (complex cell widths for z direction at E locations) + * */ -//Define sx, x, dix (and y, z versions of those) -{{dixyz_source}} +{{common_cl}} -//Define vectorized fields and pointers (eg. Hx = H + XX) -{{vec_source}} +__global ctype *inv_mu_x = inv_mu + XX; +__global ctype *inv_mu_y = inv_mu + YY; +__global ctype *inv_mu_z = inv_mu + ZZ; +__global ctype *pmc_x = pmc + XX; +__global ctype *pmc_y = pmc + YY; +__global ctype *pmc_z = pmc + ZZ; -// Wrap indices if necessary +/* + * Implement periodic boundary conditions + * + * ipx gives the index of the adjacent cell in the plus-x direction ([i]ndex [p]lus [x]). + * In the event that we start at x == (sx - 1), we actually want to wrap around and grab the cell + * where x == 0 instead, ie. ipx = i - (sx - 1) * dix . + */ int ipx, ipy, ipz; if ( x == sx - 1 ) { ipx = i - (sx - 1) * dix; @@ -32,53 +54,56 @@ if ( z == sz - 1 ) { } -//Update H components; set them to 0 if PMC is enabled there. -// Also divide by mu only if requested. +//Update H components; set them to 0 if PMC is enabled at that location. +//Mu division and PMC conditional are only included if {{mu}} and {{pmc}} are true {% if pmc -%} -if (pmc[XX + i] != 0) { - Hx[i] = cdouble_new(0.0, 0.0); +if (pmc_x[i] != 0) { + Hx[i] = zero; } else {%- endif -%} { - cdouble_t Dzy = cdouble_mul(cdouble_sub(Ez[ipy], Ez[i]), inv_dey[y]); - cdouble_t Dyz = cdouble_mul(cdouble_sub(Ey[ipz], Ey[i]), inv_dez[z]); + ctype Dzy = mul(sub(Ez[ipy], Ez[i]), inv_dey[y]); + ctype Dyz = mul(sub(Ey[ipz], Ey[i]), inv_dez[z]); + ctype x_curl = sub(Dzy, Dyz); {%- if mu -%} - Hx[i] = cdouble_mul(inv_mu[XX + i], cdouble_sub(Dzy, Dyz)); + Hx[i] = mul(inv_mu_x[i], x_curl); {%- else -%} - Hx[i] = cdouble_sub(Dzy, Dyz); + Hx[i] = x_curl; {%- endif %} } {% if pmc -%} -if (pmc[YY + i] != 0) { - Hy[i] = cdouble_new(0.0, 0.0); +if (pmc_y[i] != 0) { + Hy[i] = zero; } else {%- endif -%} { - cdouble_t Dxz = cdouble_mul(cdouble_sub(Ex[ipz], Ex[i]), inv_dez[z]); - cdouble_t Dzx = cdouble_mul(cdouble_sub(Ez[ipx], Ez[i]), inv_dex[x]); + ctype Dxz = mul(sub(Ex[ipz], Ex[i]), inv_dez[z]); + ctype Dzx = mul(sub(Ez[ipx], Ez[i]), inv_dex[x]); + ctype y_curl = sub(Dxz, Dzx); {%- if mu -%} - Hy[i] = cdouble_mul(inv_mu[YY + i], cdouble_sub(Dxz, Dzx)); + Hy[i] = mul(inv_mu_y[i], y_curl); {%- else -%} - Hy[i] = cdouble_sub(Dxz, Dzx); + Hy[i] = y_curl; {%- endif %} } {% if pmc -%} -if (pmc[ZZ + i] != 0) { - Hz[i] = cdouble_new(0.0, 0.0); +if (pmc_z[i] != 0) { + Hz[i] = zero; } else {%- endif -%} { - cdouble_t Dyx = cdouble_mul(cdouble_sub(Ey[ipx], Ey[i]), inv_dex[x]); - cdouble_t Dxy = cdouble_mul(cdouble_sub(Ex[ipy], Ex[i]), inv_dey[y]); + ctype Dyx = mul(sub(Ey[ipx], Ey[i]), inv_dex[x]); + ctype Dxy = mul(sub(Ex[ipy], Ex[i]), inv_dey[y]); + ctype z_curl = sub(Dyx, Dxy); {%- if mu -%} - Hz[i] = cdouble_mul(inv_mu[ZZ + i], cdouble_sub(Dyx, Dxy)); + Hz[i] = mul(inv_mu_z[i], z_curl); {%- else -%} - Hz[i] = cdouble_sub(Dyx, Dxy); + Hz[i] = z_curl; {%- endif %} } diff --git a/opencl_fdfd/kernels/h2e.cl b/opencl_fdfd/kernels/h2e.cl index 3dd968c..9c51a25 100644 --- a/opencl_fdfd/kernels/h2e.cl +++ b/opencl_fdfd/kernels/h2e.cl @@ -1,17 +1,45 @@ /* - * * E update equations * + * Template parameters: + * pec False if no PEC anywhere + * common_cl Rendered code from common.cl + * + * Arguments: + * ctype *E E-field + * ctype *H H-field + * ctype *oeps omega*epsilon (at E-field locations) + * ctype *Pl Entries of (diagonal) left preconditioner matrix + * char *pec Boolean mask denoting presence of PEC (at E-field locations) + * ctype *inv_dhx 1/dx_h (complex cell widths for x direction at H locations) + * ctype *inv_dhy 1/dy_h (complex cell widths for y direction at H locations) + * ctype *inv_dhz 1/dz_h (complex cell widths for z direction at H locations) + * */ -//Define sx, x, dix (and y, z versions of those) -{{dixyz_source}} - -//Define vectorized fields and pointers (eg. Hx = H + XX) -{{vec_source}} +{{common_cl}} -// Wrap indices if necessary +__global ctype *oeps_x = oeps + XX; +__global ctype *oeps_y = oeps + YY; +__global ctype *oeps_z = oeps + ZZ; + +__global ctype *pec_x = pec + XX; +__global ctype *pec_y = pec + YY; +__global ctype *pec_z = pec + ZZ; + +__global ctype *Pl_x = Pl + XX; +__global ctype *Pl_y = Pl + YY; +__global ctype *Pl_z = Pl + ZZ; + + +/* + * Implement periodic boundary conditions + * + * imx gives the index of the adjacent cell in the minus-x direction ([i]ndex [m]inus [x]). + * In the event that we start at x == 0, we actually want to wrap around and grab the cell + * where x == (sx - 1) instead, ie. imx = i + (sx - 1) * dix . + */ int imx, imy, imz; if ( x == 0 ) { imx = i + (sx - 1) * dix; @@ -34,38 +62,38 @@ if ( z == 0 ) { //Update E components; set them to 0 if PEC is enabled there. {% if pec -%} -if (pec[XX + i] == 0) +if (pec_x[i] == 0) {%- endif -%} { - cdouble_t tEx = cdouble_mul(Ex[i], oeps[XX + i]); - cdouble_t Dzy = cdouble_mul(cdouble_sub(Hz[i], Hz[imy]), inv_dhy[y]); - cdouble_t Dyz = cdouble_mul(cdouble_sub(Hy[i], Hy[imz]), inv_dhz[z]); - tEx = cdouble_add(tEx, cdouble_sub(Dzy, Dyz)); - Ex[i] = cdouble_mul(tEx, Pl[XX + i]); + ctype tEx = mul(Ex[i], oeps_x[i]); + ctype Dzy = mul(sub(Hz[i], Hz[imy]), inv_dhy[y]); + ctype Dyz = mul(sub(Hy[i], Hy[imz]), inv_dhz[z]); + tEx = add(tEx, sub(Dzy, Dyz)); + Ex[i] = mul(tEx, Pl_x[i]); } {% if pec -%} -if (pec[YY + i] == 0) +if (pec_y[i] == 0) {%- endif -%} { - cdouble_t tEy = cdouble_mul(Ey[i], oeps[YY + i]); - cdouble_t Dxz = cdouble_mul(cdouble_sub(Hx[i], Hx[imz]), inv_dhz[z]); - cdouble_t Dzx = cdouble_mul(cdouble_sub(Hz[i], Hz[imx]), inv_dhx[x]); - tEy = cdouble_add(tEy, cdouble_sub(Dxz, Dzx)); - Ey[i] = cdouble_mul(tEy, Pl[YY + i]); + ctype tEy = mul(Ey[i], oeps_y[i]); + ctype Dxz = mul(sub(Hx[i], Hx[imz]), inv_dhz[z]); + ctype Dzx = mul(sub(Hz[i], Hz[imx]), inv_dhx[x]); + tEy = add(tEy, sub(Dxz, Dzx)); + Ey[i] = mul(tEy, Pl_y[i]); } {% if pec -%} -if (pec[ZZ + i] == 0) +if (pec_z[i] == 0) {%- endif -%} { - cdouble_t tEz = cdouble_mul(Ez[i], oeps[ZZ + i]); - cdouble_t Dyx = cdouble_mul(cdouble_sub(Hy[i], Hy[imx]), inv_dhx[x]); - cdouble_t Dxy = cdouble_mul(cdouble_sub(Hx[i], Hx[imy]), inv_dhy[y]); - tEz = cdouble_add(tEz, cdouble_sub(Dyx, Dxy)); - Ez[i] = cdouble_mul(tEz, Pl[ZZ + i]); + ctype tEz = mul(Ez[i], oeps_z[i]); + ctype Dyx = mul(sub(Hy[i], Hy[imx]), inv_dhx[x]); + ctype Dxy = mul(sub(Hx[i], Hx[imy]), inv_dhy[y]); + tEz = add(tEz, sub(Dyx, Dxy)); + Ez[i] = mul(tEz, Pl_z[i]); } /* - * End H update equations + * End E update equations */ diff --git a/opencl_fdfd/kernels/p2e.cl b/opencl_fdfd/kernels/p2e.cl index 432eba9..ba2c87e 100644 --- a/opencl_fdfd/kernels/p2e.cl +++ b/opencl_fdfd/kernels/p2e.cl @@ -1,9 +1,29 @@ +/* + * Apply PEC and preconditioner. + * + * Template parameters: + * ctype name of complex type (eg. cdouble) + * pec false iff no PEC anyhwere + * + * Arguments: + * ctype *E (output) E-field + * ctype *Pr Entries of (diagonal) right preconditioner matrix + * ctype *p (input vector) + * + */ + + +//Defines to clean up operation names +#define ctype {{ctype}}_t +#define zero {{ctype}}_new(0.0, 0.0) +#define mul {{ctype}}_mul + {%- if pec -%} if (pec[i] != 0) { - E[i] = cdouble_new(0.0, 0.0); + E[i] = zero; } else {%- endif -%} { - E[i] = cdouble_mul(Pr[i], p[i]); + E[i] = mul(Pr[i], p[i]); } diff --git a/opencl_fdfd/ops.py b/opencl_fdfd/ops.py index 0a4ff87..108201b 100644 --- a/opencl_fdfd/ops.py +++ b/opencl_fdfd/ops.py @@ -29,65 +29,10 @@ def type_to_C(float_type: numpy.float32 or numpy.float64) -> str: return types[float_type] -def shape_source(shape) -> str: - """ - Defines sx, sy, sz C constants specifying the shape of the grid in each of the 3 dimensions. - - :param shape: [sx, sy, sz] values. - :return: String containing C source. - """ - sxyz = """ -// Field sizes -const int sx = {shape[0]}; -const int sy = {shape[1]}; -const int sz = {shape[2]}; -""".format(shape=shape) - return sxyz - -# Defines dix, diy, diz constants used for stepping in the x, y, z directions in a linear array -# (ie, given Ex[i] referring to position (x, y, z), Ex[i+diy] will refer to position (x, y+1, z)) -dixyz_source = """ -// Convert offset in field xyz to linear index offset -const int dix = 1; -const int diy = sx; -const int diz = sx * sy; -""" - -# Given a linear index i and shape sx, sy, sz, defines x, y, and z -# as the 3D indices of the current element (i). -xyz_source = """ -// Convert linear index to field index (xyz) -const int z = i / (sx * sy); -const int y = (i - z * sx * sy) / sx; -const int x = (i - y * sx - z * sx * sy); -""" - -vec_source = """ -if (i >= sx * sy * sz) { - PYOPENCL_ELWISE_CONTINUE; -} - -//Pointers into the components of a vectorized vector-field -const int XX = 0; -const int YY = sx * sy * sz; -const int ZZ = sx * sy * sz * 2; -""" - -E_ptrs = """ -__global cdouble_t *Ex = E + XX; -__global cdouble_t *Ey = E + YY; -__global cdouble_t *Ez = E + ZZ; -""" - -H_ptrs = """ -__global cdouble_t *Hx = H + XX; -__global cdouble_t *Hy = H + YY; -__global cdouble_t *Hz = H + ZZ; -""" - preamble = ''' #define PYOPENCL_DEFINE_CDOUBLE #include + ''' ctype = type_to_C(numpy.complex128) @@ -98,15 +43,17 @@ def ptrs(*args): def create_a(context, shape, mu=False, pec=False, pmc=False): - header = shape_source(shape) + dixyz_source + xyz_source - vec_h = vec_source + E_ptrs + H_ptrs + + common_source = jinja_env.get_template('common.cl').render(shape=shape, + ctype=ctype) pec_arg = ['char *pec'] pmc_arg = ['char *pmc'] des = [ctype + ' *inv_de' + a for a in 'xyz'] dhs = [ctype + ' *inv_dh' + a for a in 'xyz'] - p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec) + p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec, + ctype=ctype) P2E_kernel = ElementwiseKernel(context, name='P2E', preamble=preamble, @@ -115,8 +62,7 @@ def create_a(context, shape, mu=False, pec=False, pmc=False): e2h_source = jinja_env.get_template('e2h.cl').render(mu=mu, pmc=pmc, - dixyz_source=header, - vec_source=vec_h) + common_cl=common_source) E2H_kernel = ElementwiseKernel(context, name='E2H', preamble=preamble, @@ -124,8 +70,7 @@ def create_a(context, shape, mu=False, pec=False, pmc=False): arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des)) h2e_source = jinja_env.get_template('h2e.cl').render(pec=pec, - dixyz_source=header, - vec_source=vec_h) + common_cl=common_source) H2E_kernel = ElementwiseKernel(context, name='H2E', preamble=preamble, @@ -143,8 +88,8 @@ def create_a(context, shape, mu=False, pec=False, pmc=False): def create_xr_step(context): update_xr_source = ''' - x[i] = cdouble_add(x[i], cdouble_mul(alpha, p[i])); - r[i] = cdouble_sub(r[i], cdouble_mul(alpha, v[i])); + x[i] = add(x[i], mul(alpha, p[i])); + r[i] = sub(r[i], mul(alpha, v[i])); ''' xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha']) @@ -191,7 +136,7 @@ def create_rhoerr_step(context): def create_p_step(context): update_p_source = ''' - p[i] = cdouble_add(r[i], cdouble_mul(beta, p[i])); + p[i] = add(r[i], mul(beta, p[i])); ''' p_args = ptrs('p', 'r') + [ctype + ' beta'] @@ -214,9 +159,9 @@ def create_dot(context): name='dot', preamble=preamble, dtype_out=dot_dtype, - neutral='cdouble_new(0.0, 0.0)', - map_expr='cdouble_mul(p[i], v[i])', - reduce_expr='cdouble_add(a, b)', + neutral='zero', + map_expr='mul(p[i], v[i])', + reduce_expr='add(a, b)', arguments=ptrs('p', 'v')) def ri_update(p, v, e): @@ -230,14 +175,14 @@ def create_a_csr(context): spmv_source = ''' int start = m_row_ptr[i]; int stop = m_row_ptr[i+1]; - cdouble_t dot = cdouble_new(0.0, 0.0); + dtype dot = zero; int col_ind, d_ind; for (int j=start; j