forked from jan/opencl_fdfd
improve type annotations
This commit is contained in:
parent
2f7a46ff71
commit
6193a9c256
@ -14,21 +14,22 @@ satisfy the constraints for the 'conjugate gradient' algorithm
|
||||
(positive definite, symmetric) and some that don't.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any, TYPE_CHECKING
|
||||
import time
|
||||
import logging
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
from numpy.linalg import norm
|
||||
from numpy import complexfloating
|
||||
import pyopencl
|
||||
import pyopencl.array
|
||||
import scipy
|
||||
|
||||
import meanas.fdfd.solvers
|
||||
|
||||
from . import ops
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import scipy
|
||||
|
||||
__author__ = 'Jan Petykiewicz'
|
||||
|
||||
@ -58,9 +59,9 @@ def cg(
|
||||
b: ArrayLike,
|
||||
max_iters: int = 10000,
|
||||
err_threshold: float = 1e-6,
|
||||
context: Optional[pyopencl.Context] = None,
|
||||
queue: Optional[pyopencl.CommandQueue] = None,
|
||||
) -> NDArray:
|
||||
context: pyopencl.Context | None = None,
|
||||
queue: pyopencl.CommandQueue | None = None,
|
||||
) -> NDArray[complexfloating]:
|
||||
"""
|
||||
General conjugate-gradient solver for sparse matrices, where A @ x = b.
|
||||
|
||||
@ -84,7 +85,7 @@ def cg(
|
||||
if queue is None:
|
||||
queue = pyopencl.CommandQueue(context)
|
||||
|
||||
def load_field(v, dtype=numpy.complex128):
|
||||
def load_field(v: NDArray[numpy.complexfloating], dtype: type = numpy.complex128) -> pyopencl.array.Array:
|
||||
return pyopencl.array.to_device(queue, v.astype(dtype))
|
||||
|
||||
r = load_field(b)
|
||||
@ -160,9 +161,9 @@ def cg(
|
||||
|
||||
|
||||
def fdfd_cg_solver(
|
||||
solver_opts: Optional[Dict[str, Any]] = None,
|
||||
**fdfd_args
|
||||
) -> NDArray:
|
||||
solver_opts: dict[str, Any] | None = None,
|
||||
**fdfd_args,
|
||||
) -> NDArray[complexfloating]:
|
||||
"""
|
||||
Conjugate gradient FDFD solver using CSR sparse matrices, mainly for
|
||||
testing and development since it's much slower than the solver in main.py.
|
||||
|
@ -5,14 +5,13 @@ This file holds the default FDFD solver, which uses an E-field wave
|
||||
operator implemented directly as OpenCL arithmetic (rather than as
|
||||
a matrix).
|
||||
"""
|
||||
|
||||
from typing import List, Optional, cast
|
||||
import time
|
||||
import logging
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
from numpy.linalg import norm
|
||||
from numpy import floating, complexfloating
|
||||
import pyopencl
|
||||
import pyopencl.array
|
||||
|
||||
@ -28,16 +27,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def cg_solver(
|
||||
omega: complex,
|
||||
dxes: List[List[NDArray]],
|
||||
dxes: list[list[NDArray[floating | complexfloating]]],
|
||||
J: ArrayLike,
|
||||
epsilon: ArrayLike,
|
||||
mu: Optional[ArrayLike] = None,
|
||||
pec: Optional[ArrayLike] = None,
|
||||
pmc: Optional[ArrayLike] = None,
|
||||
mu: ArrayLike | None = None,
|
||||
pec: ArrayLike | None = None,
|
||||
pmc: ArrayLike | None = None,
|
||||
adjoint: bool = False,
|
||||
max_iters: int = 40000,
|
||||
err_threshold: float = 1e-6,
|
||||
context: Optional[pyopencl.Context] = None,
|
||||
context: pyopencl.Context | None = None,
|
||||
) -> NDArray:
|
||||
"""
|
||||
OpenCL FDFD solver using the iterative conjugate gradient (cg) method
|
||||
@ -108,13 +107,10 @@ def cg_solver(
|
||||
epsilon = numpy.conj(epsilon)
|
||||
if mu is not None:
|
||||
mu = numpy.conj(mu)
|
||||
assert isinstance(epsilon, NDArray[floating] | NDArray[complexfloating])
|
||||
|
||||
L, R = meanas.fdfd.operators.e_full_preconditioners(dxes)
|
||||
|
||||
if adjoint:
|
||||
b_preconditioned = R @ b
|
||||
else:
|
||||
b_preconditioned = L @ b
|
||||
b_preconditioned = (R if adjoint else L) @ b
|
||||
|
||||
'''
|
||||
Allocate GPU memory and load in data
|
||||
@ -124,7 +120,7 @@ def cg_solver(
|
||||
|
||||
queue = pyopencl.CommandQueue(context)
|
||||
|
||||
def load_field(v, dtype=numpy.complex128):
|
||||
def load_field(v: NDArray[complexfloating | floating], dtype: type = numpy.complex128) -> pyopencl.array.Array:
|
||||
return pyopencl.array.to_device(queue, v.astype(dtype))
|
||||
|
||||
r = load_field(b_preconditioned) # load preconditioned b into r
|
||||
@ -169,7 +165,12 @@ def cg_solver(
|
||||
p_step = ops.create_p_step(context)
|
||||
dot = ops.create_dot(context)
|
||||
|
||||
def a_step(E, H, p, events):
|
||||
def a_step(
|
||||
E: pyopencl.array.Array,
|
||||
H: pyopencl.array.Array,
|
||||
p: pyopencl.array.Array,
|
||||
events: list[pyopencl.Event],
|
||||
) -> list[pyopencl.Event]:
|
||||
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)
|
||||
|
||||
'''
|
||||
|
@ -7,11 +7,11 @@ kernels for use by the other solvers.
|
||||
See kernels/ for any of the .cl files loaded in this file.
|
||||
"""
|
||||
|
||||
from typing import List, Callable, Union, Type, Sequence, Optional, Tuple
|
||||
from collections.abc import Callable, Sequence
|
||||
import logging
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
from numpy.typing import ArrayLike
|
||||
import jinja2
|
||||
|
||||
import pyopencl
|
||||
@ -20,17 +20,20 @@ from pyopencl.elementwise import ElementwiseKernel
|
||||
from pyopencl.reduction import ReductionKernel
|
||||
|
||||
|
||||
from .csr import CSRMatrix
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create jinja2 env on module load
|
||||
jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels'))
|
||||
|
||||
# Return type for the create_opname(...) functions
|
||||
operation = Callable[..., List[pyopencl.Event]]
|
||||
operation = Callable[..., list[pyopencl.Event]]
|
||||
|
||||
|
||||
def type_to_C(
|
||||
float_type: Type,
|
||||
float_type: type[numpy.floating | numpy.complexfloating],
|
||||
) -> str:
|
||||
"""
|
||||
Returns a string corresponding to the C equivalent of a numpy type.
|
||||
@ -71,7 +74,7 @@ preamble = '''
|
||||
'''.format(ctype=ctype_bare)
|
||||
|
||||
|
||||
def ptrs(*args: str) -> List[str]:
|
||||
def ptrs(*args: str) -> list[str]:
|
||||
return [ctype + ' *' + s for s in args]
|
||||
|
||||
|
||||
@ -169,13 +172,13 @@ def create_a(
|
||||
p: pyopencl.array.Array,
|
||||
idxes: Sequence[Sequence[pyopencl.array.Array]],
|
||||
oeps: pyopencl.array.Array,
|
||||
inv_mu: Optional[pyopencl.array.Array],
|
||||
pec: Optional[pyopencl.array.Array],
|
||||
pmc: Optional[pyopencl.array.Array],
|
||||
inv_mu: pyopencl.array.Array | None,
|
||||
pec: pyopencl.array.Array | None,
|
||||
pmc: pyopencl.array.Array | None,
|
||||
Pl: pyopencl.array.Array,
|
||||
Pr: pyopencl.array.Array,
|
||||
e: List[pyopencl.Event],
|
||||
) -> List[pyopencl.Event]:
|
||||
e: list[pyopencl.Event],
|
||||
) -> list[pyopencl.Event]:
|
||||
e2 = P2E_kernel(E, p, Pr, pec, wait_for=e)
|
||||
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
|
||||
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
|
||||
@ -227,14 +230,14 @@ def create_xr_step(context: pyopencl.Context) -> operation:
|
||||
r: pyopencl.array.Array,
|
||||
v: pyopencl.array.Array,
|
||||
alpha: complex,
|
||||
e: List[pyopencl.Event],
|
||||
) -> List[pyopencl.Event]:
|
||||
e: list[pyopencl.Event],
|
||||
) -> list[pyopencl.Event]:
|
||||
return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
|
||||
|
||||
return xr_update
|
||||
|
||||
|
||||
def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex, complex]]:
|
||||
def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., tuple[complex, complex]]:
|
||||
"""
|
||||
Return a function
|
||||
ri_update(r, e)
|
||||
@ -272,7 +275,7 @@ def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex
|
||||
arguments=ctype + ' *r',
|
||||
)
|
||||
|
||||
def ri_update(r: pyopencl.array.Array, e: List[pyopencl.Event]) -> Tuple[complex, complex]:
|
||||
def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]:
|
||||
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
|
||||
rr, ri, ii = [g[q] for q in 'xyz']
|
||||
rho = rr + 2j * ri - ii
|
||||
@ -315,7 +318,7 @@ def create_p_step(context: pyopencl.Context) -> operation:
|
||||
p: pyopencl.array.Array,
|
||||
r: pyopencl.array.Array,
|
||||
beta: complex,
|
||||
e: List[pyopencl.Event]) -> List[pyopencl.Event]:
|
||||
e: list[pyopencl.Event]) -> list[pyopencl.Event]:
|
||||
return [p_kernel(p, r, beta, wait_for=e)]
|
||||
|
||||
return p_update
|
||||
@ -350,7 +353,7 @@ def create_dot(context: pyopencl.Context) -> Callable[..., complex]:
|
||||
def dot(
|
||||
p: pyopencl.array.Array,
|
||||
v: pyopencl.array.Array,
|
||||
e: List[pyopencl.Event],
|
||||
e: list[pyopencl.Event],
|
||||
) -> complex:
|
||||
g = dot_kernel(p, v, wait_for=e)
|
||||
return g.get()
|
||||
@ -406,11 +409,11 @@ def create_a_csr(context: pyopencl.Context) -> operation:
|
||||
)
|
||||
|
||||
def spmv(
|
||||
v_out,
|
||||
m,
|
||||
v_in,
|
||||
e: List[pyopencl.Event],
|
||||
) -> List[pyopencl.Event]:
|
||||
v_out: pyopencl.array.Array,
|
||||
m: CSRMatrix,
|
||||
v_in: pyopencl.array.Array,
|
||||
e: list[pyopencl.Event],
|
||||
) -> list[pyopencl.Event]:
|
||||
return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)]
|
||||
|
||||
return spmv
|
||||
|
Loading…
Reference in New Issue
Block a user