improve some type annotations

This commit is contained in:
Jan Petykiewicz 2024-07-29 00:32:52 -07:00
parent 1021768e30
commit 5dd9994e76
2 changed files with 9 additions and 6 deletions

View File

@ -8,7 +8,7 @@ from collections.abc import Sequence, Callable
import numpy import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
from numpy import floating from numpy import floating, complexfloating
from .types import fdfield_t, fdfield_updater_t from .types import fdfield_t, fdfield_updater_t
@ -61,9 +61,12 @@ def deriv_back(
return derivs return derivs
TT = TypeVar('TT', bound='NDArray[floating | complexfloating]')
def curl_forward( def curl_forward(
dx_e: Sequence[NDArray[floating]] | None = None, dx_e: Sequence[NDArray[floating]] | None = None,
) -> fdfield_updater_t: ) -> Callable[[TT], TT]:
r""" r"""
Curl operator for use with the E field. Curl operator for use with the E field.
@ -77,7 +80,7 @@ def curl_forward(
""" """
Dx, Dy, Dz = deriv_forward(dx_e) Dx, Dy, Dz = deriv_forward(dx_e)
def ce_fun(e: fdfield_t) -> fdfield_t: def ce_fun(e: TT) -> TT:
output = numpy.empty_like(e) output = numpy.empty_like(e)
output[0] = Dy(e[2]) output[0] = Dy(e[2])
output[1] = Dz(e[0]) output[1] = Dz(e[0])
@ -92,7 +95,7 @@ def curl_forward(
def curl_back( def curl_back(
dx_h: Sequence[NDArray[floating]] | None = None, dx_h: Sequence[NDArray[floating]] | None = None,
) -> fdfield_updater_t: ) -> Callable[[TT], TT]:
r""" r"""
Create a function which takes the backward curl of a field. Create a function which takes the backward curl of a field.
@ -106,7 +109,7 @@ def curl_back(
""" """
Dx, Dy, Dz = deriv_back(dx_h) Dx, Dy, Dz = deriv_back(dx_h)
def ch_fun(h: fdfield_t) -> fdfield_t: def ch_fun(h: TT) -> TT:
output = numpy.empty_like(h) output = numpy.empty_like(h)
output[0] = Dy(h[2]) output[0] = Dy(h[2])
output[1] = Dz(h[0]) output[1] = Dz(h[0])

View File

@ -185,7 +185,7 @@ def updates_with_cpml(
def update_H( def update_H(
e: fdfield_t, e: fdfield_t,
h: fdfield_t, h: fdfield_t,
mu: fdfield_t = numpy.ones(3), mu: fdfield_t | tuple[int, int, int] = (1, 1, 1),
) -> None: ) -> None:
dyEx = Dfy(e[0]) dyEx = Dfy(e[0])
dzEx = Dfz(e[0]) dzEx = Dfz(e[0])