fix bugs after refactor

release
jan 8 years ago
parent ff3951ba35
commit 8e3706948e

@ -1,21 +1,13 @@
/* Common code for E, H updates /* Common code for E, H updates
* *
* Template parameters: * Template parameters:
* ctype string denoting type for storing complex field values
* shape list of 3 ints specifying shape of fields * shape list of 3 ints specifying shape of fields
*/ */
//Defines to clean up operation names
#define ctype {{ctype}}_t
#define zero {{ctype}}_new(0.0, 0.0)
#define add {{ctype}}_add
#define sub {{ctype}}_sub
#define mul {{ctype}}_mul
// Field sizes // Field sizes
const int sx = {shape[0]}; const int sx = {{shape[0]}};
const int sy = {shape[1]}; const int sy = {{shape[1]}};
const int sz = {shape[2]}; const int sz = {{shape[2]}};
//Since we use i to index into Ex[], E[], ... rather than E[], do nothing if //Since we use i to index into Ex[], E[], ... rather than E[], do nothing if
// i is outside the bounds of Ex[]. // i is outside the bounds of Ex[].

@ -23,9 +23,9 @@ __global ctype *inv_mu_x = inv_mu + XX;
__global ctype *inv_mu_y = inv_mu + YY; __global ctype *inv_mu_y = inv_mu + YY;
__global ctype *inv_mu_z = inv_mu + ZZ; __global ctype *inv_mu_z = inv_mu + ZZ;
__global ctype *pmc_x = pmc + XX; __global char *pmc_x = pmc + XX;
__global ctype *pmc_y = pmc + YY; __global char *pmc_y = pmc + YY;
__global ctype *pmc_z = pmc + ZZ; __global char *pmc_z = pmc + ZZ;
/* /*
* Implement periodic boundary conditions * Implement periodic boundary conditions

@ -24,9 +24,9 @@ __global ctype *oeps_x = oeps + XX;
__global ctype *oeps_y = oeps + YY; __global ctype *oeps_y = oeps + YY;
__global ctype *oeps_z = oeps + ZZ; __global ctype *oeps_z = oeps + ZZ;
__global ctype *pec_x = pec + XX; __global char *pec_x = pec + XX;
__global ctype *pec_y = pec + YY; __global char *pec_y = pec + YY;
__global ctype *pec_z = pec + ZZ; __global char *pec_z = pec + ZZ;
__global ctype *Pl_x = Pl + XX; __global ctype *Pl_x = Pl + XX;
__global ctype *Pl_y = Pl + YY; __global ctype *Pl_y = Pl + YY;

@ -2,7 +2,6 @@
* Apply PEC and preconditioner. * Apply PEC and preconditioner.
* *
* Template parameters: * Template parameters:
* ctype name of complex type (eg. cdouble)
* pec false iff no PEC anywhere * pec false iff no PEC anywhere
* *
* Arguments: * Arguments:
@ -13,12 +12,6 @@
*/ */
//Defines to clean up operation names
#define ctype {{ctype}}_t
#define zero {{ctype}}_new(0.0, 0.0)
#define mul {{ctype}}_mul
{%- if pec -%} {%- if pec -%}
if (pec[i] != 0) { if (pec[i] != 0) {
E[i] = zero; E[i] = zero;

@ -28,14 +28,20 @@ def type_to_C(float_type: numpy.float32 or numpy.float64) -> str:
return types[float_type] return types[float_type]
ctype = type_to_C(numpy.complex128)
ctype_bare = 'cdouble'
preamble = ''' preamble = '''
#define PYOPENCL_DEFINE_CDOUBLE #define PYOPENCL_DEFINE_CDOUBLE
#include <pyopencl-complex.h> #include <pyopencl-complex.h>
''' //Defines to clean up operation and type names
#define ctype {ctype}_t
ctype = type_to_C(numpy.complex128) #define zero {ctype}_new(0.0, 0.0)
#define add {ctype}_add
#define sub {ctype}_sub
#define mul {ctype}_mul
'''.format(ctype=ctype_bare)
def ptrs(*args): def ptrs(*args):
@ -44,16 +50,14 @@ def ptrs(*args):
def create_a(context, shape, mu=False, pec=False, pmc=False): def create_a(context, shape, mu=False, pec=False, pmc=False):
common_source = jinja_env.get_template('common.cl').render(shape=shape, common_source = jinja_env.get_template('common.cl').render(shape=shape)
ctype=ctype)
pec_arg = ['char *pec'] pec_arg = ['char *pec']
pmc_arg = ['char *pmc'] pmc_arg = ['char *pmc']
des = [ctype + ' *inv_de' + a for a in 'xyz'] des = [ctype + ' *inv_de' + a for a in 'xyz']
dhs = [ctype + ' *inv_dh' + a for a in 'xyz'] dhs = [ctype + ' *inv_dh' + a for a in 'xyz']
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec, p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
ctype=ctype)
P2E_kernel = ElementwiseKernel(context, P2E_kernel = ElementwiseKernel(context,
name='P2E', name='P2E',
preamble=preamble, preamble=preamble,

Loading…
Cancel
Save