forked from jan/opencl_fdfd
fix bugs after refactor
This commit is contained in:
parent
ff3951ba35
commit
8e3706948e
@ -1,21 +1,13 @@
|
|||||||
/* Common code for E, H updates
|
/* Common code for E, H updates
|
||||||
*
|
*
|
||||||
* Template parameters:
|
* Template parameters:
|
||||||
* ctype string denoting type for storing complex field values
|
|
||||||
* shape list of 3 ints specifying shape of fields
|
* shape list of 3 ints specifying shape of fields
|
||||||
*/
|
*/
|
||||||
|
|
||||||
//Defines to clean up operation names
|
|
||||||
#define ctype {{ctype}}_t
|
|
||||||
#define zero {{ctype}}_new(0.0, 0.0)
|
|
||||||
#define add {{ctype}}_add
|
|
||||||
#define sub {{ctype}}_sub
|
|
||||||
#define mul {{ctype}}_mul
|
|
||||||
|
|
||||||
// Field sizes
|
// Field sizes
|
||||||
const int sx = {shape[0]};
|
const int sx = {{shape[0]}};
|
||||||
const int sy = {shape[1]};
|
const int sy = {{shape[1]}};
|
||||||
const int sz = {shape[2]};
|
const int sz = {{shape[2]}};
|
||||||
|
|
||||||
//Since we use i to index into Ex[], E[], ... rather than E[], do nothing if
|
//Since we use i to index into Ex[], E[], ... rather than E[], do nothing if
|
||||||
// i is outside the bounds of Ex[].
|
// i is outside the bounds of Ex[].
|
||||||
|
@ -23,9 +23,9 @@ __global ctype *inv_mu_x = inv_mu + XX;
|
|||||||
__global ctype *inv_mu_y = inv_mu + YY;
|
__global ctype *inv_mu_y = inv_mu + YY;
|
||||||
__global ctype *inv_mu_z = inv_mu + ZZ;
|
__global ctype *inv_mu_z = inv_mu + ZZ;
|
||||||
|
|
||||||
__global ctype *pmc_x = pmc + XX;
|
__global char *pmc_x = pmc + XX;
|
||||||
__global ctype *pmc_y = pmc + YY;
|
__global char *pmc_y = pmc + YY;
|
||||||
__global ctype *pmc_z = pmc + ZZ;
|
__global char *pmc_z = pmc + ZZ;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Implement periodic boundary conditions
|
* Implement periodic boundary conditions
|
||||||
|
@ -24,9 +24,9 @@ __global ctype *oeps_x = oeps + XX;
|
|||||||
__global ctype *oeps_y = oeps + YY;
|
__global ctype *oeps_y = oeps + YY;
|
||||||
__global ctype *oeps_z = oeps + ZZ;
|
__global ctype *oeps_z = oeps + ZZ;
|
||||||
|
|
||||||
__global ctype *pec_x = pec + XX;
|
__global char *pec_x = pec + XX;
|
||||||
__global ctype *pec_y = pec + YY;
|
__global char *pec_y = pec + YY;
|
||||||
__global ctype *pec_z = pec + ZZ;
|
__global char *pec_z = pec + ZZ;
|
||||||
|
|
||||||
__global ctype *Pl_x = Pl + XX;
|
__global ctype *Pl_x = Pl + XX;
|
||||||
__global ctype *Pl_y = Pl + YY;
|
__global ctype *Pl_y = Pl + YY;
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
* Apply PEC and preconditioner.
|
* Apply PEC and preconditioner.
|
||||||
*
|
*
|
||||||
* Template parameters:
|
* Template parameters:
|
||||||
* ctype name of complex type (eg. cdouble)
|
|
||||||
* pec false iff no PEC anywhere
|
* pec false iff no PEC anywhere
|
||||||
*
|
*
|
||||||
* Arguments:
|
* Arguments:
|
||||||
@ -13,12 +12,6 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
//Defines to clean up operation names
|
|
||||||
#define ctype {{ctype}}_t
|
|
||||||
#define zero {{ctype}}_new(0.0, 0.0)
|
|
||||||
#define mul {{ctype}}_mul
|
|
||||||
|
|
||||||
|
|
||||||
{%- if pec -%}
|
{%- if pec -%}
|
||||||
if (pec[i] != 0) {
|
if (pec[i] != 0) {
|
||||||
E[i] = zero;
|
E[i] = zero;
|
||||||
|
@ -28,14 +28,20 @@ def type_to_C(float_type: numpy.float32 or numpy.float64) -> str:
|
|||||||
|
|
||||||
return types[float_type]
|
return types[float_type]
|
||||||
|
|
||||||
|
ctype = type_to_C(numpy.complex128)
|
||||||
|
ctype_bare = 'cdouble'
|
||||||
|
|
||||||
preamble = '''
|
preamble = '''
|
||||||
#define PYOPENCL_DEFINE_CDOUBLE
|
#define PYOPENCL_DEFINE_CDOUBLE
|
||||||
#include <pyopencl-complex.h>
|
#include <pyopencl-complex.h>
|
||||||
|
|
||||||
'''
|
//Defines to clean up operation and type names
|
||||||
|
#define ctype {ctype}_t
|
||||||
ctype = type_to_C(numpy.complex128)
|
#define zero {ctype}_new(0.0, 0.0)
|
||||||
|
#define add {ctype}_add
|
||||||
|
#define sub {ctype}_sub
|
||||||
|
#define mul {ctype}_mul
|
||||||
|
'''.format(ctype=ctype_bare)
|
||||||
|
|
||||||
|
|
||||||
def ptrs(*args):
|
def ptrs(*args):
|
||||||
@ -44,16 +50,14 @@ def ptrs(*args):
|
|||||||
|
|
||||||
def create_a(context, shape, mu=False, pec=False, pmc=False):
|
def create_a(context, shape, mu=False, pec=False, pmc=False):
|
||||||
|
|
||||||
common_source = jinja_env.get_template('common.cl').render(shape=shape,
|
common_source = jinja_env.get_template('common.cl').render(shape=shape)
|
||||||
ctype=ctype)
|
|
||||||
|
|
||||||
pec_arg = ['char *pec']
|
pec_arg = ['char *pec']
|
||||||
pmc_arg = ['char *pmc']
|
pmc_arg = ['char *pmc']
|
||||||
des = [ctype + ' *inv_de' + a for a in 'xyz']
|
des = [ctype + ' *inv_de' + a for a in 'xyz']
|
||||||
dhs = [ctype + ' *inv_dh' + a for a in 'xyz']
|
dhs = [ctype + ' *inv_dh' + a for a in 'xyz']
|
||||||
|
|
||||||
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec,
|
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
|
||||||
ctype=ctype)
|
|
||||||
P2E_kernel = ElementwiseKernel(context,
|
P2E_kernel = ElementwiseKernel(context,
|
||||||
name='P2E',
|
name='P2E',
|
||||||
preamble=preamble,
|
preamble=preamble,
|
||||||
|
Loading…
Reference in New Issue
Block a user