fixes driven by ruff & mypy
This commit is contained in:
parent
b703f1ee20
commit
50b30d31fb
@ -1,4 +1,7 @@
|
|||||||
from .simulation import Simulation, type_to_C
|
from .simulation import (
|
||||||
|
Simulation as Simulation,
|
||||||
|
type_to_C as type_to_C,
|
||||||
|
)
|
||||||
|
|
||||||
__author__ = 'Jan Petykiewicz'
|
__author__ = 'Jan Petykiewicz'
|
||||||
__version__ = '0.4'
|
__version__ = '0.4'
|
||||||
|
@ -2,12 +2,13 @@
|
|||||||
Class for constructing and holding the basic FDTD operations and fields
|
Class for constructing and holding the basic FDTD operations and fields
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Callable, Type, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from collections import OrderedDict
|
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
from numpy import floating, complexfloating
|
||||||
import jinja2
|
import jinja2
|
||||||
import warnings
|
import warnings
|
||||||
|
import logging
|
||||||
|
|
||||||
import pyopencl
|
import pyopencl
|
||||||
import pyopencl.array
|
import pyopencl.array
|
||||||
@ -16,13 +17,18 @@ from pyopencl.elementwise import ElementwiseKernel
|
|||||||
from meanas.fdmath import vec
|
from meanas.fdmath import vec
|
||||||
|
|
||||||
|
|
||||||
__author__ = 'Jan Petykiewicz'
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Create jinja2 env on module load
|
# Create jinja2 env on module load
|
||||||
jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__.split('.')[0], 'kernels'))
|
jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__.split('.')[0], 'kernels'))
|
||||||
|
|
||||||
|
|
||||||
|
class FDTDError(Exception):
|
||||||
|
""" Custom exception for opencl_fdtd """
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Simulation:
|
class Simulation:
|
||||||
r"""
|
r"""
|
||||||
Constructs and holds the basic FDTD operations and related fields
|
Constructs and holds the basic FDTD operations and related fields
|
||||||
@ -64,7 +70,7 @@ class Simulation:
|
|||||||
dt: float
|
dt: float
|
||||||
inv_dxes: list[pyopencl.array.Array]
|
inv_dxes: list[pyopencl.array.Array]
|
||||||
|
|
||||||
arg_type: Type
|
arg_type: type
|
||||||
|
|
||||||
context: pyopencl.Context
|
context: pyopencl.Context
|
||||||
queue: pyopencl.CommandQueue
|
queue: pyopencl.CommandQueue
|
||||||
@ -85,7 +91,7 @@ class Simulation:
|
|||||||
initial_fields: dict[str, NDArray] | None = None,
|
initial_fields: dict[str, NDArray] | None = None,
|
||||||
context: pyopencl.Context | None = None,
|
context: pyopencl.Context | None = None,
|
||||||
queue: pyopencl.CommandQueue | None = None,
|
queue: pyopencl.CommandQueue | None = None,
|
||||||
float_type: Type = numpy.float32,
|
float_type: type = numpy.float32,
|
||||||
do_poynting: bool = True,
|
do_poynting: bool = True,
|
||||||
do_fieldsrc: bool = False,
|
do_fieldsrc: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -134,7 +140,7 @@ class Simulation:
|
|||||||
if dxes is None:
|
if dxes is None:
|
||||||
dxes = 1.0
|
dxes = 1.0
|
||||||
|
|
||||||
if isinstance(dxes, (float, int)):
|
if isinstance(dxes, float | int):
|
||||||
uniform_dx = dxes
|
uniform_dx = dxes
|
||||||
min_dx = dxes
|
min_dx = dxes
|
||||||
else:
|
else:
|
||||||
@ -143,13 +149,14 @@ class Simulation:
|
|||||||
min_dx = min(min(dxn) for dxn in dxes[0] + dxes[1])
|
min_dx = min(min(dxn) for dxn in dxes[0] + dxes[1])
|
||||||
|
|
||||||
max_dt = min_dx * .99 / numpy.sqrt(3)
|
max_dt = min_dx * .99 / numpy.sqrt(3)
|
||||||
|
logger.info(f'{min_dx=}, {max_dt=}, {dt=}')
|
||||||
|
|
||||||
if dt is None:
|
if dt is None:
|
||||||
self.dt = max_dt
|
self.dt = max_dt
|
||||||
elif dt > max_dt:
|
elif dt > max_dt:
|
||||||
warnings.warn(f'Warning: unstable dt: {dt}')
|
warnings.warn(f'Warning: unstable dt: {dt}', stacklevel=2)
|
||||||
elif dt <= 0:
|
elif dt <= 0:
|
||||||
raise Exception(f'Invalid dt: {dt}')
|
raise FDTDError(f'Invalid dt: {dt}')
|
||||||
else:
|
else:
|
||||||
self.dt = dt
|
self.dt = dt
|
||||||
|
|
||||||
@ -173,28 +180,31 @@ class Simulation:
|
|||||||
def ptr(arg: str) -> str:
|
def ptr(arg: str) -> str:
|
||||||
return ctype + ' *' + arg
|
return ctype + ' *' + arg
|
||||||
|
|
||||||
base_fields = OrderedDict()
|
base_fields = {
|
||||||
base_fields[ptr('E')] = self.E
|
ptr('E'): self.E,
|
||||||
base_fields[ptr('H')] = self.H
|
ptr('H'): self.H,
|
||||||
base_fields[ctype + ' dt'] = self.dt
|
ctype + ' dt': self.dt,
|
||||||
|
}
|
||||||
if uniform_dx is False:
|
if uniform_dx is False:
|
||||||
inv_dx_names = ['inv_d' + eh + r for eh in 'eh' for r in 'xyz']
|
inv_dx_names = ['inv_d' + eh + r for eh in 'eh' for r in 'xyz']
|
||||||
for name, field in zip(inv_dx_names, self.inv_dxes):
|
for name, field in zip(inv_dx_names, self.inv_dxes, strict=True):
|
||||||
base_fields[ptr(name)] = field
|
base_fields[ptr(name)] = field
|
||||||
|
|
||||||
eps_field = OrderedDict()
|
eps_field = {ptr('eps'): self.eps}
|
||||||
eps_field[ptr('eps')] = self.eps
|
|
||||||
|
|
||||||
if bloch_boundaries:
|
if bloch_boundaries:
|
||||||
base_fields[ptr('F')] = self.F
|
base_fields |= {
|
||||||
base_fields[ptr('G')] = self.G
|
ptr('F'): self.F,
|
||||||
|
ptr('G'): self.G,
|
||||||
|
}
|
||||||
|
|
||||||
bloch_fields = OrderedDict()
|
bloch_fields = {
|
||||||
bloch_fields[ptr('E')] = self.F
|
ptr('E'): self.F,
|
||||||
bloch_fields[ptr('H')] = self.G
|
ptr('H'): self.G,
|
||||||
bloch_fields[ctype + ' dt'] = self.dt
|
ctype + ' dt': self.dt,
|
||||||
bloch_fields[ptr('F')] = self.E
|
ptr('F'): self.E,
|
||||||
bloch_fields[ptr('G')] = self.H
|
ptr('G'): self.H,
|
||||||
|
}
|
||||||
|
|
||||||
common_source = jinja_env.get_template('common.cl').render(
|
common_source = jinja_env.get_template('common.cl').render(
|
||||||
ftype=ctype,
|
ftype=ctype,
|
||||||
@ -216,8 +226,8 @@ class Simulation:
|
|||||||
if bloch_boundaries:
|
if bloch_boundaries:
|
||||||
bloch_args = jinja_args.copy()
|
bloch_args = jinja_args.copy()
|
||||||
bloch_args['do_poynting'] = False
|
bloch_args['do_poynting'] = False
|
||||||
bloch_args['bloch'] = [
|
bloch_args['bloch'] = [{
|
||||||
{'axis': b['axis'],
|
'axis': b['axis'],
|
||||||
'real': b['imag'],
|
'real': b['imag'],
|
||||||
'imag': b['real'],
|
'imag': b['real'],
|
||||||
}
|
}
|
||||||
@ -227,7 +237,7 @@ class Simulation:
|
|||||||
self.sources['F'] = F_source
|
self.sources['F'] = F_source
|
||||||
self.sources['G'] = G_source
|
self.sources['G'] = G_source
|
||||||
|
|
||||||
S_fields = OrderedDict()
|
S_fields = {}
|
||||||
if do_poynting:
|
if do_poynting:
|
||||||
self.S = pyopencl.array.zeros_like(self.E)
|
self.S = pyopencl.array.zeros_like(self.E)
|
||||||
S_fields[ptr('S')] = self.S
|
S_fields[ptr('S')] = self.S
|
||||||
@ -237,7 +247,7 @@ class Simulation:
|
|||||||
S_fields[ptr('S0')] = self.S0
|
S_fields[ptr('S0')] = self.S0
|
||||||
S_fields[ptr('S1')] = self.S1
|
S_fields[ptr('S1')] = self.S1
|
||||||
|
|
||||||
J_fields = OrderedDict()
|
J_fields = {}
|
||||||
if do_fieldsrc:
|
if do_fieldsrc:
|
||||||
J_source = jinja_env.get_template('update_j.cl').render(**jinja_args)
|
J_source = jinja_env.get_template('update_j.cl').render(**jinja_args)
|
||||||
self.sources['J'] = J_source
|
self.sources['J'] = J_source
|
||||||
@ -247,37 +257,36 @@ class Simulation:
|
|||||||
J_fields[ptr('Jr')] = self.Jr
|
J_fields[ptr('Jr')] = self.Jr
|
||||||
J_fields[ptr('Ji')] = self.Ji
|
J_fields[ptr('Ji')] = self.Ji
|
||||||
|
|
||||||
'''
|
#
|
||||||
PML
|
# PML
|
||||||
'''
|
#
|
||||||
pml_e_fields, pml_h_fields = self._create_pmls(pmls)
|
pml_e_fields, pml_h_fields = self._create_pmls(pmls)
|
||||||
if bloch_boundaries:
|
if bloch_boundaries:
|
||||||
pml_f_fields, pml_g_fields = self._create_pmls(pmls)
|
pml_f_fields, pml_g_fields = self._create_pmls(pmls)
|
||||||
|
|
||||||
'''
|
#
|
||||||
Create operations
|
# Create operations
|
||||||
'''
|
#
|
||||||
self.update_E = self._create_operation(E_source, (base_fields, eps_field, pml_e_fields))
|
self.update_E = self._create_operation(E_source, (base_fields, eps_field, pml_e_fields))
|
||||||
self.update_H = self._create_operation(H_source, (base_fields, pml_h_fields, S_fields))
|
self.update_H = self._create_operation(H_source, (base_fields, pml_h_fields, S_fields))
|
||||||
if bloch_boundaries:
|
if bloch_boundaries:
|
||||||
self.update_F = self._create_operation(F_source, (bloch_fields, eps_field, pml_f_fields))
|
self.update_F = self._create_operation(F_source, (bloch_fields, eps_field, pml_f_fields))
|
||||||
self.update_G = self._create_operation(G_source, (bloch_fields, pml_g_fields))
|
self.update_G = self._create_operation(G_source, (bloch_fields, pml_g_fields))
|
||||||
if do_fieldsrc:
|
if do_fieldsrc:
|
||||||
args = OrderedDict()
|
args = base_fields | J_fields
|
||||||
[args.update(d) for d in (base_fields, J_fields)]
|
|
||||||
var_args = [ctype + ' ' + v for v in 'cs'] + ['uint ' + r + m for r in 'xyz' for m in ('min', 'max')]
|
var_args = [ctype + ' ' + v for v in 'cs'] + ['uint ' + r + m for r in 'xyz' for m in ('min', 'max')]
|
||||||
update = ElementwiseKernel(self.context, operation=J_source,
|
update = ElementwiseKernel(self.context, operation=J_source,
|
||||||
arguments=', '.join(list(args.keys()) + var_args))
|
arguments=', '.join(list(args.keys()) + var_args))
|
||||||
self.update_J = lambda e, *a: update(*args.values(), *a, wait_for=e)
|
self.update_J = lambda e, *a: update(*args.values(), *a, wait_for=e)
|
||||||
|
|
||||||
def _create_pmls(self, pmls):
|
def _create_pmls(self, pmls: Sequence[dict[str, float]]) -> tuple[dict[str, pyopencl.array.Array], ...]:
|
||||||
ctype = type_to_C(self.arg_type)
|
ctype = type_to_C(self.arg_type)
|
||||||
|
|
||||||
def ptr(arg: str) -> str:
|
def ptr(arg: str) -> str:
|
||||||
return ctype + ' *' + arg
|
return ctype + ' *' + arg
|
||||||
|
|
||||||
pml_e_fields = OrderedDict()
|
pml_e_fields = {}
|
||||||
pml_h_fields = OrderedDict()
|
pml_h_fields = {}
|
||||||
for pml in pmls:
|
for pml in pmls:
|
||||||
a = 'xyz'.find(pml['axis'])
|
a = 'xyz'.find(pml['axis'])
|
||||||
|
|
||||||
@ -285,7 +294,9 @@ class Simulation:
|
|||||||
kappa_max = numpy.sqrt(pml['mu_eff'] * pml['epsilon_eff'])
|
kappa_max = numpy.sqrt(pml['mu_eff'] * pml['epsilon_eff'])
|
||||||
alpha_max = pml['cfs_alpha']
|
alpha_max = pml['cfs_alpha']
|
||||||
|
|
||||||
def par(x):
|
print(sigma_max, kappa_max, alpha_max, pml['thickness'], self.dt)
|
||||||
|
|
||||||
|
def par(x, pml=pml, sigma_max=sigma_max, kappa_max=kappa_max, alpha_max=alpha_max): # noqa: ANN001, ANN202
|
||||||
scaling = (x / pml['thickness']) ** pml['m']
|
scaling = (x / pml['thickness']) ** pml['m']
|
||||||
sigma = scaling * sigma_max
|
sigma = scaling * sigma_max
|
||||||
kappa = 1 + scaling * (kappa_max - 1)
|
kappa = 1 + scaling * (kappa_max - 1)
|
||||||
@ -301,8 +312,13 @@ class Simulation:
|
|||||||
elif pml['polarity'] == 'n':
|
elif pml['polarity'] == 'n':
|
||||||
xh -= 0.5
|
xh -= 0.5
|
||||||
|
|
||||||
|
logger.debug(f'{pml=}')
|
||||||
|
logger.debug(f'{xe=}')
|
||||||
|
logger.debug(f'{xh=}')
|
||||||
|
logger.debug(f'{par(xe)=}')
|
||||||
|
logger.debug(f'{par(xh)=}')
|
||||||
pml_p_names = [['p' + pml['axis'] + i + eh + pml['polarity'] for i in '012'] for eh in 'eh']
|
pml_p_names = [['p' + pml['axis'] + i + eh + pml['polarity'] for i in '012'] for eh in 'eh']
|
||||||
for name_e, name_h, pe, ph in zip(pml_p_names[0], pml_p_names[1], par(xe), par(xh)):
|
for name_e, name_h, pe, ph in zip(pml_p_names[0], pml_p_names[1], par(xe), par(xh), strict=True):
|
||||||
pml_e_fields[ptr(name_e)] = pyopencl.array.to_device(self.queue, pe)
|
pml_e_fields[ptr(name_e)] = pyopencl.array.to_device(self.queue, pe)
|
||||||
pml_h_fields[ptr(name_h)] = pyopencl.array.to_device(self.queue, ph)
|
pml_h_fields[ptr(name_h)] = pyopencl.array.to_device(self.queue, ph)
|
||||||
|
|
||||||
@ -313,13 +329,13 @@ class Simulation:
|
|||||||
psi_shape = list(self.shape)
|
psi_shape = list(self.shape)
|
||||||
psi_shape[a] = pml['thickness']
|
psi_shape[a] = pml['thickness']
|
||||||
|
|
||||||
for ne, nh in zip(*psi_names):
|
for ne, nh in zip(*psi_names, strict=True):
|
||||||
pml_e_fields[ptr(ne)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
|
pml_e_fields[ptr(ne)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
|
||||||
pml_h_fields[ptr(nh)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
|
pml_h_fields[ptr(nh)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
|
||||||
return pml_e_fields, pml_h_fields
|
return pml_e_fields, pml_h_fields
|
||||||
|
|
||||||
def _create_operation(self, source, args_fields) -> Callable[..., pyopencl.Event]:
|
def _create_operation(self, source: str, args_fields: Sequence[dict[str, pyopencl.array.Array]]) -> Callable[..., pyopencl.Event]:
|
||||||
args = OrderedDict()
|
args = {}
|
||||||
for d in args_fields:
|
for d in args_fields:
|
||||||
args.update(d)
|
args.update(d)
|
||||||
update = ElementwiseKernel(
|
update = ElementwiseKernel(
|
||||||
@ -334,38 +350,30 @@ class Simulation:
|
|||||||
context: pyopencl.Context | None = None,
|
context: pyopencl.Context | None = None,
|
||||||
queue: pyopencl.CommandQueue | None = None,
|
queue: pyopencl.CommandQueue | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if context is None:
|
self.context = context or pyopencl.create_some_context()
|
||||||
self.context = pyopencl.create_some_context()
|
self.queue = queue or pyopencl.CommandQueue(self.context)
|
||||||
else:
|
|
||||||
self.context = context
|
|
||||||
|
|
||||||
if queue is None:
|
|
||||||
self.queue = pyopencl.CommandQueue(self.context)
|
|
||||||
else:
|
|
||||||
self.queue = queue
|
|
||||||
|
|
||||||
def _create_eps(self, epsilon: NDArray) -> pyopencl.array.Array:
|
def _create_eps(self, epsilon: NDArray) -> pyopencl.array.Array:
|
||||||
if len(epsilon) != 3:
|
if len(epsilon) != 3:
|
||||||
raise Exception('Epsilon must be a list with length of 3')
|
raise FDTDError('Epsilon must be a list with length of 3')
|
||||||
if not all((e.shape == epsilon[0].shape for e in epsilon[1:])):
|
if not all(e.shape == epsilon[0].shape for e in epsilon[1:]):
|
||||||
raise Exception('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon])
|
raise FDTDError('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon])
|
||||||
if not epsilon[0].shape == self.shape:
|
if not epsilon[0].shape == self.shape:
|
||||||
raise Exception(f'Epsilon shape mismatch. Expected {self.shape}, got {epsilon[0].shape}')
|
raise FDTDError(f'Epsilon shape mismatch. Expected {self.shape}, got {epsilon[0].shape}')
|
||||||
self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(self.arg_type))
|
self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(self.arg_type))
|
||||||
|
|
||||||
def _create_field(self, initial_value: NDArray | None = None) -> pyopencl.array.Array:
|
def _create_field(self, initial_value: NDArray | None = None) -> pyopencl.array.Array:
|
||||||
if initial_value is None:
|
if initial_value is None:
|
||||||
return pyopencl.array.zeros_like(self.eps)
|
return pyopencl.array.zeros_like(self.eps)
|
||||||
else:
|
|
||||||
if len(initial_value) != 3:
|
if len(initial_value) != 3:
|
||||||
Exception('Initial field value must be a list of length 3')
|
raise FDTDError('Initial field value must be a list of length 3')
|
||||||
if not all((f.shape == self.shape for f in initial_value)):
|
if not all(f.shape == self.shape for f in initial_value):
|
||||||
Exception('Initial field list elements must have same shape as epsilon elements')
|
raise FDTDError('Initial field list elements must have same shape as epsilon elements')
|
||||||
return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type))
|
return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type))
|
||||||
|
|
||||||
|
|
||||||
def type_to_C(
|
def type_to_C(
|
||||||
float_type: Type,
|
float_type: type,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Returns a string corresponding to the C equivalent of a numpy type.
|
Returns a string corresponding to the C equivalent of a numpy type.
|
||||||
@ -385,7 +393,7 @@ def type_to_C(
|
|||||||
numpy.complex128: 'cdouble_t',
|
numpy.complex128: 'cdouble_t',
|
||||||
}
|
}
|
||||||
if float_type not in types:
|
if float_type not in types:
|
||||||
raise Exception(f'Unsupported type: {float_type}')
|
raise FDTDError(f'Unsupported type: {float_type}')
|
||||||
|
|
||||||
return types[float_type]
|
return types[float_type]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user