diff --git a/opencl_fdfd/kernels/common.cl b/opencl_fdfd/kernels/common.cl index d429089..41e811f 100644 --- a/opencl_fdfd/kernels/common.cl +++ b/opencl_fdfd/kernels/common.cl @@ -1,21 +1,13 @@ /* Common code for E, H updates * * Template parameters: - * ctype string denoting type for storing complex field values * shape list of 3 ints specifying shape of fields */ -//Defines to clean up operation names -#define ctype {{ctype}}_t -#define zero {{ctype}}_new(0.0, 0.0) -#define add {{ctype}}_add -#define sub {{ctype}}_sub -#define mul {{ctype}}_mul - // Field sizes -const int sx = {shape[0]}; -const int sy = {shape[1]}; -const int sz = {shape[2]}; +const int sx = {{shape[0]}}; +const int sy = {{shape[1]}}; +const int sz = {{shape[2]}}; //Since we use i to index into Ex[], E[], ... rather than E[], do nothing if // i is outside the bounds of Ex[]. diff --git a/opencl_fdfd/kernels/e2h.cl b/opencl_fdfd/kernels/e2h.cl index 0332252..8e2a076 100644 --- a/opencl_fdfd/kernels/e2h.cl +++ b/opencl_fdfd/kernels/e2h.cl @@ -23,9 +23,9 @@ __global ctype *inv_mu_x = inv_mu + XX; __global ctype *inv_mu_y = inv_mu + YY; __global ctype *inv_mu_z = inv_mu + ZZ; -__global ctype *pmc_x = pmc + XX; -__global ctype *pmc_y = pmc + YY; -__global ctype *pmc_z = pmc + ZZ; +__global char *pmc_x = pmc + XX; +__global char *pmc_y = pmc + YY; +__global char *pmc_z = pmc + ZZ; /* * Implement periodic boundary conditions diff --git a/opencl_fdfd/kernels/h2e.cl b/opencl_fdfd/kernels/h2e.cl index 9c51a25..d71a78c 100644 --- a/opencl_fdfd/kernels/h2e.cl +++ b/opencl_fdfd/kernels/h2e.cl @@ -24,9 +24,9 @@ __global ctype *oeps_x = oeps + XX; __global ctype *oeps_y = oeps + YY; __global ctype *oeps_z = oeps + ZZ; -__global ctype *pec_x = pec + XX; -__global ctype *pec_y = pec + YY; -__global ctype *pec_z = pec + ZZ; +__global char *pec_x = pec + XX; +__global char *pec_y = pec + YY; +__global char *pec_z = pec + ZZ; __global ctype *Pl_x = Pl + XX; __global ctype *Pl_y = Pl + YY; diff --git 
a/opencl_fdfd/kernels/p2e.cl b/opencl_fdfd/kernels/p2e.cl index ba2c87e..20792cd 100644 --- a/opencl_fdfd/kernels/p2e.cl +++ b/opencl_fdfd/kernels/p2e.cl @@ -2,7 +2,6 @@ * Apply PEC and preconditioner. * * Template parameters: - * ctype name of complex type (eg. cdouble) * pec false iff no PEC anyhwere * * Arguments: @@ -13,12 +12,6 @@ */ -//Defines to clean up operation names -#define ctype {{ctype}}_t -#define zero {{ctype}}_new(0.0, 0.0) -#define mul {{ctype}}_mul - - {%- if pec -%} if (pec[i] != 0) { E[i] = zero; diff --git a/opencl_fdfd/ops.py b/opencl_fdfd/ops.py index 108201b..f0b5433 100644 --- a/opencl_fdfd/ops.py +++ b/opencl_fdfd/ops.py @@ -28,14 +28,20 @@ def type_to_C(float_type: numpy.float32 or numpy.float64) -> str: return types[float_type] +ctype = type_to_C(numpy.complex128) +ctype_bare = 'cdouble' preamble = ''' #define PYOPENCL_DEFINE_CDOUBLE #include <pyopencl-complex.h> -''' - -ctype = type_to_C(numpy.complex128) +//Defines to clean up operation and type names +#define ctype {ctype}_t +#define zero {ctype}_new(0.0, 0.0) +#define add {ctype}_add +#define sub {ctype}_sub +#define mul {ctype}_mul +'''.format(ctype=ctype_bare) def ptrs(*args): @@ -44,16 +50,14 @@ def ptrs(*args): def create_a(context, shape, mu=False, pec=False, pmc=False): - common_source = jinja_env.get_template('common.cl').render(shape=shape, - ctype=ctype) + common_source = jinja_env.get_template('common.cl').render(shape=shape) pec_arg = ['char *pec'] pmc_arg = ['char *pmc'] des = [ctype + ' *inv_de' + a for a in 'xyz'] dhs = [ctype + ' *inv_dh' + a for a in 'xyz'] - p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec, - ctype=ctype) + p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec) P2E_kernel = ElementwiseKernel(context, name='P2E', preamble=preamble,