From bc428f5e8ea3da6cc42353ee91a6d7bd014c0b3e Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Fri, 16 Oct 2020 21:46:04 -0700 Subject: [PATCH] add more type hints --- meanas/eigensolvers.py | 3 ++- meanas/fdfd/bloch.py | 28 ++++++++++----------- meanas/fdfd/functional.py | 36 +++++++++++++------------- meanas/fdfd/operators.py | 13 +++++----- meanas/fdfd/scpml.py | 16 ++++++------ meanas/fdfd/solvers.py | 6 ++--- meanas/fdfd/waveguide_2d.py | 6 ++--- meanas/fdmath/operators.py | 6 ++--- meanas/fdtd/__init__.py | 2 +- meanas/fdtd/boundaries.py | 8 +++--- meanas/fdtd/energy.py | 3 +++ meanas/fdtd/pml.py | 2 +- meanas/test/conftest.py | 22 ++++++++++------ meanas/test/test_fdfd.py | 29 ++++++++++++++------- meanas/test/test_fdfd_pml.py | 49 ++++++++++++++++++++++++++---------- meanas/test/test_fdtd.py | 34 ++++++++++++++++--------- meanas/test/utils.py | 13 ++++++++-- 17 files changed, 170 insertions(+), 106 deletions(-) diff --git a/meanas/eigensolvers.py b/meanas/eigensolvers.py index a15a06d..8c1739e 100644 --- a/meanas/eigensolvers.py +++ b/meanas/eigensolvers.py @@ -73,8 +73,9 @@ def rayleigh_quotient_iteration(operator: Union[sparse.spmatrix, spalg.LinearOpe dtype=operator.dtype, matvec=lambda v: eigval * v) if solver is None: - def solver(A, b): + def solver(A: spalg.LinearOperator, b: numpy.ndarray) -> numpy.ndarray: return spalg.bicgstab(A, b)[0] + assert(solver is not None) v = numpy.squeeze(guess_vector) v /= norm(v) diff --git a/meanas/fdfd/bloch.py b/meanas/fdfd/bloch.py index 33c5c13..ed6ca39 100644 --- a/meanas/fdfd/bloch.py +++ b/meanas/fdfd/bloch.py @@ -80,7 +80,7 @@ This module contains functions for generating and solving the ''' -from typing import Tuple, Callable +from typing import Tuple, Callable, Any, List, Optional, cast import logging import numpy # type: ignore from numpy import pi, real, trace # type: ignore @@ -109,10 +109,10 @@ try: 'planner_effort': 'FFTW_EXHAUSTIVE', } - def fftn(*args, **kwargs): + def fftn(*args: Any, **kwargs: Any) -> numpy.ndarray: return pyfftw.interfaces.numpy_fft.fftn(*args, **kwargs, **fftw_args) - def ifftn(*args, **kwargs): + def ifftn(*args: Any, **kwargs: Any) -> numpy.ndarray: return pyfftw.interfaces.numpy_fft.ifftn(*args, **kwargs, **fftw_args) except ImportError: @@ -199,7 +199,7 @@ def maxwell_operator(k0: numpy.ndarray, if mu is not None: mu = numpy.stack(mu, 3) - def operator(h: numpy.ndarray): + def operator(h: numpy.ndarray) -> numpy.ndarray: """ Maxwell operator for Bloch eigenmode simulation. @@ -309,11 +309,11 @@ def hmn_2_hxyz(k0: numpy.ndarray, shape = epsilon[0].shape + (1,) _k_mag, m, n = generate_kmn(k0, G_matrix, shape) - def operator(h: numpy.ndarray): + def operator(h: numpy.ndarray) -> fdfield_t: hin_m, hin_n = [hi.reshape(shape) for hi in numpy.split(h, 2)] h_xyz = (m * hin_m + n * hin_n) - return [ifftn(hi) for hi in numpy.rollaxis(h_xyz, 3)] + return numpy.array([ifftn(hi) for hi in numpy.rollaxis(h_xyz, 3)]) return operator @@ -351,7 +351,7 @@ def inverse_maxwell_operator_approx(k0: numpy.ndarray, if mu is not None: mu = numpy.stack(mu, 3) - def operator(h: numpy.ndarray): + def operator(h: numpy.ndarray) -> numpy.ndarray: """ Approximate inverse Maxwell operator for Bloch eigenmode simulation. @@ -429,7 +429,7 @@ def find_k(frequency: float, direction = numpy.array(direction) / norm(direction) - def get_f(k0_mag: float, band: int = 0): + def get_f(k0_mag: float, band: int = 0) -> numpy.ndarray: k0 = direction * k0_mag n, v = eigsolve(band + 1, k0, G_matrix=G_matrix, epsilon=epsilon, mu=mu) f = numpy.sqrt(numpy.abs(numpy.real(n[band]))) @@ -552,12 +552,12 @@ def eigsolve(num_modes: int, symZtD = _symmetrize(Z.conj().T @ D) symZtAD = _symmetrize(Z.conj().T @ AD) - Qi_memo = [None, None] + Qi_memo: List[Optional[float]] = [None, None] - def Qi_func(theta): + def Qi_func(theta: float) -> float: nonlocal Qi_memo if Qi_memo[0] == theta: - return Qi_memo[1] + return cast(float, Qi_memo[1]) c = numpy.cos(theta) s = numpy.sin(theta) @@ -579,7 +579,7 @@ def eigsolve(num_modes: int, Qi_memo[1] = Qi return Qi - def trace_func(theta): + def trace_func(theta: float) -> float: c = numpy.cos(theta) s = numpy.sin(theta) Qi = Qi_func(theta) @@ -685,9 +685,9 @@ def linmin(x_guess, f0, df0, x_max, f_tol=0.1, df_tol=min(tolerance, 1e-6), x_to return x, fx, dfx ''' -def _rtrace_AtB(A, B): +def _rtrace_AtB(A: numpy.ndarray, B: numpy.ndarray) -> numpy.ndarray: return real(numpy.sum(A.conj() * B)) -def _symmetrize(A): +def _symmetrize(A: numpy.ndarray) -> numpy.ndarray: return (A + A.conj().T) * 0.5 diff --git a/meanas/fdfd/functional.py b/meanas/fdfd/functional.py index 488d58e..92ec8e9 100644 --- a/meanas/fdfd/functional.py +++ b/meanas/fdfd/functional.py @@ -36,13 +36,13 @@ def e_full(omega: complex, ch = curl_back(dxes[1]) ce = curl_forward(dxes[0]) - def op_1(e): + def op_1(e: fdfield_t) -> fdfield_t: curls = ch(ce(e)) - return curls - omega ** 2 * epsilon * e + return curls - omega ** 2 * epsilon * e # type: ignore # issues with numpy/mypy - def op_mu(e): + def op_mu(e: fdfield_t) -> fdfield_t: curls = ch(mu * ce(e)) - return curls - omega ** 2 * epsilon * e + return curls - omega ** 2 * epsilon * e # type: ignore # issues with numpy/mypy if numpy.any(numpy.equal(mu, None)): return op_1 @@ -72,13 +72,13 @@ def eh_full(omega: complex, ch = curl_back(dxes[1]) ce = curl_forward(dxes[0]) - def op_1(e, h): + def op_1(e: fdfield_t, h: fdfield_t) -> Tuple[fdfield_t, fdfield_t]: return (ch(h) - 1j * omega * epsilon * e, - ce(e) + 1j * omega * h) + ce(e) + 1j * omega * h) # type: ignore # issues with numpy/mypy - def op_mu(e, h): + def op_mu(e: fdfield_t, h: fdfield_t) -> Tuple[fdfield_t, fdfield_t]: return (ch(h) - 1j * omega * epsilon * e, - ce(e) + 1j * omega * mu * h) + ce(e) + 1j * omega * mu * h) # type: ignore # issues with numpy/mypy if numpy.any(numpy.equal(mu, None)): return op_1 @@ -105,11 +105,11 @@ def e2h(omega: complex, """ ce = curl_forward(dxes[0]) - def e2h_1_1(e): - return ce(e) / (-1j * omega) + def e2h_1_1(e: fdfield_t) -> fdfield_t: + return ce(e) / (-1j * omega) # type: ignore # issues with numpy/mypy - def e2h_mu(e): - return ce(e) / (-1j * omega * mu) + def e2h_mu(e: fdfield_t) -> fdfield_t: + return ce(e) / (-1j * omega * mu) # type: ignore # issues with numpy/mypy if numpy.any(numpy.equal(mu, None)): return e2h_1_1 @@ -137,13 +137,13 @@ def m2j(omega: complex, """ ch = curl_back(dxes[1]) - def m2j_mu(m): + def m2j_mu(m: fdfield_t) -> fdfield_t: J = ch(m / mu) / (-1j * omega) - return J + return J # type: ignore # issues with numpy/mypy - def m2j_1(m): + def m2j_1(m: fdfield_t) -> fdfield_t: J = ch(m) / (-1j * omega) - return J + return J # type: ignore # issues with numpy/mypy if numpy.any(numpy.equal(mu, None)): return m2j_1 @@ -177,7 +177,7 @@ def e_tfsf_source(TF_region: fdfield_t, # TODO documentation A = e_full(omega, dxes, epsilon, mu) - def op(e): + def op(e: fdfield_t) -> fdfield_t: neg_iwj = A(TF_region * e) - TF_region * A(e) return neg_iwj / (-1j * omega) return op @@ -205,7 +205,7 @@ def poynting_e_cross_h(dxes: dx_lists_t) -> Callable[[fdfield_t, fdfield_t], fdf Returns: Function `f` that returns E x H as required for the poynting vector. """ - def exh(e: fdfield_t, h: fdfield_t): + def exh(e: fdfield_t, h: fdfield_t) -> fdfield_t: s = numpy.empty_like(e) ex = e[0] * dxes[0][0][:, None, None] ey = e[1] * dxes[0][1][None, :, None] diff --git a/meanas/fdfd/operators.py b/meanas/fdfd/operators.py index ef2fd57..370b7a2 100644 --- a/meanas/fdfd/operators.py +++ b/meanas/fdfd/operators.py @@ -416,12 +416,13 @@ def e_boundary_source(mask: vfdfield_t, shape = [len(dxe) for dxe in dxes[0]] jmask = numpy.zeros_like(mask, dtype=bool) - if periodic_mask_edges: - def shift(axis, polarity): - return rotation(axis=axis, shape=shape, shift_distance=polarity) - else: - def shift(axis, polarity): - return shift_with_mirror(axis=axis, shape=shape, shift_distance=polarity) + def shift_rot(axis: int, polarity: int) -> sparse.spmatrix: + return rotation(axis=axis, shape=shape, shift_distance=polarity) + + def shift_mir(axis: int, polarity: int) -> sparse.spmatrix: + return shift_with_mirror(axis=axis, shape=shape, shift_distance=polarity) + + shift = shift_rot if periodic_mask_edges else shift_mir for axis in (0, 1, 2): if shape[axis] == 1: diff --git a/meanas/fdfd/scpml.py b/meanas/fdfd/scpml.py index 00587f7..67d58ca 100644 --- a/meanas/fdfd/scpml.py +++ b/meanas/fdfd/scpml.py @@ -2,11 +2,9 @@ Functions for creating stretched coordinate perfectly matched layer (PML) absorbers. """ -from typing import Sequence, Union, Callable, Optional +from typing import Sequence, Union, Callable, Optional, List import numpy # type: ignore -from ..fdmath import dx_lists_t, dx_lists_mut - __author__ = 'Jan Petykiewicz' @@ -42,7 +40,7 @@ def uniform_grid_scpml(shape: Union[numpy.ndarray, Sequence[int]], omega: float, epsilon_effective: float = 1.0, s_function: Optional[s_function_t] = None, - ) -> dx_lists_mut: + ) -> List[List[numpy.ndarray]]: """ Create dx arrays for a uniform grid with a cell width of 1 and a pml. @@ -69,7 +67,7 @@ def uniform_grid_scpml(shape: Union[numpy.ndarray, Sequence[int]], s_function = prepare_s_function() # Normalized distance to nearest boundary - def ll(u, n, t): + def ll(u: numpy.ndarray, n: numpy.ndarray, t: numpy.ndarray) -> numpy.ndarray: return ((t - u).clip(0) + (u - (n - t)).clip(0)) / t dx_a = [numpy.array(numpy.inf)] * 3 @@ -90,14 +88,14 @@ def uniform_grid_scpml(shape: Union[numpy.ndarray, Sequence[int]], return [dx_a, dx_b] -def stretch_with_scpml(dxes: dx_lists_mut, +def stretch_with_scpml(dxes: List[List[numpy.ndarray]], axis: int, polarity: int, omega: float, epsilon_effective: float = 1.0, thickness: int = 10, s_function: Optional[s_function_t] = None, - ) -> dx_lists_t: + ) -> List[List[numpy.ndarray]]: """ Stretch dxes to contain a stretched-coordinate PML (SCPML) in one direction along one axis. @@ -134,7 +132,7 @@ def stretch_with_scpml(dxes: dx_lists_mut, bound = pos[thickness] d = bound - pos[0] - def l_d(x): + def l_d(x: numpy.ndarray) -> numpy.ndarray: return (bound - x) / (bound - pos[0]) slc = slice(thickness) @@ -144,7 +142,7 @@ def stretch_with_scpml(dxes: dx_lists_mut, bound = pos[-thickness - 1] d = pos[-1] - bound - def l_d(x): + def l_d(x: numpy.ndarray) -> numpy.ndarray: return (x - bound) / (pos[-1] - bound) if thickness == 0: diff --git a/meanas/fdfd/solvers.py b/meanas/fdfd/solvers.py index 73548ca..a8a423a 100644 --- a/meanas/fdfd/solvers.py +++ b/meanas/fdfd/solvers.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) def _scipy_qmr(A: scipy.sparse.csr_matrix, b: numpy.ndarray, - **kwargs + **kwargs: Any, ) -> numpy.ndarray: """ Wrapper for scipy.sparse.linalg.qmr @@ -37,14 +37,14 @@ def _scipy_qmr(A: scipy.sparse.csr_matrix, ''' ii = 0 - def log_residual(xk): + def log_residual(xk: numpy.ndarray) -> None: nonlocal ii ii += 1 if ii % 100 == 0: logger.info('Solver residual at iteration {} : {}'.format(ii, norm(A @ xk - b))) if 'callback' in kwargs: - def augmented_callback(xk): + def augmented_callback(xk: numpy.ndarray) -> None: log_residual(xk) kwargs['callback'](xk) diff --git a/meanas/fdfd/waveguide_2d.py b/meanas/fdfd/waveguide_2d.py index e5c3775..c7d8148 100644 --- a/meanas/fdfd/waveguide_2d.py +++ b/meanas/fdfd/waveguide_2d.py @@ -146,7 +146,7 @@ to account for numerical dispersion if the result is introduced into a space wit """ # TODO update module docs -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Any import numpy # type: ignore from numpy.linalg import norm # type: ignore import scipy.sparse as sparse # type: ignore @@ -721,8 +721,8 @@ def solve_modes(mode_numbers: List[int], def solve_mode(mode_number: int, - *args, - **kwargs + *args: Any, + **kwargs: Any, ) -> Tuple[vfdfield_t, complex]: """ Wrapper around `solve_modes()` that solves for a single mode. diff --git a/meanas/fdmath/operators.py b/meanas/fdmath/operators.py index 72fbf5d..2ac4e7a 100644 --- a/meanas/fdmath/operators.py +++ b/meanas/fdmath/operators.py @@ -67,7 +67,7 @@ def shift_with_mirror(axis: int, shape: Sequence[int], shift_distance: int = 1) raise Exception('Shift ({}) is too large for axis {} of size {}'.format( shift_distance, axis, shape[axis])) - def mirrored_range(n, s): + def mirrored_range(n: int, s: int) -> numpy.ndarray: v = numpy.arange(n) + s v = numpy.where(v >= n, 2 * n - v - 1, v) v = numpy.where(v < 0, - 1 - v, v) @@ -103,7 +103,7 @@ def deriv_forward(dx_e: Sequence[numpy.ndarray]) -> List[sparse.spmatrix]: dx_e_expanded = numpy.meshgrid(*dx_e, indexing='ij') - def deriv(axis): + def deriv(axis: int) -> sparse.spmatrix: return rotation(axis, shape, 1) - sparse.eye(n) Ds = [sparse.diags(+1 / dx.ravel(order='C')) @ deriv(a) @@ -128,7 +128,7 @@ def deriv_back(dx_h: Sequence[numpy.ndarray]) -> List[sparse.spmatrix]: dx_h_expanded = numpy.meshgrid(*dx_h, indexing='ij') - def deriv(axis): + def deriv(axis: int) -> sparse.spmatrix: return rotation(axis, shape, -1) - sparse.eye(n) Ds = [sparse.diags(-1 / dx.ravel(order='C')) @ deriv(a) diff --git a/meanas/fdtd/__init__.py b/meanas/fdtd/__init__.py index 1a7e1bd..92e215f 100644 --- a/meanas/fdtd/__init__.py +++ b/meanas/fdtd/__init__.py @@ -130,7 +130,7 @@ $$ \\end{aligned} $$ -This result is exact an should practically hold to within numerical precision. No time- +This result is exact and should practically hold to within numerical precision. No time- or spatial-averaging is necessary. Note that each value of $J$ contributes to the energy twice (i.e. once per field update) diff --git a/meanas/fdtd/boundaries.py b/meanas/fdtd/boundaries.py index 8cf0a25..d03a976 100644 --- a/meanas/fdtd/boundaries.py +++ b/meanas/fdtd/boundaries.py @@ -24,13 +24,13 @@ def conducting_boundary(direction: int, boundary_slice[direction] = 0 shifted1_slice[direction] = 1 - def en(e: fdfield_t): + def en(e: fdfield_t) -> fdfield_t: e[direction][boundary_slice] = 0 e[u][boundary_slice] = e[u][shifted1_slice] e[v][boundary_slice] = e[v][shifted1_slice] return e - def hn(h: fdfield_t): + def hn(h: fdfield_t) -> fdfield_t: h[direction][boundary_slice] = h[direction][shifted1_slice] h[u][boundary_slice] = 0 h[v][boundary_slice] = 0 @@ -46,14 +46,14 @@ def conducting_boundary(direction: int, shifted1_slice[direction] = -2 shifted2_slice[direction] = -3 - def ep(e: fdfield_t): + def ep(e: fdfield_t) -> fdfield_t: e[direction][boundary_slice] = -e[direction][shifted2_slice] e[direction][shifted1_slice] = 0 e[u][boundary_slice] = e[u][shifted1_slice] e[v][boundary_slice] = e[v][shifted1_slice] return e - def hp(h: fdfield_t): + def hp(h: fdfield_t) -> fdfield_t: h[direction][boundary_slice] = h[direction][shifted1_slice] h[u][boundary_slice] = -h[u][shifted2_slice] h[u][shifted1_slice] = 0 diff --git a/meanas/fdtd/energy.py b/meanas/fdtd/energy.py index b8aa8dc..121c4f6 100644 --- a/meanas/fdtd/energy.py +++ b/meanas/fdtd/energy.py @@ -5,6 +5,9 @@ from ..fdmath import dx_lists_t, fdfield_t from ..fdmath.functional import deriv_back +# TODO documentation + + def poynting(e: fdfield_t, h: fdfield_t, dxes: Optional[dx_lists_t] = None, diff --git a/meanas/fdtd/pml.py b/meanas/fdtd/pml.py index 91e7f12..e1c9668 100644 --- a/meanas/fdtd/pml.py +++ b/meanas/fdtd/pml.py @@ -63,7 +63,7 @@ def cpml(direction: int, expand_slice_l[direction] = slice(None) expand_slice = tuple(expand_slice_l) - def par(x): + def par(x: numpy.ndarray) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: scaling = (x / thickness) ** m sigma = scaling * sigma_max kappa = 1 + scaling * (kappa_max - 1) diff --git a/meanas/test/conftest.py b/meanas/test/conftest.py index 3514087..932a62c 100644 --- a/meanas/test/conftest.py +++ b/meanas/test/conftest.py @@ -3,6 +3,7 @@ Test fixtures """ +from typing import Tuple, Iterable, List import numpy # type: ignore import pytest # type: ignore @@ -14,22 +15,26 @@ from .utils import PRNG (5, 5, 5), # (7, 7, 7), ]) -def shape(request): +def shape(request: pytest.FixtureRequest) -> Iterable[Tuple[int, ...]]: yield (3, *request.param) @pytest.fixture(scope='module', params=[1.0, 1.5]) -def epsilon_bg(request): +def epsilon_bg(request: pytest.FixtureRequest) -> Iterable[float]: yield request.param @pytest.fixture(scope='module', params=[1.0, 2.5]) -def epsilon_fg(request): +def epsilon_fg(request: pytest.FixtureRequest) -> Iterable[float]: yield request.param @pytest.fixture(scope='module', params=['center', '000', 'random']) -def epsilon(request, shape, epsilon_bg, epsilon_fg): +def epsilon(request: pytest.FixtureRequest, + shape: Tuple[int, ...], + epsilon_bg: float, + epsilon_fg: float, + ) -> Iterable[numpy.ndarray]: is3d = (numpy.array(shape) == 1).sum() == 0 if is3d: if request.param == '000': @@ -53,17 +58,20 @@ def epsilon(request, shape, epsilon_bg, epsilon_fg): @pytest.fixture(scope='module', params=[1.0]) # 1.5 -def j_mag(request): +def j_mag(request: pytest.FixtureRequest) -> Iterable[float]: yield request.param @pytest.fixture(scope='module', params=[1.0, 1.5]) -def dx(request): +def dx(request: pytest.FixtureRequest) -> Iterable[float]: yield request.param @pytest.fixture(scope='module', params=['uniform', 'centerbig']) -def dxes(request, shape, dx): +def dxes(request: pytest.FixtureRequest, + shape: Tuple[int, ...], + dx: float, + ) -> Iterable[List[List[numpy.ndarray]]]: if request.param == 'uniform': dxes = [[numpy.full(s, dx) for s in shape[1:]] for _ in range(2)] elif request.param == 'centerbig': diff --git a/meanas/test/test_fdfd.py b/meanas/test/test_fdfd.py index c6b3c02..ac84213 100644 --- a/meanas/test/test_fdfd.py +++ b/meanas/test/test_fdfd.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Iterable, Optional import dataclasses import pytest # type: ignore import numpy # type: ignore @@ -9,14 +9,14 @@ from ..fdmath import vec, unvec from .utils import assert_close # , assert_fields_close -def test_residual(sim): +def test_residual(sim: 'FDResult') -> None: A = fdfd.operators.e_full(sim.omega, sim.dxes, vec(sim.epsilon)).tocsr() b = -1j * sim.omega * vec(sim.j) residual = A @ vec(sim.e) - b assert numpy.linalg.norm(residual) < 1e-10 -def test_poynting_planes(sim): +def test_poynting_planes(sim: 'FDResult') -> None: mask = (sim.j != 0).any(axis=0) if mask.sum() != 2: pytest.skip(f'test_poynting_planes will only test 2-point sources, got {mask.sum()}') @@ -53,17 +53,17 @@ def test_poynting_planes(sim): # Also see conftest.py @pytest.fixture(params=[1 / 1500]) -def omega(request): +def omega(request: pytest.FixtureRequest) -> Iterable[float]: yield request.param @pytest.fixture(params=[None]) -def pec(request): +def pec(request: pytest.FixtureRequest) -> Iterable[Optional[numpy.ndarray]]: yield request.param @pytest.fixture(params=[None]) -def pmc(request): +def pmc(request: pytest.FixtureRequest) -> Iterable[Optional[numpy.ndarray]]: yield request.param @@ -74,7 +74,10 @@ def pmc(request): @pytest.fixture(params=['diag']) # 'center' -def j_distribution(request, shape, j_mag): +def j_distribution(request: pytest.FixtureRequest, + shape: Tuple[int, ...], + j_mag: float, + ) -> Iterable[numpy.ndarray]: j = numpy.zeros(shape, dtype=complex) center_mask = numpy.zeros(shape, dtype=bool) center_mask[:, shape[1] // 2, shape[2] // 2, shape[3] // 2] = True @@ -89,7 +92,7 @@ def j_distribution(request, shape, j_mag): @dataclasses.dataclass() class FDResult: - shape: Tuple[int] + shape: Tuple[int, ...] dxes: List[List[numpy.ndarray]] epsilon: numpy.ndarray omega: complex @@ -100,7 +103,15 @@ class FDResult: @pytest.fixture() -def sim(request, shape, epsilon, dxes, j_distribution, omega, pec, pmc): +def sim(request: pytest.FixtureRequest, + shape: Tuple[int, ...], + epsilon: numpy.ndarray, + dxes: List[List[numpy.ndarray]], + j_distribution: numpy.ndarray, + omega: float, + pec: Optional[numpy.ndarray], + pmc: Optional[numpy.ndarray], + ) -> FDResult: """ Build simulation from parts """ diff --git a/meanas/test/test_fdfd_pml.py b/meanas/test/test_fdfd_pml.py index 436aa39..ac57750 100644 --- a/meanas/test/test_fdfd_pml.py +++ b/meanas/test/test_fdfd_pml.py @@ -1,15 +1,15 @@ -##################################### +from typing import Optional, Tuple, Iterable, List import pytest # type: ignore import numpy # type: ignore from numpy.testing import assert_allclose # type: ignore from .. import fdfd -from ..fdmath import vec, unvec +from ..fdmath import vec, unvec, dx_lists_mut #from .utils import assert_close, assert_fields_close from .test_fdfd import FDResult -def test_pml(sim, src_polarity): +def test_pml(sim: FDResult, src_polarity: int) -> None: e_sqr = numpy.squeeze((sim.e.conj() * sim.e).sum(axis=0)) # from matplotlib import pyplot @@ -42,34 +42,40 @@ def test_pml(sim, src_polarity): # Also see conftest.py @pytest.fixture(params=[1 / 1500]) -def omega(request): +def omega(request: pytest.FixtureRequest) -> Iterable[float]: yield request.param @pytest.fixture(params=[None]) -def pec(request): +def pec(request: pytest.FixtureRequest) -> Iterable[Optional[numpy.ndarray]]: yield request.param @pytest.fixture(params=[None]) -def pmc(request): +def pmc(request: pytest.FixtureRequest) -> Iterable[Optional[numpy.ndarray]]: yield request.param @pytest.fixture(params=[(30, 1, 1), (1, 30, 1), (1, 1, 30)]) -def shape(request): +def shape(request: pytest.FixtureRequest) -> Iterable[Tuple[int, ...]]: yield (3, *request.param) @pytest.fixture(params=[+1, -1]) -def src_polarity(request): +def src_polarity(request: pytest.FixtureRequest) -> Iterable[int]: yield request.param @pytest.fixture() -def j_distribution(request, shape, epsilon, dxes, omega, src_polarity): +def j_distribution(request: pytest.FixtureRequest, + shape: Tuple[int, ...], + epsilon: numpy.ndarray, + dxes: dx_lists_mut, + omega: float, + src_polarity: int, + ) -> Iterable[numpy.ndarray]: j = numpy.zeros(shape, dtype=complex) dim = numpy.where(numpy.array(shape[1:]) > 1)[0][0] # Propagation axis @@ -101,13 +107,22 @@ def j_distribution(request, shape, epsilon, dxes, omega, src_polarity): @pytest.fixture() -def epsilon(request, shape, epsilon_bg, epsilon_fg): +def epsilon(request: pytest.FixtureRequest, + shape: Tuple[int, ...], + epsilon_bg: float, + epsilon_fg: float, + ) -> Iterable[numpy.ndarray]: epsilon = numpy.full(shape, epsilon_fg, dtype=float) yield epsilon @pytest.fixture(params=['uniform']) -def dxes(request, shape, dx, omega, epsilon_fg): +def dxes(request: pytest.FixtureRequest, + shape: Tuple[int, ...], + dx: float, + omega: float, + epsilon_fg: float, + ) -> Iterable[List[List[numpy.ndarray]]]: if request.param == 'uniform': dxes = [[numpy.full(s, dx) for s in shape[1:]] for _ in range(2)] dim = numpy.where(numpy.array(shape[1:]) > 1)[0][0] # Propagation axis @@ -120,7 +135,15 @@ def dxes(request, shape, dx, omega, epsilon_fg): @pytest.fixture() -def sim(request, shape, epsilon, dxes, j_distribution, omega, pec, pmc): +def sim(request: pytest.FixtureRequest, + shape: Tuple[int, ...], + epsilon: numpy.ndarray, + dxes: dx_lists_mut, + j_distribution: numpy.ndarray, + omega: float, + pec: Optional[numpy.ndarray], + pmc: Optional[numpy.ndarray], + ) -> FDResult: j_vec = vec(j_distribution) eps_vec = vec(epsilon) e_vec = fdfd.solvers.generic(J=j_vec, omega=omega, dxes=dxes, epsilon=eps_vec, @@ -129,7 +152,7 @@ def sim(request, shape, epsilon, dxes, j_distribution, omega, pec, pmc): sim = FDResult( shape=shape, - dxes=dxes, + dxes=[list(d) for d in dxes], epsilon=epsilon, j=j_distribution, e=e, diff --git a/meanas/test/test_fdtd.py b/meanas/test/test_fdtd.py index efeb3e9..56fa553 100644 --- a/meanas/test/test_fdtd.py +++ b/meanas/test/test_fdtd.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Iterable import dataclasses import pytest # type: ignore import numpy # type: ignore @@ -8,7 +8,7 @@ from .. import fdtd from .utils import assert_close, assert_fields_close, PRNG -def test_initial_fields(sim): +def test_initial_fields(sim: 'TDResult') -> None: # Make sure initial fields didn't change e0 = sim.es[0] h0 = sim.hs[0] @@ -20,7 +20,7 @@ def test_initial_fields(sim): assert not h0.any() -def test_initial_energy(sim): +def test_initial_energy(sim: 'TDResult') -> None: """ Assumes fields start at 0 before J0 is added """ @@ -41,7 +41,7 @@ def test_initial_energy(sim): assert_fields_close(e0_dot_j0, u0) -def test_energy_conservation(sim): +def test_energy_conservation(sim: 'TDResult') -> None: """ Assumes fields start at 0 before J0 is added """ @@ -63,7 +63,7 @@ def test_energy_conservation(sim): assert_close(u_estep.sum(), u) -def test_poynting_divergence(sim): +def test_poynting_divergence(sim: 'TDResult') -> None: args = {'dxes': sim.dxes, 'epsilon': sim.epsilon} @@ -90,7 +90,7 @@ def test_poynting_divergence(sim): u_eprev = u_estep -def test_poynting_planes(sim): +def test_poynting_planes(sim: 'TDResult') -> None: mask = (sim.js[0] != 0).any(axis=0) if mask.sum() > 1: pytest.skip('test_poynting_planes can only test single point sources, got {}'.format(mask.sum())) @@ -140,30 +140,33 @@ def test_poynting_planes(sim): @pytest.fixture(params=[0.3]) -def dt(request): +def dt(request: pytest.FixtureRequest) -> Iterable[float]: yield request.param @dataclasses.dataclass() class TDResult: - shape: Tuple[int] + shape: Tuple[int, ...] dt: float dxes: List[List[numpy.ndarray]] epsilon: numpy.ndarray j_distribution: numpy.ndarray - j_steps: Tuple[int] + j_steps: Tuple[int, ...] es: List[numpy.ndarray] = dataclasses.field(default_factory=list) hs: List[numpy.ndarray] = dataclasses.field(default_factory=list) js: List[numpy.ndarray] = dataclasses.field(default_factory=list) @pytest.fixture(params=[(0, 4, 8)]) # (0,) -def j_steps(request): +def j_steps(request: pytest.fixtureRequest) -> Iterable[Tuple[int, ...]]: yield request.param @pytest.fixture(params=['center', 'random']) -def j_distribution(request, shape, j_mag): +def j_distribution(request: pytest.FixtureRequest, + shape: Tuple[int, ...], + j_mag: float, + ) -> Iterable[numpy.ndarray]: j = numpy.zeros(shape) if request.param == 'center': j[:, shape[1] // 2, shape[2] // 2, shape[3] // 2] = j_mag @@ -175,7 +178,14 @@ def j_distribution(request, shape, j_mag): @pytest.fixture() -def sim(request, shape, epsilon, dxes, dt, j_distribution, j_steps): +def sim(request: pytest.FixtureRequest, + shape: Tuple[int, ...], + epsilon: numpy.ndarray, + dxes: List[List[numpy.ndarray]], + dt: float, + j_distribution: numpy.ndarray, + j_steps: Tuple[int, ...], + ) -> TDResult: is3d = (numpy.array(shape) == 1).sum() == 0 if is3d: if dt != 0.3: diff --git a/meanas/test/utils.py b/meanas/test/utils.py index a49bc04..7c8c372 100644 --- a/meanas/test/utils.py +++ b/meanas/test/utils.py @@ -1,13 +1,22 @@ +from typing import Any import numpy # type: ignore PRNG = numpy.random.RandomState(12345) -def assert_fields_close(x, y, *args, **kwargs): +def assert_fields_close(x: numpy.ndarray, + y: numpy.ndarray, + *args: Any, + **kwargs: Any, + ) -> None: numpy.testing.assert_allclose( x, y, verbose=False, err_msg='Fields did not match:\n{}\n{}'.format(numpy.rollaxis(x, -1), numpy.rollaxis(y, -1)), *args, **kwargs) -def assert_close(x, y, *args, **kwargs): +def assert_close(x: numpy.ndarray, + y: numpy.ndarray, + *args: Any, + **kwargs: Any, + ) -> None: numpy.testing.assert_allclose(x, y, *args, **kwargs)