fixes driven by ruff & mypy
This commit is contained in:
parent
b703f1ee20
commit
50b30d31fb
@ -1,4 +1,7 @@
|
||||
from .simulation import Simulation, type_to_C
|
||||
from .simulation import (
|
||||
Simulation as Simulation,
|
||||
type_to_C as type_to_C,
|
||||
)
|
||||
|
||||
__author__ = 'Jan Petykiewicz'
|
||||
__version__ = '0.4'
|
||||
|
@ -2,12 +2,13 @@
|
||||
Class for constructing and holding the basic FDTD operations and fields
|
||||
"""
|
||||
|
||||
from typing import Callable, Type, Sequence
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable, Sequence
|
||||
import numpy
|
||||
from numpy.typing import NDArray
|
||||
from numpy import floating, complexfloating
|
||||
import jinja2
|
||||
import warnings
|
||||
import logging
|
||||
|
||||
import pyopencl
|
||||
import pyopencl.array
|
||||
@ -16,13 +17,18 @@ from pyopencl.elementwise import ElementwiseKernel
|
||||
from meanas.fdmath import vec
|
||||
|
||||
|
||||
__author__ = 'Jan Petykiewicz'
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Create jinja2 env on module load
|
||||
jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__.split('.')[0], 'kernels'))
|
||||
|
||||
|
||||
class FDTDError(Exception):
|
||||
""" Custom exception for opencl_fdtd """
|
||||
pass
|
||||
|
||||
|
||||
class Simulation:
|
||||
r"""
|
||||
Constructs and holds the basic FDTD operations and related fields
|
||||
@ -64,7 +70,7 @@ class Simulation:
|
||||
dt: float
|
||||
inv_dxes: list[pyopencl.array.Array]
|
||||
|
||||
arg_type: Type
|
||||
arg_type: type
|
||||
|
||||
context: pyopencl.Context
|
||||
queue: pyopencl.CommandQueue
|
||||
@ -85,7 +91,7 @@ class Simulation:
|
||||
initial_fields: dict[str, NDArray] | None = None,
|
||||
context: pyopencl.Context | None = None,
|
||||
queue: pyopencl.CommandQueue | None = None,
|
||||
float_type: Type = numpy.float32,
|
||||
float_type: type = numpy.float32,
|
||||
do_poynting: bool = True,
|
||||
do_fieldsrc: bool = False,
|
||||
) -> None:
|
||||
@ -134,7 +140,7 @@ class Simulation:
|
||||
if dxes is None:
|
||||
dxes = 1.0
|
||||
|
||||
if isinstance(dxes, (float, int)):
|
||||
if isinstance(dxes, float | int):
|
||||
uniform_dx = dxes
|
||||
min_dx = dxes
|
||||
else:
|
||||
@ -143,13 +149,14 @@ class Simulation:
|
||||
min_dx = min(min(dxn) for dxn in dxes[0] + dxes[1])
|
||||
|
||||
max_dt = min_dx * .99 / numpy.sqrt(3)
|
||||
logger.info(f'{min_dx=}, {max_dt=}, {dt=}')
|
||||
|
||||
if dt is None:
|
||||
self.dt = max_dt
|
||||
elif dt > max_dt:
|
||||
warnings.warn(f'Warning: unstable dt: {dt}')
|
||||
warnings.warn(f'Warning: unstable dt: {dt}', stacklevel=2)
|
||||
elif dt <= 0:
|
||||
raise Exception(f'Invalid dt: {dt}')
|
||||
raise FDTDError(f'Invalid dt: {dt}')
|
||||
else:
|
||||
self.dt = dt
|
||||
|
||||
@ -173,28 +180,31 @@ class Simulation:
|
||||
def ptr(arg: str) -> str:
|
||||
return ctype + ' *' + arg
|
||||
|
||||
base_fields = OrderedDict()
|
||||
base_fields[ptr('E')] = self.E
|
||||
base_fields[ptr('H')] = self.H
|
||||
base_fields[ctype + ' dt'] = self.dt
|
||||
base_fields = {
|
||||
ptr('E'): self.E,
|
||||
ptr('H'): self.H,
|
||||
ctype + ' dt': self.dt,
|
||||
}
|
||||
if uniform_dx is False:
|
||||
inv_dx_names = ['inv_d' + eh + r for eh in 'eh' for r in 'xyz']
|
||||
for name, field in zip(inv_dx_names, self.inv_dxes):
|
||||
for name, field in zip(inv_dx_names, self.inv_dxes, strict=True):
|
||||
base_fields[ptr(name)] = field
|
||||
|
||||
eps_field = OrderedDict()
|
||||
eps_field[ptr('eps')] = self.eps
|
||||
eps_field = {ptr('eps'): self.eps}
|
||||
|
||||
if bloch_boundaries:
|
||||
base_fields[ptr('F')] = self.F
|
||||
base_fields[ptr('G')] = self.G
|
||||
base_fields |= {
|
||||
ptr('F'): self.F,
|
||||
ptr('G'): self.G,
|
||||
}
|
||||
|
||||
bloch_fields = OrderedDict()
|
||||
bloch_fields[ptr('E')] = self.F
|
||||
bloch_fields[ptr('H')] = self.G
|
||||
bloch_fields[ctype + ' dt'] = self.dt
|
||||
bloch_fields[ptr('F')] = self.E
|
||||
bloch_fields[ptr('G')] = self.H
|
||||
bloch_fields = {
|
||||
ptr('E'): self.F,
|
||||
ptr('H'): self.G,
|
||||
ctype + ' dt': self.dt,
|
||||
ptr('F'): self.E,
|
||||
ptr('G'): self.H,
|
||||
}
|
||||
|
||||
common_source = jinja_env.get_template('common.cl').render(
|
||||
ftype=ctype,
|
||||
@ -216,8 +226,8 @@ class Simulation:
|
||||
if bloch_boundaries:
|
||||
bloch_args = jinja_args.copy()
|
||||
bloch_args['do_poynting'] = False
|
||||
bloch_args['bloch'] = [
|
||||
{'axis': b['axis'],
|
||||
bloch_args['bloch'] = [{
|
||||
'axis': b['axis'],
|
||||
'real': b['imag'],
|
||||
'imag': b['real'],
|
||||
}
|
||||
@ -227,7 +237,7 @@ class Simulation:
|
||||
self.sources['F'] = F_source
|
||||
self.sources['G'] = G_source
|
||||
|
||||
S_fields = OrderedDict()
|
||||
S_fields = {}
|
||||
if do_poynting:
|
||||
self.S = pyopencl.array.zeros_like(self.E)
|
||||
S_fields[ptr('S')] = self.S
|
||||
@ -237,7 +247,7 @@ class Simulation:
|
||||
S_fields[ptr('S0')] = self.S0
|
||||
S_fields[ptr('S1')] = self.S1
|
||||
|
||||
J_fields = OrderedDict()
|
||||
J_fields = {}
|
||||
if do_fieldsrc:
|
||||
J_source = jinja_env.get_template('update_j.cl').render(**jinja_args)
|
||||
self.sources['J'] = J_source
|
||||
@ -247,37 +257,36 @@ class Simulation:
|
||||
J_fields[ptr('Jr')] = self.Jr
|
||||
J_fields[ptr('Ji')] = self.Ji
|
||||
|
||||
'''
|
||||
PML
|
||||
'''
|
||||
#
|
||||
# PML
|
||||
#
|
||||
pml_e_fields, pml_h_fields = self._create_pmls(pmls)
|
||||
if bloch_boundaries:
|
||||
pml_f_fields, pml_g_fields = self._create_pmls(pmls)
|
||||
|
||||
'''
|
||||
Create operations
|
||||
'''
|
||||
#
|
||||
# Create operations
|
||||
#
|
||||
self.update_E = self._create_operation(E_source, (base_fields, eps_field, pml_e_fields))
|
||||
self.update_H = self._create_operation(H_source, (base_fields, pml_h_fields, S_fields))
|
||||
if bloch_boundaries:
|
||||
self.update_F = self._create_operation(F_source, (bloch_fields, eps_field, pml_f_fields))
|
||||
self.update_G = self._create_operation(G_source, (bloch_fields, pml_g_fields))
|
||||
if do_fieldsrc:
|
||||
args = OrderedDict()
|
||||
[args.update(d) for d in (base_fields, J_fields)]
|
||||
args = base_fields | J_fields
|
||||
var_args = [ctype + ' ' + v for v in 'cs'] + ['uint ' + r + m for r in 'xyz' for m in ('min', 'max')]
|
||||
update = ElementwiseKernel(self.context, operation=J_source,
|
||||
arguments=', '.join(list(args.keys()) + var_args))
|
||||
self.update_J = lambda e, *a: update(*args.values(), *a, wait_for=e)
|
||||
|
||||
def _create_pmls(self, pmls):
|
||||
def _create_pmls(self, pmls: Sequence[dict[str, float]]) -> tuple[dict[str, pyopencl.array.Array], ...]:
|
||||
ctype = type_to_C(self.arg_type)
|
||||
|
||||
def ptr(arg: str) -> str:
|
||||
return ctype + ' *' + arg
|
||||
|
||||
pml_e_fields = OrderedDict()
|
||||
pml_h_fields = OrderedDict()
|
||||
pml_e_fields = {}
|
||||
pml_h_fields = {}
|
||||
for pml in pmls:
|
||||
a = 'xyz'.find(pml['axis'])
|
||||
|
||||
@ -285,7 +294,9 @@ class Simulation:
|
||||
kappa_max = numpy.sqrt(pml['mu_eff'] * pml['epsilon_eff'])
|
||||
alpha_max = pml['cfs_alpha']
|
||||
|
||||
def par(x):
|
||||
print(sigma_max, kappa_max, alpha_max, pml['thickness'], self.dt)
|
||||
|
||||
def par(x, pml=pml, sigma_max=sigma_max, kappa_max=kappa_max, alpha_max=alpha_max): # noqa: ANN001, ANN202
|
||||
scaling = (x / pml['thickness']) ** pml['m']
|
||||
sigma = scaling * sigma_max
|
||||
kappa = 1 + scaling * (kappa_max - 1)
|
||||
@ -301,8 +312,13 @@ class Simulation:
|
||||
elif pml['polarity'] == 'n':
|
||||
xh -= 0.5
|
||||
|
||||
logger.debug(f'{pml=}')
|
||||
logger.debug(f'{xe=}')
|
||||
logger.debug(f'{xh=}')
|
||||
logger.debug(f'{par(xe)=}')
|
||||
logger.debug(f'{par(xh)=}')
|
||||
pml_p_names = [['p' + pml['axis'] + i + eh + pml['polarity'] for i in '012'] for eh in 'eh']
|
||||
for name_e, name_h, pe, ph in zip(pml_p_names[0], pml_p_names[1], par(xe), par(xh)):
|
||||
for name_e, name_h, pe, ph in zip(pml_p_names[0], pml_p_names[1], par(xe), par(xh), strict=True):
|
||||
pml_e_fields[ptr(name_e)] = pyopencl.array.to_device(self.queue, pe)
|
||||
pml_h_fields[ptr(name_h)] = pyopencl.array.to_device(self.queue, ph)
|
||||
|
||||
@ -313,13 +329,13 @@ class Simulation:
|
||||
psi_shape = list(self.shape)
|
||||
psi_shape[a] = pml['thickness']
|
||||
|
||||
for ne, nh in zip(*psi_names):
|
||||
for ne, nh in zip(*psi_names, strict=True):
|
||||
pml_e_fields[ptr(ne)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
|
||||
pml_h_fields[ptr(nh)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
|
||||
return pml_e_fields, pml_h_fields
|
||||
|
||||
def _create_operation(self, source, args_fields) -> Callable[..., pyopencl.Event]:
|
||||
args = OrderedDict()
|
||||
def _create_operation(self, source: str, args_fields: Sequence[dict[str, pyopencl.array.Array]]) -> Callable[..., pyopencl.Event]:
|
||||
args = {}
|
||||
for d in args_fields:
|
||||
args.update(d)
|
||||
update = ElementwiseKernel(
|
||||
@ -334,38 +350,30 @@ class Simulation:
|
||||
context: pyopencl.Context | None = None,
|
||||
queue: pyopencl.CommandQueue | None = None,
|
||||
) -> None:
|
||||
if context is None:
|
||||
self.context = pyopencl.create_some_context()
|
||||
else:
|
||||
self.context = context
|
||||
|
||||
if queue is None:
|
||||
self.queue = pyopencl.CommandQueue(self.context)
|
||||
else:
|
||||
self.queue = queue
|
||||
self.context = context or pyopencl.create_some_context()
|
||||
self.queue = queue or pyopencl.CommandQueue(self.context)
|
||||
|
||||
def _create_eps(self, epsilon: NDArray) -> pyopencl.array.Array:
|
||||
if len(epsilon) != 3:
|
||||
raise Exception('Epsilon must be a list with length of 3')
|
||||
if not all((e.shape == epsilon[0].shape for e in epsilon[1:])):
|
||||
raise Exception('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon])
|
||||
raise FDTDError('Epsilon must be a list with length of 3')
|
||||
if not all(e.shape == epsilon[0].shape for e in epsilon[1:]):
|
||||
raise FDTDError('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon])
|
||||
if not epsilon[0].shape == self.shape:
|
||||
raise Exception(f'Epsilon shape mismatch. Expected {self.shape}, got {epsilon[0].shape}')
|
||||
raise FDTDError(f'Epsilon shape mismatch. Expected {self.shape}, got {epsilon[0].shape}')
|
||||
self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(self.arg_type))
|
||||
|
||||
def _create_field(self, initial_value: NDArray | None = None) -> pyopencl.array.Array:
|
||||
if initial_value is None:
|
||||
return pyopencl.array.zeros_like(self.eps)
|
||||
else:
|
||||
if len(initial_value) != 3:
|
||||
Exception('Initial field value must be a list of length 3')
|
||||
if not all((f.shape == self.shape for f in initial_value)):
|
||||
Exception('Initial field list elements must have same shape as epsilon elements')
|
||||
raise FDTDError('Initial field value must be a list of length 3')
|
||||
if not all(f.shape == self.shape for f in initial_value):
|
||||
raise FDTDError('Initial field list elements must have same shape as epsilon elements')
|
||||
return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type))
|
||||
|
||||
|
||||
def type_to_C(
|
||||
float_type: Type,
|
||||
float_type: type,
|
||||
) -> str:
|
||||
"""
|
||||
Returns a string corresponding to the C equivalent of a numpy type.
|
||||
@ -385,7 +393,7 @@ def type_to_C(
|
||||
numpy.complex128: 'cdouble_t',
|
||||
}
|
||||
if float_type not in types:
|
||||
raise Exception(f'Unsupported type: {float_type}')
|
||||
raise FDTDError(f'Unsupported type: {float_type}')
|
||||
|
||||
return types[float_type]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user