From 6193a9c25691d9fdd0d835add1788149116e5a9c Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 30 Jul 2024 22:41:27 -0700 Subject: [PATCH] improve type annotations --- opencl_fdfd/csr.py | 21 +++++++++++---------- opencl_fdfd/main.py | 29 +++++++++++++++-------------- opencl_fdfd/ops.py | 45 ++++++++++++++++++++++++--------------------- 3 files changed, 50 insertions(+), 45 deletions(-) diff --git a/opencl_fdfd/csr.py b/opencl_fdfd/csr.py index 0f5a837..e2c5677 100644 --- a/opencl_fdfd/csr.py +++ b/opencl_fdfd/csr.py @@ -14,21 +14,22 @@ satisfy the constraints for the 'conjugate gradient' algorithm (positive definite, symmetric) and some that don't. """ -from typing import Dict, Any, Optional +from typing import Any, TYPE_CHECKING import time import logging import numpy from numpy.typing import NDArray, ArrayLike from numpy.linalg import norm +from numpy import complexfloating import pyopencl import pyopencl.array -import scipy - import meanas.fdfd.solvers from . import ops +if TYPE_CHECKING: + import scipy __author__ = 'Jan Petykiewicz' @@ -58,9 +59,9 @@ def cg( b: ArrayLike, max_iters: int = 10000, err_threshold: float = 1e-6, - context: Optional[pyopencl.Context] = None, - queue: Optional[pyopencl.CommandQueue] = None, - ) -> NDArray: + context: pyopencl.Context | None = None, + queue: pyopencl.CommandQueue | None = None, + ) -> NDArray[complexfloating]: """ General conjugate-gradient solver for sparse matrices, where A @ x = b. @@ -84,7 +85,7 @@ def cg( if queue is None: queue = pyopencl.CommandQueue(context) - def load_field(v, dtype=numpy.complex128): + def load_field(v: NDArray[numpy.complexfloating], dtype: type = numpy.complex128) -> pyopencl.array.Array: return pyopencl.array.to_device(queue, v.astype(dtype)) r = load_field(b) @@ -160,9 +161,9 @@ def cg( def fdfd_cg_solver( - solver_opts: Optional[Dict[str, Any]] = None, - **fdfd_args - ) -> NDArray: + solver_opts: dict[str, Any] | None = None, + **fdfd_args, + ) -> NDArray[complexfloating]: """ Conjugate gradient FDFD solver using CSR sparse matrices, mainly for testing and development since it's much slower than the solver in main.py. diff --git a/opencl_fdfd/main.py b/opencl_fdfd/main.py index 337b4e0..f4bd139 100644 --- a/opencl_fdfd/main.py +++ b/opencl_fdfd/main.py @@ -5,14 +5,13 @@ This file holds the default FDFD solver, which uses an E-field wave operator implemented directly as OpenCL arithmetic (rather than as a matrix). """ - -from typing import List, Optional, cast import time import logging import numpy from numpy.typing import NDArray, ArrayLike from numpy.linalg import norm +from numpy import floating, complexfloating import pyopencl import pyopencl.array @@ -28,16 +27,16 @@ logger = logging.getLogger(__name__) def cg_solver( omega: complex, - dxes: List[List[NDArray]], + dxes: list[list[NDArray[floating | complexfloating]]], J: ArrayLike, epsilon: ArrayLike, - mu: Optional[ArrayLike] = None, - pec: Optional[ArrayLike] = None, - pmc: Optional[ArrayLike] = None, + mu: ArrayLike | None = None, + pec: ArrayLike | None = None, + pmc: ArrayLike | None = None, adjoint: bool = False, max_iters: int = 40000, err_threshold: float = 1e-6, - context: Optional[pyopencl.Context] = None, + context: pyopencl.Context | None = None, ) -> NDArray: """ OpenCL FDFD solver using the iterative conjugate gradient (cg) method @@ -108,13 +107,10 @@ def cg_solver( epsilon = numpy.conj(epsilon) if mu is not None: mu = numpy.conj(mu) + assert isinstance(epsilon, NDArray[floating] | NDArray[complexfloating]) L, R = meanas.fdfd.operators.e_full_preconditioners(dxes) - - if adjoint: - b_preconditioned = R @ b - else: - b_preconditioned = L @ b + b_preconditioned = (R if adjoint else L) @ b ''' Allocate GPU memory and load in data @@ -124,7 +120,7 @@ def cg_solver( queue = pyopencl.CommandQueue(context) - def load_field(v, dtype=numpy.complex128): + def load_field(v: NDArray[complexfloating | floating], dtype: type = numpy.complex128) -> pyopencl.array.Array: return pyopencl.array.to_device(queue, v.astype(dtype)) r = load_field(b_preconditioned) # load preconditioned b into r @@ -169,7 +165,12 @@ def cg_solver( p_step = ops.create_p_step(context) dot = ops.create_dot(context) - def a_step(E, H, p, events): + def a_step( + E: pyopencl.array.Array, + H: pyopencl.array.Array, + p: pyopencl.array.Array, + events: list[pyopencl.Event], + ) -> list[pyopencl.Event]: return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events) ''' diff --git a/opencl_fdfd/ops.py b/opencl_fdfd/ops.py index b0e4108..16d0d6b 100644 --- a/opencl_fdfd/ops.py +++ b/opencl_fdfd/ops.py @@ -7,11 +7,11 @@ kernels for use by the other solvers. See kernels/ for any of the .cl files loaded in this file. """ -from typing import List, Callable, Union, Type, Sequence, Optional, Tuple +from collections.abc import Callable, Sequence import logging import numpy -from numpy.typing import NDArray, ArrayLike +from numpy.typing import ArrayLike import jinja2 import pyopencl @@ -20,17 +20,20 @@ from pyopencl.elementwise import ElementwiseKernel from pyopencl.reduction import ReductionKernel +from .csr import CSRMatrix + + logger = logging.getLogger(__name__) # Create jinja2 env on module load jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels')) # Return type for the create_opname(...) functions -operation = Callable[..., List[pyopencl.Event]] +operation = Callable[..., list[pyopencl.Event]] def type_to_C( - float_type: Type, + float_type: type[numpy.floating | numpy.complexfloating], ) -> str: """ Returns a string corresponding to the C equivalent of a numpy type. @@ -71,7 +74,7 @@ preamble = ''' '''.format(ctype=ctype_bare) -def ptrs(*args: str) -> List[str]: +def ptrs(*args: str) -> list[str]: return [ctype + ' *' + s for s in args] @@ -169,13 +172,13 @@ def create_a( p: pyopencl.array.Array, idxes: Sequence[Sequence[pyopencl.array.Array]], oeps: pyopencl.array.Array, - inv_mu: Optional[pyopencl.array.Array], - pec: Optional[pyopencl.array.Array], - pmc: Optional[pyopencl.array.Array], + inv_mu: pyopencl.array.Array | None, + pec: pyopencl.array.Array | None, + pmc: pyopencl.array.Array | None, Pl: pyopencl.array.Array, Pr: pyopencl.array.Array, - e: List[pyopencl.Event], - ) -> List[pyopencl.Event]: + e: list[pyopencl.Event], + ) -> list[pyopencl.Event]: e2 = P2E_kernel(E, p, Pr, pec, wait_for=e) e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2]) e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2]) @@ -227,14 +230,14 @@ def create_xr_step(context: pyopencl.Context) -> operation: r: pyopencl.array.Array, v: pyopencl.array.Array, alpha: complex, - e: List[pyopencl.Event], - ) -> List[pyopencl.Event]: + e: list[pyopencl.Event], + ) -> list[pyopencl.Event]: return [xr_kernel(x, p, r, v, alpha, wait_for=e)] return xr_update -def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex, complex]]: +def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., tuple[complex, complex]]: """ Return a function ri_update(r, e) @@ -272,7 +275,7 @@ def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex arguments=ctype + ' *r', ) - def ri_update(r: pyopencl.array.Array, e: List[pyopencl.Event]) -> Tuple[complex, complex]: + def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]: g = ri_kernel(r, wait_for=e).astype(ri_dtype).get() rr, ri, ii = [g[q] for q in 'xyz'] rho = rr + 2j * ri - ii @@ -315,7 +318,7 @@ def create_p_step(context: pyopencl.Context) -> operation: p: pyopencl.array.Array, r: pyopencl.array.Array, beta: complex, - e: List[pyopencl.Event]) -> List[pyopencl.Event]: + e: list[pyopencl.Event]) -> list[pyopencl.Event]: return [p_kernel(p, r, beta, wait_for=e)] return p_update @@ -350,7 +353,7 @@ def create_dot(context: pyopencl.Context) -> Callable[..., complex]: def dot( p: pyopencl.array.Array, v: pyopencl.array.Array, - e: List[pyopencl.Event], + e: list[pyopencl.Event], ) -> complex: g = dot_kernel(p, v, wait_for=e) return g.get() @@ -406,11 +409,11 @@ def create_a_csr(context: pyopencl.Context) -> operation: ) def spmv( - v_out, - m, - v_in, - e: List[pyopencl.Event], - ) -> List[pyopencl.Event]: + v_out: pyopencl.array.Array, + m: CSRMatrix, + v_in: pyopencl.array.Array, + e: list[pyopencl.Event], + ) -> list[pyopencl.Event]: return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)] return spmv