improve type annotations
This commit is contained in:
parent
36bea6a593
commit
ee51c7db49
@ -7,6 +7,7 @@ its parameters into 2D equivalents and expands the results back into 3D.
|
||||
from typing import Sequence, Any
|
||||
import numpy
|
||||
from numpy.typing import NDArray
|
||||
from numpy import complexfloating
|
||||
|
||||
from ..fdmath import vec, unvec, dx_lists_t, fdfield_t, cfdfield_t
|
||||
from . import operators, waveguide_2d
|
||||
@ -21,7 +22,7 @@ def solve_mode(
|
||||
slices: Sequence[slice],
|
||||
epsilon: fdfield_t,
|
||||
mu: fdfield_t | None = None,
|
||||
) -> dict[str, complex | NDArray[numpy.float_]]:
|
||||
) -> dict[str, complex | NDArray[complexfloating]]:
|
||||
"""
|
||||
Given a 3D grid, selects a slice from the grid and attempts to
|
||||
solve for an eigenmode propagating through that slice.
|
||||
@ -40,8 +41,8 @@ def solve_mode(
|
||||
Returns:
|
||||
```
|
||||
{
|
||||
'E': list[NDArray[numpy.float_]],
|
||||
'H': list[NDArray[numpy.float_]],
|
||||
'E': NDArray[complexfloating],
|
||||
'H': NDArray[complexfloating],
|
||||
'wavenumber': complex,
|
||||
}
|
||||
```
|
||||
|
@ -11,7 +11,7 @@ As the z-dependence is known, all the functions in this file assume a 2D grid
|
||||
import numpy
|
||||
import scipy.sparse as sparse # type: ignore
|
||||
|
||||
from ..fdmath import vec, unvec, dx_lists_t, fdfield_t, vfdfield_t, cfdfield_t
|
||||
from ..fdmath import vec, unvec, dx_lists_t, vfdfield_t, cfdfield_t
|
||||
from ..fdmath.operators import deriv_forward, deriv_back
|
||||
from ..eigensolvers import signed_eigensolve, rayleigh_quotient_iteration
|
||||
|
||||
|
@ -7,12 +7,13 @@ from typing import Sequence, Callable
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray
|
||||
from numpy import floating
|
||||
|
||||
from .types import fdfield_t, fdfield_updater_t
|
||||
|
||||
|
||||
def deriv_forward(
|
||||
dx_e: Sequence[NDArray[numpy.float_]] | None = None,
|
||||
dx_e: Sequence[NDArray[floating]] | None = None,
|
||||
) -> tuple[fdfield_updater_t, fdfield_updater_t, fdfield_updater_t]:
|
||||
"""
|
||||
Utility operators for taking discretized derivatives (backward variant).
|
||||
@ -36,7 +37,7 @@ def deriv_forward(
|
||||
|
||||
|
||||
def deriv_back(
|
||||
dx_h: Sequence[NDArray[numpy.float_]] | None = None,
|
||||
dx_h: Sequence[NDArray[floating]] | None = None,
|
||||
) -> tuple[fdfield_updater_t, fdfield_updater_t, fdfield_updater_t]:
|
||||
"""
|
||||
Utility operators for taking discretized derivatives (forward variant).
|
||||
@ -60,7 +61,7 @@ def deriv_back(
|
||||
|
||||
|
||||
def curl_forward(
|
||||
dx_e: Sequence[NDArray[numpy.float_]] | None = None,
|
||||
dx_e: Sequence[NDArray[floating]] | None = None,
|
||||
) -> fdfield_updater_t:
|
||||
r"""
|
||||
Curl operator for use with the E field.
|
||||
@ -89,7 +90,7 @@ def curl_forward(
|
||||
|
||||
|
||||
def curl_back(
|
||||
dx_h: Sequence[NDArray[numpy.float_]] | None = None,
|
||||
dx_h: Sequence[NDArray[floating]] | None = None,
|
||||
) -> fdfield_updater_t:
|
||||
r"""
|
||||
Create a function which takes the backward curl of a field.
|
||||
@ -118,7 +119,7 @@ def curl_back(
|
||||
|
||||
|
||||
def curl_forward_parts(
|
||||
dx_e: Sequence[NDArray[numpy.float_]] | None = None,
|
||||
dx_e: Sequence[NDArray[floating]] | None = None,
|
||||
) -> Callable:
|
||||
Dx, Dy, Dz = deriv_forward(dx_e)
|
||||
|
||||
@ -131,7 +132,7 @@ def curl_forward_parts(
|
||||
|
||||
|
||||
def curl_back_parts(
|
||||
dx_h: Sequence[NDArray[numpy.float_]] | None = None,
|
||||
dx_h: Sequence[NDArray[floating]] | None = None,
|
||||
) -> Callable:
|
||||
Dx, Dy, Dz = deriv_back(dx_h)
|
||||
|
||||
|
@ -6,6 +6,7 @@ Basic discrete calculus etc.
|
||||
from typing import Sequence
|
||||
import numpy
|
||||
from numpy.typing import NDArray
|
||||
from numpy import floating
|
||||
import scipy.sparse as sparse # type: ignore
|
||||
|
||||
from .types import vfdfield_t
|
||||
@ -96,7 +97,7 @@ def shift_with_mirror(
|
||||
|
||||
|
||||
def deriv_forward(
|
||||
dx_e: Sequence[NDArray[numpy.float_]],
|
||||
dx_e: Sequence[NDArray[floating]],
|
||||
) -> list[sparse.spmatrix]:
|
||||
"""
|
||||
Utility operators for taking discretized derivatives (forward variant).
|
||||
@ -123,7 +124,7 @@ def deriv_forward(
|
||||
|
||||
|
||||
def deriv_back(
|
||||
dx_h: Sequence[NDArray[numpy.float_]],
|
||||
dx_h: Sequence[NDArray[floating]],
|
||||
) -> list[sparse.spmatrix]:
|
||||
"""
|
||||
Utility operators for taking discretized derivatives (backward variant).
|
||||
@ -218,7 +219,7 @@ def avg_back(axis: int, shape: Sequence[int]) -> sparse.spmatrix:
|
||||
|
||||
|
||||
def curl_forward(
|
||||
dx_e: Sequence[NDArray[numpy.float_]],
|
||||
dx_e: Sequence[NDArray[floating]],
|
||||
) -> sparse.spmatrix:
|
||||
"""
|
||||
Curl operator for use with the E field.
|
||||
@ -234,7 +235,7 @@ def curl_forward(
|
||||
|
||||
|
||||
def curl_back(
|
||||
dx_h: Sequence[NDArray[numpy.float_]],
|
||||
dx_h: Sequence[NDArray[floating]],
|
||||
) -> sparse.spmatrix:
|
||||
"""
|
||||
Curl operator for use with the H field.
|
||||
|
@ -2,25 +2,25 @@
|
||||
Types shared across multiple submodules
|
||||
"""
|
||||
from typing import Sequence, Callable, MutableSequence
|
||||
import numpy
|
||||
from numpy.typing import NDArray
|
||||
from numpy import floating, complexfloating
|
||||
|
||||
|
||||
# Field types
|
||||
fdfield_t = NDArray[numpy.float_]
|
||||
fdfield_t = NDArray[floating]
|
||||
"""Vector field with shape (3, X, Y, Z) (e.g. `[E_x, E_y, E_z]`)"""
|
||||
|
||||
vfdfield_t = NDArray[numpy.float_]
|
||||
vfdfield_t = NDArray[floating]
|
||||
"""Linearized vector field (single vector of length 3*X*Y*Z)"""
|
||||
|
||||
cfdfield_t = NDArray[numpy.complex_]
|
||||
cfdfield_t = NDArray[complexfloating]
|
||||
"""Complex vector field with shape (3, X, Y, Z) (e.g. `[E_x, E_y, E_z]`)"""
|
||||
|
||||
vcfdfield_t = NDArray[numpy.complex_]
|
||||
vcfdfield_t = NDArray[complexfloating]
|
||||
"""Linearized complex vector field (single vector of length 3*X*Y*Z)"""
|
||||
|
||||
|
||||
dx_lists_t = Sequence[Sequence[NDArray[numpy.float_]]]
|
||||
dx_lists_t = Sequence[Sequence[NDArray[floating]]]
|
||||
"""
|
||||
'dxes' datastructure which contains grid cell width information in the following format:
|
||||
|
||||
@ -31,7 +31,7 @@ dx_lists_t = Sequence[Sequence[NDArray[numpy.float_]]]
|
||||
and `dy_h[0]` is the y-width of the `y=0` cells, as used when calculating dH/dy, etc.
|
||||
"""
|
||||
|
||||
dx_lists_mut = MutableSequence[MutableSequence[NDArray[numpy.float_]]]
|
||||
dx_lists_mut = MutableSequence[MutableSequence[NDArray[floating]]]
|
||||
"""Mutable version of `dx_lists_t`"""
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user