type hints and lint

This commit is contained in:
Forgejo Actions 2026-04-21 21:13:34 -07:00
commit c6c9159b13
30 changed files with 198 additions and 136 deletions

View file

@ -14,6 +14,7 @@ simple straight interface:
from __future__ import annotations
import importlib
from typing import TYPE_CHECKING
import numpy
from numpy import pi
@ -24,6 +25,9 @@ from gridlock import Extent
from meanas.fdfd import eme, waveguide_2d
from meanas.fdmath import unvec
if TYPE_CHECKING:
from types import ModuleType
WL = 1310.0
DX = 40.0
@ -35,7 +39,7 @@ EPS_OX = 1.453 ** 2
MODE_NUMBERS = numpy.array([0])
def require_optional(name: str, package_name: str | None = None):
def require_optional(name: str, package_name: str | None = None) -> ModuleType:
package_name = package_name or name
try:
return importlib.import_module(name)
@ -159,7 +163,7 @@ def print_summary(ss: numpy.ndarray, wavenumbers_left: numpy.ndarray, wavenumber
def plot_results(
*,
pyplot,
pyplot: ModuleType,
ss: numpy.ndarray,
left_mode: tuple[numpy.ndarray, numpy.ndarray],
right_mode: tuple[numpy.ndarray, numpy.ndarray],

View file

@ -15,6 +15,7 @@ This example demonstrates a cylindrical-waveguide EME workflow:
from __future__ import annotations
import importlib
from typing import TYPE_CHECKING
import numpy
from numpy import pi
@ -26,6 +27,9 @@ from gridlock import Extent
from meanas.fdfd import eme, waveguide_2d, waveguide_cyl
from meanas.fdmath import unvec
if TYPE_CHECKING:
from types import ModuleType
WL = 1310.0
DX = 40.0
@ -40,7 +44,7 @@ STRAIGHT_SECTION_LENGTH = 12e3
BEND_ANGLE = pi / 2
def require_optional(name: str, package_name: str | None = None):
def require_optional(name: str, package_name: str | None = None) -> ModuleType:
package_name = package_name or name
try:
return importlib.import_module(name)
@ -163,7 +167,7 @@ def solve_bend_modes(
def build_cascaded_network(
skrf,
skrf: ModuleType,
*,
interface_s: numpy.ndarray,
straight_wavenumbers: numpy.ndarray,
@ -216,7 +220,7 @@ def print_summary(
def plot_results(
*,
pyplot,
pyplot: ModuleType,
interface_s: numpy.ndarray,
cascaded_s: numpy.ndarray,
straight_mode: tuple[numpy.ndarray, numpy.ndarray],

View file

@ -89,7 +89,7 @@ def perturbed_l3(a: float, radius: float, **kwargs) -> Pattern:
return pat
def main():
def main() -> None:
dtype = numpy.float32
max_t = 3600 # number of timesteps
@ -97,7 +97,6 @@ def main():
pml_thickness = 8 # (number of cells)
wl = 1550 # Excitation wavelength and fwhm
dwl = 100
# Device design parameters
xy_size = numpy.array([10, 10])

View file

@ -683,7 +683,16 @@ def eigsolve(
return numpy.abs(trace)
if False:
def trace_deriv(theta, sgn: int = sgn, ZtAZ=ZtAZ, DtAD=DtAD, symZtD=symZtD, symZtAD=symZtAD, ZtZ=ZtZ, DtD=DtD): # noqa: ANN001
def trace_deriv(
theta: float,
sgn: int = sgn,
ZtAZ=ZtAZ, # noqa: ANN001
DtAD=DtAD, # noqa: ANN001
symZtD=symZtD, # noqa: ANN001
symZtAD=symZtAD, # noqa: ANN001
ZtZ=ZtZ, # noqa: ANN001
DtD=DtD, # noqa: ANN001
) -> float:
Qi = Qi_func(theta)
c2 = numpy.cos(2 * theta)
s2 = numpy.sin(2 * theta)

View file

@ -27,11 +27,13 @@ from scipy import sparse
from ..fdmath import dx_lists2_t, vcfdfield2
from .waveguide_2d import inner_product
type wavenumber_seq = Sequence[complex] | NDArray[numpy.complexfloating] | NDArray[numpy.floating]
def _validate_port_modes(
name: str,
ehs: Sequence[Sequence[vcfdfield2]],
wavenumbers: Sequence[complex],
wavenumbers: wavenumber_seq,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
if len(ehs) != len(wavenumbers):
raise ValueError(f'{name} mode list and wavenumber list must have the same length')
@ -61,9 +63,9 @@ def _validate_port_modes(
def get_tr(
ehLs: Sequence[Sequence[vcfdfield2]],
wavenumbers_L: Sequence[complex],
wavenumbers_L: wavenumber_seq,
ehRs: Sequence[Sequence[vcfdfield2]],
wavenumbers_R: Sequence[complex],
wavenumbers_R: wavenumber_seq,
dxes: dx_lists2_t,
) -> tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]:
"""
@ -118,9 +120,9 @@ def get_tr(
def get_abcd(
ehLs: Sequence[Sequence[vcfdfield2]],
wavenumbers_L: Sequence[complex],
wavenumbers_L: wavenumber_seq,
ehRs: Sequence[Sequence[vcfdfield2]],
wavenumbers_R: Sequence[complex],
wavenumbers_R: wavenumber_seq,
**kwargs,
) -> sparse.sparray:
"""
@ -151,9 +153,9 @@ def get_abcd(
def get_s(
ehLs: Sequence[Sequence[vcfdfield2]],
wavenumbers_L: Sequence[complex],
wavenumbers_L: wavenumber_seq,
ehRs: Sequence[Sequence[vcfdfield2]],
wavenumbers_R: Sequence[complex],
wavenumbers_R: wavenumber_seq,
force_nogain: bool = False,
force_reciprocal: bool = False,
**kwargs,

View file

@ -1,23 +1,24 @@
"""
Functions for performing near-to-farfield transformation (and the reverse).
"""
from typing import Any, cast, TYPE_CHECKING
from typing import Any, cast
from collections.abc import Sequence
import numpy
from numpy.fft import fft2, fftshift, fftfreq, ifft2, ifftshift
from numpy import pi
from numpy.typing import NDArray
from numpy import complexfloating
from ..fdmath import cfdfield_t
if TYPE_CHECKING:
from collections.abc import Sequence
type farfield_slice = NDArray[complexfloating]
type transverse_slice_pair = Sequence[farfield_slice]
def near_to_farfield(
E_near: cfdfield_t,
H_near: cfdfield_t,
E_near: transverse_slice_pair,
H_near: transverse_slice_pair,
dx: float,
dy: float,
padded_size: list[int] | int | None = None
padded_size: Sequence[int] | int | None = None
) -> dict[str, Any]:
"""
Compute the farfield, i.e. the distribution of the fields after propagation
@ -58,7 +59,7 @@ def near_to_farfield(
raise Exception('H_near must be a length-2 list of ndarrays')
s = E_near[0].shape
if not all(s == f.shape for f in E_near + H_near):
if not all(s == f.shape for f in [*E_near, *H_near]):
raise Exception('All fields must be the same shape!')
if padded_size is None:
@ -123,11 +124,11 @@ def near_to_farfield(
def far_to_nearfield(
E_far: cfdfield_t,
H_far: cfdfield_t,
E_far: transverse_slice_pair,
H_far: transverse_slice_pair,
dkx: float,
dky: float,
padded_size: list[int] | int | None = None
padded_size: Sequence[int] | int | None = None
) -> dict[str, Any]:
"""
Compute the farfield, i.e. the distribution of the fields after propagation
@ -164,7 +165,7 @@ def far_to_nearfield(
raise Exception('H_far must be a length-2 list of ndarrays')
s = E_far[0].shape
if not all(s == f.shape for f in E_far + H_far):
if not all(s == f.shape for f in [*E_far, *H_far]):
raise Exception('All fields must be the same shape!')
if padded_size is None:

View file

@ -423,10 +423,10 @@ def normalized_fields_h(
def _normalized_fields(
e: vcfdslice,
h: vcfdslice,
omega: complex,
_omega: complex,
dxes: dx_lists2_t,
epsilon: vfdslice,
mu: vfdslice | None = None,
_mu: vfdslice | None = None,
prop_phase: float = 0,
) -> tuple[vcfdslice_t, vcfdslice_t]:
r"""

View file

@ -19,9 +19,8 @@ The intended workflow is:
That same convention controls which side of the selected slice is used for the
overlap window and how the expanded field is phased.
"""
from typing import Any, cast
from typing import Any, TypedDict, cast
import warnings
from typing import Any
from collections.abc import Sequence
import numpy
from numpy.typing import NDArray
@ -31,6 +30,13 @@ from ..fdmath import vec, unvec, dx_lists_t, cfdfield_t, fdfield, cfdfield
from . import operators, waveguide_2d
class Waveguide3DMode(TypedDict):
wavenumber: complex
wavenumber_2d: complex
H: NDArray[complexfloating]
E: NDArray[complexfloating]
def solve_mode(
mode_number: int,
omega: complex,
@ -40,7 +46,7 @@ def solve_mode(
slices: Sequence[slice],
epsilon: fdfield,
mu: fdfield | None = None,
) -> dict[str, complex | NDArray[complexfloating]]:
) -> Waveguide3DMode:
r"""
Given a 3D grid, selects a slice from the grid and attempts to
solve for an eigenmode propagating through that slice.
@ -121,7 +127,7 @@ def solve_mode(
E[iii] = e[oo][:, :, None].transpose(reverse_order)
H[iii] = h[oo][:, :, None].transpose(reverse_order)
results = {
results: Waveguide3DMode = {
'wavenumber': wavenumber,
'wavenumber_2d': wavenumber_2d,
'H': H,
@ -184,13 +190,13 @@ def compute_source(
def compute_overlap_e(
E: cfdfield_t,
E: cfdfield,
wavenumber: complex,
dxes: dx_lists_t,
axis: int,
polarity: int,
slices: Sequence[slice],
omega: float,
_omega: float,
) -> cfdfield_t:
r"""
Build an overlap field for projecting another 3D electric field onto a mode.
@ -262,7 +268,7 @@ def compute_overlap_e(
if clipped_start >= clipped_stop:
raise ValueError('Requested overlap window lies outside the domain')
if clipped_start != start or clipped_stop != stop:
warnings.warn('Requested overlap window was clipped to fit within the domain', RuntimeWarning)
warnings.warn('Requested overlap window was clipped to fit within the domain', RuntimeWarning, stacklevel=2)
slices2_l = list(slices)
slices2_l[axis] = slice(clipped_start, clipped_stop)
@ -275,7 +281,7 @@ def compute_overlap_e(
norm = (Etgt.conj() * Etgt).sum()
if norm == 0:
raise ValueError('Requested overlap window contains no overlap field support')
Etgt /= norm
Etgt = Etgt / norm
return cfdfield_t(Etgt)

View file

@ -130,7 +130,7 @@ import numpy
from numpy.typing import NDArray, ArrayLike
from scipy import sparse
from ..fdmath import vec, unvec, dx_lists2_t, vcfdslice_t, vcfdfield2_t, vfdslice, vcfdslice, vcfdfield2
from ..fdmath import vec, unvec, dx_lists2_t, vcfdslice_t, vfdslice, vcfdslice, vcfdfield2
from ..fdmath.operators import deriv_forward, deriv_back
from ..eigensolvers import signed_eigensolve, rayleigh_quotient_iteration
from . import waveguide_2d
@ -267,7 +267,7 @@ def solve_mode(
mode_number: int,
*args: Any,
**kwargs: Any,
) -> tuple[vcfdslice, complex]:
) -> tuple[vcfdfield2, complex]:
"""
Wrapper around `solve_modes()` that solves for a single mode.
@ -285,7 +285,7 @@ def solve_mode(
def linear_wavenumbers(
e_xys: list[vcfdfield2_t],
e_xys: Sequence[vcfdfield2] | NDArray[numpy.complex128],
angular_wavenumbers: ArrayLike,
epsilon: vfdslice,
dxes: dx_lists2_t,
@ -537,11 +537,11 @@ def normalized_fields_e(
def _normalized_fields(
e: vcfdslice,
h: vcfdslice,
omega: complex,
_omega: complex,
dxes: dx_lists2_t,
rmin: float, # Currently unused, but may want to use cylindrical poynting
_rmin: float, # Currently unused, but may want to use cylindrical poynting
epsilon: vfdslice,
mu: vfdslice | None = None,
_mu: vfdslice | None = None,
prop_phase: float = 0,
) -> tuple[vcfdslice_t, vcfdslice_t]:
r"""

View file

@ -10,7 +10,7 @@ import numpy
from numpy.typing import NDArray
from numpy import floating, complexfloating
from .types import fdfield_t, fdfield_updater_t
from .types import fdfield, fdfield_updater_t
def deriv_forward(
@ -127,7 +127,7 @@ def curl_forward_parts(
) -> Callable:
Dx, Dy, Dz = deriv_forward(dx_e)
def mkparts_fwd(e: fdfield_t) -> tuple[tuple[fdfield_t, fdfield_t], ...]:
def mkparts_fwd(e: fdfield) -> tuple[tuple[fdfield, fdfield], ...]:
return ((-Dz(e[1]), Dy(e[2])),
( Dz(e[0]), -Dx(e[2])),
(-Dy(e[0]), Dx(e[1])))
@ -140,7 +140,7 @@ def curl_back_parts(
) -> Callable:
Dx, Dy, Dz = deriv_back(dx_h)
def mkparts_back(h: fdfield_t) -> tuple[tuple[fdfield_t, fdfield_t], ...]:
def mkparts_back(h: fdfield) -> tuple[tuple[fdfield, fdfield], ...]:
return ((-Dz(h[1]), Dy(h[2])),
( Dz(h[0]), -Dx(h[2])),
(-Dy(h[0]), Dx(h[1])))

View file

@ -9,7 +9,7 @@ from numpy.typing import NDArray
from numpy import floating, complexfloating
from scipy import sparse
from .types import vfdfield_t
from .types import vfdfield
def shift_circ(
@ -171,7 +171,7 @@ def cross(
[-B[1], B[0], zero]])
def vec_cross(b: vfdfield_t) -> sparse.sparray:
def vec_cross(b: vfdfield) -> sparse.sparray:
"""
Vector cross product operator

View file

@ -88,8 +88,8 @@ dx_lists2_mut = MutableSequence[MutableSequence[NDArray[floating | complexfloati
"""Mutable version of `dx_lists2_t`"""
fdfield_updater_t = Callable[..., fdfield_t]
"""Convenience type for functions which take and return an fdfield_t"""
fdfield_updater_t = Callable[..., fdfield]
"""Convenience type for functions which take and return a real `fdfield`"""
cfdfield_updater_t = Callable[..., cfdfield_t]
"""Convenience type for functions which take and return an cfdfield_t"""
cfdfield_updater_t = Callable[..., cfdfield]
"""Convenience type for functions which take and return a complex `cfdfield`"""

View file

@ -3,7 +3,7 @@ Basic FDTD field updates
"""
from ..fdmath import dx_lists_t, fdfield_t, fdfield_updater_t
from ..fdmath import dx_lists_t, fdfield, fdfield_updater_t
from ..fdmath.functional import curl_forward, curl_back
@ -47,7 +47,7 @@ def maxwell_e(
else:
curl_h_fun = curl_back()
def me_fun(e: fdfield_t, h: fdfield_t, epsilon: fdfield_t | float) -> fdfield_t:
def me_fun(e: fdfield, h: fdfield, epsilon: fdfield | float) -> fdfield:
"""
Update the E-field.
@ -103,7 +103,7 @@ def maxwell_h(
else:
curl_e_fun = curl_forward()
def mh_fun(e: fdfield_t, h: fdfield_t, mu: fdfield_t | float | None = None) -> fdfield_t:
def mh_fun(e: fdfield, h: fdfield, mu: fdfield | float | None = None) -> fdfield:
"""
Update the H-field.

View file

@ -6,7 +6,7 @@ Boundary conditions
from typing import Any
from ..fdmath import fdfield_t, fdfield_updater_t
from ..fdmath import fdfield, fdfield_updater_t
def conducting_boundary(
@ -15,7 +15,7 @@ def conducting_boundary(
) -> tuple[fdfield_updater_t, fdfield_updater_t]:
dirs = [0, 1, 2]
if direction not in dirs:
raise Exception(f'Invalid direction: {direction}')
raise ValueError(f'Invalid direction: {direction}')
dirs.remove(direction)
u, v = dirs
@ -31,13 +31,13 @@ def conducting_boundary(
boundary = tuple(boundary_slice)
shifted1 = tuple(shifted1_slice)
def en(e: fdfield_t) -> fdfield_t:
def en(e: fdfield) -> fdfield:
e[direction][boundary] = 0
e[u][boundary] = e[u][shifted1]
e[v][boundary] = e[v][shifted1]
return e
def hn(h: fdfield_t) -> fdfield_t:
def hn(h: fdfield) -> fdfield:
h[direction][boundary] = h[direction][shifted1]
h[u][boundary] = 0
h[v][boundary] = 0
@ -56,14 +56,14 @@ def conducting_boundary(
shifted1 = tuple(shifted1_slice)
shifted2 = tuple(shifted2_slice)
def ep(e: fdfield_t) -> fdfield_t:
def ep(e: fdfield) -> fdfield:
e[direction][boundary] = -e[direction][shifted2]
e[direction][shifted1] = 0
e[u][boundary] = e[u][shifted1]
e[v][boundary] = e[v][shifted1]
return e
def hp(h: fdfield_t) -> fdfield_t:
def hp(h: fdfield) -> fdfield:
h[direction][boundary] = h[direction][shifted1]
h[u][boundary] = -h[u][shifted2]
h[u][shifted1] = 0
@ -73,4 +73,4 @@ def conducting_boundary(
return ep, hp
raise Exception(f'Bad polarity: {polarity}')
raise ValueError(f'Bad polarity: {polarity}')

View file

@ -1,5 +1,6 @@
from collections.abc import Callable
import logging
from typing import cast
import numpy
from numpy.typing import NDArray, ArrayLike
@ -9,7 +10,14 @@ from numpy import pi
logger = logging.getLogger(__name__)
pulse_fn_t = Callable[[int | NDArray], tuple[float, float, float]]
type pulse_scalar_t = float | NDArray[numpy.floating]
pulse_fn_t = Callable[[ArrayLike], tuple[pulse_scalar_t, pulse_scalar_t, pulse_scalar_t]]
def _scalar_or_array(values: NDArray[numpy.floating]) -> pulse_scalar_t:
if values.ndim == 0:
return float(values)
return cast('NDArray[numpy.floating]', values)
def gaussian_packet(
@ -49,8 +57,9 @@ def gaussian_packet(
delay = numpy.ceil(delay * freq) / freq # force delay to integer number of periods to maintain phase
logger.info(f'src_time {2 * delay / dt}')
def source_phasor(ii: int | NDArray) -> tuple[float, float, float]:
t0 = ii * dt - delay
def source_phasor(ii: ArrayLike) -> tuple[pulse_scalar_t, pulse_scalar_t, pulse_scalar_t]:
ii_array = numpy.asarray(ii, dtype=float)
t0 = ii_array * dt - delay
envelope = numpy.sqrt(numpy.sqrt(2 * alpha / pi)) * numpy.exp(-alpha * t0 * t0)
if one_sided:
@ -59,7 +68,7 @@ def gaussian_packet(
cc = numpy.cos(omega * t0)
ss = numpy.sin(omega * t0)
return envelope, cc, ss
return _scalar_or_array(envelope), _scalar_or_array(cc), _scalar_or_array(ss)
# nrm = numpy.exp(-omega * omega / alpha) / 2
@ -105,15 +114,16 @@ def ricker_pulse(
delay = delay_results.root
delay = numpy.ceil(delay * freq) / freq # force delay to integer number of periods to maintain phase
def source_phasor(ii: int | NDArray) -> tuple[float, float, float]:
t0 = ii * dt - delay
def source_phasor(ii: ArrayLike) -> tuple[pulse_scalar_t, pulse_scalar_t, pulse_scalar_t]:
ii_array = numpy.asarray(ii, dtype=float)
t0 = ii_array * dt - delay
rr = omega * t0 / 2
ff = (1 - 2 * rr * rr) * numpy.exp(-rr * rr)
cc = numpy.cos(omega * t0)
ss = numpy.sin(omega * t0)
return ff, cc, ss
return _scalar_or_array(ff), _scalar_or_array(cc), _scalar_or_array(ss)
return source_phasor, delay

View file

@ -23,7 +23,7 @@ from copy import deepcopy
import numpy
from numpy.typing import NDArray, DTypeLike
from ..fdmath import fdfield, fdfield_t, dx_lists_t
from ..fdmath import fdfield, dx_lists_t
from ..fdmath.functional import deriv_forward, deriv_back
@ -67,16 +67,16 @@ def cpml_params(
"""
if axis not in range(3):
raise Exception(f'Invalid axis: {axis}')
raise ValueError(f'Invalid axis: {axis}')
if polarity not in (-1, 1):
raise Exception(f'Invalid polarity: {polarity}')
raise ValueError(f'Invalid polarity: {polarity}')
if thickness <= 2:
raise Exception('It would be wise to have a pml with 4+ cells of thickness')
raise ValueError('It would be wise to have a pml with 4+ cells of thickness')
if epsilon_eff <= 0:
raise Exception('epsilon_eff must be positive')
raise ValueError('epsilon_eff must be positive')
sigma_max = -ln_R_per_layer / 2 * (m + 1)
kappa_max = numpy.sqrt(epsilon_eff * mu_eff)
@ -129,8 +129,7 @@ def updates_with_cpml(
epsilon: fdfield,
*,
dtype: DTypeLike = numpy.float32,
) -> tuple[Callable[[fdfield_t, fdfield_t, fdfield_t], None],
Callable[[fdfield_t, fdfield_t, fdfield_t], None]]:
) -> tuple[Callable[..., None], Callable[..., None]]:
"""
Build Yee-step update closures augmented with CPML terms.
@ -187,9 +186,9 @@ def updates_with_cpml(
pH = numpy.empty_like(epsilon, dtype=dtype)
def update_E(
e: fdfield_t,
h: fdfield_t,
epsilon: fdfield_t,
e: fdfield,
h: fdfield,
epsilon: fdfield,
) -> None:
dyHx = Dby(h[0])
dzHx = Dbz(h[0])
@ -233,9 +232,9 @@ def updates_with_cpml(
e[2] += dt / epsilon[2] * (dxHy - dyHx + pE[2])
def update_H(
e: fdfield_t,
h: fdfield_t,
mu: fdfield_t | tuple[int, int, int] = (1, 1, 1),
e: fdfield,
h: fdfield,
mu: fdfield | tuple[int, int, int] = (1, 1, 1),
) -> None:
dyEx = Dfy(e[0])
dzEx = Dfz(e[0])

View file

@ -4,7 +4,7 @@ from numpy.testing import assert_allclose
from types import SimpleNamespace
from ..fdfd import bloch
from ._bloch_case import EPSILON, G_MATRIX, H_SIZE, K0_X, SHAPE, Y0, Y0_TWO_MODE, build_overlap_fixture
from ._bloch_case import EPSILON, G_MATRIX, H_SIZE, K0_X, Y0, Y0_TWO_MODE, build_overlap_fixture
from .utils import assert_close

View file

@ -1,3 +1,5 @@
from typing import cast
import numpy
import pytest
from scipy import sparse
@ -51,6 +53,10 @@ def _nonsymmetric_tr(left_marker: object):
return fake_get_tr
def _dummy_modes() -> tuple[list[tuple[numpy.ndarray, numpy.ndarray]], numpy.ndarray]:
return [_mode(0.0), _mode(0.7)], numpy.array([1.0, 0.5])
def test_get_tr_returns_finite_bounded_transfer_matrices() -> None:
left_modes, right_modes = _mode_sets()
@ -103,9 +109,10 @@ def test_get_s_plain_matches_block_assembly_from_get_tr() -> None:
def test_get_s_force_nogain_caps_singular_values(monkeypatch) -> None:
monkeypatch.setattr(eme, 'get_tr', _gain_only_tr)
modes, wavenumbers = _dummy_modes()
plain_s = eme.get_s(None, None, None, None)
clipped_s = eme.get_s(None, None, None, None, force_nogain=True)
plain_s = eme.get_s(modes, wavenumbers, modes, wavenumbers)
clipped_s = eme.get_s(modes, wavenumbers, modes, wavenumbers, force_nogain=True)
plain_singular_values = numpy.linalg.svd(plain_s, compute_uv=False)
clipped_singular_values = numpy.linalg.svd(clipped_s, compute_uv=False)
@ -116,18 +123,20 @@ def test_get_s_force_nogain_caps_singular_values(monkeypatch) -> None:
def test_get_s_force_reciprocal_symmetrizes_output(monkeypatch) -> None:
left = object()
right = object()
left = numpy.array([1.0, 0.5])
right = numpy.array([0.9, 0.4])
modes, _wavenumbers = _dummy_modes()
monkeypatch.setattr(eme, 'get_tr', _nonsymmetric_tr(left))
ss = eme.get_s(None, left, None, right, force_reciprocal=True)
ss = eme.get_s(modes, left, modes, right, force_reciprocal=True)
assert_close(ss, ss.T)
def test_get_s_force_nogain_and_reciprocal_returns_finite_output(monkeypatch) -> None:
monkeypatch.setattr(eme, 'get_tr', _gain_and_reflection_tr)
ss = eme.get_s(None, None, None, None, force_nogain=True, force_reciprocal=True)
modes, wavenumbers = _dummy_modes()
ss = eme.get_s(modes, wavenumbers, modes, wavenumbers, force_nogain=True, force_reciprocal=True)
assert ss.shape == (4, 4)
assert numpy.isfinite(ss).all()
@ -143,15 +152,15 @@ def test_get_tr_rejects_length_mismatches() -> None:
def test_get_tr_rejects_malformed_mode_tuples() -> None:
bad_modes = [(numpy.ones(4),)]
bad_modes = cast(list[tuple[numpy.ndarray, numpy.ndarray]], [(numpy.ones(4, dtype=complex),)])
with pytest.raises(ValueError, match='2-tuple'):
eme.get_tr(bad_modes, [1.0], bad_modes, [1.0], dxes=DXES)
def test_get_tr_rejects_incompatible_field_shapes() -> None:
left_modes = [(numpy.ones(4), numpy.ones(4))]
right_modes = [(numpy.ones(6), numpy.ones(6))]
left_modes = [(numpy.ones(4, dtype=complex), numpy.ones(4, dtype=complex))]
right_modes = [(numpy.ones(6, dtype=complex), numpy.ones(6, dtype=complex))]
with pytest.raises(ValueError, match='same E/H shapes'):
eme.get_tr(left_modes, [1.0], right_modes, [1.0], dxes=DXES)

View file

@ -85,8 +85,10 @@ def j_distribution(
other_dims = [0, 1, 2]
other_dims.remove(dim)
dx_prop = (dxes[0][dim][shape[dim + 1] // 2]
+ dxes[1][dim][shape[dim + 1] // 2]) / 2 # noqa: E128 # TODO is this right for nonuniform dxes?
dx_prop = (
dxes[0][dim][shape[dim + 1] // 2]
+ dxes[1][dim][shape[dim + 1] // 2]
) / 2 # TODO is this right for nonuniform dxes?
# Mask only contains components orthogonal to propagation direction
center_mask = numpy.zeros(shape, dtype=bool)

View file

@ -1,3 +1,5 @@
from typing import cast
import numpy
from ..fdfd import solvers
@ -41,7 +43,7 @@ def test_scipy_qmr_installs_logging_callback_when_missing(monkeypatch) -> None:
def test_generic_forward_preconditions_system_and_guess(monkeypatch) -> None:
case = solver_plumbing_case()
captured: dict[str, object] = {}
captured: dict[str, numpy.ndarray | float | object] = {}
monkeypatch.setattr(solvers.operators, 'e_full', lambda *args, **kwargs: case.a0)
monkeypatch.setattr(solvers.operators, 'e_full_preconditioners', lambda dxes: (case.pl, case.pr))
@ -63,16 +65,16 @@ def test_generic_forward_preconditions_system_and_guess(monkeypatch) -> None:
E_guess=case.guess,
)
assert_close(captured['a'].toarray(), (case.pl @ case.a0 @ case.pr).toarray())
assert_close(captured['b'], case.pl @ (-1j * case.omega * case.j))
assert_close(captured['x0'], case.pl @ case.guess)
assert_close(cast(object, captured['a']).toarray(), (case.pl @ case.a0 @ case.pr).toarray()) # type: ignore[attr-defined]
assert_close(cast(numpy.ndarray, captured['b']), case.pl @ (-1j * case.omega * case.j))
assert_close(cast(numpy.ndarray, captured['x0']), case.pl @ case.guess)
assert captured['atol'] == 1e-12
assert_close(result, case.pr @ case.solver_result)
def test_generic_adjoint_preconditions_system_and_guess(monkeypatch) -> None:
case = solver_plumbing_case()
captured: dict[str, object] = {}
captured: dict[str, numpy.ndarray | float | object] = {}
monkeypatch.setattr(solvers.operators, 'e_full', lambda *args, **kwargs: case.a0)
monkeypatch.setattr(solvers.operators, 'e_full_preconditioners', lambda dxes: (case.pl, case.pr))
@ -96,9 +98,9 @@ def test_generic_adjoint_preconditions_system_and_guess(monkeypatch) -> None:
)
expected_matrix = (case.pl @ case.a0 @ case.pr).T.conjugate()
assert_close(captured['a'].toarray(), expected_matrix.toarray())
assert_close(captured['b'], case.pr.T.conjugate() @ (-1j * case.omega * case.j))
assert_close(captured['x0'], case.pr.T.conjugate() @ case.guess)
assert_close(cast(object, captured['a']).toarray(), expected_matrix.toarray()) # type: ignore[attr-defined]
assert_close(cast(numpy.ndarray, captured['b']), case.pr.T.conjugate() @ (-1j * case.omega * case.j))
assert_close(cast(numpy.ndarray, captured['x0']), case.pr.T.conjugate() @ case.guess)
assert captured['rtol'] == 1e-9
assert_close(result, case.pl.T.conjugate() @ case.solver_result)
@ -122,5 +124,5 @@ def test_generic_without_guess_does_not_inject_x0(monkeypatch) -> None:
matrix_solver=fake_solver,
)
assert 'x0' not in captured['kwargs']
assert 'x0' not in cast(dict[str, object], captured['kwargs'])
assert_close(result, case.pr @ numpy.array([1.0, -1.0]))

View file

@ -1,5 +1,3 @@
import numpy
from ..fdmath import functional as fd_functional
from ..fdtd import base
from ._test_builders import real_ramp

View file

@ -58,5 +58,5 @@ def test_conducting_boundary_updates_expected_faces(direction: int, polarity: in
[(-1, 1), (3, 1), (0, 0)],
)
def test_conducting_boundary_rejects_invalid_arguments(direction: int, polarity: int) -> None:
with pytest.raises(Exception):
with pytest.raises(ValueError, match='Invalid direction|Bad polarity'):
conducting_boundary(direction, polarity)

View file

@ -10,6 +10,9 @@ def test_gaussian_packet_accepts_array_input(one_sided: bool) -> None:
source, delay = gaussian_packet(1.55, 0.1, dt, one_sided=one_sided)
steps = numpy.array([0, int(numpy.ceil(delay / dt)) + 5])
envelope, cc, ss = source(steps)
assert isinstance(envelope, numpy.ndarray)
assert isinstance(cc, numpy.ndarray)
assert isinstance(ss, numpy.ndarray)
assert envelope.shape == (2,)
assert numpy.isfinite(envelope).all()

View file

@ -371,7 +371,7 @@ def _real_pulse_case() -> RealPulseCase:
source_phasor, _delay = gaussian_packet(wl=wavelength, dwl=1.0, dt=dt, turn_on=1e-5)
aa, cc, ss = source_phasor(numpy.arange(total_steps) + 0.5)
waveform = aa * (cc + 1j * ss)
waveform = numpy.asarray(aa * (cc + 1j * ss), dtype=complex)
scale = fdtd.real_injection_scale(waveform, omega, dt, offset_steps=0.5)[0]
j_accumulator = numpy.zeros((1, *full_shape), dtype=complex)

View file

@ -1,3 +1,5 @@
from typing import Any
import numpy
import pytest
@ -12,7 +14,7 @@ from .utils import assert_close
[(3, 1, 4, 1.0), (0, 0, 4, 1.0), (0, 1, 2, 1.0), (0, 1, 4, 0.0)],
)
def test_cpml_params_reject_invalid_arguments(axis: int, polarity: int, thickness: int, epsilon_eff: float) -> None:
with pytest.raises(Exception):
with pytest.raises(ValueError, match='Invalid axis|Invalid polarity|wise to have a pml|epsilon_eff must be positive'):
cpml_params(axis=axis, polarity=polarity, dt=0.1, thickness=thickness, epsilon_eff=epsilon_eff)
@ -36,7 +38,7 @@ def test_updates_with_cpml_keeps_zero_fields_zero() -> None:
e = numpy.zeros(shape, dtype=float)
h = numpy.zeros(shape, dtype=float)
dxes = [[numpy.ones(4), numpy.ones(4), numpy.ones(4)] for _ in range(2)]
params = [[None, None] for _ in range(3)]
params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
params[0][0] = cpml_params(axis=0, polarity=-1, dt=0.1, thickness=3)
update_e, update_h = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon)
@ -69,7 +71,7 @@ def test_updates_with_cpml_matches_base_updates_when_all_faces_disabled() -> Non
e = _real_field(shape, 10.0)
h = _real_field(shape, 100.0)
dxes = _unit_dxes(shape)
params = [[None, None] for _ in range(3)]
params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon)
update_e_base = maxwell_e(dt=0.1, dxes=dxes)
@ -96,7 +98,7 @@ def test_updates_with_cpml_matches_base_updates_with_complex_dtype_when_all_face
e = _complex_field(shape, 10.0)
h = _complex_field(shape, 100.0)
dxes = _unit_dxes(shape)
params = [[None, None] for _ in range(3)]
params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon, dtype=complex)
update_e_base = maxwell_e(dt=0.1, dxes=dxes)
@ -125,7 +127,7 @@ def test_updates_with_cpml_only_changes_the_configured_face_region() -> None:
dxes = _unit_dxes(shape)
thickness = 3
params = [[None, None] for _ in range(3)]
params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
params[0][0] = cpml_params(axis=0, polarity=-1, dt=0.1, thickness=thickness)
update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon)
@ -166,7 +168,7 @@ def test_cpml_plane_wave_phasor_decays_monotonically_through_outgoing_pml() -> N
epsilon = numpy.ones(shape, dtype=float)
dxes = _unit_dxes(shape)
params = [[None, None] for _ in range(3)]
params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
for polarity_index, polarity in enumerate((-1, 1)):
params[0][polarity_index] = cpml_params(axis=0, polarity=polarity, dt=dt, thickness=thickness)
@ -212,7 +214,7 @@ def test_cpml_point_source_total_energy_reaches_late_time_plateau() -> None:
epsilon = numpy.ones(shape, dtype=float)
dxes = _unit_dxes(shape)
params = [[None, None] for _ in range(3)]
params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)]
for axis in range(3):
for polarity_index, polarity in enumerate((-1, 1)):
params[axis][polarity_index] = cpml_params(axis=axis, polarity=polarity, dt=dt, thickness=thickness)

View file

@ -1,26 +1,28 @@
import builtins
import importlib
import pathlib
from types import ModuleType
from typing import Any
import pytest
import meanas
from ..fdfd import bloch
from .utils import assert_close
def _reload(module):
def _reload(module: ModuleType) -> ModuleType:
return importlib.reload(module)
def _restore_reloaded(monkeypatch, module):
def _restore_reloaded(monkeypatch: pytest.MonkeyPatch, module: ModuleType) -> ModuleType:
monkeypatch.undo()
return _reload(module)
def test_meanas_import_survives_readme_open_failure(monkeypatch) -> None: # type: ignore[no-untyped-def]
def test_meanas_import_survives_readme_open_failure(monkeypatch: pytest.MonkeyPatch) -> None:
expected_version = meanas.__version__
original_open = pathlib.Path.open
def failing_open(self: pathlib.Path, *args, **kwargs): # type: ignore[no-untyped-def]
def failing_open(self: pathlib.Path, *args: Any, **kwargs: Any) -> Any:
if self.name == 'README.md':
raise FileNotFoundError('forced README failure')
return original_open(self, *args, **kwargs)
@ -35,10 +37,16 @@ def test_meanas_import_survives_readme_open_failure(monkeypatch) -> None: # typ
_restore_reloaded(monkeypatch, meanas)
def test_bloch_reloads_with_numpy_fft_when_pyfftw_is_unavailable(monkeypatch) -> None: # type: ignore[no-untyped-def]
def test_bloch_reloads_with_numpy_fft_when_pyfftw_is_unavailable(monkeypatch: pytest.MonkeyPatch) -> None:
original_import = builtins.__import__
def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0): # type: ignore[no-untyped-def]
def fake_import(
name: str,
globals: dict[str, Any] | None = None,
locals: dict[str, Any] | None = None,
fromlist: tuple[str, ...] = (),
level: int = 0,
) -> Any:
if name.startswith('pyfftw'):
raise ImportError('forced pyfftw failure')
return original_import(name, globals, locals, fromlist, level)

View file

@ -224,7 +224,7 @@ def _build_cpml_params() -> list[list[dict[str, numpy.ndarray | float]]]:
def _build_complex_pulse_waveform(total_steps: int) -> tuple[numpy.ndarray, complex]:
source_phasor, _delay = gaussian_packet(wl=WAVELENGTH, dwl=PULSE_DWL, dt=DT, turn_on=1e-5)
aa, cc, ss = source_phasor(numpy.arange(total_steps) + 0.5)
waveform = aa * (cc + 1j * ss)
waveform = numpy.asarray(aa * (cc + 1j * ss), dtype=complex)
scale = fdtd.temporal_phasor_scale(waveform, OMEGA, DT, offset_steps=0.5)[0]
return waveform, scale
@ -272,7 +272,7 @@ def _run_real_field_straight_waveguide_case() -> RealFieldWaveguideResult:
slices=REAL_FIELD_SOURCE_SLICES,
epsilon=epsilon,
)
j_mode *= numpy.exp(1j * REAL_FIELD_SOURCE_PHASE)
j_mode = j_mode * numpy.exp(1j * REAL_FIELD_SOURCE_PHASE)
monitor_mode = waveguide_3d.solve_mode(
0,
omega=OMEGA,
@ -425,8 +425,8 @@ def _run_straight_waveguide_case(variant: str) -> WaveguideCalibrationResult:
)
h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd)
overlap_td = vec(e_ph) @ vec(overlap_e).conj()
overlap_fd = vec(e_fdfd) @ vec(overlap_e).conj()
overlap_td = complex(vec(e_ph) @ vec(overlap_e).conj())
overlap_fd = complex(vec(e_fdfd) @ vec(overlap_e).conj())
poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj())
poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj())
@ -551,10 +551,10 @@ def _run_width_step_scattering_case() -> WaveguideScatteringResult:
)
h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd)
reflected_td = vec(e_ph) @ vec(reflected_overlap).conj()
reflected_fd = vec(e_fdfd) @ vec(reflected_overlap).conj()
transmitted_td = vec(e_ph) @ vec(transmitted_overlap).conj()
transmitted_fd = vec(e_fdfd) @ vec(transmitted_overlap).conj()
reflected_td = complex(vec(e_ph) @ vec(reflected_overlap).conj())
reflected_fd = complex(vec(e_fdfd) @ vec(reflected_overlap).conj())
transmitted_td = complex(vec(e_ph) @ vec(transmitted_overlap).conj())
transmitted_fd = complex(vec(e_fdfd) @ vec(transmitted_overlap).conj())
poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj())
poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj())
@ -664,8 +664,8 @@ def _run_pulsed_straight_waveguide_case() -> PulsedWaveguideCalibrationResult:
)
h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd)
overlap_td = vec(e_ph) @ vec(overlap_e).conj()
overlap_fd = vec(e_fdfd) @ vec(overlap_e).conj()
overlap_td = complex(vec(e_ph) @ vec(overlap_e).conj())
overlap_fd = complex(vec(e_fdfd) @ vec(overlap_e).conj())
poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj())
poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj())

View file

@ -16,7 +16,7 @@ def build_waveguide_3d_mode(
*,
slice_start: int,
polarity: int,
) -> tuple[numpy.ndarray, list[list[numpy.ndarray]], tuple[slice, slice, slice], dict[str, complex | numpy.ndarray]]:
) -> tuple[numpy.ndarray, list[list[numpy.ndarray]], tuple[slice, slice, slice], waveguide_3d.Waveguide3DMode]:
epsilon = numpy.ones((3, 5, 5, 1), dtype=float)
dxes = [[numpy.ones(5), numpy.ones(5), numpy.ones(1)] for _ in range(2)]
slices = (slice(slice_start, slice_start + 1), slice(None), slice(None))

View file

@ -1,5 +1,6 @@
import numpy
from numpy.typing import NDArray
from numpy.typing import ArrayLike
def make_prng(seed: int = 12345) -> numpy.random.RandomState:
@ -24,9 +25,9 @@ def assert_fields_close(
)
def assert_close(
x: NDArray,
y: NDArray,
x: ArrayLike,
y: ArrayLike,
*args,
**kwargs,
) -> None:
numpy.testing.assert_allclose(x, y, *args, **kwargs)
numpy.testing.assert_allclose(numpy.asarray(x), numpy.asarray(y), *args, **kwargs)

View file

@ -104,6 +104,9 @@ lint.ignore = [
"TRY002", # Exception()
]
[tool.ruff.lint.per-file-ignores]
"meanas/test/**/*.py" = ["ANN", "ARG", "TC006"]
[[tool.mypy.overrides]]
module = [