improve type annotations

This commit is contained in:
Jan Petykiewicz 2025-01-28 19:54:04 -08:00
parent 4afc6cf62e
commit 1987ee473a
5 changed files with 17 additions and 17 deletions

View File

@ -931,7 +931,7 @@ def inner_product( # TODO documentation
prop_phase: float = 0,
conj_h: bool = False,
trapezoid: bool = False,
) -> tuple[vcfdfield_t, vcfdfield_t]:
) -> complex:
shape = [s.size for s in dxes[0]]

View File

@ -8,7 +8,7 @@ As the z-dependence is known, all the functions in this file assume a 2D grid
"""
# TODO update module docs
from typing import Any
from typing import Any, cast
from collections.abc import Sequence
import logging
@ -142,7 +142,7 @@ def solve_modes(
# Wavenumbers assume the mode is at rmin, which is unlikely
# Instead, return the wavenumber in inverse radians
angular_wavenumbers = wavenumbers * rmin
angular_wavenumbers = wavenumbers * cast(complex, rmin)
order = angular_wavenumbers.argsort()[::-1]
e_xys = e_xys[order]

View File

@ -14,7 +14,7 @@ from .types import fdfield_t, fdfield_updater_t
def deriv_forward(
dx_e: Sequence[NDArray[floating]] | None = None,
dx_e: Sequence[NDArray[floating | complexfloating]] | None = None,
) -> tuple[fdfield_updater_t, fdfield_updater_t, fdfield_updater_t]:
"""
Utility operators for taking discretized derivatives (backward variant).
@ -38,7 +38,7 @@ def deriv_forward(
def deriv_back(
dx_h: Sequence[NDArray[floating]] | None = None,
dx_h: Sequence[NDArray[floating | complexfloating]] | None = None,
) -> tuple[fdfield_updater_t, fdfield_updater_t, fdfield_updater_t]:
"""
Utility operators for taking discretized derivatives (forward variant).
@ -65,7 +65,7 @@ TT = TypeVar('TT', bound='NDArray[floating | complexfloating]')
def curl_forward(
dx_e: Sequence[NDArray[floating]] | None = None,
dx_e: Sequence[NDArray[floating | complexfloating]] | None = None,
) -> Callable[[TT], TT]:
r"""
Curl operator for use with the E field.
@ -94,7 +94,7 @@ def curl_forward(
def curl_back(
dx_h: Sequence[NDArray[floating]] | None = None,
dx_h: Sequence[NDArray[floating | complexfloating]] | None = None,
) -> Callable[[TT], TT]:
r"""
Create a function which takes the backward curl of a field.
@ -123,7 +123,7 @@ def curl_back(
def curl_forward_parts(
dx_e: Sequence[NDArray[floating]] | None = None,
dx_e: Sequence[NDArray[floating | complexfloating]] | None = None,
) -> Callable:
Dx, Dy, Dz = deriv_forward(dx_e)
@ -136,7 +136,7 @@ def curl_forward_parts(
def curl_back_parts(
dx_h: Sequence[NDArray[floating]] | None = None,
dx_h: Sequence[NDArray[floating | complexfloating]] | None = None,
) -> Callable:
Dx, Dy, Dz = deriv_back(dx_h)

View File

@ -6,7 +6,7 @@ Basic discrete calculus etc.
from collections.abc import Sequence
import numpy
from numpy.typing import NDArray
from numpy import floating
from numpy import floating, complexfloating
from scipy import sparse
from .types import vfdfield_t
@ -97,7 +97,7 @@ def shift_with_mirror(
def deriv_forward(
dx_e: Sequence[NDArray[floating]],
dx_e: Sequence[NDArray[floating | complexfloating]],
) -> list[sparse.spmatrix]:
"""
Utility operators for taking discretized derivatives (forward variant).
@ -124,7 +124,7 @@ def deriv_forward(
def deriv_back(
dx_h: Sequence[NDArray[floating]],
dx_h: Sequence[NDArray[floating | complexfloating]],
) -> list[sparse.spmatrix]:
"""
Utility operators for taking discretized derivatives (backward variant).
@ -219,7 +219,7 @@ def avg_back(axis: int, shape: Sequence[int]) -> sparse.spmatrix:
def curl_forward(
dx_e: Sequence[NDArray[floating]],
dx_e: Sequence[NDArray[floating | complexfloating]],
) -> sparse.spmatrix:
"""
Curl operator for use with the E field.
@ -235,7 +235,7 @@ def curl_forward(
def curl_back(
dx_h: Sequence[NDArray[floating]],
dx_h: Sequence[NDArray[floating | complexfloating]],
) -> sparse.spmatrix:
"""
Curl operator for use with the H field.

View File

@ -49,15 +49,15 @@ def vec(
@overload
def unvec(v: None, shape: Sequence[int], nvdim: int) -> None:
def unvec(v: None, shape: Sequence[int], nvdim: int = 3) -> None:
pass
@overload
def unvec(v: vfdfield_t, shape: Sequence[int], nvdim: int) -> fdfield_t:
def unvec(v: vfdfield_t, shape: Sequence[int], nvdim: int = 3) -> fdfield_t:
pass
@overload
def unvec(v: vcfdfield_t, shape: Sequence[int], nvdim: int) -> cfdfield_t:
def unvec(v: vcfdfield_t, shape: Sequence[int], nvdim: int = 3) -> cfdfield_t:
pass
def unvec(