diff --git a/opencl_fdtd/simulation.py b/opencl_fdtd/simulation.py index 525d818..7b94c64 100644 --- a/opencl_fdtd/simulation.py +++ b/opencl_fdtd/simulation.py @@ -76,8 +76,7 @@ class Simulation(object): epsilon: List[numpy.ndarray], pmls: List[Dict[str, int or float]], dt: float = .99/numpy.sqrt(3), - initial_E: List[numpy.ndarray] = None, - initial_H: List[numpy.ndarray] = None, + initial_fields: Dict[str, List[numpy.ndarray]] = None, context: pyopencl.Context = None, queue: pyopencl.CommandQueue = None, float_type: numpy.float32 or numpy.float64 = numpy.float32, @@ -113,21 +112,14 @@ class Simulation(object): * GPU memory requirements are approximately doubled, since S and the intermediate products must be stored. """ + if initial_fields is None: + initial_fields = {} - if len(epsilon) != 3: - Exception('Epsilon must be a list with length of 3') - if not all((e.shape == epsilon[0].shape for e in epsilon[1:])): - Exception('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon]) - - if context is None: - self.context = pyopencl.create_some_context() - else: - self.context = context - - if queue is None: - self.queue = pyopencl.CommandQueue(self.context) - else: - self.queue = queue + self.shape = epsilon[0].shape + self.arg_type = float_type + self.sources = {} + self._create_context(context, queue) + self._create_eps(epsilon) if dt > .99/numpy.sqrt(3): warnings.warn('Warning: unstable dt: {}'.format(dt)) @@ -136,27 +128,8 @@ class Simulation(object): else: self.dt = dt - self.arg_type = float_type - self.sources = {} - self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(float_type)) - - if initial_E is None: - self.E = pyopencl.array.zeros_like(self.eps) - else: - if len(initial_E) != 3: - Exception('Initial_E must be a list of length 3') - if not all((E.shape == epsilon[0].shape for E in initial_E)): - Exception('Initial_E list elements must have same shape as epsilon elements') - self.E = pyopencl.array.to_device(self.queue, vec(E).astype(float_type)) - - if initial_H is None: - self.H = pyopencl.array.zeros_like(self.eps) - else: - if len(initial_H) != 3: - Exception('Initial_H must be a list of length 3') - if not all((H.shape == epsilon[0].shape for H in initial_H)): - Exception('Initial_H list elements must have same shape as epsilon elements') - self.H = pyopencl.array.to_device(self.queue, vec(H).astype(float_type)) + self.E = self._create_field(initial_fields.get('E', None)) + self.H = self._create_field(initial_fields.get('H', None)) for pml in pmls: pml.setdefault('thickness', 8) @@ -181,7 +154,7 @@ class Simulation(object): common_source = jinja_env.get_template('common.cl').render( ftype=ctype, - shape=epsilon[0].shape, + shape=self.shape, ) jinja_args = { 'common_header': common_source, @@ -194,21 +167,37 @@ class Simulation(object): self.sources['E'] = E_source self.sources['H'] = H_source + + + S_fields = OrderedDict() if do_poynting: S_source = jinja_env.get_template('update_s.cl').render(**jinja_args) self.sources['S'] = S_source - self.oS = pyopencl.array.zeros(self.queue, self.E.shape + (2,), dtype=float_type) + self.oS = pyopencl.array.zeros(self.queue, self.E.shape + (2,), dtype=self.arg_type) self.S = pyopencl.array.zeros_like(self.E) - S_fields = OrderedDict() S_fields[ptr('oS')] = self.oS S_fields[ptr('S')] = self.S - else: - S_fields = OrderedDict() ''' PML ''' + pml_e_fields, pml_h_fields = self._create_pmls(pmls) + + ''' + Create operations + ''' + self.update_E = self._create_operation(E_source, (base_fields, eps_field, pml_e_fields)) + self.update_H = self._create_operation(H_source, (base_fields, pml_h_fields, S_fields)) + if do_poynting: + self.update_S = self._create_operation(S_source, (base_fields, S_fields)) + + + def _create_pmls(self, pmls): + ctype = type_to_C(self.arg_type) + def ptr(arg: str) -> str: + return ctype + ' *' + arg + pml_e_fields = OrderedDict() pml_h_fields = OrderedDict() for pml in pmls: @@ -225,7 +214,7 @@ class Simulation(object): p1 = sigma / (sigma + alpha) * (p0 - 1) return p0, p1 - xe, xh = (numpy.arange(1, pml['thickness'] + 1, dtype=float_type)[::-1] for _ in range(2)) + xe, xh = (numpy.arange(1, pml['thickness'] + 1, dtype=self.arg_type)[::-1] for _ in range(2)) if pml['polarity'] == 'p': xe -= 0.5 elif pml['polarity'] == 'n': @@ -240,39 +229,34 @@ class Simulation(object): psi_base = 'Psi_' + pml['axis'] + pml['polarity'] + '_' psi_names = [[psi_base + eh + c for c in uv] for eh in 'EH'] - psi_shape = list(epsilon[0].shape) + psi_shape = list(self.shape) psi_shape[a] = pml['thickness'] for ne, nh in zip(*psi_names): pml_e_fields[ptr(ne)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type) pml_h_fields[ptr(nh)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type) + return pml_e_fields, pml_h_fields - self.pml_e_fields = pml_e_fields - self.pml_h_fields = pml_h_fields + def _create_operation(self, source, args_fields): + args = OrderedDict() + [args.update(d) for d in args_fields] + update = ElementwiseKernel(self.context, operation=source, + arguments=', '.join(args.keys())) + return lambda e: update(*args.values(), wait_for=e) - ''' - Create operations - ''' - E_args = OrderedDict() - [E_args.update(d) for d in (base_fields, eps_field, pml_e_fields)] - E_update = ElementwiseKernel(self.context, operation=E_source, - arguments=', '.join(E_args.keys())) + def _create_context(self, context: pyopencl.Context = None, + queue: pyopencl.CommandQueue = None): + if context is None: + self.context = pyopencl.create_some_context() + else: + self.context = context - H_args = OrderedDict() - [H_args.update(d) for d in (base_fields, pml_h_fields, S_fields)] - H_update = ElementwiseKernel(self.context, operation=H_source, - arguments=', '.join(H_args.keys())) - self.update_E = lambda e: E_update(*E_args.values(), wait_for=e) - self.update_H = lambda e: H_update(*H_args.values(), wait_for=e) + if queue is None: + self.queue = pyopencl.CommandQueue(self.context) + else: + self.queue = queue - if do_poynting: - S_args = OrderedDict() - [S_args.update(d) for d in (base_fields, S_fields)] - S_update = ElementwiseKernel(self.context, operation=S_source, - arguments=', '.join(S_args.keys())) - - self.update_S = lambda e: S_update(*S_args.values(), wait_for=e) def _create_eps(self, epsilon: List[numpy.ndarray]): if len(epsilon) != 3: raise Exception('Epsilon must be a list with length of 3')