diff --git a/opencl_fdtd/__init__.py b/opencl_fdtd/__init__.py index 0dcc446..5bb721e 100644 --- a/opencl_fdtd/__init__.py +++ b/opencl_fdtd/__init__.py @@ -1,4 +1,7 @@ -from .simulation import Simulation, type_to_C +from .simulation import ( + Simulation as Simulation, + type_to_C as type_to_C, + ) __author__ = 'Jan Petykiewicz' __version__ = '0.4' diff --git a/opencl_fdtd/simulation.py b/opencl_fdtd/simulation.py index fbd282a..99ba3ad 100644 --- a/opencl_fdtd/simulation.py +++ b/opencl_fdtd/simulation.py @@ -2,12 +2,13 @@ Class for constructing and holding the basic FDTD operations and fields """ -from typing import Callable, Type, Sequence -from collections import OrderedDict +from collections.abc import Callable, Sequence import numpy from numpy.typing import NDArray +from numpy import floating, complexfloating import jinja2 import warnings +import logging import pyopencl import pyopencl.array @@ -16,13 +17,18 @@ from pyopencl.elementwise import ElementwiseKernel from meanas.fdmath import vec -__author__ = 'Jan Petykiewicz' +logger = logging.getLogger(__name__) # Create jinja2 env on module load jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__.split('.')[0], 'kernels')) +class FDTDError(Exception): + """ Custom exception for opencl_fdtd """ + pass + + class Simulation: r""" Constructs and holds the basic FDTD operations and related fields @@ -64,7 +70,7 @@ class Simulation: dt: float inv_dxes: list[pyopencl.array.Array] - arg_type: Type + arg_type: type context: pyopencl.Context queue: pyopencl.CommandQueue @@ -85,7 +91,7 @@ class Simulation: initial_fields: dict[str, NDArray] | None = None, context: pyopencl.Context | None = None, queue: pyopencl.CommandQueue | None = None, - float_type: Type = numpy.float32, + float_type: type = numpy.float32, do_poynting: bool = True, do_fieldsrc: bool = False, ) -> None: @@ -134,7 +140,7 @@ class Simulation: if dxes is None: dxes = 1.0 - if isinstance(dxes, (float, int)): + if isinstance(dxes, float | int): uniform_dx = dxes min_dx = dxes else: @@ -143,13 +149,14 @@ class Simulation: min_dx = min(min(dxn) for dxn in dxes[0] + dxes[1]) max_dt = min_dx * .99 / numpy.sqrt(3) + logger.info(f'{min_dx=}, {max_dt=}, {dt=}') if dt is None: self.dt = max_dt elif dt > max_dt: - warnings.warn(f'Warning: unstable dt: {dt}') + warnings.warn(f'Warning: unstable dt: {dt}', stacklevel=2) elif dt <= 0: - raise Exception(f'Invalid dt: {dt}') + raise FDTDError(f'Invalid dt: {dt}') else: self.dt = dt @@ -173,28 +180,31 @@ class Simulation: def ptr(arg: str) -> str: return ctype + ' *' + arg - base_fields = OrderedDict() - base_fields[ptr('E')] = self.E - base_fields[ptr('H')] = self.H - base_fields[ctype + ' dt'] = self.dt + base_fields = { + ptr('E'): self.E, + ptr('H'): self.H, + ctype + ' dt': self.dt, + } if uniform_dx is False: inv_dx_names = ['inv_d' + eh + r for eh in 'eh' for r in 'xyz'] - for name, field in zip(inv_dx_names, self.inv_dxes): + for name, field in zip(inv_dx_names, self.inv_dxes, strict=True): base_fields[ptr(name)] = field - eps_field = OrderedDict() - eps_field[ptr('eps')] = self.eps + eps_field = {ptr('eps'): self.eps} if bloch_boundaries: - base_fields[ptr('F')] = self.F - base_fields[ptr('G')] = self.G + base_fields |= { + ptr('F'): self.F, + ptr('G'): self.G, + } - bloch_fields = OrderedDict() - bloch_fields[ptr('E')] = self.F - bloch_fields[ptr('H')] = self.G - bloch_fields[ctype + ' dt'] = self.dt - bloch_fields[ptr('F')] = self.E - bloch_fields[ptr('G')] = self.H + bloch_fields = { + ptr('E'): self.F, + ptr('H'): self.G, + ctype + ' dt': self.dt, + ptr('F'): self.E, + ptr('G'): self.H, + } common_source = jinja_env.get_template('common.cl').render( ftype=ctype, @@ -216,18 +226,18 @@ class Simulation: if bloch_boundaries: bloch_args = jinja_args.copy() bloch_args['do_poynting'] = False - bloch_args['bloch'] = [ - {'axis': b['axis'], - 'real': b['imag'], - 'imag': b['real'], - } + bloch_args['bloch'] = [{ + 'axis': b['axis'], + 'real': b['imag'], + 'imag': b['real'], + } for b in bloch_boundaries] F_source = jinja_env.get_template('update_e.cl').render(**bloch_args) G_source = jinja_env.get_template('update_h.cl').render(**bloch_args) self.sources['F'] = F_source self.sources['G'] = G_source - S_fields = OrderedDict() + S_fields = {} if do_poynting: self.S = pyopencl.array.zeros_like(self.E) S_fields[ptr('S')] = self.S @@ -237,7 +247,7 @@ class Simulation: S_fields[ptr('S0')] = self.S0 S_fields[ptr('S1')] = self.S1 - J_fields = OrderedDict() + J_fields = {} if do_fieldsrc: J_source = jinja_env.get_template('update_j.cl').render(**jinja_args) self.sources['J'] = J_source @@ -247,37 +257,36 @@ class Simulation: J_fields[ptr('Jr')] = self.Jr J_fields[ptr('Ji')] = self.Ji - ''' - PML - ''' + # + # PML + # pml_e_fields, pml_h_fields = self._create_pmls(pmls) if bloch_boundaries: pml_f_fields, pml_g_fields = self._create_pmls(pmls) - ''' - Create operations - ''' + # + # Create operations + # self.update_E = self._create_operation(E_source, (base_fields, eps_field, pml_e_fields)) self.update_H = self._create_operation(H_source, (base_fields, pml_h_fields, S_fields)) if bloch_boundaries: self.update_F = self._create_operation(F_source, (bloch_fields, eps_field, pml_f_fields)) self.update_G = self._create_operation(G_source, (bloch_fields, pml_g_fields)) if do_fieldsrc: - args = OrderedDict() - [args.update(d) for d in (base_fields, J_fields)] + args = base_fields | J_fields var_args = [ctype + ' ' + v for v in 'cs'] + ['uint ' + r + m for r in 'xyz' for m in ('min', 'max')] update = ElementwiseKernel(self.context, operation=J_source, arguments=', '.join(list(args.keys()) + var_args)) self.update_J = lambda e, *a: update(*args.values(), *a, wait_for=e) - def _create_pmls(self, pmls): + def _create_pmls(self, pmls: Sequence[dict[str, float]]) -> tuple[dict[str, pyopencl.array.Array], ...]: ctype = type_to_C(self.arg_type) def ptr(arg: str) -> str: return ctype + ' *' + arg - pml_e_fields = OrderedDict() - pml_h_fields = OrderedDict() + pml_e_fields = {} + pml_h_fields = {} for pml in pmls: a = 'xyz'.find(pml['axis']) @@ -285,7 +294,9 @@ class Simulation: kappa_max = numpy.sqrt(pml['mu_eff'] * pml['epsilon_eff']) alpha_max = pml['cfs_alpha'] - def par(x): + print(sigma_max, kappa_max, alpha_max, pml['thickness'], self.dt) + + def par(x, pml=pml, sigma_max=sigma_max, kappa_max=kappa_max, alpha_max=alpha_max): # noqa: ANN001, ANN202 scaling = (x / pml['thickness']) ** pml['m'] sigma = scaling * sigma_max kappa = 1 + scaling * (kappa_max - 1) @@ -301,8 +312,13 @@ class Simulation: elif pml['polarity'] == 'n': xh -= 0.5 + logger.debug(f'{pml=}') + logger.debug(f'{xe=}') + logger.debug(f'{xh=}') + logger.debug(f'{par(xe)=}') + logger.debug(f'{par(xh)=}') pml_p_names = [['p' + pml['axis'] + i + eh + pml['polarity'] for i in '012'] for eh in 'eh'] - for name_e, name_h, pe, ph in zip(pml_p_names[0], pml_p_names[1], par(xe), par(xh)): + for name_e, name_h, pe, ph in zip(pml_p_names[0], pml_p_names[1], par(xe), par(xh), strict=True): pml_e_fields[ptr(name_e)] = pyopencl.array.to_device(self.queue, pe) pml_h_fields[ptr(name_h)] = pyopencl.array.to_device(self.queue, ph) @@ -313,13 +329,13 @@ class Simulation: psi_shape = list(self.shape) psi_shape[a] = pml['thickness'] - for ne, nh in zip(*psi_names): + for ne, nh in zip(*psi_names, strict=True): pml_e_fields[ptr(ne)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type) pml_h_fields[ptr(nh)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type) return pml_e_fields, pml_h_fields - def _create_operation(self, source, args_fields) -> Callable[..., pyopencl.Event]: - args = OrderedDict() + def _create_operation(self, source: str, args_fields: Sequence[dict[str, pyopencl.array.Array]]) -> Callable[..., pyopencl.Event]: + args = {} for d in args_fields: args.update(d) update = ElementwiseKernel( @@ -334,38 +350,30 @@ class Simulation: context: pyopencl.Context | None = None, queue: pyopencl.CommandQueue | None = None, ) -> None: - if context is None: - self.context = pyopencl.create_some_context() - else: - self.context = context - - if queue is None: - self.queue = pyopencl.CommandQueue(self.context) - else: - self.queue = queue + self.context = context or pyopencl.create_some_context() + self.queue = queue or pyopencl.CommandQueue(self.context) def _create_eps(self, epsilon: NDArray) -> pyopencl.array.Array: if len(epsilon) != 3: - raise Exception('Epsilon must be a list with length of 3') - if not all((e.shape == epsilon[0].shape for e in epsilon[1:])): - raise Exception('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon]) + raise FDTDError('Epsilon must be a list with length of 3') + if not all(e.shape == epsilon[0].shape for e in epsilon[1:]): + raise FDTDError('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon]) if not epsilon[0].shape == self.shape: - raise Exception(f'Epsilon shape mismatch. Expected {self.shape}, got {epsilon[0].shape}') + raise FDTDError(f'Epsilon shape mismatch. Expected {self.shape}, got {epsilon[0].shape}') self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(self.arg_type)) def _create_field(self, initial_value: NDArray | None = None) -> pyopencl.array.Array: if initial_value is None: return pyopencl.array.zeros_like(self.eps) - else: - if len(initial_value) != 3: - Exception('Initial field value must be a list of length 3') - if not all((f.shape == self.shape for f in initial_value)): - Exception('Initial field list elements must have same shape as epsilon elements') - return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type)) + if len(initial_value) != 3: + raise FDTDError('Initial field value must be a list of length 3') + if not all(f.shape == self.shape for f in initial_value): + raise FDTDError('Initial field list elements must have same shape as epsilon elements') + return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type)) def type_to_C( - float_type: Type, + float_type: type, ) -> str: """ Returns a string corresponding to the C equivalent of a numpy type. @@ -385,7 +393,7 @@ def type_to_C( numpy.complex128: 'cdouble_t', } if float_type not in types: - raise Exception(f'Unsupported type: {float_type}') + raise FDTDError(f'Unsupported type: {float_type}') return types[float_type]