improve type annotations

This commit is contained in:
Jan Petykiewicz 2024-07-30 22:41:27 -07:00
parent 2f7a46ff71
commit 6193a9c256
3 changed files with 50 additions and 45 deletions

View File

@ -14,21 +14,22 @@ satisfy the constraints for the 'conjugate gradient' algorithm
(positive definite, symmetric) and some that don't.
"""
from typing import Dict, Any, Optional
from typing import Any, TYPE_CHECKING
import time
import logging
import numpy
from numpy.typing import NDArray, ArrayLike
from numpy.linalg import norm
from numpy import complexfloating
import pyopencl
import pyopencl.array
import scipy
import meanas.fdfd.solvers
from . import ops
if TYPE_CHECKING:
import scipy
__author__ = 'Jan Petykiewicz'
@ -58,9 +59,9 @@ def cg(
b: ArrayLike,
max_iters: int = 10000,
err_threshold: float = 1e-6,
context: Optional[pyopencl.Context] = None,
queue: Optional[pyopencl.CommandQueue] = None,
) -> NDArray:
context: pyopencl.Context | None = None,
queue: pyopencl.CommandQueue | None = None,
) -> NDArray[complexfloating]:
"""
General conjugate-gradient solver for sparse matrices, where A @ x = b.
@ -84,7 +85,7 @@ def cg(
if queue is None:
queue = pyopencl.CommandQueue(context)
def load_field(v, dtype=numpy.complex128):
def load_field(v: NDArray[numpy.complexfloating], dtype: type = numpy.complex128) -> pyopencl.array.Array:
return pyopencl.array.to_device(queue, v.astype(dtype))
r = load_field(b)
@ -160,9 +161,9 @@ def cg(
def fdfd_cg_solver(
solver_opts: Optional[Dict[str, Any]] = None,
**fdfd_args
) -> NDArray:
solver_opts: dict[str, Any] | None = None,
**fdfd_args,
) -> NDArray[complexfloating]:
"""
Conjugate gradient FDFD solver using CSR sparse matrices, mainly for
testing and development since it's much slower than the solver in main.py.

View File

@ -5,14 +5,13 @@ This file holds the default FDFD solver, which uses an E-field wave
operator implemented directly as OpenCL arithmetic (rather than as
a matrix).
"""
from typing import List, Optional, cast
import time
import logging
import numpy
from numpy.typing import NDArray, ArrayLike
from numpy.linalg import norm
from numpy import floating, complexfloating
import pyopencl
import pyopencl.array
@ -28,16 +27,16 @@ logger = logging.getLogger(__name__)
def cg_solver(
omega: complex,
dxes: List[List[NDArray]],
dxes: list[list[NDArray[floating | complexfloating]]],
J: ArrayLike,
epsilon: ArrayLike,
mu: Optional[ArrayLike] = None,
pec: Optional[ArrayLike] = None,
pmc: Optional[ArrayLike] = None,
mu: ArrayLike | None = None,
pec: ArrayLike | None = None,
pmc: ArrayLike | None = None,
adjoint: bool = False,
max_iters: int = 40000,
err_threshold: float = 1e-6,
context: Optional[pyopencl.Context] = None,
context: pyopencl.Context | None = None,
) -> NDArray:
"""
OpenCL FDFD solver using the iterative conjugate gradient (cg) method
@ -108,13 +107,10 @@ def cg_solver(
epsilon = numpy.conj(epsilon)
if mu is not None:
mu = numpy.conj(mu)
assert isinstance(epsilon, NDArray[floating] | NDArray[complexfloating])
L, R = meanas.fdfd.operators.e_full_preconditioners(dxes)
if adjoint:
b_preconditioned = R @ b
else:
b_preconditioned = L @ b
b_preconditioned = (R if adjoint else L) @ b
'''
Allocate GPU memory and load in data
@ -124,7 +120,7 @@ def cg_solver(
queue = pyopencl.CommandQueue(context)
def load_field(v, dtype=numpy.complex128):
def load_field(v: NDArray[complexfloating | floating], dtype: type = numpy.complex128) -> pyopencl.array.Array:
return pyopencl.array.to_device(queue, v.astype(dtype))
r = load_field(b_preconditioned) # load preconditioned b into r
@ -169,7 +165,12 @@ def cg_solver(
p_step = ops.create_p_step(context)
dot = ops.create_dot(context)
def a_step(E, H, p, events):
def a_step(
E: pyopencl.array.Array,
H: pyopencl.array.Array,
p: pyopencl.array.Array,
events: list[pyopencl.Event],
) -> list[pyopencl.Event]:
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)
'''

View File

@ -7,11 +7,11 @@ kernels for use by the other solvers.
See kernels/ for any of the .cl files loaded in this file.
"""
from typing import List, Callable, Union, Type, Sequence, Optional, Tuple
from collections.abc import Callable, Sequence
import logging
import numpy
from numpy.typing import NDArray, ArrayLike
from numpy.typing import ArrayLike
import jinja2
import pyopencl
@ -20,17 +20,20 @@ from pyopencl.elementwise import ElementwiseKernel
from pyopencl.reduction import ReductionKernel
from .csr import CSRMatrix
logger = logging.getLogger(__name__)
# Create jinja2 env on module load
jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels'))
# Return type for the create_opname(...) functions
operation = Callable[..., List[pyopencl.Event]]
operation = Callable[..., list[pyopencl.Event]]
def type_to_C(
float_type: Type,
float_type: type[numpy.floating | numpy.complexfloating],
) -> str:
"""
Returns a string corresponding to the C equivalent of a numpy type.
@ -71,7 +74,7 @@ preamble = '''
'''.format(ctype=ctype_bare)
def ptrs(*args: str) -> List[str]:
def ptrs(*args: str) -> list[str]:
return [ctype + ' *' + s for s in args]
@ -169,13 +172,13 @@ def create_a(
p: pyopencl.array.Array,
idxes: Sequence[Sequence[pyopencl.array.Array]],
oeps: pyopencl.array.Array,
inv_mu: Optional[pyopencl.array.Array],
pec: Optional[pyopencl.array.Array],
pmc: Optional[pyopencl.array.Array],
inv_mu: pyopencl.array.Array | None,
pec: pyopencl.array.Array | None,
pmc: pyopencl.array.Array | None,
Pl: pyopencl.array.Array,
Pr: pyopencl.array.Array,
e: List[pyopencl.Event],
) -> List[pyopencl.Event]:
e: list[pyopencl.Event],
) -> list[pyopencl.Event]:
e2 = P2E_kernel(E, p, Pr, pec, wait_for=e)
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
@ -227,14 +230,14 @@ def create_xr_step(context: pyopencl.Context) -> operation:
r: pyopencl.array.Array,
v: pyopencl.array.Array,
alpha: complex,
e: List[pyopencl.Event],
) -> List[pyopencl.Event]:
e: list[pyopencl.Event],
) -> list[pyopencl.Event]:
return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
return xr_update
def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex, complex]]:
def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., tuple[complex, complex]]:
"""
Return a function
ri_update(r, e)
@ -272,7 +275,7 @@ def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex
arguments=ctype + ' *r',
)
def ri_update(r: pyopencl.array.Array, e: List[pyopencl.Event]) -> Tuple[complex, complex]:
def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]:
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
rr, ri, ii = [g[q] for q in 'xyz']
rho = rr + 2j * ri - ii
@ -315,7 +318,7 @@ def create_p_step(context: pyopencl.Context) -> operation:
p: pyopencl.array.Array,
r: pyopencl.array.Array,
beta: complex,
e: List[pyopencl.Event]) -> List[pyopencl.Event]:
e: list[pyopencl.Event]) -> list[pyopencl.Event]:
return [p_kernel(p, r, beta, wait_for=e)]
return p_update
@ -350,7 +353,7 @@ def create_dot(context: pyopencl.Context) -> Callable[..., complex]:
def dot(
p: pyopencl.array.Array,
v: pyopencl.array.Array,
e: List[pyopencl.Event],
e: list[pyopencl.Event],
) -> complex:
g = dot_kernel(p, v, wait_for=e)
return g.get()
@ -406,11 +409,11 @@ def create_a_csr(context: pyopencl.Context) -> operation:
)
def spmv(
v_out,
m,
v_in,
e: List[pyopencl.Event],
) -> List[pyopencl.Event]:
v_out: pyopencl.array.Array,
m: CSRMatrix,
v_in: pyopencl.array.Array,
e: list[pyopencl.Event],
) -> list[pyopencl.Event]:
return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)]
return spmv