type hints and lint

This commit is contained in:
Forgejo Actions 2026-04-21 21:13:34 -07:00
commit c6c9159b13
30 changed files with 198 additions and 136 deletions

View file

@ -14,6 +14,7 @@ simple straight interface:
from __future__ import annotations from __future__ import annotations
import importlib import importlib
from typing import TYPE_CHECKING
import numpy import numpy
from numpy import pi from numpy import pi
@ -24,6 +25,9 @@ from gridlock import Extent
from meanas.fdfd import eme, waveguide_2d from meanas.fdfd import eme, waveguide_2d
from meanas.fdmath import unvec from meanas.fdmath import unvec
if TYPE_CHECKING:
from types import ModuleType
WL = 1310.0 WL = 1310.0
DX = 40.0 DX = 40.0
@ -35,7 +39,7 @@ EPS_OX = 1.453 ** 2
MODE_NUMBERS = numpy.array([0]) MODE_NUMBERS = numpy.array([0])
def require_optional(name: str, package_name: str | None = None): def require_optional(name: str, package_name: str | None = None) -> ModuleType:
package_name = package_name or name package_name = package_name or name
try: try:
return importlib.import_module(name) return importlib.import_module(name)
@ -159,7 +163,7 @@ def print_summary(ss: numpy.ndarray, wavenumbers_left: numpy.ndarray, wavenumber
def plot_results( def plot_results(
*, *,
pyplot, pyplot: ModuleType,
ss: numpy.ndarray, ss: numpy.ndarray,
left_mode: tuple[numpy.ndarray, numpy.ndarray], left_mode: tuple[numpy.ndarray, numpy.ndarray],
right_mode: tuple[numpy.ndarray, numpy.ndarray], right_mode: tuple[numpy.ndarray, numpy.ndarray],

View file

@ -15,6 +15,7 @@ This example demonstrates a cylindrical-waveguide EME workflow:
from __future__ import annotations from __future__ import annotations
import importlib import importlib
from typing import TYPE_CHECKING
import numpy import numpy
from numpy import pi from numpy import pi
@ -26,6 +27,9 @@ from gridlock import Extent
from meanas.fdfd import eme, waveguide_2d, waveguide_cyl from meanas.fdfd import eme, waveguide_2d, waveguide_cyl
from meanas.fdmath import unvec from meanas.fdmath import unvec
if TYPE_CHECKING:
from types import ModuleType
WL = 1310.0 WL = 1310.0
DX = 40.0 DX = 40.0
@ -40,7 +44,7 @@ STRAIGHT_SECTION_LENGTH = 12e3
BEND_ANGLE = pi / 2 BEND_ANGLE = pi / 2
def require_optional(name: str, package_name: str | None = None): def require_optional(name: str, package_name: str | None = None) -> ModuleType:
package_name = package_name or name package_name = package_name or name
try: try:
return importlib.import_module(name) return importlib.import_module(name)
@ -163,7 +167,7 @@ def solve_bend_modes(
def build_cascaded_network( def build_cascaded_network(
skrf, skrf: ModuleType,
*, *,
interface_s: numpy.ndarray, interface_s: numpy.ndarray,
straight_wavenumbers: numpy.ndarray, straight_wavenumbers: numpy.ndarray,
@ -216,7 +220,7 @@ def print_summary(
def plot_results( def plot_results(
*, *,
pyplot, pyplot: ModuleType,
interface_s: numpy.ndarray, interface_s: numpy.ndarray,
cascaded_s: numpy.ndarray, cascaded_s: numpy.ndarray,
straight_mode: tuple[numpy.ndarray, numpy.ndarray], straight_mode: tuple[numpy.ndarray, numpy.ndarray],

View file

@ -89,7 +89,7 @@ def perturbed_l3(a: float, radius: float, **kwargs) -> Pattern:
return pat return pat
def main(): def main() -> None:
dtype = numpy.float32 dtype = numpy.float32
max_t = 3600 # number of timesteps max_t = 3600 # number of timesteps
@ -97,7 +97,6 @@ def main():
pml_thickness = 8 # (number of cells) pml_thickness = 8 # (number of cells)
wl = 1550 # Excitation wavelength and fwhm wl = 1550 # Excitation wavelength and fwhm
dwl = 100
# Device design parameters # Device design parameters
xy_size = numpy.array([10, 10]) xy_size = numpy.array([10, 10])

View file

@ -683,7 +683,16 @@ def eigsolve(
return numpy.abs(trace) return numpy.abs(trace)
if False: if False:
def trace_deriv(theta, sgn: int = sgn, ZtAZ=ZtAZ, DtAD=DtAD, symZtD=symZtD, symZtAD=symZtAD, ZtZ=ZtZ, DtD=DtD): # noqa: ANN001 def trace_deriv(
theta: float,
sgn: int = sgn,
ZtAZ=ZtAZ, # noqa: ANN001
DtAD=DtAD, # noqa: ANN001
symZtD=symZtD, # noqa: ANN001
symZtAD=symZtAD, # noqa: ANN001
ZtZ=ZtZ, # noqa: ANN001
DtD=DtD, # noqa: ANN001
) -> float:
Qi = Qi_func(theta) Qi = Qi_func(theta)
c2 = numpy.cos(2 * theta) c2 = numpy.cos(2 * theta)
s2 = numpy.sin(2 * theta) s2 = numpy.sin(2 * theta)

View file

@ -27,11 +27,13 @@ from scipy import sparse
from ..fdmath import dx_lists2_t, vcfdfield2 from ..fdmath import dx_lists2_t, vcfdfield2
from .waveguide_2d import inner_product from .waveguide_2d import inner_product
type wavenumber_seq = Sequence[complex] | NDArray[numpy.complexfloating] | NDArray[numpy.floating]
def _validate_port_modes( def _validate_port_modes(
name: str, name: str,
ehs: Sequence[Sequence[vcfdfield2]], ehs: Sequence[Sequence[vcfdfield2]],
wavenumbers: Sequence[complex], wavenumbers: wavenumber_seq,
) -> tuple[tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...]]:
if len(ehs) != len(wavenumbers): if len(ehs) != len(wavenumbers):
raise ValueError(f'{name} mode list and wavenumber list must have the same length') raise ValueError(f'{name} mode list and wavenumber list must have the same length')
@ -61,9 +63,9 @@ def _validate_port_modes(
def get_tr( def get_tr(
ehLs: Sequence[Sequence[vcfdfield2]], ehLs: Sequence[Sequence[vcfdfield2]],
wavenumbers_L: Sequence[complex], wavenumbers_L: wavenumber_seq,
ehRs: Sequence[Sequence[vcfdfield2]], ehRs: Sequence[Sequence[vcfdfield2]],
wavenumbers_R: Sequence[complex], wavenumbers_R: wavenumber_seq,
dxes: dx_lists2_t, dxes: dx_lists2_t,
) -> tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]: ) -> tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]:
""" """
@ -118,9 +120,9 @@ def get_tr(
def get_abcd( def get_abcd(
ehLs: Sequence[Sequence[vcfdfield2]], ehLs: Sequence[Sequence[vcfdfield2]],
wavenumbers_L: Sequence[complex], wavenumbers_L: wavenumber_seq,
ehRs: Sequence[Sequence[vcfdfield2]], ehRs: Sequence[Sequence[vcfdfield2]],
wavenumbers_R: Sequence[complex], wavenumbers_R: wavenumber_seq,
**kwargs, **kwargs,
) -> sparse.sparray: ) -> sparse.sparray:
""" """
@ -151,9 +153,9 @@ def get_abcd(
def get_s( def get_s(
ehLs: Sequence[Sequence[vcfdfield2]], ehLs: Sequence[Sequence[vcfdfield2]],
wavenumbers_L: Sequence[complex], wavenumbers_L: wavenumber_seq,
ehRs: Sequence[Sequence[vcfdfield2]], ehRs: Sequence[Sequence[vcfdfield2]],
wavenumbers_R: Sequence[complex], wavenumbers_R: wavenumber_seq,
force_nogain: bool = False, force_nogain: bool = False,
force_reciprocal: bool = False, force_reciprocal: bool = False,
**kwargs, **kwargs,

View file

@ -1,23 +1,24 @@
""" """
Functions for performing near-to-farfield transformation (and the reverse). Functions for performing near-to-farfield transformation (and the reverse).
""" """
from typing import Any, cast, TYPE_CHECKING from typing import Any, cast
from collections.abc import Sequence
import numpy import numpy
from numpy.fft import fft2, fftshift, fftfreq, ifft2, ifftshift from numpy.fft import fft2, fftshift, fftfreq, ifft2, ifftshift
from numpy import pi from numpy import pi
from numpy.typing import NDArray
from numpy import complexfloating
from ..fdmath import cfdfield_t type farfield_slice = NDArray[complexfloating]
type transverse_slice_pair = Sequence[farfield_slice]
if TYPE_CHECKING:
from collections.abc import Sequence
def near_to_farfield( def near_to_farfield(
E_near: cfdfield_t, E_near: transverse_slice_pair,
H_near: cfdfield_t, H_near: transverse_slice_pair,
dx: float, dx: float,
dy: float, dy: float,
padded_size: list[int] | int | None = None padded_size: Sequence[int] | int | None = None
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Compute the farfield, i.e. the distribution of the fields after propagation Compute the farfield, i.e. the distribution of the fields after propagation
@ -58,7 +59,7 @@ def near_to_farfield(
raise Exception('H_near must be a length-2 list of ndarrays') raise Exception('H_near must be a length-2 list of ndarrays')
s = E_near[0].shape s = E_near[0].shape
if not all(s == f.shape for f in E_near + H_near): if not all(s == f.shape for f in [*E_near, *H_near]):
raise Exception('All fields must be the same shape!') raise Exception('All fields must be the same shape!')
if padded_size is None: if padded_size is None:
@ -123,11 +124,11 @@ def near_to_farfield(
def far_to_nearfield( def far_to_nearfield(
E_far: cfdfield_t, E_far: transverse_slice_pair,
H_far: cfdfield_t, H_far: transverse_slice_pair,
dkx: float, dkx: float,
dky: float, dky: float,
padded_size: list[int] | int | None = None padded_size: Sequence[int] | int | None = None
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Compute the farfield, i.e. the distribution of the fields after propagation Compute the farfield, i.e. the distribution of the fields after propagation
@ -164,7 +165,7 @@ def far_to_nearfield(
raise Exception('H_far must be a length-2 list of ndarrays') raise Exception('H_far must be a length-2 list of ndarrays')
s = E_far[0].shape s = E_far[0].shape
if not all(s == f.shape for f in E_far + H_far): if not all(s == f.shape for f in [*E_far, *H_far]):
raise Exception('All fields must be the same shape!') raise Exception('All fields must be the same shape!')
if padded_size is None: if padded_size is None:

View file

@ -423,10 +423,10 @@ def normalized_fields_h(
def _normalized_fields( def _normalized_fields(
e: vcfdslice, e: vcfdslice,
h: vcfdslice, h: vcfdslice,
omega: complex, _omega: complex,
dxes: dx_lists2_t, dxes: dx_lists2_t,
epsilon: vfdslice, epsilon: vfdslice,
mu: vfdslice | None = None, _mu: vfdslice | None = None,
prop_phase: float = 0, prop_phase: float = 0,
) -> tuple[vcfdslice_t, vcfdslice_t]: ) -> tuple[vcfdslice_t, vcfdslice_t]:
r""" r"""

View file

@ -19,9 +19,8 @@ The intended workflow is:
That same convention controls which side of the selected slice is used for the That same convention controls which side of the selected slice is used for the
overlap window and how the expanded field is phased. overlap window and how the expanded field is phased.
""" """
from typing import Any, cast from typing import Any, TypedDict, cast
import warnings import warnings
from typing import Any
from collections.abc import Sequence from collections.abc import Sequence
import numpy import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
@ -31,6 +30,13 @@ from ..fdmath import vec, unvec, dx_lists_t, cfdfield_t, fdfield, cfdfield
from . import operators, waveguide_2d from . import operators, waveguide_2d
class Waveguide3DMode(TypedDict):
wavenumber: complex
wavenumber_2d: complex
H: NDArray[complexfloating]
E: NDArray[complexfloating]
def solve_mode( def solve_mode(
mode_number: int, mode_number: int,
omega: complex, omega: complex,
@ -40,7 +46,7 @@ def solve_mode(
slices: Sequence[slice], slices: Sequence[slice],
epsilon: fdfield, epsilon: fdfield,
mu: fdfield | None = None, mu: fdfield | None = None,
) -> dict[str, complex | NDArray[complexfloating]]: ) -> Waveguide3DMode:
r""" r"""
Given a 3D grid, selects a slice from the grid and attempts to Given a 3D grid, selects a slice from the grid and attempts to
solve for an eigenmode propagating through that slice. solve for an eigenmode propagating through that slice.
@ -121,7 +127,7 @@ def solve_mode(
E[iii] = e[oo][:, :, None].transpose(reverse_order) E[iii] = e[oo][:, :, None].transpose(reverse_order)
H[iii] = h[oo][:, :, None].transpose(reverse_order) H[iii] = h[oo][:, :, None].transpose(reverse_order)
results = { results: Waveguide3DMode = {
'wavenumber': wavenumber, 'wavenumber': wavenumber,
'wavenumber_2d': wavenumber_2d, 'wavenumber_2d': wavenumber_2d,
'H': H, 'H': H,
@ -184,13 +190,13 @@ def compute_source(
def compute_overlap_e( def compute_overlap_e(
E: cfdfield_t, E: cfdfield,
wavenumber: complex, wavenumber: complex,
dxes: dx_lists_t, dxes: dx_lists_t,
axis: int, axis: int,
polarity: int, polarity: int,
slices: Sequence[slice], slices: Sequence[slice],
omega: float, _omega: float,
) -> cfdfield_t: ) -> cfdfield_t:
r""" r"""
Build an overlap field for projecting another 3D electric field onto a mode. Build an overlap field for projecting another 3D electric field onto a mode.
@ -262,7 +268,7 @@ def compute_overlap_e(
if clipped_start >= clipped_stop: if clipped_start >= clipped_stop:
raise ValueError('Requested overlap window lies outside the domain') raise ValueError('Requested overlap window lies outside the domain')
if clipped_start != start or clipped_stop != stop: if clipped_start != start or clipped_stop != stop:
warnings.warn('Requested overlap window was clipped to fit within the domain', RuntimeWarning) warnings.warn('Requested overlap window was clipped to fit within the domain', RuntimeWarning, stacklevel=2)
slices2_l = list(slices) slices2_l = list(slices)
slices2_l[axis] = slice(clipped_start, clipped_stop) slices2_l[axis] = slice(clipped_start, clipped_stop)
@ -275,7 +281,7 @@ def compute_overlap_e(
norm = (Etgt.conj() * Etgt).sum() norm = (Etgt.conj() * Etgt).sum()
if norm == 0: if norm == 0:
raise ValueError('Requested overlap window contains no overlap field support') raise ValueError('Requested overlap window contains no overlap field support')
Etgt /= norm Etgt = Etgt / norm
return cfdfield_t(Etgt) return cfdfield_t(Etgt)

View file

@ -130,7 +130,7 @@ import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray, ArrayLike
from scipy import sparse from scipy import sparse
from ..fdmath import vec, unvec, dx_lists2_t, vcfdslice_t, vcfdfield2_t, vfdslice, vcfdslice, vcfdfield2 from ..fdmath import vec, unvec, dx_lists2_t, vcfdslice_t, vfdslice, vcfdslice, vcfdfield2
from ..fdmath.operators import deriv_forward, deriv_back from ..fdmath.operators import deriv_forward, deriv_back
from ..eigensolvers import signed_eigensolve, rayleigh_quotient_iteration from ..eigensolvers import signed_eigensolve, rayleigh_quotient_iteration
from . import waveguide_2d from . import waveguide_2d
@ -267,7 +267,7 @@ def solve_mode(
mode_number: int, mode_number: int,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> tuple[vcfdslice, complex]: ) -> tuple[vcfdfield2, complex]:
""" """
Wrapper around `solve_modes()` that solves for a single mode. Wrapper around `solve_modes()` that solves for a single mode.
@ -285,7 +285,7 @@ def solve_mode(
def linear_wavenumbers( def linear_wavenumbers(
e_xys: list[vcfdfield2_t], e_xys: Sequence[vcfdfield2] | NDArray[numpy.complex128],
angular_wavenumbers: ArrayLike, angular_wavenumbers: ArrayLike,
epsilon: vfdslice, epsilon: vfdslice,
dxes: dx_lists2_t, dxes: dx_lists2_t,
@ -537,11 +537,11 @@ def normalized_fields_e(
def _normalized_fields( def _normalized_fields(
e: vcfdslice, e: vcfdslice,
h: vcfdslice, h: vcfdslice,
omega: complex, _omega: complex,
dxes: dx_lists2_t, dxes: dx_lists2_t,
rmin: float, # Currently unused, but may want to use cylindrical poynting _rmin: float, # Currently unused, but may want to use cylindrical poynting
epsilon: vfdslice, epsilon: vfdslice,
mu: vfdslice | None = None, _mu: vfdslice | None = None,
prop_phase: float = 0, prop_phase: float = 0,
) -> tuple[vcfdslice_t, vcfdslice_t]: ) -> tuple[vcfdslice_t, vcfdslice_t]:
r""" r"""

View file

@ -10,7 +10,7 @@ import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
from numpy import floating, complexfloating from numpy import floating, complexfloating
from .types import fdfield_t, fdfield_updater_t from .types import fdfield, fdfield_updater_t
def deriv_forward( def deriv_forward(
@ -127,7 +127,7 @@ def curl_forward_parts(
) -> Callable: ) -> Callable:
Dx, Dy, Dz = deriv_forward(dx_e) Dx, Dy, Dz = deriv_forward(dx_e)
def mkparts_fwd(e: fdfield_t) -> tuple[tuple[fdfield_t, fdfield_t], ...]: def mkparts_fwd(e: fdfield) -> tuple[tuple[fdfield, fdfield], ...]:
return ((-Dz(e[1]), Dy(e[2])), return ((-Dz(e[1]), Dy(e[2])),
( Dz(e[0]), -Dx(e[2])), ( Dz(e[0]), -Dx(e[2])),
(-Dy(e[0]), Dx(e[1]))) (-Dy(e[0]), Dx(e[1])))
@ -140,7 +140,7 @@ def curl_back_parts(
) -> Callable: ) -> Callable:
Dx, Dy, Dz = deriv_back(dx_h) Dx, Dy, Dz = deriv_back(dx_h)
def mkparts_back(h: fdfield_t) -> tuple[tuple[fdfield_t, fdfield_t], ...]: def mkparts_back(h: fdfield) -> tuple[tuple[fdfield, fdfield], ...]:
return ((-Dz(h[1]), Dy(h[2])), return ((-Dz(h[1]), Dy(h[2])),
( Dz(h[0]), -Dx(h[2])), ( Dz(h[0]), -Dx(h[2])),
(-Dy(h[0]), Dx(h[1]))) (-Dy(h[0]), Dx(h[1])))

View file

@ -9,7 +9,7 @@ from numpy.typing import NDArray
from numpy import floating, complexfloating from numpy import floating, complexfloating
from scipy import sparse from scipy import sparse
from .types import vfdfield_t from .types import vfdfield
def shift_circ( def shift_circ(
@ -171,7 +171,7 @@ def cross(
[-B[1], B[0], zero]]) [-B[1], B[0], zero]])
def vec_cross(b: vfdfield_t) -> sparse.sparray: def vec_cross(b: vfdfield) -> sparse.sparray:
""" """
Vector cross product operator Vector cross product operator

View file

@ -88,8 +88,8 @@ dx_lists2_mut = MutableSequence[MutableSequence[NDArray[floating | complexfloati
"""Mutable version of `dx_lists2_t`""" """Mutable version of `dx_lists2_t`"""
fdfield_updater_t = Callable[..., fdfield_t] fdfield_updater_t = Callable[..., fdfield]
"""Convenience type for functions which take and return an fdfield_t""" """Convenience type for functions which take and return a real `fdfield`"""
cfdfield_updater_t = Callable[..., cfdfield_t] cfdfield_updater_t = Callable[..., cfdfield]
"""Convenience type for functions which take and return an cfdfield_t""" """Convenience type for functions which take and return a complex `cfdfield`"""

View file

@ -3,7 +3,7 @@ Basic FDTD field updates
""" """
from ..fdmath import dx_lists_t, fdfield_t, fdfield_updater_t from ..fdmath import dx_lists_t, fdfield, fdfield_updater_t
from ..fdmath.functional import curl_forward, curl_back from ..fdmath.functional import curl_forward, curl_back
@ -47,7 +47,7 @@ def maxwell_e(
else: else:
curl_h_fun = curl_back() curl_h_fun = curl_back()
def me_fun(e: fdfield_t, h: fdfield_t, epsilon: fdfield_t | float) -> fdfield_t: def me_fun(e: fdfield, h: fdfield, epsilon: fdfield | float) -> fdfield:
""" """
Update the E-field. Update the E-field.
@ -103,7 +103,7 @@ def maxwell_h(
else: else:
curl_e_fun = curl_forward() curl_e_fun = curl_forward()
def mh_fun(e: fdfield_t, h: fdfield_t, mu: fdfield_t | float | None = None) -> fdfield_t: def mh_fun(e: fdfield, h: fdfield, mu: fdfield | float | None = None) -> fdfield:
""" """
Update the H-field. Update the H-field.

View file

@ -6,7 +6,7 @@ Boundary conditions
from typing import Any from typing import Any
from ..fdmath import fdfield_t, fdfield_updater_t from ..fdmath import fdfield, fdfield_updater_t
def conducting_boundary( def conducting_boundary(
@ -15,7 +15,7 @@ def conducting_boundary(
) -> tuple[fdfield_updater_t, fdfield_updater_t]: ) -> tuple[fdfield_updater_t, fdfield_updater_t]:
dirs = [0, 1, 2] dirs = [0, 1, 2]
if direction not in dirs: if direction not in dirs:
raise Exception(f'Invalid direction: {direction}') raise ValueError(f'Invalid direction: {direction}')
dirs.remove(direction) dirs.remove(direction)
u, v = dirs u, v = dirs
@ -31,13 +31,13 @@ def conducting_boundary(
boundary = tuple(boundary_slice) boundary = tuple(boundary_slice)
shifted1 = tuple(shifted1_slice) shifted1 = tuple(shifted1_slice)
def en(e: fdfield_t) -> fdfield_t: def en(e: fdfield) -> fdfield:
e[direction][boundary] = 0 e[direction][boundary] = 0
e[u][boundary] = e[u][shifted1] e[u][boundary] = e[u][shifted1]
e[v][boundary] = e[v][shifted1] e[v][boundary] = e[v][shifted1]
return e return e
def hn(h: fdfield_t) -> fdfield_t: def hn(h: fdfield) -> fdfield:
h[direction][boundary] = h[direction][shifted1] h[direction][boundary] = h[direction][shifted1]
h[u][boundary] = 0 h[u][boundary] = 0
h[v][boundary] = 0 h[v][boundary] = 0
@ -56,14 +56,14 @@ def conducting_boundary(
shifted1 = tuple(shifted1_slice) shifted1 = tuple(shifted1_slice)
shifted2 = tuple(shifted2_slice) shifted2 = tuple(shifted2_slice)
def ep(e: fdfield_t) -> fdfield_t: def ep(e: fdfield) -> fdfield:
e[direction][boundary] = -e[direction][shifted2] e[direction][boundary] = -e[direction][shifted2]
e[direction][shifted1] = 0 e[direction][shifted1] = 0
e[u][boundary] = e[u][shifted1] e[u][boundary] = e[u][shifted1]
e[v][boundary] = e[v][shifted1] e[v][boundary] = e[v][shifted1]
return e return e
def hp(h: fdfield_t) -> fdfield_t: def hp(h: fdfield) -> fdfield:
h[direction][boundary] = h[direction][shifted1] h[direction][boundary] = h[direction][shifted1]
h[u][boundary] = -h[u][shifted2] h[u][boundary] = -h[u][shifted2]
h[u][shifted1] = 0 h[u][shifted1] = 0
@ -73,4 +73,4 @@ def conducting_boundary(
return ep, hp return ep, hp
raise Exception(f'Bad polarity: {polarity}') raise ValueError(f'Bad polarity: {polarity}')

View file

@ -1,5 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
import logging import logging
from typing import cast
import numpy import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray, ArrayLike
@ -9,7 +10,14 @@ from numpy import pi
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
pulse_fn_t = Callable[[int | NDArray], tuple[float, float, float]] type pulse_scalar_t = float | NDArray[numpy.floating]
pulse_fn_t = Callable[[ArrayLike], tuple[pulse_scalar_t, pulse_scalar_t, pulse_scalar_t]]
def _scalar_or_array(values: NDArray[numpy.floating]) -> pulse_scalar_t:
if values.ndim == 0:
return float(values)
return cast('NDArray[numpy.floating]', values)
def gaussian_packet( def gaussian_packet(
@ -49,8 +57,9 @@ def gaussian_packet(
delay = numpy.ceil(delay * freq) / freq # force delay to integer number of periods to maintain phase delay = numpy.ceil(delay * freq) / freq # force delay to integer number of periods to maintain phase
logger.info(f'src_time {2 * delay / dt}') logger.info(f'src_time {2 * delay / dt}')
def source_phasor(ii: int | NDArray) -> tuple[float, float, float]: def source_phasor(ii: ArrayLike) -> tuple[pulse_scalar_t, pulse_scalar_t, pulse_scalar_t]:
t0 = ii * dt - delay ii_array = numpy.asarray(ii, dtype=float)
t0 = ii_array * dt - delay
envelope = numpy.sqrt(numpy.sqrt(2 * alpha / pi)) * numpy.exp(-alpha * t0 * t0) envelope = numpy.sqrt(numpy.sqrt(2 * alpha / pi)) * numpy.exp(-alpha * t0 * t0)
if one_sided: if one_sided:
@ -59,7 +68,7 @@ def gaussian_packet(
cc = numpy.cos(omega * t0) cc = numpy.cos(omega * t0)
ss = numpy.sin(omega * t0) ss = numpy.sin(omega * t0)
return envelope, cc, ss return _scalar_or_array(envelope), _scalar_or_array(cc), _scalar_or_array(ss)
# nrm = numpy.exp(-omega * omega / alpha) / 2 # nrm = numpy.exp(-omega * omega / alpha) / 2
@ -105,15 +114,16 @@ def ricker_pulse(
delay = delay_results.root delay = delay_results.root
delay = numpy.ceil(delay * freq) / freq # force delay to integer number of periods to maintain phase delay = numpy.ceil(delay * freq) / freq # force delay to integer number of periods to maintain phase
def source_phasor(ii: int | NDArray) -> tuple[float, float, float]: def source_phasor(ii: ArrayLike) -> tuple[pulse_scalar_t, pulse_scalar_t, pulse_scalar_t]:
t0 = ii * dt - delay ii_array = numpy.asarray(ii, dtype=float)
t0 = ii_array * dt - delay
rr = omega * t0 / 2 rr = omega * t0 / 2
ff = (1 - 2 * rr * rr) * numpy.exp(-rr * rr) ff = (1 - 2 * rr * rr) * numpy.exp(-rr * rr)
cc = numpy.cos(omega * t0) cc = numpy.cos(omega * t0)
ss = numpy.sin(omega * t0) ss = numpy.sin(omega * t0)
return ff, cc, ss return _scalar_or_array(ff), _scalar_or_array(cc), _scalar_or_array(ss)
return source_phasor, delay return source_phasor, delay

View file

@ -23,7 +23,7 @@ from copy import deepcopy
import numpy import numpy
from numpy.typing import NDArray, DTypeLike from numpy.typing import NDArray, DTypeLike
from ..fdmath import fdfield, fdfield_t, dx_lists_t from ..fdmath import fdfield, dx_lists_t
from ..fdmath.functional import deriv_forward, deriv_back from ..fdmath.functional import deriv_forward, deriv_back
@ -67,16 +67,16 @@ def cpml_params(
""" """
if axis not in range(3): if axis not in range(3):
raise Exception(f'Invalid axis: {axis}') raise ValueError(f'Invalid axis: {axis}')
if polarity not in (-1, 1): if polarity not in (-1, 1):
raise Exception(f'Invalid polarity: {polarity}') raise ValueError(f'Invalid polarity: {polarity}')
if thickness <= 2: if thickness <= 2:
raise Exception('It would be wise to have a pml with 4+ cells of thickness') raise ValueError('It would be wise to have a pml with 4+ cells of thickness')
if epsilon_eff <= 0: if epsilon_eff <= 0:
raise Exception('epsilon_eff must be positive') raise ValueError('epsilon_eff must be positive')
sigma_max = -ln_R_per_layer / 2 * (m + 1) sigma_max = -ln_R_per_layer / 2 * (m + 1)
kappa_max = numpy.sqrt(epsilon_eff * mu_eff) kappa_max = numpy.sqrt(epsilon_eff * mu_eff)
@ -129,8 +129,7 @@ def updates_with_cpml(
epsilon: fdfield, epsilon: fdfield,
*, *,
dtype: DTypeLike = numpy.float32, dtype: DTypeLike = numpy.float32,
) -> tuple[Callable[[fdfield_t, fdfield_t, fdfield_t], None], ) -> tuple[Callable[..., None], Callable[..., None]]:
Callable[[fdfield_t, fdfield_t, fdfield_t], None]]:
""" """
Build Yee-step update closures augmented with CPML terms. Build Yee-step update closures augmented with CPML terms.
@ -187,9 +186,9 @@ def updates_with_cpml(
pH = numpy.empty_like(epsilon, dtype=dtype) pH = numpy.empty_like(epsilon, dtype=dtype)
def update_E( def update_E(
e: fdfield_t, e: fdfield,
h: fdfield_t, h: fdfield,
epsilon: fdfield_t, epsilon: fdfield,
) -> None: ) -> None:
dyHx = Dby(h[0]) dyHx = Dby(h[0])
dzHx = Dbz(h[0]) dzHx = Dbz(h[0])
@ -233,9 +232,9 @@ def updates_with_cpml(
e[2] += dt / epsilon[2] * (dxHy - dyHx + pE[2]) e[2] += dt / epsilon[2] * (dxHy - dyHx + pE[2])
def update_H( def update_H(
e: fdfield_t, e: fdfield,
h: fdfield_t, h: fdfield,
mu: fdfield_t | tuple[int, int, int] = (1, 1, 1), mu: fdfield | tuple[int, int, int] = (1, 1, 1),
) -> None: ) -> None:
dyEx = Dfy(e[0]) dyEx = Dfy(e[0])
dzEx = Dfz(e[0]) dzEx = Dfz(e[0])

View file

@ -4,7 +4,7 @@ from numpy.testing import assert_allclose
from types import SimpleNamespace from types import SimpleNamespace
from ..fdfd import bloch from ..fdfd import bloch
from ._bloch_case import EPSILON, G_MATRIX, H_SIZE, K0_X, SHAPE, Y0, Y0_TWO_MODE, build_overlap_fixture from ._bloch_case import EPSILON, G_MATRIX, H_SIZE, K0_X, Y0, Y0_TWO_MODE, build_overlap_fixture
from .utils import assert_close from .utils import assert_close

View file

@ -1,3 +1,5 @@
from typing import cast
import numpy import numpy
import pytest import pytest
from scipy import sparse from scipy import sparse
@ -51,6 +53,10 @@ def _nonsymmetric_tr(left_marker: object):
return fake_get_tr return fake_get_tr
def _dummy_modes() -> tuple[list[tuple[numpy.ndarray, numpy.ndarray]], numpy.ndarray]:
return [_mode(0.0), _mode(0.7)], numpy.array([1.0, 0.5])
def test_get_tr_returns_finite_bounded_transfer_matrices() -> None: def test_get_tr_returns_finite_bounded_transfer_matrices() -> None:
left_modes, right_modes = _mode_sets() left_modes, right_modes = _mode_sets()
@ -103,9 +109,10 @@ def test_get_s_plain_matches_block_assembly_from_get_tr() -> None:
def test_get_s_force_nogain_caps_singular_values(monkeypatch) -> None: def test_get_s_force_nogain_caps_singular_values(monkeypatch) -> None:
monkeypatch.setattr(eme, 'get_tr', _gain_only_tr) monkeypatch.setattr(eme, 'get_tr', _gain_only_tr)
modes, wavenumbers = _dummy_modes()
plain_s = eme.get_s(None, None, None, None) plain_s = eme.get_s(modes, wavenumbers, modes, wavenumbers)
clipped_s = eme.get_s(None, None, None, None, force_nogain=True) clipped_s = eme.get_s(modes, wavenumbers, modes, wavenumbers, force_nogain=True)
plain_singular_values = numpy.linalg.svd(plain_s, compute_uv=False) plain_singular_values = numpy.linalg.svd(plain_s, compute_uv=False)
clipped_singular_values = numpy.linalg.svd(clipped_s, compute_uv=False) clipped_singular_values = numpy.linalg.svd(clipped_s, compute_uv=False)
@ -116,18 +123,20 @@ def test_get_s_force_nogain_caps_singular_values(monkeypatch) -> None:
def test_get_s_force_reciprocal_symmetrizes_output(monkeypatch) -> None: def test_get_s_force_reciprocal_symmetrizes_output(monkeypatch) -> None:
left = object() left = numpy.array([1.0, 0.5])
right = object() right = numpy.array([0.9, 0.4])
modes, _wavenumbers = _dummy_modes()
monkeypatch.setattr(eme, 'get_tr', _nonsymmetric_tr(left)) monkeypatch.setattr(eme, 'get_tr', _nonsymmetric_tr(left))
ss = eme.get_s(None, left, None, right, force_reciprocal=True) ss = eme.get_s(modes, left, modes, right, force_reciprocal=True)
assert_close(ss, ss.T) assert_close(ss, ss.T)
def test_get_s_force_nogain_and_reciprocal_returns_finite_output(monkeypatch) -> None: def test_get_s_force_nogain_and_reciprocal_returns_finite_output(monkeypatch) -> None:
monkeypatch.setattr(eme, 'get_tr', _gain_and_reflection_tr) monkeypatch.setattr(eme, 'get_tr', _gain_and_reflection_tr)
ss = eme.get_s(None, None, None, None, force_nogain=True, force_reciprocal=True) modes, wavenumbers = _dummy_modes()
ss = eme.get_s(modes, wavenumbers, modes, wavenumbers, force_nogain=True, force_reciprocal=True)
assert ss.shape == (4, 4) assert ss.shape == (4, 4)
assert numpy.isfinite(ss).all() assert numpy.isfinite(ss).all()
@ -143,15 +152,15 @@ def test_get_tr_rejects_length_mismatches() -> None:
def test_get_tr_rejects_malformed_mode_tuples() -> None: def test_get_tr_rejects_malformed_mode_tuples() -> None:
bad_modes = [(numpy.ones(4),)] bad_modes = cast(list[tuple[numpy.ndarray, numpy.ndarray]], [(numpy.ones(4, dtype=complex),)])
with pytest.raises(ValueError, match='2-tuple'): with pytest.raises(ValueError, match='2-tuple'):
eme.get_tr(bad_modes, [1.0], bad_modes, [1.0], dxes=DXES) eme.get_tr(bad_modes, [1.0], bad_modes, [1.0], dxes=DXES)
def test_get_tr_rejects_incompatible_field_shapes() -> None: def test_get_tr_rejects_incompatible_field_shapes() -> None:
left_modes = [(numpy.ones(4), numpy.ones(4))] left_modes = [(numpy.ones(4, dtype=complex), numpy.ones(4, dtype=complex))]
right_modes = [(numpy.ones(6), numpy.ones(6))] right_modes = [(numpy.ones(6, dtype=complex), numpy.ones(6, dtype=complex))]
with pytest.raises(ValueError, match='same E/H shapes'): with pytest.raises(ValueError, match='same E/H shapes'):
eme.get_tr(left_modes, [1.0], right_modes, [1.0], dxes=DXES) eme.get_tr(left_modes, [1.0], right_modes, [1.0], dxes=DXES)

View file

@ -85,8 +85,10 @@ def j_distribution(
other_dims = [0, 1, 2] other_dims = [0, 1, 2]
other_dims.remove(dim) other_dims.remove(dim)
dx_prop = (dxes[0][dim][shape[dim + 1] // 2] dx_prop = (
+ dxes[1][dim][shape[dim + 1] // 2]) / 2 # noqa: E128 # TODO is this right for nonuniform dxes? dxes[0][dim][shape[dim + 1] // 2]
+ dxes[1][dim][shape[dim + 1] // 2]
) / 2 # TODO is this right for nonuniform dxes?
# Mask only contains components orthogonal to propagation direction # Mask only contains components orthogonal to propagation direction
center_mask = numpy.zeros(shape, dtype=bool) center_mask = numpy.zeros(shape, dtype=bool)

View file

@ -1,3 +1,5 @@
from typing import cast
import numpy import numpy
from ..fdfd import solvers from ..fdfd import solvers
@ -41,7 +43,7 @@ def test_scipy_qmr_installs_logging_callback_when_missing(monkeypatch) -> None:
def test_generic_forward_preconditions_system_and_guess(monkeypatch) -> None: def test_generic_forward_preconditions_system_and_guess(monkeypatch) -> None:
case = solver_plumbing_case() case = solver_plumbing_case()
captured: dict[str, object] = {} captured: dict[str, numpy.ndarray | float | object] = {}
monkeypatch.setattr(solvers.operators, 'e_full', lambda *args, **kwargs: case.a0) monkeypatch.setattr(solvers.operators, 'e_full', lambda *args, **kwargs: case.a0)
monkeypatch.setattr(solvers.operators, 'e_full_preconditioners', lambda dxes: (case.pl, case.pr)) monkeypatch.setattr(solvers.operators, 'e_full_preconditioners', lambda dxes: (case.pl, case.pr))
@ -63,16 +65,16 @@ def test_generic_forward_preconditions_system_and_guess(monkeypatch) -> None:
E_guess=case.guess, E_guess=case.guess,
) )
assert_close(captured['a'].toarray(), (case.pl @ case.a0 @ case.pr).toarray()) assert_close(cast(object, captured['a']).toarray(), (case.pl @ case.a0 @ case.pr).toarray()) # type: ignore[attr-defined]
assert_close(captured['b'], case.pl @ (-1j * case.omega * case.j)) assert_close(cast(numpy.ndarray, captured['b']), case.pl @ (-1j * case.omega * case.j))
assert_close(captured['x0'], case.pl @ case.guess) assert_close(cast(numpy.ndarray, captured['x0']), case.pl @ case.guess)
assert captured['atol'] == 1e-12 assert captured['atol'] == 1e-12
assert_close(result, case.pr @ case.solver_result) assert_close(result, case.pr @ case.solver_result)
def test_generic_adjoint_preconditions_system_and_guess(monkeypatch) -> None: def test_generic_adjoint_preconditions_system_and_guess(monkeypatch) -> None:
case = solver_plumbing_case() case = solver_plumbing_case()
captured: dict[str, object] = {} captured: dict[str, numpy.ndarray | float | object] = {}
monkeypatch.setattr(solvers.operators, 'e_full', lambda *args, **kwargs: case.a0) monkeypatch.setattr(solvers.operators, 'e_full', lambda *args, **kwargs: case.a0)
monkeypatch.setattr(solvers.operators, 'e_full_preconditioners', lambda dxes: (case.pl, case.pr)) monkeypatch.setattr(solvers.operators, 'e_full_preconditioners', lambda dxes: (case.pl, case.pr))
@ -96,9 +98,9 @@ def test_generic_adjoint_preconditions_system_and_guess(monkeypatch) -> None:
) )
expected_matrix = (case.pl @ case.a0 @ case.pr).T.conjugate() expected_matrix = (case.pl @ case.a0 @ case.pr).T.conjugate()
assert_close(captured['a'].toarray(), expected_matrix.toarray()) assert_close(cast(object, captured['a']).toarray(), expected_matrix.toarray()) # type: ignore[attr-defined]
assert_close(captured['b'], case.pr.T.conjugate() @ (-1j * case.omega * case.j)) assert_close(cast(numpy.ndarray, captured['b']), case.pr.T.conjugate() @ (-1j * case.omega * case.j))
assert_close(captured['x0'], case.pr.T.conjugate() @ case.guess) assert_close(cast(numpy.ndarray, captured['x0']), case.pr.T.conjugate() @ case.guess)
assert captured['rtol'] == 1e-9 assert captured['rtol'] == 1e-9
assert_close(result, case.pl.T.conjugate() @ case.solver_result) assert_close(result, case.pl.T.conjugate() @ case.solver_result)
@ -122,5 +124,5 @@ def test_generic_without_guess_does_not_inject_x0(monkeypatch) -> None:
matrix_solver=fake_solver, matrix_solver=fake_solver,
) )
assert 'x0' not in captured['kwargs'] assert 'x0' not in cast(dict[str, object], captured['kwargs'])
assert_close(result, case.pr @ numpy.array([1.0, -1.0])) assert_close(result, case.pr @ numpy.array([1.0, -1.0]))

View file

@ -1,5 +1,3 @@
import numpy
from ..fdmath import functional as fd_functional from ..fdmath import functional as fd_functional
from ..fdtd import base from ..fdtd import base
from ._test_builders import real_ramp from ._test_builders import real_ramp

View file

@ -58,5 +58,5 @@ def test_conducting_boundary_updates_expected_faces(direction: int, polarity: in
[(-1, 1), (3, 1), (0, 0)], [(-1, 1), (3, 1), (0, 0)],
) )
def test_conducting_boundary_rejects_invalid_arguments(direction: int, polarity: int) -> None: def test_conducting_boundary_rejects_invalid_arguments(direction: int, polarity: int) -> None:
with pytest.raises(Exception): with pytest.raises(ValueError, match='Invalid direction|Bad polarity'):
conducting_boundary(direction, polarity) conducting_boundary(direction, polarity)

View file

@ -10,6 +10,9 @@ def test_gaussian_packet_accepts_array_input(one_sided: bool) -> None:
source, delay = gaussian_packet(1.55, 0.1, dt, one_sided=one_sided) source, delay = gaussian_packet(1.55, 0.1, dt, one_sided=one_sided)
steps = numpy.array([0, int(numpy.ceil(delay / dt)) + 5]) steps = numpy.array([0, int(numpy.ceil(delay / dt)) + 5])
envelope, cc, ss = source(steps) envelope, cc, ss = source(steps)
assert isinstance(envelope, numpy.ndarray)
assert isinstance(cc, numpy.ndarray)
assert isinstance(ss, numpy.ndarray)
assert envelope.shape == (2,) assert envelope.shape == (2,)
assert numpy.isfinite(envelope).all() assert numpy.isfinite(envelope).all()

View file

@ -371,7 +371,7 @@ def _real_pulse_case() -> RealPulseCase:
source_phasor, _delay = gaussian_packet(wl=wavelength, dwl=1.0, dt=dt, turn_on=1e-5) source_phasor, _delay = gaussian_packet(wl=wavelength, dwl=1.0, dt=dt, turn_on=1e-5)
aa, cc, ss = source_phasor(numpy.arange(total_steps) + 0.5) aa, cc, ss = source_phasor(numpy.arange(total_steps) + 0.5)
waveform = aa * (cc + 1j * ss) waveform = numpy.asarray(aa * (cc + 1j * ss), dtype=complex)
scale = fdtd.real_injection_scale(waveform, omega, dt, offset_steps=0.5)[0] scale = fdtd.real_injection_scale(waveform, omega, dt, offset_steps=0.5)[0]
j_accumulator = numpy.zeros((1, *full_shape), dtype=complex) j_accumulator = numpy.zeros((1, *full_shape), dtype=complex)

View file

@ -1,3 +1,5 @@
from typing import Any
import numpy import numpy
import pytest import pytest
@ -12,7 +14,7 @@ from .utils import assert_close
[(3, 1, 4, 1.0), (0, 0, 4, 1.0), (0, 1, 2, 1.0), (0, 1, 4, 0.0)], [(3, 1, 4, 1.0), (0, 0, 4, 1.0), (0, 1, 2, 1.0), (0, 1, 4, 0.0)],
) )
def test_cpml_params_reject_invalid_arguments(axis: int, polarity: int, thickness: int, epsilon_eff: float) -> None: def test_cpml_params_reject_invalid_arguments(axis: int, polarity: int, thickness: int, epsilon_eff: float) -> None:
with pytest.raises(Exception): with pytest.raises(ValueError, match='Invalid axis|Invalid polarity|wise to have a pml|epsilon_eff must be positive'):
cpml_params(axis=axis, polarity=polarity, dt=0.1, thickness=thickness, epsilon_eff=epsilon_eff) cpml_params(axis=axis, polarity=polarity, dt=0.1, thickness=thickness, epsilon_eff=epsilon_eff)
@ -36,7 +38,7 @@ def test_updates_with_cpml_keeps_zero_fields_zero() -> None:
e = numpy.zeros(shape, dtype=float) e = numpy.zeros(shape, dtype=float)
h = numpy.zeros(shape, dtype=float) h = numpy.zeros(shape, dtype=float)
dxes = [[numpy.ones(4), numpy.ones(4), numpy.ones(4)] for _ in range(2)] dxes = [[numpy.ones(4), numpy.ones(4), numpy.ones(4)] for _ in range(2)]
params = [[None, None] for _ in range(3)] params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
params[0][0] = cpml_params(axis=0, polarity=-1, dt=0.1, thickness=3) params[0][0] = cpml_params(axis=0, polarity=-1, dt=0.1, thickness=3)
update_e, update_h = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon) update_e, update_h = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon)
@ -69,7 +71,7 @@ def test_updates_with_cpml_matches_base_updates_when_all_faces_disabled() -> Non
e = _real_field(shape, 10.0) e = _real_field(shape, 10.0)
h = _real_field(shape, 100.0) h = _real_field(shape, 100.0)
dxes = _unit_dxes(shape) dxes = _unit_dxes(shape)
params = [[None, None] for _ in range(3)] params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon) update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon)
update_e_base = maxwell_e(dt=0.1, dxes=dxes) update_e_base = maxwell_e(dt=0.1, dxes=dxes)
@ -96,7 +98,7 @@ def test_updates_with_cpml_matches_base_updates_with_complex_dtype_when_all_face
e = _complex_field(shape, 10.0) e = _complex_field(shape, 10.0)
h = _complex_field(shape, 100.0) h = _complex_field(shape, 100.0)
dxes = _unit_dxes(shape) dxes = _unit_dxes(shape)
params = [[None, None] for _ in range(3)] params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon, dtype=complex) update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon, dtype=complex)
update_e_base = maxwell_e(dt=0.1, dxes=dxes) update_e_base = maxwell_e(dt=0.1, dxes=dxes)
@ -125,7 +127,7 @@ def test_updates_with_cpml_only_changes_the_configured_face_region() -> None:
dxes = _unit_dxes(shape) dxes = _unit_dxes(shape)
thickness = 3 thickness = 3
params = [[None, None] for _ in range(3)] params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
params[0][0] = cpml_params(axis=0, polarity=-1, dt=0.1, thickness=thickness) params[0][0] = cpml_params(axis=0, polarity=-1, dt=0.1, thickness=thickness)
update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon) update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon)
@ -166,7 +168,7 @@ def test_cpml_plane_wave_phasor_decays_monotonically_through_outgoing_pml() -> N
epsilon = numpy.ones(shape, dtype=float) epsilon = numpy.ones(shape, dtype=float)
dxes = _unit_dxes(shape) dxes = _unit_dxes(shape)
params = [[None, None] for _ in range(3)] params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
for polarity_index, polarity in enumerate((-1, 1)): for polarity_index, polarity in enumerate((-1, 1)):
params[0][polarity_index] = cpml_params(axis=0, polarity=polarity, dt=dt, thickness=thickness) params[0][polarity_index] = cpml_params(axis=0, polarity=polarity, dt=dt, thickness=thickness)
@ -212,7 +214,7 @@ def test_cpml_point_source_total_energy_reaches_late_time_plateau() -> None:
epsilon = numpy.ones(shape, dtype=float) epsilon = numpy.ones(shape, dtype=float)
dxes = _unit_dxes(shape) dxes = _unit_dxes(shape)
params = [[None, None] for _ in range(3)] params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
for axis in range(3): for axis in range(3):
for polarity_index, polarity in enumerate((-1, 1)): for polarity_index, polarity in enumerate((-1, 1)):
params[axis][polarity_index] = cpml_params(axis=axis, polarity=polarity, dt=dt, thickness=thickness) params[axis][polarity_index] = cpml_params(axis=axis, polarity=polarity, dt=dt, thickness=thickness)

View file

@ -1,26 +1,28 @@
import builtins import builtins
import importlib import importlib
import pathlib import pathlib
from types import ModuleType
from typing import Any
import pytest
import meanas import meanas
from ..fdfd import bloch from ..fdfd import bloch
from .utils import assert_close
def _reload(module): def _reload(module: ModuleType) -> ModuleType:
return importlib.reload(module) return importlib.reload(module)
def _restore_reloaded(monkeypatch, module): def _restore_reloaded(monkeypatch: pytest.MonkeyPatch, module: ModuleType) -> ModuleType:
monkeypatch.undo() monkeypatch.undo()
return _reload(module) return _reload(module)
def test_meanas_import_survives_readme_open_failure(monkeypatch) -> None: # type: ignore[no-untyped-def] def test_meanas_import_survives_readme_open_failure(monkeypatch: pytest.MonkeyPatch) -> None:
expected_version = meanas.__version__ expected_version = meanas.__version__
original_open = pathlib.Path.open original_open = pathlib.Path.open
def failing_open(self: pathlib.Path, *args, **kwargs): # type: ignore[no-untyped-def] def failing_open(self: pathlib.Path, *args: Any, **kwargs: Any) -> Any:
if self.name == 'README.md': if self.name == 'README.md':
raise FileNotFoundError('forced README failure') raise FileNotFoundError('forced README failure')
return original_open(self, *args, **kwargs) return original_open(self, *args, **kwargs)
@ -35,10 +37,16 @@ def test_meanas_import_survives_readme_open_failure(monkeypatch) -> None: # typ
_restore_reloaded(monkeypatch, meanas) _restore_reloaded(monkeypatch, meanas)
def test_bloch_reloads_with_numpy_fft_when_pyfftw_is_unavailable(monkeypatch) -> None: # type: ignore[no-untyped-def] def test_bloch_reloads_with_numpy_fft_when_pyfftw_is_unavailable(monkeypatch: pytest.MonkeyPatch) -> None:
original_import = builtins.__import__ original_import = builtins.__import__
def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0): # type: ignore[no-untyped-def] def fake_import(
name: str,
globals: dict[str, Any] | None = None,
locals: dict[str, Any] | None = None,
fromlist: tuple[str, ...] = (),
level: int = 0,
) -> Any:
if name.startswith('pyfftw'): if name.startswith('pyfftw'):
raise ImportError('forced pyfftw failure') raise ImportError('forced pyfftw failure')
return original_import(name, globals, locals, fromlist, level) return original_import(name, globals, locals, fromlist, level)

View file

@ -224,7 +224,7 @@ def _build_cpml_params() -> list[list[dict[str, numpy.ndarray | float]]]:
def _build_complex_pulse_waveform(total_steps: int) -> tuple[numpy.ndarray, complex]: def _build_complex_pulse_waveform(total_steps: int) -> tuple[numpy.ndarray, complex]:
source_phasor, _delay = gaussian_packet(wl=WAVELENGTH, dwl=PULSE_DWL, dt=DT, turn_on=1e-5) source_phasor, _delay = gaussian_packet(wl=WAVELENGTH, dwl=PULSE_DWL, dt=DT, turn_on=1e-5)
aa, cc, ss = source_phasor(numpy.arange(total_steps) + 0.5) aa, cc, ss = source_phasor(numpy.arange(total_steps) + 0.5)
waveform = aa * (cc + 1j * ss) waveform = numpy.asarray(aa * (cc + 1j * ss), dtype=complex)
scale = fdtd.temporal_phasor_scale(waveform, OMEGA, DT, offset_steps=0.5)[0] scale = fdtd.temporal_phasor_scale(waveform, OMEGA, DT, offset_steps=0.5)[0]
return waveform, scale return waveform, scale
@ -272,7 +272,7 @@ def _run_real_field_straight_waveguide_case() -> RealFieldWaveguideResult:
slices=REAL_FIELD_SOURCE_SLICES, slices=REAL_FIELD_SOURCE_SLICES,
epsilon=epsilon, epsilon=epsilon,
) )
j_mode *= numpy.exp(1j * REAL_FIELD_SOURCE_PHASE) j_mode = j_mode * numpy.exp(1j * REAL_FIELD_SOURCE_PHASE)
monitor_mode = waveguide_3d.solve_mode( monitor_mode = waveguide_3d.solve_mode(
0, 0,
omega=OMEGA, omega=OMEGA,
@ -425,8 +425,8 @@ def _run_straight_waveguide_case(variant: str) -> WaveguideCalibrationResult:
) )
h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd) h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd)
overlap_td = vec(e_ph) @ vec(overlap_e).conj() overlap_td = complex(vec(e_ph) @ vec(overlap_e).conj())
overlap_fd = vec(e_fdfd) @ vec(overlap_e).conj() overlap_fd = complex(vec(e_fdfd) @ vec(overlap_e).conj())
poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj()) poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj())
poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj()) poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj())
@ -551,10 +551,10 @@ def _run_width_step_scattering_case() -> WaveguideScatteringResult:
) )
h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd) h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd)
reflected_td = vec(e_ph) @ vec(reflected_overlap).conj() reflected_td = complex(vec(e_ph) @ vec(reflected_overlap).conj())
reflected_fd = vec(e_fdfd) @ vec(reflected_overlap).conj() reflected_fd = complex(vec(e_fdfd) @ vec(reflected_overlap).conj())
transmitted_td = vec(e_ph) @ vec(transmitted_overlap).conj() transmitted_td = complex(vec(e_ph) @ vec(transmitted_overlap).conj())
transmitted_fd = vec(e_fdfd) @ vec(transmitted_overlap).conj() transmitted_fd = complex(vec(e_fdfd) @ vec(transmitted_overlap).conj())
poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj()) poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj())
poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj()) poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj())
@ -664,8 +664,8 @@ def _run_pulsed_straight_waveguide_case() -> PulsedWaveguideCalibrationResult:
) )
h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd) h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd)
overlap_td = vec(e_ph) @ vec(overlap_e).conj() overlap_td = complex(vec(e_ph) @ vec(overlap_e).conj())
overlap_fd = vec(e_fdfd) @ vec(overlap_e).conj() overlap_fd = complex(vec(e_fdfd) @ vec(overlap_e).conj())
poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj()) poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj())
poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj()) poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj())

View file

@ -16,7 +16,7 @@ def build_waveguide_3d_mode(
*, *,
slice_start: int, slice_start: int,
polarity: int, polarity: int,
) -> tuple[numpy.ndarray, list[list[numpy.ndarray]], tuple[slice, slice, slice], dict[str, complex | numpy.ndarray]]: ) -> tuple[numpy.ndarray, list[list[numpy.ndarray]], tuple[slice, slice, slice], waveguide_3d.Waveguide3DMode]:
epsilon = numpy.ones((3, 5, 5, 1), dtype=float) epsilon = numpy.ones((3, 5, 5, 1), dtype=float)
dxes = [[numpy.ones(5), numpy.ones(5), numpy.ones(1)] for _ in range(2)] dxes = [[numpy.ones(5), numpy.ones(5), numpy.ones(1)] for _ in range(2)]
slices = (slice(slice_start, slice_start + 1), slice(None), slice(None)) slices = (slice(slice_start, slice_start + 1), slice(None), slice(None))

View file

@ -1,5 +1,6 @@
import numpy import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
from numpy.typing import ArrayLike
def make_prng(seed: int = 12345) -> numpy.random.RandomState: def make_prng(seed: int = 12345) -> numpy.random.RandomState:
@ -24,9 +25,9 @@ def assert_fields_close(
) )
def assert_close( def assert_close(
x: NDArray, x: ArrayLike,
y: NDArray, y: ArrayLike,
*args, *args,
**kwargs, **kwargs,
) -> None: ) -> None:
numpy.testing.assert_allclose(x, y, *args, **kwargs) numpy.testing.assert_allclose(numpy.asarray(x), numpy.asarray(y), *args, **kwargs)

View file

@ -104,6 +104,9 @@ lint.ignore = [
"TRY002", # Exception() "TRY002", # Exception()
] ]
[tool.ruff.lint.per-file-ignores]
"meanas/test/**/*.py" = ["ANN", "ARG", "TC006"]
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = [ module = [