From ee51c7db496ec6db9cb82a724f7018926348d1b1 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Sun, 28 Jul 2024 23:23:47 -0700 Subject: [PATCH] improve type annotations --- meanas/fdfd/waveguide_3d.py | 7 ++++--- meanas/fdfd/waveguide_cyl.py | 2 +- meanas/fdmath/functional.py | 13 +++++++------ meanas/fdmath/operators.py | 9 +++++---- meanas/fdmath/types.py | 14 +++++++------- 5 files changed, 24 insertions(+), 21 deletions(-) diff --git a/meanas/fdfd/waveguide_3d.py b/meanas/fdfd/waveguide_3d.py index 7f994d3..2f499fa 100644 --- a/meanas/fdfd/waveguide_3d.py +++ b/meanas/fdfd/waveguide_3d.py @@ -7,6 +7,7 @@ its parameters into 2D equivalents and expands the results back into 3D. from typing import Sequence, Any import numpy from numpy.typing import NDArray +from numpy import complexfloating from ..fdmath import vec, unvec, dx_lists_t, fdfield_t, cfdfield_t from . import operators, waveguide_2d @@ -21,7 +22,7 @@ def solve_mode( slices: Sequence[slice], epsilon: fdfield_t, mu: fdfield_t | None = None, - ) -> dict[str, complex | NDArray[numpy.float_]]: + ) -> dict[str, complex | NDArray[complexfloating]]: """ Given a 3D grid, selects a slice from the grid and attempts to solve for an eigenmode propagating through that slice. @@ -40,8 +41,8 @@ def solve_mode( Returns: ``` { - 'E': list[NDArray[numpy.float_]], - 'H': list[NDArray[numpy.float_]], + 'E': NDArray[complexfloating], + 'H': NDArray[complexfloating], 'wavenumber': complex, } ``` diff --git a/meanas/fdfd/waveguide_cyl.py b/meanas/fdfd/waveguide_cyl.py index d476caa..596c6be 100644 --- a/meanas/fdfd/waveguide_cyl.py +++ b/meanas/fdfd/waveguide_cyl.py @@ -11,7 +11,7 @@ As the z-dependence is known, all the functions in this file assume a 2D grid import numpy import scipy.sparse as sparse # type: ignore -from ..fdmath import vec, unvec, dx_lists_t, fdfield_t, vfdfield_t, cfdfield_t +from ..fdmath import vec, unvec, dx_lists_t, vfdfield_t, cfdfield_t from ..fdmath.operators import deriv_forward, deriv_back from ..eigensolvers import signed_eigensolve, rayleigh_quotient_iteration diff --git a/meanas/fdmath/functional.py b/meanas/fdmath/functional.py index 3a10a00..91d8d29 100644 --- a/meanas/fdmath/functional.py +++ b/meanas/fdmath/functional.py @@ -7,12 +7,13 @@ from typing import Sequence, Callable import numpy from numpy.typing import NDArray +from numpy import floating from .types import fdfield_t, fdfield_updater_t def deriv_forward( - dx_e: Sequence[NDArray[numpy.float_]] | None = None, + dx_e: Sequence[NDArray[floating]] | None = None, ) -> tuple[fdfield_updater_t, fdfield_updater_t, fdfield_updater_t]: """ Utility operators for taking discretized derivatives (backward variant). @@ -36,7 +37,7 @@ def deriv_forward( def deriv_back( - dx_h: Sequence[NDArray[numpy.float_]] | None = None, + dx_h: Sequence[NDArray[floating]] | None = None, ) -> tuple[fdfield_updater_t, fdfield_updater_t, fdfield_updater_t]: """ Utility operators for taking discretized derivatives (forward variant). @@ -60,7 +61,7 @@ def deriv_back( def curl_forward( - dx_e: Sequence[NDArray[numpy.float_]] | None = None, + dx_e: Sequence[NDArray[floating]] | None = None, ) -> fdfield_updater_t: r""" Curl operator for use with the E field. @@ -89,7 +90,7 @@ def curl_forward( def curl_back( - dx_h: Sequence[NDArray[numpy.float_]] | None = None, + dx_h: Sequence[NDArray[floating]] | None = None, ) -> fdfield_updater_t: r""" Create a function which takes the backward curl of a field. @@ -118,7 +119,7 @@ def curl_back( def curl_forward_parts( - dx_e: Sequence[NDArray[numpy.float_]] | None = None, + dx_e: Sequence[NDArray[floating]] | None = None, ) -> Callable: Dx, Dy, Dz = deriv_forward(dx_e) @@ -131,7 +132,7 @@ def curl_forward_parts( def curl_back_parts( - dx_h: Sequence[NDArray[numpy.float_]] | None = None, + dx_h: Sequence[NDArray[floating]] | None = None, ) -> Callable: Dx, Dy, Dz = deriv_back(dx_h) diff --git a/meanas/fdmath/operators.py b/meanas/fdmath/operators.py index 9d5988d..c085808 100644 --- a/meanas/fdmath/operators.py +++ b/meanas/fdmath/operators.py @@ -6,6 +6,7 @@ Basic discrete calculus etc. from typing import Sequence import numpy from numpy.typing import NDArray +from numpy import floating import scipy.sparse as sparse # type: ignore from .types import vfdfield_t @@ -96,7 +97,7 @@ def shift_with_mirror( def deriv_forward( - dx_e: Sequence[NDArray[numpy.float_]], + dx_e: Sequence[NDArray[floating]], ) -> list[sparse.spmatrix]: """ Utility operators for taking discretized derivatives (forward variant). @@ -123,7 +124,7 @@ def deriv_forward( def deriv_back( - dx_h: Sequence[NDArray[numpy.float_]], + dx_h: Sequence[NDArray[floating]], ) -> list[sparse.spmatrix]: """ Utility operators for taking discretized derivatives (backward variant). @@ -218,7 +219,7 @@ def avg_back(axis: int, shape: Sequence[int]) -> sparse.spmatrix: def curl_forward( - dx_e: Sequence[NDArray[numpy.float_]], + dx_e: Sequence[NDArray[floating]], ) -> sparse.spmatrix: """ Curl operator for use with the E field. @@ -234,7 +235,7 @@ def curl_forward( def curl_back( - dx_h: Sequence[NDArray[numpy.float_]], + dx_h: Sequence[NDArray[floating]], ) -> sparse.spmatrix: """ Curl operator for use with the H field. diff --git a/meanas/fdmath/types.py b/meanas/fdmath/types.py index aae9594..b78e93f 100644 --- a/meanas/fdmath/types.py +++ b/meanas/fdmath/types.py @@ -2,25 +2,25 @@ Types shared across multiple submodules """ from typing import Sequence, Callable, MutableSequence -import numpy from numpy.typing import NDArray +from numpy import floating, complexfloating # Field types -fdfield_t = NDArray[numpy.float_] +fdfield_t = NDArray[floating] """Vector field with shape (3, X, Y, Z) (e.g. `[E_x, E_y, E_z]`)""" -vfdfield_t = NDArray[numpy.float_] +vfdfield_t = NDArray[floating] """Linearized vector field (single vector of length 3*X*Y*Z)""" -cfdfield_t = NDArray[numpy.complex_] +cfdfield_t = NDArray[complexfloating] """Complex vector field with shape (3, X, Y, Z) (e.g. `[E_x, E_y, E_z]`)""" -vcfdfield_t = NDArray[numpy.complex_] +vcfdfield_t = NDArray[complexfloating] """Linearized complex vector field (single vector of length 3*X*Y*Z)""" -dx_lists_t = Sequence[Sequence[NDArray[numpy.float_]]] +dx_lists_t = Sequence[Sequence[NDArray[floating]]] """ 'dxes' datastructure which contains grid cell width information in the following format: @@ -31,7 +31,7 @@ dx_lists_t = Sequence[Sequence[NDArray[numpy.float_]]] and `dy_h[0]` is the y-width of the `y=0` cells, as used when calculating dH/dy, etc. """ -dx_lists_mut = MutableSequence[MutableSequence[NDArray[numpy.float_]]] +dx_lists_mut = MutableSequence[MutableSequence[NDArray[floating]]] """Mutable version of `dx_lists_t`"""