fixes driven by ruff & mypy

This commit is contained in:
Jan Petykiewicz 2024-07-30 23:18:54 -07:00
parent b703f1ee20
commit 50b30d31fb
2 changed files with 79 additions and 68 deletions

View File

@ -1,4 +1,7 @@
from .simulation import Simulation, type_to_C
from .simulation import (
Simulation as Simulation,
type_to_C as type_to_C,
)
__author__ = 'Jan Petykiewicz'
__version__ = '0.4'

View File

@ -2,12 +2,13 @@
Class for constructing and holding the basic FDTD operations and fields
"""
from typing import Callable, Type, Sequence
from collections import OrderedDict
from collections.abc import Callable, Sequence
import numpy
from numpy.typing import NDArray
from numpy import floating, complexfloating
import jinja2
import warnings
import logging
import pyopencl
import pyopencl.array
@ -16,13 +17,18 @@ from pyopencl.elementwise import ElementwiseKernel
from meanas.fdmath import vec
__author__ = 'Jan Petykiewicz'
logger = logging.getLogger(__name__)
# Create jinja2 env on module load
jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__.split('.')[0], 'kernels'))
class FDTDError(Exception):
""" Custom exception for opencl_fdtd """
pass
class Simulation:
r"""
Constructs and holds the basic FDTD operations and related fields
@ -64,7 +70,7 @@ class Simulation:
dt: float
inv_dxes: list[pyopencl.array.Array]
arg_type: Type
arg_type: type
context: pyopencl.Context
queue: pyopencl.CommandQueue
@ -85,7 +91,7 @@ class Simulation:
initial_fields: dict[str, NDArray] | None = None,
context: pyopencl.Context | None = None,
queue: pyopencl.CommandQueue | None = None,
float_type: Type = numpy.float32,
float_type: type = numpy.float32,
do_poynting: bool = True,
do_fieldsrc: bool = False,
) -> None:
@ -134,7 +140,7 @@ class Simulation:
if dxes is None:
dxes = 1.0
if isinstance(dxes, (float, int)):
if isinstance(dxes, float | int):
uniform_dx = dxes
min_dx = dxes
else:
@ -143,13 +149,14 @@ class Simulation:
min_dx = min(min(dxn) for dxn in dxes[0] + dxes[1])
max_dt = min_dx * .99 / numpy.sqrt(3)
logger.info(f'{min_dx=}, {max_dt=}, {dt=}')
if dt is None:
self.dt = max_dt
elif dt > max_dt:
warnings.warn(f'Warning: unstable dt: {dt}')
warnings.warn(f'Warning: unstable dt: {dt}', stacklevel=2)
elif dt <= 0:
raise Exception(f'Invalid dt: {dt}')
raise FDTDError(f'Invalid dt: {dt}')
else:
self.dt = dt
@ -173,28 +180,31 @@ class Simulation:
def ptr(arg: str) -> str:
return ctype + ' *' + arg
base_fields = OrderedDict()
base_fields[ptr('E')] = self.E
base_fields[ptr('H')] = self.H
base_fields[ctype + ' dt'] = self.dt
base_fields = {
ptr('E'): self.E,
ptr('H'): self.H,
ctype + ' dt': self.dt,
}
if uniform_dx is False:
inv_dx_names = ['inv_d' + eh + r for eh in 'eh' for r in 'xyz']
for name, field in zip(inv_dx_names, self.inv_dxes):
for name, field in zip(inv_dx_names, self.inv_dxes, strict=True):
base_fields[ptr(name)] = field
eps_field = OrderedDict()
eps_field[ptr('eps')] = self.eps
eps_field = {ptr('eps'): self.eps}
if bloch_boundaries:
base_fields[ptr('F')] = self.F
base_fields[ptr('G')] = self.G
base_fields |= {
ptr('F'): self.F,
ptr('G'): self.G,
}
bloch_fields = OrderedDict()
bloch_fields[ptr('E')] = self.F
bloch_fields[ptr('H')] = self.G
bloch_fields[ctype + ' dt'] = self.dt
bloch_fields[ptr('F')] = self.E
bloch_fields[ptr('G')] = self.H
bloch_fields = {
ptr('E'): self.F,
ptr('H'): self.G,
ctype + ' dt': self.dt,
ptr('F'): self.E,
ptr('G'): self.H,
}
common_source = jinja_env.get_template('common.cl').render(
ftype=ctype,
@ -216,18 +226,18 @@ class Simulation:
if bloch_boundaries:
bloch_args = jinja_args.copy()
bloch_args['do_poynting'] = False
bloch_args['bloch'] = [
{'axis': b['axis'],
'real': b['imag'],
'imag': b['real'],
}
bloch_args['bloch'] = [{
'axis': b['axis'],
'real': b['imag'],
'imag': b['real'],
}
for b in bloch_boundaries]
F_source = jinja_env.get_template('update_e.cl').render(**bloch_args)
G_source = jinja_env.get_template('update_h.cl').render(**bloch_args)
self.sources['F'] = F_source
self.sources['G'] = G_source
S_fields = OrderedDict()
S_fields = {}
if do_poynting:
self.S = pyopencl.array.zeros_like(self.E)
S_fields[ptr('S')] = self.S
@ -237,7 +247,7 @@ class Simulation:
S_fields[ptr('S0')] = self.S0
S_fields[ptr('S1')] = self.S1
J_fields = OrderedDict()
J_fields = {}
if do_fieldsrc:
J_source = jinja_env.get_template('update_j.cl').render(**jinja_args)
self.sources['J'] = J_source
@ -247,37 +257,36 @@ class Simulation:
J_fields[ptr('Jr')] = self.Jr
J_fields[ptr('Ji')] = self.Ji
'''
PML
'''
#
# PML
#
pml_e_fields, pml_h_fields = self._create_pmls(pmls)
if bloch_boundaries:
pml_f_fields, pml_g_fields = self._create_pmls(pmls)
'''
Create operations
'''
#
# Create operations
#
self.update_E = self._create_operation(E_source, (base_fields, eps_field, pml_e_fields))
self.update_H = self._create_operation(H_source, (base_fields, pml_h_fields, S_fields))
if bloch_boundaries:
self.update_F = self._create_operation(F_source, (bloch_fields, eps_field, pml_f_fields))
self.update_G = self._create_operation(G_source, (bloch_fields, pml_g_fields))
if do_fieldsrc:
args = OrderedDict()
[args.update(d) for d in (base_fields, J_fields)]
args = base_fields | J_fields
var_args = [ctype + ' ' + v for v in 'cs'] + ['uint ' + r + m for r in 'xyz' for m in ('min', 'max')]
update = ElementwiseKernel(self.context, operation=J_source,
arguments=', '.join(list(args.keys()) + var_args))
self.update_J = lambda e, *a: update(*args.values(), *a, wait_for=e)
def _create_pmls(self, pmls):
def _create_pmls(self, pmls: Sequence[dict[str, float]]) -> tuple[dict[str, pyopencl.array.Array], ...]:
ctype = type_to_C(self.arg_type)
def ptr(arg: str) -> str:
return ctype + ' *' + arg
pml_e_fields = OrderedDict()
pml_h_fields = OrderedDict()
pml_e_fields = {}
pml_h_fields = {}
for pml in pmls:
a = 'xyz'.find(pml['axis'])
@ -285,7 +294,9 @@ class Simulation:
kappa_max = numpy.sqrt(pml['mu_eff'] * pml['epsilon_eff'])
alpha_max = pml['cfs_alpha']
def par(x):
print(sigma_max, kappa_max, alpha_max, pml['thickness'], self.dt)
def par(x, pml=pml, sigma_max=sigma_max, kappa_max=kappa_max, alpha_max=alpha_max): # noqa: ANN001, ANN202
scaling = (x / pml['thickness']) ** pml['m']
sigma = scaling * sigma_max
kappa = 1 + scaling * (kappa_max - 1)
@ -301,8 +312,13 @@ class Simulation:
elif pml['polarity'] == 'n':
xh -= 0.5
logger.debug(f'{pml=}')
logger.debug(f'{xe=}')
logger.debug(f'{xh=}')
logger.debug(f'{par(xe)=}')
logger.debug(f'{par(xh)=}')
pml_p_names = [['p' + pml['axis'] + i + eh + pml['polarity'] for i in '012'] for eh in 'eh']
for name_e, name_h, pe, ph in zip(pml_p_names[0], pml_p_names[1], par(xe), par(xh)):
for name_e, name_h, pe, ph in zip(pml_p_names[0], pml_p_names[1], par(xe), par(xh), strict=True):
pml_e_fields[ptr(name_e)] = pyopencl.array.to_device(self.queue, pe)
pml_h_fields[ptr(name_h)] = pyopencl.array.to_device(self.queue, ph)
@ -313,13 +329,13 @@ class Simulation:
psi_shape = list(self.shape)
psi_shape[a] = pml['thickness']
for ne, nh in zip(*psi_names):
for ne, nh in zip(*psi_names, strict=True):
pml_e_fields[ptr(ne)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
pml_h_fields[ptr(nh)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
return pml_e_fields, pml_h_fields
def _create_operation(self, source, args_fields) -> Callable[..., pyopencl.Event]:
args = OrderedDict()
def _create_operation(self, source: str, args_fields: Sequence[dict[str, pyopencl.array.Array]]) -> Callable[..., pyopencl.Event]:
args = {}
for d in args_fields:
args.update(d)
update = ElementwiseKernel(
@ -334,38 +350,30 @@ class Simulation:
context: pyopencl.Context | None = None,
queue: pyopencl.CommandQueue | None = None,
) -> None:
if context is None:
self.context = pyopencl.create_some_context()
else:
self.context = context
if queue is None:
self.queue = pyopencl.CommandQueue(self.context)
else:
self.queue = queue
self.context = context or pyopencl.create_some_context()
self.queue = queue or pyopencl.CommandQueue(self.context)
def _create_eps(self, epsilon: NDArray) -> pyopencl.array.Array:
if len(epsilon) != 3:
raise Exception('Epsilon must be a list with length of 3')
if not all((e.shape == epsilon[0].shape for e in epsilon[1:])):
raise Exception('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon])
raise FDTDError('Epsilon must be a list with length of 3')
if not all(e.shape == epsilon[0].shape for e in epsilon[1:]):
raise FDTDError('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon])
if not epsilon[0].shape == self.shape:
raise Exception(f'Epsilon shape mismatch. Expected {self.shape}, got {epsilon[0].shape}')
raise FDTDError(f'Epsilon shape mismatch. Expected {self.shape}, got {epsilon[0].shape}')
self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(self.arg_type))
def _create_field(self, initial_value: NDArray | None = None) -> pyopencl.array.Array:
if initial_value is None:
return pyopencl.array.zeros_like(self.eps)
else:
if len(initial_value) != 3:
Exception('Initial field value must be a list of length 3')
if not all((f.shape == self.shape for f in initial_value)):
Exception('Initial field list elements must have same shape as epsilon elements')
return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type))
if len(initial_value) != 3:
raise FDTDError('Initial field value must be a list of length 3')
if not all(f.shape == self.shape for f in initial_value):
raise FDTDError('Initial field list elements must have same shape as epsilon elements')
return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type))
def type_to_C(
float_type: Type,
float_type: type,
) -> str:
"""
Returns a string corresponding to the C equivalent of a numpy type.
@ -385,7 +393,7 @@ def type_to_C(
numpy.complex128: 'cdouble_t',
}
if float_type not in types:
raise Exception(f'Unsupported type: {float_type}')
raise FDTDError(f'Unsupported type: {float_type}')
return types[float_type]