improve type annotations
This commit is contained in:
parent
2f7a46ff71
commit
6193a9c256
@ -14,21 +14,22 @@ satisfy the constraints for the 'conjugate gradient' algorithm
|
|||||||
(positive definite, symmetric) and some that don't.
|
(positive definite, symmetric) and some that don't.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Any, Optional
|
from typing import Any, TYPE_CHECKING
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.typing import NDArray, ArrayLike
|
from numpy.typing import NDArray, ArrayLike
|
||||||
from numpy.linalg import norm
|
from numpy.linalg import norm
|
||||||
|
from numpy import complexfloating
|
||||||
import pyopencl
|
import pyopencl
|
||||||
import pyopencl.array
|
import pyopencl.array
|
||||||
import scipy
|
|
||||||
|
|
||||||
import meanas.fdfd.solvers
|
import meanas.fdfd.solvers
|
||||||
|
|
||||||
from . import ops
|
from . import ops
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import scipy
|
||||||
|
|
||||||
__author__ = 'Jan Petykiewicz'
|
__author__ = 'Jan Petykiewicz'
|
||||||
|
|
||||||
@ -58,9 +59,9 @@ def cg(
|
|||||||
b: ArrayLike,
|
b: ArrayLike,
|
||||||
max_iters: int = 10000,
|
max_iters: int = 10000,
|
||||||
err_threshold: float = 1e-6,
|
err_threshold: float = 1e-6,
|
||||||
context: Optional[pyopencl.Context] = None,
|
context: pyopencl.Context | None = None,
|
||||||
queue: Optional[pyopencl.CommandQueue] = None,
|
queue: pyopencl.CommandQueue | None = None,
|
||||||
) -> NDArray:
|
) -> NDArray[complexfloating]:
|
||||||
"""
|
"""
|
||||||
General conjugate-gradient solver for sparse matrices, where A @ x = b.
|
General conjugate-gradient solver for sparse matrices, where A @ x = b.
|
||||||
|
|
||||||
@ -84,7 +85,7 @@ def cg(
|
|||||||
if queue is None:
|
if queue is None:
|
||||||
queue = pyopencl.CommandQueue(context)
|
queue = pyopencl.CommandQueue(context)
|
||||||
|
|
||||||
def load_field(v, dtype=numpy.complex128):
|
def load_field(v: NDArray[numpy.complexfloating], dtype: type = numpy.complex128) -> pyopencl.array.Array:
|
||||||
return pyopencl.array.to_device(queue, v.astype(dtype))
|
return pyopencl.array.to_device(queue, v.astype(dtype))
|
||||||
|
|
||||||
r = load_field(b)
|
r = load_field(b)
|
||||||
@ -160,9 +161,9 @@ def cg(
|
|||||||
|
|
||||||
|
|
||||||
def fdfd_cg_solver(
|
def fdfd_cg_solver(
|
||||||
solver_opts: Optional[Dict[str, Any]] = None,
|
solver_opts: dict[str, Any] | None = None,
|
||||||
**fdfd_args
|
**fdfd_args,
|
||||||
) -> NDArray:
|
) -> NDArray[complexfloating]:
|
||||||
"""
|
"""
|
||||||
Conjugate gradient FDFD solver using CSR sparse matrices, mainly for
|
Conjugate gradient FDFD solver using CSR sparse matrices, mainly for
|
||||||
testing and development since it's much slower than the solver in main.py.
|
testing and development since it's much slower than the solver in main.py.
|
||||||
|
@ -5,14 +5,13 @@ This file holds the default FDFD solver, which uses an E-field wave
|
|||||||
operator implemented directly as OpenCL arithmetic (rather than as
|
operator implemented directly as OpenCL arithmetic (rather than as
|
||||||
a matrix).
|
a matrix).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, cast
|
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.typing import NDArray, ArrayLike
|
from numpy.typing import NDArray, ArrayLike
|
||||||
from numpy.linalg import norm
|
from numpy.linalg import norm
|
||||||
|
from numpy import floating, complexfloating
|
||||||
import pyopencl
|
import pyopencl
|
||||||
import pyopencl.array
|
import pyopencl.array
|
||||||
|
|
||||||
@ -28,16 +27,16 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def cg_solver(
|
def cg_solver(
|
||||||
omega: complex,
|
omega: complex,
|
||||||
dxes: List[List[NDArray]],
|
dxes: list[list[NDArray[floating | complexfloating]]],
|
||||||
J: ArrayLike,
|
J: ArrayLike,
|
||||||
epsilon: ArrayLike,
|
epsilon: ArrayLike,
|
||||||
mu: Optional[ArrayLike] = None,
|
mu: ArrayLike | None = None,
|
||||||
pec: Optional[ArrayLike] = None,
|
pec: ArrayLike | None = None,
|
||||||
pmc: Optional[ArrayLike] = None,
|
pmc: ArrayLike | None = None,
|
||||||
adjoint: bool = False,
|
adjoint: bool = False,
|
||||||
max_iters: int = 40000,
|
max_iters: int = 40000,
|
||||||
err_threshold: float = 1e-6,
|
err_threshold: float = 1e-6,
|
||||||
context: Optional[pyopencl.Context] = None,
|
context: pyopencl.Context | None = None,
|
||||||
) -> NDArray:
|
) -> NDArray:
|
||||||
"""
|
"""
|
||||||
OpenCL FDFD solver using the iterative conjugate gradient (cg) method
|
OpenCL FDFD solver using the iterative conjugate gradient (cg) method
|
||||||
@ -108,13 +107,10 @@ def cg_solver(
|
|||||||
epsilon = numpy.conj(epsilon)
|
epsilon = numpy.conj(epsilon)
|
||||||
if mu is not None:
|
if mu is not None:
|
||||||
mu = numpy.conj(mu)
|
mu = numpy.conj(mu)
|
||||||
|
assert isinstance(epsilon, NDArray[floating] | NDArray[complexfloating])
|
||||||
|
|
||||||
L, R = meanas.fdfd.operators.e_full_preconditioners(dxes)
|
L, R = meanas.fdfd.operators.e_full_preconditioners(dxes)
|
||||||
|
b_preconditioned = (R if adjoint else L) @ b
|
||||||
if adjoint:
|
|
||||||
b_preconditioned = R @ b
|
|
||||||
else:
|
|
||||||
b_preconditioned = L @ b
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Allocate GPU memory and load in data
|
Allocate GPU memory and load in data
|
||||||
@ -124,7 +120,7 @@ def cg_solver(
|
|||||||
|
|
||||||
queue = pyopencl.CommandQueue(context)
|
queue = pyopencl.CommandQueue(context)
|
||||||
|
|
||||||
def load_field(v, dtype=numpy.complex128):
|
def load_field(v: NDArray[complexfloating | floating], dtype: type = numpy.complex128) -> pyopencl.array.Array:
|
||||||
return pyopencl.array.to_device(queue, v.astype(dtype))
|
return pyopencl.array.to_device(queue, v.astype(dtype))
|
||||||
|
|
||||||
r = load_field(b_preconditioned) # load preconditioned b into r
|
r = load_field(b_preconditioned) # load preconditioned b into r
|
||||||
@ -169,7 +165,12 @@ def cg_solver(
|
|||||||
p_step = ops.create_p_step(context)
|
p_step = ops.create_p_step(context)
|
||||||
dot = ops.create_dot(context)
|
dot = ops.create_dot(context)
|
||||||
|
|
||||||
def a_step(E, H, p, events):
|
def a_step(
|
||||||
|
E: pyopencl.array.Array,
|
||||||
|
H: pyopencl.array.Array,
|
||||||
|
p: pyopencl.array.Array,
|
||||||
|
events: list[pyopencl.Event],
|
||||||
|
) -> list[pyopencl.Event]:
|
||||||
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)
|
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
@ -7,11 +7,11 @@ kernels for use by the other solvers.
|
|||||||
See kernels/ for any of the .cl files loaded in this file.
|
See kernels/ for any of the .cl files loaded in this file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Callable, Union, Type, Sequence, Optional, Tuple
|
from collections.abc import Callable, Sequence
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.typing import NDArray, ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
import pyopencl
|
import pyopencl
|
||||||
@ -20,17 +20,20 @@ from pyopencl.elementwise import ElementwiseKernel
|
|||||||
from pyopencl.reduction import ReductionKernel
|
from pyopencl.reduction import ReductionKernel
|
||||||
|
|
||||||
|
|
||||||
|
from .csr import CSRMatrix
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Create jinja2 env on module load
|
# Create jinja2 env on module load
|
||||||
jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels'))
|
jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels'))
|
||||||
|
|
||||||
# Return type for the create_opname(...) functions
|
# Return type for the create_opname(...) functions
|
||||||
operation = Callable[..., List[pyopencl.Event]]
|
operation = Callable[..., list[pyopencl.Event]]
|
||||||
|
|
||||||
|
|
||||||
def type_to_C(
|
def type_to_C(
|
||||||
float_type: Type,
|
float_type: type[numpy.floating | numpy.complexfloating],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Returns a string corresponding to the C equivalent of a numpy type.
|
Returns a string corresponding to the C equivalent of a numpy type.
|
||||||
@ -71,7 +74,7 @@ preamble = '''
|
|||||||
'''.format(ctype=ctype_bare)
|
'''.format(ctype=ctype_bare)
|
||||||
|
|
||||||
|
|
||||||
def ptrs(*args: str) -> List[str]:
|
def ptrs(*args: str) -> list[str]:
|
||||||
return [ctype + ' *' + s for s in args]
|
return [ctype + ' *' + s for s in args]
|
||||||
|
|
||||||
|
|
||||||
@ -169,13 +172,13 @@ def create_a(
|
|||||||
p: pyopencl.array.Array,
|
p: pyopencl.array.Array,
|
||||||
idxes: Sequence[Sequence[pyopencl.array.Array]],
|
idxes: Sequence[Sequence[pyopencl.array.Array]],
|
||||||
oeps: pyopencl.array.Array,
|
oeps: pyopencl.array.Array,
|
||||||
inv_mu: Optional[pyopencl.array.Array],
|
inv_mu: pyopencl.array.Array | None,
|
||||||
pec: Optional[pyopencl.array.Array],
|
pec: pyopencl.array.Array | None,
|
||||||
pmc: Optional[pyopencl.array.Array],
|
pmc: pyopencl.array.Array | None,
|
||||||
Pl: pyopencl.array.Array,
|
Pl: pyopencl.array.Array,
|
||||||
Pr: pyopencl.array.Array,
|
Pr: pyopencl.array.Array,
|
||||||
e: List[pyopencl.Event],
|
e: list[pyopencl.Event],
|
||||||
) -> List[pyopencl.Event]:
|
) -> list[pyopencl.Event]:
|
||||||
e2 = P2E_kernel(E, p, Pr, pec, wait_for=e)
|
e2 = P2E_kernel(E, p, Pr, pec, wait_for=e)
|
||||||
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
|
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
|
||||||
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
|
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
|
||||||
@ -227,14 +230,14 @@ def create_xr_step(context: pyopencl.Context) -> operation:
|
|||||||
r: pyopencl.array.Array,
|
r: pyopencl.array.Array,
|
||||||
v: pyopencl.array.Array,
|
v: pyopencl.array.Array,
|
||||||
alpha: complex,
|
alpha: complex,
|
||||||
e: List[pyopencl.Event],
|
e: list[pyopencl.Event],
|
||||||
) -> List[pyopencl.Event]:
|
) -> list[pyopencl.Event]:
|
||||||
return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
|
return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
|
||||||
|
|
||||||
return xr_update
|
return xr_update
|
||||||
|
|
||||||
|
|
||||||
def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex, complex]]:
|
def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., tuple[complex, complex]]:
|
||||||
"""
|
"""
|
||||||
Return a function
|
Return a function
|
||||||
ri_update(r, e)
|
ri_update(r, e)
|
||||||
@ -272,7 +275,7 @@ def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex
|
|||||||
arguments=ctype + ' *r',
|
arguments=ctype + ' *r',
|
||||||
)
|
)
|
||||||
|
|
||||||
def ri_update(r: pyopencl.array.Array, e: List[pyopencl.Event]) -> Tuple[complex, complex]:
|
def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]:
|
||||||
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
|
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
|
||||||
rr, ri, ii = [g[q] for q in 'xyz']
|
rr, ri, ii = [g[q] for q in 'xyz']
|
||||||
rho = rr + 2j * ri - ii
|
rho = rr + 2j * ri - ii
|
||||||
@ -315,7 +318,7 @@ def create_p_step(context: pyopencl.Context) -> operation:
|
|||||||
p: pyopencl.array.Array,
|
p: pyopencl.array.Array,
|
||||||
r: pyopencl.array.Array,
|
r: pyopencl.array.Array,
|
||||||
beta: complex,
|
beta: complex,
|
||||||
e: List[pyopencl.Event]) -> List[pyopencl.Event]:
|
e: list[pyopencl.Event]) -> list[pyopencl.Event]:
|
||||||
return [p_kernel(p, r, beta, wait_for=e)]
|
return [p_kernel(p, r, beta, wait_for=e)]
|
||||||
|
|
||||||
return p_update
|
return p_update
|
||||||
@ -350,7 +353,7 @@ def create_dot(context: pyopencl.Context) -> Callable[..., complex]:
|
|||||||
def dot(
|
def dot(
|
||||||
p: pyopencl.array.Array,
|
p: pyopencl.array.Array,
|
||||||
v: pyopencl.array.Array,
|
v: pyopencl.array.Array,
|
||||||
e: List[pyopencl.Event],
|
e: list[pyopencl.Event],
|
||||||
) -> complex:
|
) -> complex:
|
||||||
g = dot_kernel(p, v, wait_for=e)
|
g = dot_kernel(p, v, wait_for=e)
|
||||||
return g.get()
|
return g.get()
|
||||||
@ -406,11 +409,11 @@ def create_a_csr(context: pyopencl.Context) -> operation:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def spmv(
|
def spmv(
|
||||||
v_out,
|
v_out: pyopencl.array.Array,
|
||||||
m,
|
m: CSRMatrix,
|
||||||
v_in,
|
v_in: pyopencl.array.Array,
|
||||||
e: List[pyopencl.Event],
|
e: list[pyopencl.Event],
|
||||||
) -> List[pyopencl.Event]:
|
) -> list[pyopencl.Event]:
|
||||||
return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)]
|
return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)]
|
||||||
|
|
||||||
return spmv
|
return spmv
|
||||||
|
Loading…
Reference in New Issue
Block a user