improve type annotations

This commit is contained in:
Jan Petykiewicz 2024-07-30 22:41:27 -07:00
parent 2f7a46ff71
commit 6193a9c256
3 changed files with 50 additions and 45 deletions

View File

@ -14,21 +14,22 @@ satisfy the constraints for the 'conjugate gradient' algorithm
(positive definite, symmetric) and some that don't. (positive definite, symmetric) and some that don't.
""" """
from typing import Dict, Any, Optional from typing import Any, TYPE_CHECKING
import time import time
import logging import logging
import numpy import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray, ArrayLike
from numpy.linalg import norm from numpy.linalg import norm
from numpy import complexfloating
import pyopencl import pyopencl
import pyopencl.array import pyopencl.array
import scipy
import meanas.fdfd.solvers import meanas.fdfd.solvers
from . import ops from . import ops
if TYPE_CHECKING:
import scipy
__author__ = 'Jan Petykiewicz' __author__ = 'Jan Petykiewicz'
@ -58,9 +59,9 @@ def cg(
b: ArrayLike, b: ArrayLike,
max_iters: int = 10000, max_iters: int = 10000,
err_threshold: float = 1e-6, err_threshold: float = 1e-6,
context: Optional[pyopencl.Context] = None, context: pyopencl.Context | None = None,
queue: Optional[pyopencl.CommandQueue] = None, queue: pyopencl.CommandQueue | None = None,
) -> NDArray: ) -> NDArray[complexfloating]:
""" """
General conjugate-gradient solver for sparse matrices, where A @ x = b. General conjugate-gradient solver for sparse matrices, where A @ x = b.
@ -84,7 +85,7 @@ def cg(
if queue is None: if queue is None:
queue = pyopencl.CommandQueue(context) queue = pyopencl.CommandQueue(context)
def load_field(v, dtype=numpy.complex128): def load_field(v: NDArray[numpy.complexfloating], dtype: type = numpy.complex128) -> pyopencl.array.Array:
return pyopencl.array.to_device(queue, v.astype(dtype)) return pyopencl.array.to_device(queue, v.astype(dtype))
r = load_field(b) r = load_field(b)
@ -160,9 +161,9 @@ def cg(
def fdfd_cg_solver( def fdfd_cg_solver(
solver_opts: Optional[Dict[str, Any]] = None, solver_opts: dict[str, Any] | None = None,
**fdfd_args **fdfd_args,
) -> NDArray: ) -> NDArray[complexfloating]:
""" """
Conjugate gradient FDFD solver using CSR sparse matrices, mainly for Conjugate gradient FDFD solver using CSR sparse matrices, mainly for
testing and development since it's much slower than the solver in main.py. testing and development since it's much slower than the solver in main.py.

View File

@ -5,14 +5,13 @@ This file holds the default FDFD solver, which uses an E-field wave
operator implemented directly as OpenCL arithmetic (rather than as operator implemented directly as OpenCL arithmetic (rather than as
a matrix). a matrix).
""" """
from typing import List, Optional, cast
import time import time
import logging import logging
import numpy import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray, ArrayLike
from numpy.linalg import norm from numpy.linalg import norm
from numpy import floating, complexfloating
import pyopencl import pyopencl
import pyopencl.array import pyopencl.array
@ -28,16 +27,16 @@ logger = logging.getLogger(__name__)
def cg_solver( def cg_solver(
omega: complex, omega: complex,
dxes: List[List[NDArray]], dxes: list[list[NDArray[floating | complexfloating]]],
J: ArrayLike, J: ArrayLike,
epsilon: ArrayLike, epsilon: ArrayLike,
mu: Optional[ArrayLike] = None, mu: ArrayLike | None = None,
pec: Optional[ArrayLike] = None, pec: ArrayLike | None = None,
pmc: Optional[ArrayLike] = None, pmc: ArrayLike | None = None,
adjoint: bool = False, adjoint: bool = False,
max_iters: int = 40000, max_iters: int = 40000,
err_threshold: float = 1e-6, err_threshold: float = 1e-6,
context: Optional[pyopencl.Context] = None, context: pyopencl.Context | None = None,
) -> NDArray: ) -> NDArray:
""" """
OpenCL FDFD solver using the iterative conjugate gradient (cg) method OpenCL FDFD solver using the iterative conjugate gradient (cg) method
@ -108,13 +107,10 @@ def cg_solver(
epsilon = numpy.conj(epsilon) epsilon = numpy.conj(epsilon)
if mu is not None: if mu is not None:
mu = numpy.conj(mu) mu = numpy.conj(mu)
# Narrow `epsilon` (declared as ArrayLike) to a concrete array for the type
# checker. isinstance() cannot take parameterized generics such as
# NDArray[floating] -- it raises TypeError at runtime -- so test against the
# plain ndarray class instead.
assert isinstance(epsilon, numpy.ndarray)
L, R = meanas.fdfd.operators.e_full_preconditioners(dxes) L, R = meanas.fdfd.operators.e_full_preconditioners(dxes)
b_preconditioned = (R if adjoint else L) @ b
if adjoint:
b_preconditioned = R @ b
else:
b_preconditioned = L @ b
''' '''
Allocate GPU memory and load in data Allocate GPU memory and load in data
@ -124,7 +120,7 @@ def cg_solver(
queue = pyopencl.CommandQueue(context) queue = pyopencl.CommandQueue(context)
def load_field(v, dtype=numpy.complex128): def load_field(v: NDArray[complexfloating | floating], dtype: type = numpy.complex128) -> pyopencl.array.Array:
return pyopencl.array.to_device(queue, v.astype(dtype)) return pyopencl.array.to_device(queue, v.astype(dtype))
r = load_field(b_preconditioned) # load preconditioned b into r r = load_field(b_preconditioned) # load preconditioned b into r
@ -169,7 +165,12 @@ def cg_solver(
p_step = ops.create_p_step(context) p_step = ops.create_p_step(context)
dot = ops.create_dot(context) dot = ops.create_dot(context)
def a_step(E, H, p, events): def a_step(
E: pyopencl.array.Array,
H: pyopencl.array.Array,
p: pyopencl.array.Array,
events: list[pyopencl.Event],
) -> list[pyopencl.Event]:
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events) return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)
''' '''

View File

@ -7,11 +7,11 @@ kernels for use by the other solvers.
See kernels/ for any of the .cl files loaded in this file. See kernels/ for any of the .cl files loaded in this file.
""" """
from typing import List, Callable, Union, Type, Sequence, Optional, Tuple from collections.abc import Callable, Sequence
import logging import logging
import numpy import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import ArrayLike
import jinja2 import jinja2
import pyopencl import pyopencl
@ -20,17 +20,20 @@ from pyopencl.elementwise import ElementwiseKernel
from pyopencl.reduction import ReductionKernel from pyopencl.reduction import ReductionKernel
from .csr import CSRMatrix
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Create jinja2 env on module load # Create jinja2 env on module load
jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels')) jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels'))
# Return type for the create_opname(...) functions # Return type for the create_opname(...) functions
operation = Callable[..., List[pyopencl.Event]] operation = Callable[..., list[pyopencl.Event]]
def type_to_C( def type_to_C(
float_type: Type, float_type: type[numpy.floating | numpy.complexfloating],
) -> str: ) -> str:
""" """
Returns a string corresponding to the C equivalent of a numpy type. Returns a string corresponding to the C equivalent of a numpy type.
@ -71,7 +74,7 @@ preamble = '''
'''.format(ctype=ctype_bare) '''.format(ctype=ctype_bare)
def ptrs(*args: str) -> List[str]: def ptrs(*args: str) -> list[str]:
return [ctype + ' *' + s for s in args] return [ctype + ' *' + s for s in args]
@ -169,13 +172,13 @@ def create_a(
p: pyopencl.array.Array, p: pyopencl.array.Array,
idxes: Sequence[Sequence[pyopencl.array.Array]], idxes: Sequence[Sequence[pyopencl.array.Array]],
oeps: pyopencl.array.Array, oeps: pyopencl.array.Array,
inv_mu: Optional[pyopencl.array.Array], inv_mu: pyopencl.array.Array | None,
pec: Optional[pyopencl.array.Array], pec: pyopencl.array.Array | None,
pmc: Optional[pyopencl.array.Array], pmc: pyopencl.array.Array | None,
Pl: pyopencl.array.Array, Pl: pyopencl.array.Array,
Pr: pyopencl.array.Array, Pr: pyopencl.array.Array,
e: List[pyopencl.Event], e: list[pyopencl.Event],
) -> List[pyopencl.Event]: ) -> list[pyopencl.Event]:
e2 = P2E_kernel(E, p, Pr, pec, wait_for=e) e2 = P2E_kernel(E, p, Pr, pec, wait_for=e)
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2]) e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2]) e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
@ -227,14 +230,14 @@ def create_xr_step(context: pyopencl.Context) -> operation:
r: pyopencl.array.Array, r: pyopencl.array.Array,
v: pyopencl.array.Array, v: pyopencl.array.Array,
alpha: complex, alpha: complex,
e: List[pyopencl.Event], e: list[pyopencl.Event],
) -> List[pyopencl.Event]: ) -> list[pyopencl.Event]:
return [xr_kernel(x, p, r, v, alpha, wait_for=e)] return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
return xr_update return xr_update
def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex, complex]]: def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., tuple[complex, complex]]:
""" """
Return a function Return a function
ri_update(r, e) ri_update(r, e)
@ -272,7 +275,7 @@ def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex
arguments=ctype + ' *r', arguments=ctype + ' *r',
) )
def ri_update(r: pyopencl.array.Array, e: List[pyopencl.Event]) -> Tuple[complex, complex]: def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]:
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get() g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
rr, ri, ii = [g[q] for q in 'xyz'] rr, ri, ii = [g[q] for q in 'xyz']
rho = rr + 2j * ri - ii rho = rr + 2j * ri - ii
@ -315,7 +318,7 @@ def create_p_step(context: pyopencl.Context) -> operation:
p: pyopencl.array.Array, p: pyopencl.array.Array,
r: pyopencl.array.Array, r: pyopencl.array.Array,
beta: complex, beta: complex,
e: List[pyopencl.Event]) -> List[pyopencl.Event]: e: list[pyopencl.Event]) -> list[pyopencl.Event]:
return [p_kernel(p, r, beta, wait_for=e)] return [p_kernel(p, r, beta, wait_for=e)]
return p_update return p_update
@ -350,7 +353,7 @@ def create_dot(context: pyopencl.Context) -> Callable[..., complex]:
def dot( def dot(
p: pyopencl.array.Array, p: pyopencl.array.Array,
v: pyopencl.array.Array, v: pyopencl.array.Array,
e: List[pyopencl.Event], e: list[pyopencl.Event],
) -> complex: ) -> complex:
g = dot_kernel(p, v, wait_for=e) g = dot_kernel(p, v, wait_for=e)
return g.get() return g.get()
@ -406,11 +409,11 @@ def create_a_csr(context: pyopencl.Context) -> operation:
) )
def spmv( def spmv(
v_out, v_out: pyopencl.array.Array,
m, m: CSRMatrix,
v_in, v_in: pyopencl.array.Array,
e: List[pyopencl.Event], e: list[pyopencl.Event],
) -> List[pyopencl.Event]: ) -> list[pyopencl.Event]:
return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)] return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)]
return spmv return spmv