diff --git a/opencl_fdtd/simulation.py b/opencl_fdtd/simulation.py index 78460ec..525d818 100644 --- a/opencl_fdtd/simulation.py +++ b/opencl_fdtd/simulation.py @@ -273,7 +273,24 @@ class Simulation(object): arguments=', '.join(S_args.keys())) self.update_S = lambda e: S_update(*S_args.values(), wait_for=e) + def _create_eps(self, epsilon: List[numpy.ndarray]): + if len(epsilon) != 3: + raise Exception('Epsilon must be a list with length of 3') + if not all((e.shape == epsilon[0].shape for e in epsilon[1:])): + raise Exception('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon]) + if not epsilon[0].shape == self.shape: + raise Exception('Epsilon shape mismatch. Expected {}, got {}'.format(self.shape, epsilon[0].shape)) + self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(self.arg_type)) + def _create_field(self, initial_value: List[numpy.ndarray] = None): + if initial_value is None: + return pyopencl.array.zeros_like(self.eps) + else: + if len(initial_value) != 3: + Exception('Initial field value must be a list of length 3') + if not all((f.shape == self.shape for f in initial_value)): + Exception('Initial field list elements must have same shape as epsilon elements') + return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type)) def type_to_C(float_type) -> str: """