From c6c9159b133e9c22b55c45ff2c2f719aef086b9c Mon Sep 17 00:00:00 2001 From: Forgejo Actions Date: Tue, 21 Apr 2026 21:13:34 -0700 Subject: [PATCH] type hints and lint --- examples/eme.py | 8 +++++-- examples/eme_bend.py | 10 +++++--- examples/fdtd.py | 3 +-- meanas/fdfd/bloch.py | 11 ++++++++- meanas/fdfd/eme.py | 16 +++++++------ meanas/fdfd/farfield.py | 27 +++++++++++----------- meanas/fdfd/waveguide_2d.py | 4 ++-- meanas/fdfd/waveguide_3d.py | 22 +++++++++++------- meanas/fdfd/waveguide_cyl.py | 12 +++++----- meanas/fdmath/functional.py | 6 ++--- meanas/fdmath/operators.py | 4 ++-- meanas/fdmath/types.py | 8 +++---- meanas/fdtd/base.py | 6 ++--- meanas/fdtd/boundaries.py | 14 +++++------ meanas/fdtd/misc.py | 24 +++++++++++++------ meanas/fdtd/pml.py | 25 ++++++++++---------- meanas/test/test_bloch_interactions.py | 2 +- meanas/test/test_eme_numerics.py | 27 ++++++++++++++-------- meanas/test/test_fdfd_pml.py | 6 +++-- meanas/test/test_fdfd_solvers.py | 20 ++++++++-------- meanas/test/test_fdtd_base.py | 2 -- meanas/test/test_fdtd_boundaries.py | 2 +- meanas/test/test_fdtd_misc.py | 3 +++ meanas/test/test_fdtd_phasor.py | 2 +- meanas/test/test_fdtd_pml.py | 16 +++++++------ meanas/test/test_import_fallbacks.py | 22 ++++++++++++------ meanas/test/test_waveguide_fdtd_fdfd.py | 20 ++++++++-------- meanas/test/test_waveguide_mode_helpers.py | 2 +- meanas/test/utils.py | 7 +++--- pyproject.toml | 3 +++ 30 files changed, 198 insertions(+), 136 deletions(-) diff --git a/examples/eme.py b/examples/eme.py index 6215cbc..3a26dc8 100644 --- a/examples/eme.py +++ b/examples/eme.py @@ -14,6 +14,7 @@ simple straight interface: from __future__ import annotations import importlib +from typing import TYPE_CHECKING import numpy from numpy import pi @@ -24,6 +25,9 @@ from gridlock import Extent from meanas.fdfd import eme, waveguide_2d from meanas.fdmath import unvec +if TYPE_CHECKING: + from types import ModuleType + WL = 1310.0 DX = 40.0 @@ -35,7 +39,7 @@ EPS_OX = 1.453 ** 2 MODE_NUMBERS = numpy.array([0]) -def require_optional(name: str, package_name: str | None = None): +def require_optional(name: str, package_name: str | None = None) -> ModuleType: package_name = package_name or name try: return importlib.import_module(name) @@ -159,7 +163,7 @@ def print_summary(ss: numpy.ndarray, wavenumbers_left: numpy.ndarray, wavenumber def plot_results( *, - pyplot, + pyplot: ModuleType, ss: numpy.ndarray, left_mode: tuple[numpy.ndarray, numpy.ndarray], right_mode: tuple[numpy.ndarray, numpy.ndarray], diff --git a/examples/eme_bend.py b/examples/eme_bend.py index caff4df..e5eaebd 100644 --- a/examples/eme_bend.py +++ b/examples/eme_bend.py @@ -15,6 +15,7 @@ This example demonstrates a cylindrical-waveguide EME workflow: from __future__ import annotations import importlib +from typing import TYPE_CHECKING import numpy from numpy import pi @@ -26,6 +27,9 @@ from gridlock import Extent from meanas.fdfd import eme, waveguide_2d, waveguide_cyl from meanas.fdmath import unvec +if TYPE_CHECKING: + from types import ModuleType + WL = 1310.0 DX = 40.0 @@ -40,7 +44,7 @@ STRAIGHT_SECTION_LENGTH = 12e3 BEND_ANGLE = pi / 2 -def require_optional(name: str, package_name: str | None = None): +def require_optional(name: str, package_name: str | None = None) -> ModuleType: package_name = package_name or name try: return importlib.import_module(name) @@ -163,7 +167,7 @@ def solve_bend_modes( def build_cascaded_network( - skrf, + skrf: ModuleType, *, interface_s: numpy.ndarray, straight_wavenumbers: numpy.ndarray, @@ -216,7 +220,7 @@ def print_summary( def plot_results( *, - pyplot, + pyplot: ModuleType, interface_s: numpy.ndarray, cascaded_s: numpy.ndarray, straight_mode: tuple[numpy.ndarray, numpy.ndarray], diff --git a/examples/fdtd.py b/examples/fdtd.py index d8cd101..fd6026d 100644 --- a/examples/fdtd.py +++ b/examples/fdtd.py @@ -89,7 +89,7 @@ def perturbed_l3(a: float, radius: float, **kwargs) -> Pattern: return pat -def main(): +def main() -> None: dtype = numpy.float32 max_t = 3600 # number of timesteps @@ -97,7 +97,6 @@ def main(): pml_thickness = 8 # (number of cells) wl = 1550 # Excitation wavelength and fwhm - dwl = 100 # Device design parameters xy_size = numpy.array([10, 10]) diff --git a/meanas/fdfd/bloch.py b/meanas/fdfd/bloch.py index 5701ed9..df04999 100644 --- a/meanas/fdfd/bloch.py +++ b/meanas/fdfd/bloch.py @@ -683,7 +683,16 @@ def eigsolve( return numpy.abs(trace) if False: - def trace_deriv(theta, sgn: int = sgn, ZtAZ=ZtAZ, DtAD=DtAD, symZtD=symZtD, symZtAD=symZtAD, ZtZ=ZtZ, DtD=DtD): # noqa: ANN001 + def trace_deriv( + theta: float, + sgn: int = sgn, + ZtAZ=ZtAZ, # noqa: ANN001 + DtAD=DtAD, # noqa: ANN001 + symZtD=symZtD, # noqa: ANN001 + symZtAD=symZtAD, # noqa: ANN001 + ZtZ=ZtZ, # noqa: ANN001 + DtD=DtD, # noqa: ANN001 + ) -> float: Qi = Qi_func(theta) c2 = numpy.cos(2 * theta) s2 = numpy.sin(2 * theta) diff --git a/meanas/fdfd/eme.py b/meanas/fdfd/eme.py index 366de8e..af745e8 100644 --- a/meanas/fdfd/eme.py +++ b/meanas/fdfd/eme.py @@ -27,11 +27,13 @@ from scipy import sparse from ..fdmath import dx_lists2_t, vcfdfield2 from .waveguide_2d import inner_product +type wavenumber_seq = Sequence[complex] | NDArray[numpy.complexfloating] | NDArray[numpy.floating] + def _validate_port_modes( name: str, ehs: Sequence[Sequence[vcfdfield2]], - wavenumbers: Sequence[complex], + wavenumbers: wavenumber_seq, ) -> tuple[tuple[int, ...], tuple[int, ...]]: if len(ehs) != len(wavenumbers): raise ValueError(f'{name} mode list and wavenumber list must have the same length') @@ -61,9 +63,9 @@ def _validate_port_modes( def get_tr( ehLs: Sequence[Sequence[vcfdfield2]], - wavenumbers_L: Sequence[complex], + wavenumbers_L: wavenumber_seq, ehRs: Sequence[Sequence[vcfdfield2]], - wavenumbers_R: Sequence[complex], + wavenumbers_R: wavenumber_seq, dxes: dx_lists2_t, ) -> tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]: """ @@ -118,9 +120,9 @@ def get_tr( def get_abcd( ehLs: Sequence[Sequence[vcfdfield2]], - wavenumbers_L: Sequence[complex], + wavenumbers_L: wavenumber_seq, ehRs: Sequence[Sequence[vcfdfield2]], - wavenumbers_R: Sequence[complex], + wavenumbers_R: wavenumber_seq, **kwargs, ) -> sparse.sparray: """ @@ -151,9 +153,9 @@ def get_abcd( def get_s( ehLs: Sequence[Sequence[vcfdfield2]], - wavenumbers_L: Sequence[complex], + wavenumbers_L: wavenumber_seq, ehRs: Sequence[Sequence[vcfdfield2]], - wavenumbers_R: Sequence[complex], + wavenumbers_R: wavenumber_seq, force_nogain: bool = False, force_reciprocal: bool = False, **kwargs, diff --git a/meanas/fdfd/farfield.py b/meanas/fdfd/farfield.py index 0051cd0..00e6989 100644 --- a/meanas/fdfd/farfield.py +++ b/meanas/fdfd/farfield.py @@ -1,23 +1,24 @@ """ Functions for performing near-to-farfield transformation (and the reverse). """ -from typing import Any, cast, TYPE_CHECKING +from typing import Any, cast +from collections.abc import Sequence import numpy from numpy.fft import fft2, fftshift, fftfreq, ifft2, ifftshift from numpy import pi +from numpy.typing import NDArray +from numpy import complexfloating -from ..fdmath import cfdfield_t - -if TYPE_CHECKING: - from collections.abc import Sequence +type farfield_slice = NDArray[complexfloating] +type transverse_slice_pair = Sequence[farfield_slice] def near_to_farfield( - E_near: cfdfield_t, - H_near: cfdfield_t, + E_near: transverse_slice_pair, + H_near: transverse_slice_pair, dx: float, dy: float, - padded_size: list[int] | int | None = None + padded_size: Sequence[int] | int | None = None ) -> dict[str, Any]: """ Compute the farfield, i.e. the distribution of the fields after propagation @@ -58,7 +59,7 @@ def near_to_farfield( raise Exception('H_near must be a length-2 list of ndarrays') s = E_near[0].shape - if not all(s == f.shape for f in E_near + H_near): + if not all(s == f.shape for f in [*E_near, *H_near]): raise Exception('All fields must be the same shape!') if padded_size is None: @@ -123,11 +124,11 @@ def near_to_farfield( def far_to_nearfield( - E_far: cfdfield_t, - H_far: cfdfield_t, + E_far: transverse_slice_pair, + H_far: transverse_slice_pair, dkx: float, dky: float, - padded_size: list[int] | int | None = None + padded_size: Sequence[int] | int | None = None ) -> dict[str, Any]: """ Compute the farfield, i.e. the distribution of the fields after propagation @@ -164,7 +165,7 @@ def far_to_nearfield( raise Exception('H_far must be a length-2 list of ndarrays') s = E_far[0].shape - if not all(s == f.shape for f in E_far + H_far): + if not all(s == f.shape for f in [*E_far, *H_far]): raise Exception('All fields must be the same shape!') if padded_size is None: diff --git a/meanas/fdfd/waveguide_2d.py b/meanas/fdfd/waveguide_2d.py index 1074e2b..fa2fe76 100644 --- a/meanas/fdfd/waveguide_2d.py +++ b/meanas/fdfd/waveguide_2d.py @@ -423,10 +423,10 @@ def normalized_fields_h( def _normalized_fields( e: vcfdslice, h: vcfdslice, - omega: complex, + _omega: complex, dxes: dx_lists2_t, epsilon: vfdslice, - mu: vfdslice | None = None, + _mu: vfdslice | None = None, prop_phase: float = 0, ) -> tuple[vcfdslice_t, vcfdslice_t]: r""" diff --git a/meanas/fdfd/waveguide_3d.py b/meanas/fdfd/waveguide_3d.py index e7dfd22..01db9b1 100644 --- a/meanas/fdfd/waveguide_3d.py +++ b/meanas/fdfd/waveguide_3d.py @@ -19,9 +19,8 @@ The intended workflow is: That same convention controls which side of the selected slice is used for the overlap window and how the expanded field is phased. """ -from typing import Any, cast +from typing import Any, TypedDict, cast import warnings -from typing import Any from collections.abc import Sequence import numpy from numpy.typing import NDArray @@ -31,6 +30,13 @@ from ..fdmath import vec, unvec, dx_lists_t, cfdfield_t, fdfield, cfdfield from . import operators, waveguide_2d +class Waveguide3DMode(TypedDict): + wavenumber: complex + wavenumber_2d: complex + H: NDArray[complexfloating] + E: NDArray[complexfloating] + + def solve_mode( mode_number: int, omega: complex, @@ -40,7 +46,7 @@ def solve_mode( slices: Sequence[slice], epsilon: fdfield, mu: fdfield | None = None, - ) -> dict[str, complex | NDArray[complexfloating]]: + ) -> Waveguide3DMode: r""" Given a 3D grid, selects a slice from the grid and attempts to solve for an eigenmode propagating through that slice. @@ -121,7 +127,7 @@ def solve_mode( E[iii] = e[oo][:, :, None].transpose(reverse_order) H[iii] = h[oo][:, :, None].transpose(reverse_order) - results = { + results: Waveguide3DMode = { 'wavenumber': wavenumber, 'wavenumber_2d': wavenumber_2d, 'H': H, @@ -184,13 +190,13 @@ def compute_source( def compute_overlap_e( - E: cfdfield_t, + E: cfdfield, wavenumber: complex, dxes: dx_lists_t, axis: int, polarity: int, slices: Sequence[slice], - omega: float, + _omega: float, ) -> cfdfield_t: r""" Build an overlap field for projecting another 3D electric field onto a mode. @@ -262,7 +268,7 @@ def compute_overlap_e( if clipped_start >= clipped_stop: raise ValueError('Requested overlap window lies outside the domain') if clipped_start != start or clipped_stop != stop: - warnings.warn('Requested overlap window was clipped to fit within the domain', RuntimeWarning) + warnings.warn('Requested overlap window was clipped to fit within the domain', RuntimeWarning, stacklevel=2) slices2_l = list(slices) slices2_l[axis] = slice(clipped_start, clipped_stop) @@ -275,7 +281,7 @@ def compute_overlap_e( norm = (Etgt.conj() * Etgt).sum() if norm == 0: raise ValueError('Requested overlap window contains no overlap field support') - Etgt /= norm + Etgt = Etgt / norm return cfdfield_t(Etgt) diff --git a/meanas/fdfd/waveguide_cyl.py b/meanas/fdfd/waveguide_cyl.py index 201f709..e4e2666 100644 --- a/meanas/fdfd/waveguide_cyl.py +++ b/meanas/fdfd/waveguide_cyl.py @@ -130,7 +130,7 @@ import numpy from numpy.typing import NDArray, ArrayLike from scipy import sparse -from ..fdmath import vec, unvec, dx_lists2_t, vcfdslice_t, vcfdfield2_t, vfdslice, vcfdslice, vcfdfield2 +from ..fdmath import vec, unvec, dx_lists2_t, vcfdslice_t, vfdslice, vcfdslice, vcfdfield2 from ..fdmath.operators import deriv_forward, deriv_back from ..eigensolvers import signed_eigensolve, rayleigh_quotient_iteration from . import waveguide_2d @@ -267,7 +267,7 @@ def solve_mode( mode_number: int, *args: Any, **kwargs: Any, - ) -> tuple[vcfdslice, complex]: + ) -> tuple[vcfdfield2, complex]: """ Wrapper around `solve_modes()` that solves for a single mode. @@ -285,7 +285,7 @@ def solve_mode( def linear_wavenumbers( - e_xys: list[vcfdfield2_t], + e_xys: Sequence[vcfdfield2] | NDArray[numpy.complex128], angular_wavenumbers: ArrayLike, epsilon: vfdslice, dxes: dx_lists2_t, @@ -537,11 +537,11 @@ def normalized_fields_e( def _normalized_fields( e: vcfdslice, h: vcfdslice, - omega: complex, + _omega: complex, dxes: dx_lists2_t, - rmin: float, # Currently unused, but may want to use cylindrical poynting + _rmin: float, # Currently unused, but may want to use cylindrical poynting epsilon: vfdslice, - mu: vfdslice | None = None, + _mu: vfdslice | None = None, prop_phase: float = 0, ) -> tuple[vcfdslice_t, vcfdslice_t]: r""" diff --git a/meanas/fdmath/functional.py b/meanas/fdmath/functional.py index 034d4ba..27d368a 100644 --- a/meanas/fdmath/functional.py +++ b/meanas/fdmath/functional.py @@ -10,7 +10,7 @@ import numpy from numpy.typing import NDArray from numpy import floating, complexfloating -from .types import fdfield_t, fdfield_updater_t +from .types import fdfield, fdfield_updater_t def deriv_forward( @@ -127,7 +127,7 @@ def curl_forward_parts( ) -> Callable: Dx, Dy, Dz = deriv_forward(dx_e) - def mkparts_fwd(e: fdfield_t) -> tuple[tuple[fdfield_t, fdfield_t], ...]: + def mkparts_fwd(e: fdfield) -> tuple[tuple[fdfield, fdfield], ...]: return ((-Dz(e[1]), Dy(e[2])), ( Dz(e[0]), -Dx(e[2])), (-Dy(e[0]), Dx(e[1]))) @@ -140,7 +140,7 @@ def curl_back_parts( ) -> Callable: Dx, Dy, Dz = deriv_back(dx_h) - def mkparts_back(h: fdfield_t) -> tuple[tuple[fdfield_t, fdfield_t], ...]: + def mkparts_back(h: fdfield) -> tuple[tuple[fdfield, fdfield], ...]: return ((-Dz(h[1]), Dy(h[2])), ( Dz(h[0]), -Dx(h[2])), (-Dy(h[0]), Dx(h[1]))) diff --git a/meanas/fdmath/operators.py b/meanas/fdmath/operators.py index 0c64ae7..8b7cabc 100644 --- a/meanas/fdmath/operators.py +++ b/meanas/fdmath/operators.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from numpy import floating, complexfloating from scipy import sparse -from .types import vfdfield_t +from .types import vfdfield def shift_circ( @@ -171,7 +171,7 @@ def cross( [-B[1], B[0], zero]]) -def vec_cross(b: vfdfield_t) -> sparse.sparray: +def vec_cross(b: vfdfield) -> sparse.sparray: """ Vector cross product operator diff --git a/meanas/fdmath/types.py b/meanas/fdmath/types.py index 222d18a..b82a5ae 100644 --- a/meanas/fdmath/types.py +++ b/meanas/fdmath/types.py @@ -88,8 +88,8 @@ dx_lists2_mut = MutableSequence[MutableSequence[NDArray[floating | complexfloati """Mutable version of `dx_lists2_t`""" -fdfield_updater_t = Callable[..., fdfield_t] -"""Convenience type for functions which take and return an fdfield_t""" +fdfield_updater_t = Callable[..., fdfield] +"""Convenience type for functions which take and return a real `fdfield`""" -cfdfield_updater_t = Callable[..., cfdfield_t] -"""Convenience type for functions which take and return an cfdfield_t""" +cfdfield_updater_t = Callable[..., cfdfield] +"""Convenience type for functions which take and return a complex `cfdfield`""" diff --git a/meanas/fdtd/base.py b/meanas/fdtd/base.py index 3891e28..480ed87 100644 --- a/meanas/fdtd/base.py +++ b/meanas/fdtd/base.py @@ -3,7 +3,7 @@ Basic FDTD field updates """ -from ..fdmath import dx_lists_t, fdfield_t, fdfield_updater_t +from ..fdmath import dx_lists_t, fdfield, fdfield_updater_t from ..fdmath.functional import curl_forward, curl_back @@ -47,7 +47,7 @@ def maxwell_e( else: curl_h_fun = curl_back() - def me_fun(e: fdfield_t, h: fdfield_t, epsilon: fdfield_t | float) -> fdfield_t: + def me_fun(e: fdfield, h: fdfield, epsilon: fdfield | float) -> fdfield: """ Update the E-field. @@ -103,7 +103,7 @@ def maxwell_h( else: curl_e_fun = curl_forward() - def mh_fun(e: fdfield_t, h: fdfield_t, mu: fdfield_t | float | None = None) -> fdfield_t: + def mh_fun(e: fdfield, h: fdfield, mu: fdfield | float | None = None) -> fdfield: """ Update the H-field. diff --git a/meanas/fdtd/boundaries.py b/meanas/fdtd/boundaries.py index aa0bff5..ca8940d 100644 --- a/meanas/fdtd/boundaries.py +++ b/meanas/fdtd/boundaries.py @@ -6,7 +6,7 @@ Boundary conditions from typing import Any -from ..fdmath import fdfield_t, fdfield_updater_t +from ..fdmath import fdfield, fdfield_updater_t def conducting_boundary( @@ -15,7 +15,7 @@ def conducting_boundary( ) -> tuple[fdfield_updater_t, fdfield_updater_t]: dirs = [0, 1, 2] if direction not in dirs: - raise Exception(f'Invalid direction: {direction}') + raise ValueError(f'Invalid direction: {direction}') dirs.remove(direction) u, v = dirs @@ -31,13 +31,13 @@ def conducting_boundary( boundary = tuple(boundary_slice) shifted1 = tuple(shifted1_slice) - def en(e: fdfield_t) -> fdfield_t: + def en(e: fdfield) -> fdfield: e[direction][boundary] = 0 e[u][boundary] = e[u][shifted1] e[v][boundary] = e[v][shifted1] return e - def hn(h: fdfield_t) -> fdfield_t: + def hn(h: fdfield) -> fdfield: h[direction][boundary] = h[direction][shifted1] h[u][boundary] = 0 h[v][boundary] = 0 @@ -56,14 +56,14 @@ def conducting_boundary( shifted1 = tuple(shifted1_slice) shifted2 = tuple(shifted2_slice) - def ep(e: fdfield_t) -> fdfield_t: + def ep(e: fdfield) -> fdfield: e[direction][boundary] = -e[direction][shifted2] e[direction][shifted1] = 0 e[u][boundary] = e[u][shifted1] e[v][boundary] = e[v][shifted1] return e - def hp(h: fdfield_t) -> fdfield_t: + def hp(h: fdfield) -> fdfield: h[direction][boundary] = h[direction][shifted1] h[u][boundary] = -h[u][shifted2] h[u][shifted1] = 0 @@ -73,4 +73,4 @@ def conducting_boundary( return ep, hp - raise Exception(f'Bad polarity: {polarity}') + raise ValueError(f'Bad polarity: {polarity}') diff --git a/meanas/fdtd/misc.py b/meanas/fdtd/misc.py index 89ccb3d..585c745 100644 --- a/meanas/fdtd/misc.py +++ b/meanas/fdtd/misc.py @@ -1,5 +1,6 @@ from collections.abc import Callable import logging +from typing import cast import numpy from numpy.typing import NDArray, ArrayLike @@ -9,7 +10,14 @@ from numpy import pi logger = logging.getLogger(__name__) -pulse_fn_t = Callable[[int | NDArray], tuple[float, float, float]] +type pulse_scalar_t = float | NDArray[numpy.floating] +pulse_fn_t = Callable[[ArrayLike], tuple[pulse_scalar_t, pulse_scalar_t, pulse_scalar_t]] + + +def _scalar_or_array(values: NDArray[numpy.floating]) -> pulse_scalar_t: + if values.ndim == 0: + return float(values) + return cast('NDArray[numpy.floating]', values) def gaussian_packet( @@ -49,8 +57,9 @@ def gaussian_packet( delay = numpy.ceil(delay * freq) / freq # force delay to integer number of periods to maintain phase logger.info(f'src_time {2 * delay / dt}') - def source_phasor(ii: int | NDArray) -> tuple[float, float, float]: - t0 = ii * dt - delay + def source_phasor(ii: ArrayLike) -> tuple[pulse_scalar_t, pulse_scalar_t, pulse_scalar_t]: + ii_array = numpy.asarray(ii, dtype=float) + t0 = ii_array * dt - delay envelope = numpy.sqrt(numpy.sqrt(2 * alpha / pi)) * numpy.exp(-alpha * t0 * t0) if one_sided: @@ -59,7 +68,7 @@ def gaussian_packet( cc = numpy.cos(omega * t0) ss = numpy.sin(omega * t0) - return envelope, cc, ss + return _scalar_or_array(envelope), _scalar_or_array(cc), _scalar_or_array(ss) # nrm = numpy.exp(-omega * omega / alpha) / 2 @@ -105,15 +114,16 @@ def ricker_pulse( delay = delay_results.root delay = numpy.ceil(delay * freq) / freq # force delay to integer number of periods to maintain phase - def source_phasor(ii: int | NDArray) -> tuple[float, float, float]: - t0 = ii * dt - delay + def source_phasor(ii: ArrayLike) -> tuple[pulse_scalar_t, pulse_scalar_t, pulse_scalar_t]: + ii_array = numpy.asarray(ii, dtype=float) + t0 = ii_array * dt - delay rr = omega * t0 / 2 ff = (1 - 2 * rr * rr) * numpy.exp(-rr * rr) cc = numpy.cos(omega * t0) ss = numpy.sin(omega * t0) - return ff, cc, ss + return _scalar_or_array(ff), _scalar_or_array(cc), _scalar_or_array(ss) return source_phasor, delay diff --git a/meanas/fdtd/pml.py b/meanas/fdtd/pml.py index bf61b4e..aba9cb7 100644 --- a/meanas/fdtd/pml.py +++ b/meanas/fdtd/pml.py @@ -23,7 +23,7 @@ from copy import deepcopy import numpy from numpy.typing import NDArray, DTypeLike -from ..fdmath import fdfield, fdfield_t, dx_lists_t +from ..fdmath import fdfield, dx_lists_t from ..fdmath.functional import deriv_forward, deriv_back @@ -67,16 +67,16 @@ def cpml_params( """ if axis not in range(3): - raise Exception(f'Invalid axis: {axis}') + raise ValueError(f'Invalid axis: {axis}') if polarity not in (-1, 1): - raise Exception(f'Invalid polarity: {polarity}') + raise ValueError(f'Invalid polarity: {polarity}') if thickness <= 2: - raise Exception('It would be wise to have a pml with 4+ cells of thickness') + raise ValueError('It would be wise to have a pml with 4+ cells of thickness') if epsilon_eff <= 0: - raise Exception('epsilon_eff must be positive') + raise ValueError('epsilon_eff must be positive') sigma_max = -ln_R_per_layer / 2 * (m + 1) kappa_max = numpy.sqrt(epsilon_eff * mu_eff) @@ -129,8 +129,7 @@ def updates_with_cpml( epsilon: fdfield, *, dtype: DTypeLike = numpy.float32, - ) -> tuple[Callable[[fdfield_t, fdfield_t, fdfield_t], None], - Callable[[fdfield_t, fdfield_t, fdfield_t], None]]: + ) -> tuple[Callable[..., None], Callable[..., None]]: """ Build Yee-step update closures augmented with CPML terms. @@ -187,9 +186,9 @@ def updates_with_cpml( pH = numpy.empty_like(epsilon, dtype=dtype) def update_E( - e: fdfield_t, - h: fdfield_t, - epsilon: fdfield_t, + e: fdfield, + h: fdfield, + epsilon: fdfield, ) -> None: dyHx = Dby(h[0]) dzHx = Dbz(h[0]) @@ -233,9 +232,9 @@ def updates_with_cpml( e[2] += dt / epsilon[2] * (dxHy - dyHx + pE[2]) def update_H( - e: fdfield_t, - h: fdfield_t, - mu: fdfield_t | tuple[int, int, int] = (1, 1, 1), + e: fdfield, + h: fdfield, + mu: fdfield | tuple[int, int, int] = (1, 1, 1), ) -> None: dyEx = Dfy(e[0]) dzEx = Dfz(e[0]) diff --git a/meanas/test/test_bloch_interactions.py b/meanas/test/test_bloch_interactions.py index b67d5ce..0628a55 100644 --- a/meanas/test/test_bloch_interactions.py +++ b/meanas/test/test_bloch_interactions.py @@ -4,7 +4,7 @@ from numpy.testing import assert_allclose from types import SimpleNamespace from ..fdfd import bloch -from ._bloch_case import EPSILON, G_MATRIX, H_SIZE, K0_X, SHAPE, Y0, Y0_TWO_MODE, build_overlap_fixture +from ._bloch_case import EPSILON, G_MATRIX, H_SIZE, K0_X, Y0, Y0_TWO_MODE, build_overlap_fixture from .utils import assert_close diff --git a/meanas/test/test_eme_numerics.py b/meanas/test/test_eme_numerics.py index 7486128..2949e4c 100644 --- a/meanas/test/test_eme_numerics.py +++ b/meanas/test/test_eme_numerics.py @@ -1,3 +1,5 @@ +from typing import cast + import numpy import pytest from scipy import sparse @@ -51,6 +53,10 @@ def _nonsymmetric_tr(left_marker: object): return fake_get_tr +def _dummy_modes() -> tuple[list[tuple[numpy.ndarray, numpy.ndarray]], numpy.ndarray]: + return [_mode(0.0), _mode(0.7)], numpy.array([1.0, 0.5]) + + def test_get_tr_returns_finite_bounded_transfer_matrices() -> None: left_modes, right_modes = _mode_sets() @@ -103,9 +109,10 @@ def test_get_s_plain_matches_block_assembly_from_get_tr() -> None: def test_get_s_force_nogain_caps_singular_values(monkeypatch) -> None: monkeypatch.setattr(eme, 'get_tr', _gain_only_tr) + modes, wavenumbers = _dummy_modes() - plain_s = eme.get_s(None, None, None, None) - clipped_s = eme.get_s(None, None, None, None, force_nogain=True) + plain_s = eme.get_s(modes, wavenumbers, modes, wavenumbers) + clipped_s = eme.get_s(modes, wavenumbers, modes, wavenumbers, force_nogain=True) plain_singular_values = numpy.linalg.svd(plain_s, compute_uv=False) clipped_singular_values = numpy.linalg.svd(clipped_s, compute_uv=False) @@ -116,18 +123,20 @@ def test_get_s_force_nogain_caps_singular_values(monkeypatch) -> None: def test_get_s_force_reciprocal_symmetrizes_output(monkeypatch) -> None: - left = object() - right = object() + left = numpy.array([1.0, 0.5]) + right = numpy.array([0.9, 0.4]) + modes, _wavenumbers = _dummy_modes() monkeypatch.setattr(eme, 'get_tr', _nonsymmetric_tr(left)) - ss = eme.get_s(None, left, None, right, force_reciprocal=True) + ss = eme.get_s(modes, left, modes, right, force_reciprocal=True) assert_close(ss, ss.T) def test_get_s_force_nogain_and_reciprocal_returns_finite_output(monkeypatch) -> None: monkeypatch.setattr(eme, 'get_tr', _gain_and_reflection_tr) - ss = eme.get_s(None, None, None, None, force_nogain=True, force_reciprocal=True) + modes, wavenumbers = _dummy_modes() + ss = eme.get_s(modes, wavenumbers, modes, wavenumbers, force_nogain=True, force_reciprocal=True) assert ss.shape == (4, 4) assert numpy.isfinite(ss).all() @@ -143,15 +152,15 @@ def test_get_tr_rejects_length_mismatches() -> None: def test_get_tr_rejects_malformed_mode_tuples() -> None: - bad_modes = [(numpy.ones(4),)] + bad_modes = cast(list[tuple[numpy.ndarray, numpy.ndarray]], [(numpy.ones(4, dtype=complex),)]) with pytest.raises(ValueError, match='2-tuple'): eme.get_tr(bad_modes, [1.0], bad_modes, [1.0], dxes=DXES) def test_get_tr_rejects_incompatible_field_shapes() -> None: - left_modes = [(numpy.ones(4), numpy.ones(4))] - right_modes = [(numpy.ones(6), numpy.ones(6))] + left_modes = [(numpy.ones(4, dtype=complex), numpy.ones(4, dtype=complex))] + right_modes = [(numpy.ones(6, dtype=complex), numpy.ones(6, dtype=complex))] with pytest.raises(ValueError, match='same E/H shapes'): eme.get_tr(left_modes, [1.0], right_modes, [1.0], dxes=DXES) diff --git a/meanas/test/test_fdfd_pml.py b/meanas/test/test_fdfd_pml.py index 540a3a0..1a8d66c 100644 --- a/meanas/test/test_fdfd_pml.py +++ b/meanas/test/test_fdfd_pml.py @@ -85,8 +85,10 @@ def j_distribution( other_dims = [0, 1, 2] other_dims.remove(dim) - dx_prop = (dxes[0][dim][shape[dim + 1] // 2] - + dxes[1][dim][shape[dim + 1] // 2]) / 2 # noqa: E128 # TODO is this right for nonuniform dxes? + dx_prop = ( + dxes[0][dim][shape[dim + 1] // 2] + + dxes[1][dim][shape[dim + 1] // 2] + ) / 2 # TODO is this right for nonuniform dxes? # Mask only contains components orthogonal to propagation direction center_mask = numpy.zeros(shape, dtype=bool) diff --git a/meanas/test/test_fdfd_solvers.py b/meanas/test/test_fdfd_solvers.py index b841dc9..de39d70 100644 --- a/meanas/test/test_fdfd_solvers.py +++ b/meanas/test/test_fdfd_solvers.py @@ -1,3 +1,5 @@ +from typing import cast + import numpy from ..fdfd import solvers @@ -41,7 +43,7 @@ def test_scipy_qmr_installs_logging_callback_when_missing(monkeypatch) -> None: def test_generic_forward_preconditions_system_and_guess(monkeypatch) -> None: case = solver_plumbing_case() - captured: dict[str, object] = {} + captured: dict[str, numpy.ndarray | float | object] = {} monkeypatch.setattr(solvers.operators, 'e_full', lambda *args, **kwargs: case.a0) monkeypatch.setattr(solvers.operators, 'e_full_preconditioners', lambda dxes: (case.pl, case.pr)) @@ -63,16 +65,16 @@ def test_generic_forward_preconditions_system_and_guess(monkeypatch) -> None: E_guess=case.guess, ) - assert_close(captured['a'].toarray(), (case.pl @ case.a0 @ case.pr).toarray()) - assert_close(captured['b'], case.pl @ (-1j * case.omega * case.j)) - assert_close(captured['x0'], case.pl @ case.guess) + assert_close(cast(object, captured['a']).toarray(), (case.pl @ case.a0 @ case.pr).toarray()) # type: ignore[attr-defined] + assert_close(cast(numpy.ndarray, captured['b']), case.pl @ (-1j * case.omega * case.j)) + assert_close(cast(numpy.ndarray, captured['x0']), case.pl @ case.guess) assert captured['atol'] == 1e-12 assert_close(result, case.pr @ case.solver_result) def test_generic_adjoint_preconditions_system_and_guess(monkeypatch) -> None: case = solver_plumbing_case() - captured: dict[str, object] = {} + captured: dict[str, numpy.ndarray | float | object] = {} monkeypatch.setattr(solvers.operators, 'e_full', lambda *args, **kwargs: case.a0) monkeypatch.setattr(solvers.operators, 'e_full_preconditioners', lambda dxes: (case.pl, case.pr)) @@ -96,9 +98,9 @@ def test_generic_adjoint_preconditions_system_and_guess(monkeypatch) -> None: ) expected_matrix = (case.pl @ case.a0 @ case.pr).T.conjugate() - assert_close(captured['a'].toarray(), expected_matrix.toarray()) - assert_close(captured['b'], case.pr.T.conjugate() @ (-1j * case.omega * case.j)) - assert_close(captured['x0'], case.pr.T.conjugate() @ case.guess) + assert_close(cast(object, captured['a']).toarray(), expected_matrix.toarray()) # type: ignore[attr-defined] + assert_close(cast(numpy.ndarray, captured['b']), case.pr.T.conjugate() @ (-1j * case.omega * case.j)) + assert_close(cast(numpy.ndarray, captured['x0']), case.pr.T.conjugate() @ case.guess) assert captured['rtol'] == 1e-9 assert_close(result, case.pl.T.conjugate() @ case.solver_result) @@ -122,5 +124,5 @@ def test_generic_without_guess_does_not_inject_x0(monkeypatch) -> None: matrix_solver=fake_solver, ) - assert 'x0' not in captured['kwargs'] + assert 'x0' not in cast(dict[str, object], captured['kwargs']) assert_close(result, case.pr @ numpy.array([1.0, -1.0])) diff --git a/meanas/test/test_fdtd_base.py b/meanas/test/test_fdtd_base.py index bc1f514..c8246d5 100644 --- a/meanas/test/test_fdtd_base.py +++ b/meanas/test/test_fdtd_base.py @@ -1,5 +1,3 @@ -import numpy - from ..fdmath import functional as fd_functional from ..fdtd import base from ._test_builders import real_ramp diff --git a/meanas/test/test_fdtd_boundaries.py b/meanas/test/test_fdtd_boundaries.py index d7ba186..d60ca7a 100644 --- a/meanas/test/test_fdtd_boundaries.py +++ b/meanas/test/test_fdtd_boundaries.py @@ -58,5 +58,5 @@ def test_conducting_boundary_updates_expected_faces(direction: int, polarity: in [(-1, 1), (3, 1), (0, 0)], ) def test_conducting_boundary_rejects_invalid_arguments(direction: int, polarity: int) -> None: - with pytest.raises(Exception): + with pytest.raises(ValueError, match='Invalid direction|Bad polarity'): conducting_boundary(direction, polarity) diff --git a/meanas/test/test_fdtd_misc.py b/meanas/test/test_fdtd_misc.py index 65dc713..3688c6c 100644 --- a/meanas/test/test_fdtd_misc.py +++ b/meanas/test/test_fdtd_misc.py @@ -10,6 +10,9 @@ def test_gaussian_packet_accepts_array_input(one_sided: bool) -> None: source, delay = gaussian_packet(1.55, 0.1, dt, one_sided=one_sided) steps = numpy.array([0, int(numpy.ceil(delay / dt)) + 5]) envelope, cc, ss = source(steps) + assert isinstance(envelope, numpy.ndarray) + assert isinstance(cc, numpy.ndarray) + assert isinstance(ss, numpy.ndarray) assert envelope.shape == (2,) assert numpy.isfinite(envelope).all() diff --git a/meanas/test/test_fdtd_phasor.py b/meanas/test/test_fdtd_phasor.py index 7e1126b..9d28ee3 100644 --- a/meanas/test/test_fdtd_phasor.py +++ b/meanas/test/test_fdtd_phasor.py @@ -371,7 +371,7 @@ def _real_pulse_case() -> RealPulseCase: source_phasor, _delay = gaussian_packet(wl=wavelength, dwl=1.0, dt=dt, turn_on=1e-5) aa, cc, ss = source_phasor(numpy.arange(total_steps) + 0.5) - waveform = aa * (cc + 1j * ss) + waveform = numpy.asarray(aa * (cc + 1j * ss), dtype=complex) scale = fdtd.real_injection_scale(waveform, omega, dt, offset_steps=0.5)[0] j_accumulator = numpy.zeros((1, *full_shape), dtype=complex) diff --git a/meanas/test/test_fdtd_pml.py b/meanas/test/test_fdtd_pml.py index 06c2588..319260f 100644 --- a/meanas/test/test_fdtd_pml.py +++ b/meanas/test/test_fdtd_pml.py @@ -1,3 +1,5 @@ +from typing import Any + import numpy import pytest @@ -12,7 +14,7 @@ from .utils import assert_close [(3, 1, 4, 1.0), (0, 0, 4, 1.0), (0, 1, 2, 1.0), (0, 1, 4, 0.0)], ) def test_cpml_params_reject_invalid_arguments(axis: int, polarity: int, thickness: int, epsilon_eff: float) -> None: - with pytest.raises(Exception): + with pytest.raises(ValueError, match='Invalid axis|Invalid polarity|wise to have a pml|epsilon_eff must be positive'): cpml_params(axis=axis, polarity=polarity, dt=0.1, thickness=thickness, epsilon_eff=epsilon_eff) @@ -36,7 +38,7 @@ def test_updates_with_cpml_keeps_zero_fields_zero() -> None: e = numpy.zeros(shape, dtype=float) h = numpy.zeros(shape, dtype=float) dxes = [[numpy.ones(4), numpy.ones(4), numpy.ones(4)] for _ in range(2)] - params = [[None, None] for _ in range(3)] + params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)] params[0][0] = cpml_params(axis=0, polarity=-1, dt=0.1, thickness=3) update_e, update_h = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon) @@ -69,7 +71,7 @@ def test_updates_with_cpml_matches_base_updates_when_all_faces_disabled() -> Non e = _real_field(shape, 10.0) h = _real_field(shape, 100.0) dxes = _unit_dxes(shape) - params = [[None, None] for _ in range(3)] + params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)] update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon) update_e_base = maxwell_e(dt=0.1, dxes=dxes) @@ -96,7 +98,7 @@ def test_updates_with_cpml_matches_base_updates_with_complex_dtype_when_all_face e = _complex_field(shape, 10.0) h = _complex_field(shape, 100.0) dxes = _unit_dxes(shape) - params = [[None, None] for _ in range(3)] + params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)] update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon, dtype=complex) update_e_base = maxwell_e(dt=0.1, dxes=dxes) @@ -125,7 +127,7 @@ def test_updates_with_cpml_only_changes_the_configured_face_region() -> None: dxes = _unit_dxes(shape) thickness = 3 - params = [[None, None] for _ in range(3)] + params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)] params[0][0] = cpml_params(axis=0, polarity=-1, dt=0.1, thickness=thickness) update_e_cpml, update_h_cpml = updates_with_cpml(params, dt=0.1, dxes=dxes, epsilon=epsilon) @@ -166,7 +168,7 @@ def test_cpml_plane_wave_phasor_decays_monotonically_through_outgoing_pml() -> N epsilon = numpy.ones(shape, dtype=float) dxes = _unit_dxes(shape) - params = [[None, None] for _ in range(3)] + params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)] for polarity_index, polarity in enumerate((-1, 1)): params[0][polarity_index] = cpml_params(axis=0, polarity=polarity, dt=dt, thickness=thickness) @@ -212,7 +214,7 @@ def test_cpml_point_source_total_energy_reaches_late_time_plateau() -> None: epsilon = numpy.ones(shape, dtype=float) dxes = _unit_dxes(shape) - params = [[None, None] for _ in range(3)] + params: list[list[dict[str, Any] | None]] = [[None, None] for _ in range(3)] for axis in range(3): for polarity_index, polarity in enumerate((-1, 1)): params[axis][polarity_index] = cpml_params(axis=axis, polarity=polarity, dt=dt, thickness=thickness) diff --git a/meanas/test/test_import_fallbacks.py b/meanas/test/test_import_fallbacks.py index 75005d0..e332d1b 100644 --- a/meanas/test/test_import_fallbacks.py +++ b/meanas/test/test_import_fallbacks.py @@ -1,26 +1,28 @@ import builtins import importlib import pathlib +from types import ModuleType +from typing import Any +import pytest import meanas from ..fdfd import bloch -from .utils import assert_close -def _reload(module): +def _reload(module: ModuleType) -> ModuleType: return importlib.reload(module) -def _restore_reloaded(monkeypatch, module): +def _restore_reloaded(monkeypatch: pytest.MonkeyPatch, module: ModuleType) -> ModuleType: monkeypatch.undo() return _reload(module) -def test_meanas_import_survives_readme_open_failure(monkeypatch) -> None: # type: ignore[no-untyped-def] +def test_meanas_import_survives_readme_open_failure(monkeypatch: pytest.MonkeyPatch) -> None: expected_version = meanas.__version__ original_open = pathlib.Path.open - def failing_open(self: pathlib.Path, *args, **kwargs): # type: ignore[no-untyped-def] + def failing_open(self: pathlib.Path, *args: Any, **kwargs: Any) -> Any: if self.name == 'README.md': raise FileNotFoundError('forced README failure') return original_open(self, *args, **kwargs) @@ -35,10 +37,16 @@ def test_meanas_import_survives_readme_open_failure(monkeypatch) -> None: # typ _restore_reloaded(monkeypatch, meanas) -def test_bloch_reloads_with_numpy_fft_when_pyfftw_is_unavailable(monkeypatch) -> None: # type: ignore[no-untyped-def] +def test_bloch_reloads_with_numpy_fft_when_pyfftw_is_unavailable(monkeypatch: pytest.MonkeyPatch) -> None: original_import = builtins.__import__ - def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0): # type: ignore[no-untyped-def] + def fake_import( + name: str, + globals: dict[str, Any] | None = None, + locals: dict[str, Any] | None = None, + fromlist: tuple[str, ...] = (), + level: int = 0, + ) -> Any: if name.startswith('pyfftw'): raise ImportError('forced pyfftw failure') return original_import(name, globals, locals, fromlist, level) diff --git a/meanas/test/test_waveguide_fdtd_fdfd.py b/meanas/test/test_waveguide_fdtd_fdfd.py index ae2078d..9d42bf5 100644 --- a/meanas/test/test_waveguide_fdtd_fdfd.py +++ b/meanas/test/test_waveguide_fdtd_fdfd.py @@ -224,7 +224,7 @@ def _build_cpml_params() -> list[list[dict[str, numpy.ndarray | float]]]: def _build_complex_pulse_waveform(total_steps: int) -> tuple[numpy.ndarray, complex]: source_phasor, _delay = gaussian_packet(wl=WAVELENGTH, dwl=PULSE_DWL, dt=DT, turn_on=1e-5) aa, cc, ss = source_phasor(numpy.arange(total_steps) + 0.5) - waveform = aa * (cc + 1j * ss) + waveform = numpy.asarray(aa * (cc + 1j * ss), dtype=complex) scale = fdtd.temporal_phasor_scale(waveform, OMEGA, DT, offset_steps=0.5)[0] return waveform, scale @@ -272,7 +272,7 @@ def _run_real_field_straight_waveguide_case() -> RealFieldWaveguideResult: slices=REAL_FIELD_SOURCE_SLICES, epsilon=epsilon, ) - j_mode *= numpy.exp(1j * REAL_FIELD_SOURCE_PHASE) + j_mode = j_mode * numpy.exp(1j * REAL_FIELD_SOURCE_PHASE) monitor_mode = waveguide_3d.solve_mode( 0, omega=OMEGA, @@ -425,8 +425,8 @@ def _run_straight_waveguide_case(variant: str) -> WaveguideCalibrationResult: ) h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd) - overlap_td = vec(e_ph) @ vec(overlap_e).conj() - overlap_fd = vec(e_fdfd) @ vec(overlap_e).conj() + overlap_td = complex(vec(e_ph) @ vec(overlap_e).conj()) + overlap_fd = complex(vec(e_fdfd) @ vec(overlap_e).conj()) poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj()) poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj()) @@ -551,10 +551,10 @@ def _run_width_step_scattering_case() -> WaveguideScatteringResult: ) h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd) - reflected_td = vec(e_ph) @ vec(reflected_overlap).conj() - reflected_fd = vec(e_fdfd) @ vec(reflected_overlap).conj() - transmitted_td = vec(e_ph) @ vec(transmitted_overlap).conj() - transmitted_fd = vec(e_fdfd) @ vec(transmitted_overlap).conj() + reflected_td = complex(vec(e_ph) @ vec(reflected_overlap).conj()) + reflected_fd = complex(vec(e_fdfd) @ vec(reflected_overlap).conj()) + transmitted_td = complex(vec(e_ph) @ vec(transmitted_overlap).conj()) + transmitted_fd = complex(vec(e_fdfd) @ vec(transmitted_overlap).conj()) poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj()) poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj()) @@ -664,8 +664,8 @@ def _run_pulsed_straight_waveguide_case() -> PulsedWaveguideCalibrationResult: ) h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd) - overlap_td = vec(e_ph) @ vec(overlap_e).conj() - overlap_fd = vec(e_fdfd) @ vec(overlap_e).conj() + overlap_td = complex(vec(e_ph) @ vec(overlap_e).conj()) + overlap_fd = complex(vec(e_fdfd) @ vec(overlap_e).conj()) poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj()) poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj()) diff --git a/meanas/test/test_waveguide_mode_helpers.py b/meanas/test/test_waveguide_mode_helpers.py index d5d3abf..ca2d917 100644 --- a/meanas/test/test_waveguide_mode_helpers.py +++ b/meanas/test/test_waveguide_mode_helpers.py @@ -16,7 +16,7 @@ def build_waveguide_3d_mode( *, slice_start: int, polarity: int, - ) -> tuple[numpy.ndarray, list[list[numpy.ndarray]], tuple[slice, slice, slice], dict[str, complex | numpy.ndarray]]: + ) -> tuple[numpy.ndarray, list[list[numpy.ndarray]], tuple[slice, slice, slice], waveguide_3d.Waveguide3DMode]: epsilon = numpy.ones((3, 5, 5, 1), dtype=float) dxes = [[numpy.ones(5), numpy.ones(5), numpy.ones(1)] for _ in range(2)] slices = (slice(slice_start, slice_start + 1), slice(None), slice(None)) diff --git a/meanas/test/utils.py b/meanas/test/utils.py index 3bafd49..62afaf0 100644 --- a/meanas/test/utils.py +++ b/meanas/test/utils.py @@ -1,5 +1,6 @@ import numpy from numpy.typing import NDArray +from numpy.typing import ArrayLike def make_prng(seed: int = 12345) -> numpy.random.RandomState: @@ -24,9 +25,9 @@ def assert_fields_close( ) def assert_close( - x: NDArray, - y: NDArray, + x: ArrayLike, + y: ArrayLike, *args, **kwargs, ) -> None: - numpy.testing.assert_allclose(x, y, *args, **kwargs) + numpy.testing.assert_allclose(numpy.asarray(x), numpy.asarray(y), *args, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index 55fcac1..7f1d6b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,9 @@ lint.ignore = [ "TRY002", # Exception() ] +[tool.ruff.lint.per-file-ignores] +"meanas/test/**/*.py" = ["ANN", "ARG", "TC006"] + [[tool.mypy.overrides]] module = [