improve type annotations

This commit is contained in:
Jan Petykiewicz 2024-07-28 23:23:47 -07:00
parent 36bea6a593
commit ee51c7db49
5 changed files with 24 additions and 21 deletions

View File

@ -7,6 +7,7 @@ its parameters into 2D equivalents and expands the results back into 3D.
from typing import Sequence, Any from typing import Sequence, Any
import numpy import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
from numpy import complexfloating
from ..fdmath import vec, unvec, dx_lists_t, fdfield_t, cfdfield_t from ..fdmath import vec, unvec, dx_lists_t, fdfield_t, cfdfield_t
from . import operators, waveguide_2d from . import operators, waveguide_2d
@ -21,7 +22,7 @@ def solve_mode(
slices: Sequence[slice], slices: Sequence[slice],
epsilon: fdfield_t, epsilon: fdfield_t,
mu: fdfield_t | None = None, mu: fdfield_t | None = None,
) -> dict[str, complex | NDArray[numpy.float_]]: ) -> dict[str, complex | NDArray[complexfloating]]:
""" """
Given a 3D grid, selects a slice from the grid and attempts to Given a 3D grid, selects a slice from the grid and attempts to
solve for an eigenmode propagating through that slice. solve for an eigenmode propagating through that slice.
@ -40,8 +41,8 @@ def solve_mode(
Returns: Returns:
``` ```
{ {
'E': list[NDArray[numpy.float_]], 'E': NDArray[complexfloating],
'H': list[NDArray[numpy.float_]], 'H': NDArray[complexfloating],
'wavenumber': complex, 'wavenumber': complex,
} }
``` ```

View File

@ -11,7 +11,7 @@ As the z-dependence is known, all the functions in this file assume a 2D grid
import numpy import numpy
import scipy.sparse as sparse # type: ignore import scipy.sparse as sparse # type: ignore
from ..fdmath import vec, unvec, dx_lists_t, fdfield_t, vfdfield_t, cfdfield_t from ..fdmath import vec, unvec, dx_lists_t, vfdfield_t, cfdfield_t
from ..fdmath.operators import deriv_forward, deriv_back from ..fdmath.operators import deriv_forward, deriv_back
from ..eigensolvers import signed_eigensolve, rayleigh_quotient_iteration from ..eigensolvers import signed_eigensolve, rayleigh_quotient_iteration

View File

@ -7,12 +7,13 @@ from typing import Sequence, Callable
import numpy import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
from numpy import floating
from .types import fdfield_t, fdfield_updater_t from .types import fdfield_t, fdfield_updater_t
def deriv_forward( def deriv_forward(
dx_e: Sequence[NDArray[numpy.float_]] | None = None, dx_e: Sequence[NDArray[floating]] | None = None,
) -> tuple[fdfield_updater_t, fdfield_updater_t, fdfield_updater_t]: ) -> tuple[fdfield_updater_t, fdfield_updater_t, fdfield_updater_t]:
""" """
Utility operators for taking discretized derivatives (backward variant). Utility operators for taking discretized derivatives (backward variant).
@ -36,7 +37,7 @@ def deriv_forward(
def deriv_back( def deriv_back(
dx_h: Sequence[NDArray[numpy.float_]] | None = None, dx_h: Sequence[NDArray[floating]] | None = None,
) -> tuple[fdfield_updater_t, fdfield_updater_t, fdfield_updater_t]: ) -> tuple[fdfield_updater_t, fdfield_updater_t, fdfield_updater_t]:
""" """
Utility operators for taking discretized derivatives (forward variant). Utility operators for taking discretized derivatives (forward variant).
@ -60,7 +61,7 @@ def deriv_back(
def curl_forward( def curl_forward(
dx_e: Sequence[NDArray[numpy.float_]] | None = None, dx_e: Sequence[NDArray[floating]] | None = None,
) -> fdfield_updater_t: ) -> fdfield_updater_t:
r""" r"""
Curl operator for use with the E field. Curl operator for use with the E field.
@ -89,7 +90,7 @@ def curl_forward(
def curl_back( def curl_back(
dx_h: Sequence[NDArray[numpy.float_]] | None = None, dx_h: Sequence[NDArray[floating]] | None = None,
) -> fdfield_updater_t: ) -> fdfield_updater_t:
r""" r"""
Create a function which takes the backward curl of a field. Create a function which takes the backward curl of a field.
@ -118,7 +119,7 @@ def curl_back(
def curl_forward_parts( def curl_forward_parts(
dx_e: Sequence[NDArray[numpy.float_]] | None = None, dx_e: Sequence[NDArray[floating]] | None = None,
) -> Callable: ) -> Callable:
Dx, Dy, Dz = deriv_forward(dx_e) Dx, Dy, Dz = deriv_forward(dx_e)
@ -131,7 +132,7 @@ def curl_forward_parts(
def curl_back_parts( def curl_back_parts(
dx_h: Sequence[NDArray[numpy.float_]] | None = None, dx_h: Sequence[NDArray[floating]] | None = None,
) -> Callable: ) -> Callable:
Dx, Dy, Dz = deriv_back(dx_h) Dx, Dy, Dz = deriv_back(dx_h)

View File

@ -6,6 +6,7 @@ Basic discrete calculus etc.
from typing import Sequence from typing import Sequence
import numpy import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
from numpy import floating
import scipy.sparse as sparse # type: ignore import scipy.sparse as sparse # type: ignore
from .types import vfdfield_t from .types import vfdfield_t
@ -96,7 +97,7 @@ def shift_with_mirror(
def deriv_forward( def deriv_forward(
dx_e: Sequence[NDArray[numpy.float_]], dx_e: Sequence[NDArray[floating]],
) -> list[sparse.spmatrix]: ) -> list[sparse.spmatrix]:
""" """
Utility operators for taking discretized derivatives (forward variant). Utility operators for taking discretized derivatives (forward variant).
@ -123,7 +124,7 @@ def deriv_forward(
def deriv_back( def deriv_back(
dx_h: Sequence[NDArray[numpy.float_]], dx_h: Sequence[NDArray[floating]],
) -> list[sparse.spmatrix]: ) -> list[sparse.spmatrix]:
""" """
Utility operators for taking discretized derivatives (backward variant). Utility operators for taking discretized derivatives (backward variant).
@ -218,7 +219,7 @@ def avg_back(axis: int, shape: Sequence[int]) -> sparse.spmatrix:
def curl_forward( def curl_forward(
dx_e: Sequence[NDArray[numpy.float_]], dx_e: Sequence[NDArray[floating]],
) -> sparse.spmatrix: ) -> sparse.spmatrix:
""" """
Curl operator for use with the E field. Curl operator for use with the E field.
@ -234,7 +235,7 @@ def curl_forward(
def curl_back( def curl_back(
dx_h: Sequence[NDArray[numpy.float_]], dx_h: Sequence[NDArray[floating]],
) -> sparse.spmatrix: ) -> sparse.spmatrix:
""" """
Curl operator for use with the H field. Curl operator for use with the H field.

View File

@ -2,25 +2,25 @@
Types shared across multiple submodules Types shared across multiple submodules
""" """
from typing import Sequence, Callable, MutableSequence from typing import Sequence, Callable, MutableSequence
import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
from numpy import floating, complexfloating
# Field types # Field types
fdfield_t = NDArray[numpy.float_] fdfield_t = NDArray[floating]
"""Vector field with shape (3, X, Y, Z) (e.g. `[E_x, E_y, E_z]`)""" """Vector field with shape (3, X, Y, Z) (e.g. `[E_x, E_y, E_z]`)"""
vfdfield_t = NDArray[numpy.float_] vfdfield_t = NDArray[floating]
"""Linearized vector field (single vector of length 3*X*Y*Z)""" """Linearized vector field (single vector of length 3*X*Y*Z)"""
cfdfield_t = NDArray[numpy.complex_] cfdfield_t = NDArray[complexfloating]
"""Complex vector field with shape (3, X, Y, Z) (e.g. `[E_x, E_y, E_z]`)""" """Complex vector field with shape (3, X, Y, Z) (e.g. `[E_x, E_y, E_z]`)"""
vcfdfield_t = NDArray[numpy.complex_] vcfdfield_t = NDArray[complexfloating]
"""Linearized complex vector field (single vector of length 3*X*Y*Z)""" """Linearized complex vector field (single vector of length 3*X*Y*Z)"""
dx_lists_t = Sequence[Sequence[NDArray[numpy.float_]]] dx_lists_t = Sequence[Sequence[NDArray[floating]]]
""" """
'dxes' datastructure which contains grid cell width information in the following format: 'dxes' datastructure which contains grid cell width information in the following format:
@ -31,7 +31,7 @@ dx_lists_t = Sequence[Sequence[NDArray[numpy.float_]]]
and `dy_h[0]` is the y-width of the `y=0` cells, as used when calculating dH/dy, etc. and `dy_h[0]` is the y-width of the `y=0` cells, as used when calculating dH/dy, etc.
""" """
dx_lists_mut = MutableSequence[MutableSequence[NDArray[numpy.float_]]] dx_lists_mut = MutableSequence[MutableSequence[NDArray[floating]]]
"""Mutable version of `dx_lists_t`""" """Mutable version of `dx_lists_t`"""