improve type annotations, formatting, comment styles
This commit is contained in:
parent
81bb1dd2c0
commit
efeb29479b
@ -14,14 +14,16 @@ satisfy the constraints for the 'conjugate gradient' algorithm
|
|||||||
(positive definite, symmetric) and some that don't.
|
(positive definite, symmetric) and some that don't.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, Optional
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
from numpy.typing import NDArray, ArrayLike
|
||||||
from numpy.linalg import norm
|
from numpy.linalg import norm
|
||||||
import pyopencl
|
import pyopencl
|
||||||
import pyopencl.array
|
import pyopencl.array
|
||||||
|
import scipy
|
||||||
|
|
||||||
import meanas.fdfd.solvers
|
import meanas.fdfd.solvers
|
||||||
|
|
||||||
@ -33,39 +35,45 @@ __author__ = 'Jan Petykiewicz'
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CSRMatrix(object):
|
class CSRMatrix:
|
||||||
"""
|
"""
|
||||||
Matrix stored in Compressed Sparse Row format, in GPU RAM.
|
Matrix stored in Compressed Sparse Row format, in GPU RAM.
|
||||||
"""
|
"""
|
||||||
row_ptr = None # type: pyopencl.array.Array
|
row_ptr: pyopencl.array.Array
|
||||||
col_ind = None # type: pyopencl.array.Array
|
col_ind: pyopencl.array.Array
|
||||||
data = None # type: pyopencl.array.Array
|
data: pyopencl.array.Array
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
queue: pyopencl.CommandQueue,
|
self,
|
||||||
m: 'scipy.sparse.csr_matrix'):
|
queue: pyopencl.CommandQueue,
|
||||||
|
m: 'scipy.sparse.csr_matrix',
|
||||||
|
) -> None:
|
||||||
self.row_ptr = pyopencl.array.to_device(queue, m.indptr)
|
self.row_ptr = pyopencl.array.to_device(queue, m.indptr)
|
||||||
self.col_ind = pyopencl.array.to_device(queue, m.indices)
|
self.col_ind = pyopencl.array.to_device(queue, m.indices)
|
||||||
self.data = pyopencl.array.to_device(queue, m.data.astype(numpy.complex128))
|
self.data = pyopencl.array.to_device(queue, m.data.astype(numpy.complex128))
|
||||||
|
|
||||||
|
|
||||||
def cg(A: 'scipy.sparse.csr_matrix',
|
def cg(
|
||||||
b: numpy.ndarray,
|
A: 'scipy.sparse.csr_matrix',
|
||||||
max_iters: int = 10000,
|
b: ArrayLike,
|
||||||
err_threshold: float = 1e-6,
|
max_iters: int = 10000,
|
||||||
context: pyopencl.Context = None,
|
err_threshold: float = 1e-6,
|
||||||
queue: pyopencl.CommandQueue = None,
|
context: Optional[pyopencl.Context] = None,
|
||||||
) -> numpy.ndarray:
|
queue: Optional[pyopencl.CommandQueue] = None,
|
||||||
|
) -> NDArray:
|
||||||
"""
|
"""
|
||||||
General conjugate-gradient solver for sparse matrices, where A @ x = b.
|
General conjugate-gradient solver for sparse matrices, where A @ x = b.
|
||||||
|
|
||||||
:param A: Matrix to solve (CSR format)
|
Args:
|
||||||
:param b: Right-hand side vector (dense ndarray)
|
A: Matrix to solve (CSR format)
|
||||||
:param max_iters: Maximum number of iterations
|
b: Right-hand side vector (dense ndarray)
|
||||||
:param err_threshold: Error threshold for successful solve, relative to norm(b)
|
max_iters: Maximum number of iterations
|
||||||
:param context: PyOpenCL context. Will be created if not given.
|
err_threshold: Error threshold for successful solve, relative to norm(b)
|
||||||
:param queue: PyOpenCL command queue. Will be created if not given.
|
context: PyOpenCL context. Will be created if not given.
|
||||||
:return: Solution vector x; returned even if solve doesn't converge.
|
queue: PyOpenCL command queue. Will be created if not given.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Solution vector x; returned even if solve doesn't converge.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
@ -151,29 +159,37 @@ def cg(A: 'scipy.sparse.csr_matrix',
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def fdfd_cg_solver(solver_opts: Dict[str, Any] = None,
|
def fdfd_cg_solver(
|
||||||
**fdfd_args
|
solver_opts: Optional[Dict[str, Any]] = None,
|
||||||
) -> numpy.ndarray:
|
**fdfd_args
|
||||||
|
) -> NDArray:
|
||||||
"""
|
"""
|
||||||
Conjugate gradient FDFD solver using CSR sparse matrices, mainly for
|
Conjugate gradient FDFD solver using CSR sparse matrices, mainly for
|
||||||
testing and development since it's much slower than the solver in main.py.
|
testing and development since it's much slower than the solver in main.py.
|
||||||
|
|
||||||
Calls meanas.fdfd.solvers.generic(**fdfd_args,
|
Calls meanas.fdfd.solvers.generic(
|
||||||
matrix_solver=opencl_fdfd.csr.cg,
|
**fdfd_args,
|
||||||
matrix_solver_opts=solver_opts)
|
matrix_solver=opencl_fdfd.csr.cg,
|
||||||
|
matrix_solver_opts=solver_opts,
|
||||||
|
)
|
||||||
|
|
||||||
:param solver_opts: Passed as matrix_solver_opts to fdfd_tools.solver.generic(...).
|
Args:
|
||||||
Default {}.
|
solver_opts: Passed as matrix_solver_opts to fdfd_tools.solver.generic(...).
|
||||||
:param fdfd_args: Passed as **fdfd_args to fdfd_tools.solver.generic(...).
|
Default {}.
|
||||||
Should include all of the arguments **except** matrix_solver and matrix_solver_opts
|
fdfd_args: Passed as **fdfd_args to fdfd_tools.solver.generic(...).
|
||||||
:return: E-field which solves the system.
|
Should include all of the arguments **except** matrix_solver and matrix_solver_opts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
E-field which solves the system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if solver_opts is None:
|
if solver_opts is None:
|
||||||
solver_opts = dict()
|
solver_opts = dict()
|
||||||
|
|
||||||
x = meanas.fdfd.solvers.generic(matrix_solver=cg,
|
x = meanas.fdfd.solvers.generic(
|
||||||
matrix_solver_opts=solver_opts,
|
matrix_solver=cg,
|
||||||
**fdfd_args)
|
matrix_solver_opts=solver_opts,
|
||||||
|
**fdfd_args,
|
||||||
|
)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
@ -6,11 +6,12 @@ operator implemented directly as OpenCL arithmetic (rather than as
|
|||||||
a matrix).
|
a matrix).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Optional, cast
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
from numpy.typing import NDArray, ArrayLike
|
||||||
from numpy.linalg import norm
|
from numpy.linalg import norm
|
||||||
import pyopencl
|
import pyopencl
|
||||||
import pyopencl.array
|
import pyopencl.array
|
||||||
@ -25,18 +26,19 @@ __author__ = 'Jan Petykiewicz'
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def cg_solver(omega: complex,
|
def cg_solver(
|
||||||
dxes: List[List[numpy.ndarray]],
|
omega: complex,
|
||||||
J: numpy.ndarray,
|
dxes: List[List[NDArray]],
|
||||||
epsilon: numpy.ndarray,
|
J: ArrayLike,
|
||||||
mu: numpy.ndarray = None,
|
epsilon: ArrayLike,
|
||||||
pec: numpy.ndarray = None,
|
mu: Optional[ArrayLike] = None,
|
||||||
pmc: numpy.ndarray = None,
|
pec: Optional[ArrayLike] = None,
|
||||||
adjoint: bool = False,
|
pmc: Optional[ArrayLike] = None,
|
||||||
max_iters: int = 40000,
|
adjoint: bool = False,
|
||||||
err_threshold: float = 1e-6,
|
max_iters: int = 40000,
|
||||||
context: pyopencl.Context = None,
|
err_threshold: float = 1e-6,
|
||||||
) -> numpy.ndarray:
|
context: Optional[pyopencl.Context] = None,
|
||||||
|
) -> NDArray:
|
||||||
"""
|
"""
|
||||||
OpenCL FDFD solver using the iterative conjugate gradient (cg) method
|
OpenCL FDFD solver using the iterative conjugate gradient (cg) method
|
||||||
and implementing the diagonalized E-field wave operator directly in
|
and implementing the diagonalized E-field wave operator directly in
|
||||||
@ -46,28 +48,30 @@ def cg_solver(omega: complex,
|
|||||||
either use meanas.fdmath.vec() or numpy:
|
either use meanas.fdmath.vec() or numpy:
|
||||||
f_1D = numpy.hstack(tuple((fi.flatten(order='F') for fi in [f_x, f_y, f_z])))
|
f_1D = numpy.hstack(tuple((fi.flatten(order='F') for fi in [f_x, f_y, f_z])))
|
||||||
|
|
||||||
:param omega: Complex frequency to solve at.
|
Args:
|
||||||
:param dxes: [[dx_e, dy_e, dz_e], [dx_h, dy_h, dz_h]] (complex cell sizes)
|
omega: Complex frequency to solve at.
|
||||||
:param J: Electric current distribution (at E-field locations)
|
dxes: [[dx_e, dy_e, dz_e], [dx_h, dy_h, dz_h]] (complex cell sizes)
|
||||||
:param epsilon: Dielectric constant distribution (at E-field locations)
|
J: Electric current distribution (at E-field locations)
|
||||||
:param mu: Magnetic permeability distribution (at H-field locations)
|
epsilon: Dielectric constant distribution (at E-field locations)
|
||||||
:param pec: Perfect electric conductor distribution
|
mu: Magnetic permeability distribution (at H-field locations)
|
||||||
(at E-field locations; non-zero value indicates PEC is present)
|
pec: Perfect electric conductor distribution
|
||||||
:param pmc: Perfect magnetic conductor distribution
|
(at E-field locations; non-zero value indicates PEC is present)
|
||||||
(at H-field locations; non-zero value indicates PMC is present)
|
pmc: Perfect magnetic conductor distribution
|
||||||
:param adjoint: If true, solves the adjoint problem.
|
(at H-field locations; non-zero value indicates PMC is present)
|
||||||
:param max_iters: Maximum number of iterations. Default 40,000.
|
adjoint: If true, solves the adjoint problem.
|
||||||
:param err_threshold: If (r @ r.conj()) / norm(1j * omega * J) < err_threshold, success.
|
max_iters: Maximum number of iterations. Default 40,000.
|
||||||
Default 1e-6.
|
err_threshold: If (r @ r.conj()) / norm(1j * omega * J) < err_threshold, success.
|
||||||
:param context: PyOpenCL context to run in. If not given, construct a new context.
|
Default 1e-6.
|
||||||
:return: E-field which solves the system. Returned even if we did not converge.
|
context: PyOpenCL context to run in. If not given, construct a new context.
|
||||||
"""
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
E-field which solves the system. Returned even if we did not converge.
|
||||||
|
"""
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
b = -1j * omega * J
|
shape = [dd.size for dd in dxes[0]]
|
||||||
|
|
||||||
shape = [d.size for d in dxes[0]]
|
b = -1j * omega * numpy.array(J, copy=False)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
** In this comment, I use the following notation:
|
** In this comment, I use the following notation:
|
||||||
@ -96,9 +100,10 @@ def cg_solver(omega: complex,
|
|||||||
We can accomplish all this simply by conjugating everything (except J) and
|
We can accomplish all this simply by conjugating everything (except J) and
|
||||||
reversing the order of L and R
|
reversing the order of L and R
|
||||||
'''
|
'''
|
||||||
|
epsilon = numpy.array(epsilon, copy=False)
|
||||||
if adjoint:
|
if adjoint:
|
||||||
# Conjugate everything
|
# Conjugate everything
|
||||||
dxes = [[numpy.conj(d) for d in dd] for dd in dxes]
|
dxes = [[numpy.conj(dd) for dd in dds] for dds in dxes]
|
||||||
omega = numpy.conj(omega)
|
omega = numpy.conj(omega)
|
||||||
epsilon = numpy.conj(epsilon)
|
epsilon = numpy.conj(epsilon)
|
||||||
if mu is not None:
|
if mu is not None:
|
||||||
@ -132,7 +137,7 @@ def cg_solver(omega: complex,
|
|||||||
rho = 1.0 + 0j
|
rho = 1.0 + 0j
|
||||||
errs = []
|
errs = []
|
||||||
|
|
||||||
inv_dxes = [[load_field(1 / d) for d in dd] for dd in dxes]
|
inv_dxes = [[load_field(1 / numpy.array(dd, copy=False)) for dd in dds] for dds in dxes]
|
||||||
oeps = load_field(-omega ** 2 * epsilon)
|
oeps = load_field(-omega ** 2 * epsilon)
|
||||||
Pl = load_field(L.diagonal())
|
Pl = load_field(L.diagonal())
|
||||||
Pr = load_field(R.diagonal())
|
Pr = load_field(R.diagonal())
|
||||||
@ -140,17 +145,18 @@ def cg_solver(omega: complex,
|
|||||||
if mu is None:
|
if mu is None:
|
||||||
invm = load_field(numpy.array([]))
|
invm = load_field(numpy.array([]))
|
||||||
else:
|
else:
|
||||||
invm = load_field(1 / mu)
|
invm = load_field(1 / numpy.array(mu, copy=False))
|
||||||
|
mu = numpy.array(mu, copy=False)
|
||||||
|
|
||||||
if pec is None:
|
if pec is None:
|
||||||
gpec = load_field(numpy.array([]), dtype=numpy.int8)
|
gpec = load_field(numpy.array([]), dtype=numpy.int8)
|
||||||
else:
|
else:
|
||||||
gpec = load_field(pec.astype(bool), dtype=numpy.int8)
|
gpec = load_field(numpy.array(pec, dtype=bool, copy=False), dtype=numpy.int8)
|
||||||
|
|
||||||
if pmc is None:
|
if pmc is None:
|
||||||
gpmc = load_field(numpy.array([]), dtype=numpy.int8)
|
gpmc = load_field(numpy.array([]), dtype=numpy.int8)
|
||||||
else:
|
else:
|
||||||
gpmc = load_field(pmc.astype(bool), dtype=numpy.int8)
|
gpmc = load_field(numpy.array(pmc, dtype=bool, copy=False), dtype=numpy.int8)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Generate OpenCL kernels
|
Generate OpenCL kernels
|
||||||
|
@ -7,10 +7,11 @@ kernels for use by the other solvers.
|
|||||||
See kernels/ for any of the .cl files loaded in this file.
|
See kernels/ for any of the .cl files loaded in this file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Callable
|
from typing import List, Callable, Union, Type, Sequence, Optional, Tuple
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
from numpy.typing import NDArray, ArrayLike
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
import pyopencl
|
import pyopencl
|
||||||
@ -28,12 +29,17 @@ jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels'))
|
|||||||
operation = Callable[..., List[pyopencl.Event]]
|
operation = Callable[..., List[pyopencl.Event]]
|
||||||
|
|
||||||
|
|
||||||
def type_to_C(float_type: numpy.float32 or numpy.float64) -> str:
|
def type_to_C(
|
||||||
|
float_type: Type,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Returns a string corresponding to the C equivalent of a numpy type.
|
Returns a string corresponding to the C equivalent of a numpy type.
|
||||||
|
|
||||||
:param float_type: numpy type: float32, float64, complex64, complex128
|
Args:
|
||||||
:return: string containing the corresponding C type (eg. 'double')
|
float_type: numpy type: float32, float64, complex64, complex128
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
string containing the corresponding C type (eg. 'double')
|
||||||
"""
|
"""
|
||||||
types = {
|
types = {
|
||||||
numpy.float32: 'float',
|
numpy.float32: 'float',
|
||||||
@ -68,12 +74,13 @@ def ptrs(*args: str) -> List[str]:
|
|||||||
return [ctype + ' *' + s for s in args]
|
return [ctype + ' *' + s for s in args]
|
||||||
|
|
||||||
|
|
||||||
def create_a(context: pyopencl.Context,
|
def create_a(
|
||||||
shape: numpy.ndarray,
|
context: pyopencl.Context,
|
||||||
mu: bool = False,
|
shape: ArrayLike,
|
||||||
pec: bool = False,
|
mu: bool = False,
|
||||||
pmc: bool = False,
|
pec: bool = False,
|
||||||
) -> operation:
|
pmc: bool = False,
|
||||||
|
) -> operation:
|
||||||
"""
|
"""
|
||||||
Return a function which performs (A @ p), where A is the FDFD wave equation for E-field.
|
Return a function which performs (A @ p), where A is the FDFD wave equation for E-field.
|
||||||
|
|
||||||
@ -94,12 +101,15 @@ def create_a(context: pyopencl.Context,
|
|||||||
|
|
||||||
and returns a list of pyopencl.Event.
|
and returns a list of pyopencl.Event.
|
||||||
|
|
||||||
:param context: PyOpenCL context
|
Args:
|
||||||
:param shape: Dimensions of the E-field
|
context: PyOpenCL context
|
||||||
:param mu: False iff (mu == 1) everywhere
|
shape: Dimensions of the E-field
|
||||||
:param pec: False iff no PEC anywhere
|
mu: False iff (mu == 1) everywhere
|
||||||
:param pmc: False iff no PMC anywhere
|
pec: False iff no PEC anywhere
|
||||||
:return: Function for computing (A @ p)
|
pmc: False iff no PMC anywhere
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function for computing (A @ p)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
common_source = jinja_env.get_template('common.cl').render(shape=shape)
|
common_source = jinja_env.get_template('common.cl').render(shape=shape)
|
||||||
@ -113,45 +123,67 @@ def create_a(context: pyopencl.Context,
|
|||||||
Convert p to initial E (ie, apply right preconditioner and PEC)
|
Convert p to initial E (ie, apply right preconditioner and PEC)
|
||||||
'''
|
'''
|
||||||
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
|
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
|
||||||
P2E_kernel = ElementwiseKernel(context,
|
P2E_kernel = ElementwiseKernel(
|
||||||
name='P2E',
|
context,
|
||||||
preamble=preamble,
|
name='P2E',
|
||||||
operation=p2e_source,
|
preamble=preamble,
|
||||||
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg))
|
operation=p2e_source,
|
||||||
|
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg),
|
||||||
|
)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Calculate intermediate H from intermediate E
|
Calculate intermediate H from intermediate E
|
||||||
'''
|
'''
|
||||||
e2h_source = jinja_env.get_template('e2h.cl').render(mu=mu,
|
e2h_source = jinja_env.get_template('e2h.cl').render(
|
||||||
pmc=pmc,
|
mu=mu,
|
||||||
common_cl=common_source)
|
pmc=pmc,
|
||||||
E2H_kernel = ElementwiseKernel(context,
|
common_cl=common_source,
|
||||||
name='E2H',
|
)
|
||||||
preamble=preamble,
|
E2H_kernel = ElementwiseKernel(
|
||||||
operation=e2h_source,
|
context,
|
||||||
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des))
|
name='E2H',
|
||||||
|
preamble=preamble,
|
||||||
|
operation=e2h_source,
|
||||||
|
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des),
|
||||||
|
)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Calculate final E (including left preconditioner)
|
Calculate final E (including left preconditioner)
|
||||||
'''
|
'''
|
||||||
h2e_source = jinja_env.get_template('h2e.cl').render(pec=pec,
|
h2e_source = jinja_env.get_template('h2e.cl').render(
|
||||||
common_cl=common_source)
|
pec=pec,
|
||||||
H2E_kernel = ElementwiseKernel(context,
|
common_cl=common_source,
|
||||||
name='H2E',
|
)
|
||||||
preamble=preamble,
|
H2E_kernel = ElementwiseKernel(
|
||||||
operation=h2e_source,
|
context,
|
||||||
arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs))
|
name='H2E',
|
||||||
|
preamble=preamble,
|
||||||
|
operation=h2e_source,
|
||||||
|
arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs),
|
||||||
|
)
|
||||||
|
|
||||||
def spmv(E, H, p, idxes, oeps, inv_mu, pec, pmc, Pl, Pr, e):
|
def spmv(
|
||||||
|
E: pyopencl.array.Array,
|
||||||
|
H: pyopencl.array.Array,
|
||||||
|
p: pyopencl.array.Array,
|
||||||
|
idxes: Sequence[Sequence[pyopencl.array.Array]],
|
||||||
|
oeps: pyopencl.array.Array,
|
||||||
|
inv_mu: Optional[pyopencl.array.Array],
|
||||||
|
pec: Optional[pyopencl.array.Array],
|
||||||
|
pmc: Optional[pyopencl.array.Array],
|
||||||
|
Pl: pyopencl.array.Array,
|
||||||
|
Pr: pyopencl.array.Array,
|
||||||
|
e: List[pyopencl.Event],
|
||||||
|
) -> List[pyopencl.Event]:
|
||||||
e2 = P2E_kernel(E, p, Pr, pec, wait_for=e)
|
e2 = P2E_kernel(E, p, Pr, pec, wait_for=e)
|
||||||
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
|
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
|
||||||
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
|
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
|
||||||
return [e2]
|
return [e2]
|
||||||
|
|
||||||
logger.debug('Preamble: \n{}'.format(preamble))
|
logger.debug(f'Preamble: \n{preamble}')
|
||||||
logger.debug('p2e: \n{}'.format(p2e_source))
|
logger.debug(f'p2e: \n{p2e_source}')
|
||||||
logger.debug('e2h: \n{}'.format(e2h_source))
|
logger.debug(f'e2h: \n{e2h_source}')
|
||||||
logger.debug('h2e: \n{}'.format(h2e_source))
|
logger.debug(f'h2e: \n{h2e_source}')
|
||||||
|
|
||||||
return spmv
|
return spmv
|
||||||
|
|
||||||
@ -167,8 +199,11 @@ def create_xr_step(context: pyopencl.Context) -> operation:
|
|||||||
after waiting for all in the list e
|
after waiting for all in the list e
|
||||||
and returns a list of pyopencl.Event
|
and returns a list of pyopencl.Event
|
||||||
|
|
||||||
:param context: PyOpenCL context
|
Args:
|
||||||
:return: Function for performing x and r updates
|
context: PyOpenCL context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function for performing x and r updates
|
||||||
"""
|
"""
|
||||||
update_xr_source = '''
|
update_xr_source = '''
|
||||||
x[i] = add(x[i], mul(alpha, p[i]));
|
x[i] = add(x[i], mul(alpha, p[i]));
|
||||||
@ -177,19 +212,28 @@ def create_xr_step(context: pyopencl.Context) -> operation:
|
|||||||
|
|
||||||
xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha'])
|
xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha'])
|
||||||
|
|
||||||
xr_kernel = ElementwiseKernel(context,
|
xr_kernel = ElementwiseKernel(
|
||||||
name='XR',
|
context,
|
||||||
preamble=preamble,
|
name='XR',
|
||||||
operation=update_xr_source,
|
preamble=preamble,
|
||||||
arguments=xr_args)
|
operation=update_xr_source,
|
||||||
|
arguments=xr_args,
|
||||||
|
)
|
||||||
|
|
||||||
def xr_update(x, p, r, v, alpha, e):
|
def xr_update(
|
||||||
|
x: pyopencl.array.Array,
|
||||||
|
p: pyopencl.array.Array,
|
||||||
|
r: pyopencl.array.Array,
|
||||||
|
v: pyopencl.array.Array,
|
||||||
|
alpha: complex,
|
||||||
|
e: List[pyopencl.Event],
|
||||||
|
) -> List[pyopencl.Event]:
|
||||||
return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
|
return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
|
||||||
|
|
||||||
return xr_update
|
return xr_update
|
||||||
|
|
||||||
|
|
||||||
def create_rhoerr_step(context: pyopencl.Context) -> operation:
|
def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex, complex]]:
|
||||||
"""
|
"""
|
||||||
Return a function
|
Return a function
|
||||||
ri_update(r, e)
|
ri_update(r, e)
|
||||||
@ -200,8 +244,11 @@ def create_rhoerr_step(context: pyopencl.Context) -> operation:
|
|||||||
after waiting for all pyopencl.Event in the list e
|
after waiting for all pyopencl.Event in the list e
|
||||||
and returns a list of pyopencl.Event
|
and returns a list of pyopencl.Event
|
||||||
|
|
||||||
:param context: PyOpenCL context
|
Args:
|
||||||
:return: Function for performing x and r updates
|
context: PyOpenCL context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function for performing x and r updates
|
||||||
"""
|
"""
|
||||||
|
|
||||||
update_ri_source = '''
|
update_ri_source = '''
|
||||||
@ -213,16 +260,18 @@ def create_rhoerr_step(context: pyopencl.Context) -> operation:
|
|||||||
# Use a vector type (double3) to make the reduction simpler
|
# Use a vector type (double3) to make the reduction simpler
|
||||||
ri_dtype = pyopencl.array.vec.double3
|
ri_dtype = pyopencl.array.vec.double3
|
||||||
|
|
||||||
ri_kernel = ReductionKernel(context,
|
ri_kernel = ReductionKernel(
|
||||||
name='RHOERR',
|
context,
|
||||||
preamble=preamble,
|
name='RHOERR',
|
||||||
dtype_out=ri_dtype,
|
preamble=preamble,
|
||||||
neutral='(double3)(0.0, 0.0, 0.0)',
|
dtype_out=ri_dtype,
|
||||||
map_expr=update_ri_source,
|
neutral='(double3)(0.0, 0.0, 0.0)',
|
||||||
reduce_expr='a+b',
|
map_expr=update_ri_source,
|
||||||
arguments=ctype + ' *r')
|
reduce_expr='a+b',
|
||||||
|
arguments=ctype + ' *r',
|
||||||
|
)
|
||||||
|
|
||||||
def ri_update(r, e):
|
def ri_update(r: pyopencl.array.Array, e: List[pyopencl.Event]) -> Tuple[complex, complex]:
|
||||||
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
|
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
|
||||||
rr, ri, ii = [g[q] for q in 'xyz']
|
rr, ri, ii = [g[q] for q in 'xyz']
|
||||||
rho = rr + 2j * ri - ii
|
rho = rr + 2j * ri - ii
|
||||||
@ -242,48 +291,66 @@ def create_p_step(context: pyopencl.Context) -> operation:
|
|||||||
after waiting for all pyopencl.Event in the list e
|
after waiting for all pyopencl.Event in the list e
|
||||||
and returns a list of pyopencl.Event
|
and returns a list of pyopencl.Event
|
||||||
|
|
||||||
:param context: PyOpenCL context
|
Args:
|
||||||
:return: Function for performing the p update
|
context: PyOpenCL context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function for performing the p update
|
||||||
"""
|
"""
|
||||||
update_p_source = '''
|
update_p_source = '''
|
||||||
p[i] = add(r[i], mul(beta, p[i]));
|
p[i] = add(r[i], mul(beta, p[i]));
|
||||||
'''
|
'''
|
||||||
p_args = ptrs('p', 'r') + [ctype + ' beta']
|
p_args = ptrs('p', 'r') + [ctype + ' beta']
|
||||||
|
|
||||||
p_kernel = ElementwiseKernel(context,
|
p_kernel = ElementwiseKernel(
|
||||||
name='P',
|
context,
|
||||||
preamble=preamble,
|
name='P',
|
||||||
operation=update_p_source,
|
preamble=preamble,
|
||||||
arguments=', '.join(p_args))
|
operation=update_p_source,
|
||||||
|
arguments=', '.join(p_args),
|
||||||
|
)
|
||||||
|
|
||||||
def p_update(p, r, beta, e):
|
def p_update(
|
||||||
|
p: pyopencl.array.Array,
|
||||||
|
r: pyopencl.array.Array,
|
||||||
|
beta: complex,
|
||||||
|
e: List[pyopencl.Event]) -> List[pyopencl.Event]:
|
||||||
return [p_kernel(p, r, beta, wait_for=e)]
|
return [p_kernel(p, r, beta, wait_for=e)]
|
||||||
|
|
||||||
return p_update
|
return p_update
|
||||||
|
|
||||||
|
|
||||||
def create_dot(context: pyopencl.Context) -> operation:
|
def create_dot(context: pyopencl.Context) -> Callable[..., complex]:
|
||||||
"""
|
"""
|
||||||
Return a function for performing the dot product
|
Return a function for performing the dot product
|
||||||
p @ v
|
p @ v
|
||||||
with the signature
|
with the signature
|
||||||
dot(p, v, e) -> float
|
dot(p, v, e) -> complex
|
||||||
|
|
||||||
:param context: PyOpenCL context
|
Args:
|
||||||
:return: Function for performing the dot product
|
context: PyOpenCL context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function for performing the dot product
|
||||||
"""
|
"""
|
||||||
dot_dtype = numpy.complex128
|
dot_dtype = numpy.complex128
|
||||||
|
|
||||||
dot_kernel = ReductionKernel(context,
|
dot_kernel = ReductionKernel(
|
||||||
name='dot',
|
context,
|
||||||
preamble=preamble,
|
name='dot',
|
||||||
dtype_out=dot_dtype,
|
preamble=preamble,
|
||||||
neutral='zero',
|
dtype_out=dot_dtype,
|
||||||
map_expr='mul(p[i], v[i])',
|
neutral='zero',
|
||||||
reduce_expr='add(a, b)',
|
map_expr='mul(p[i], v[i])',
|
||||||
arguments=ptrs('p', 'v'))
|
reduce_expr='add(a, b)',
|
||||||
|
arguments=ptrs('p', 'v'),
|
||||||
|
)
|
||||||
|
|
||||||
def dot(p, v, e):
|
def dot(
|
||||||
|
p: pyopencl.array.Array,
|
||||||
|
v: pyopencl.array.Array,
|
||||||
|
e: List[pyopencl.Event],
|
||||||
|
) -> complex:
|
||||||
g = dot_kernel(p, v, wait_for=e)
|
g = dot_kernel(p, v, wait_for=e)
|
||||||
return g.get()
|
return g.get()
|
||||||
|
|
||||||
@ -304,8 +371,11 @@ def create_a_csr(context: pyopencl.Context) -> operation:
|
|||||||
The function waits on all the pyopencl.Event in e before running, and returns
|
The function waits on all the pyopencl.Event in e before running, and returns
|
||||||
a list of pyopencl.Event.
|
a list of pyopencl.Event.
|
||||||
|
|
||||||
:param context: PyOpenCL context
|
Args:
|
||||||
:return: Function for sparse (M @ v) operation where M is in CSR format
|
context: PyOpenCL context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function for sparse (M @ v) operation where M is in CSR format
|
||||||
"""
|
"""
|
||||||
spmv_source = '''
|
spmv_source = '''
|
||||||
int start = m_row_ptr[i];
|
int start = m_row_ptr[i];
|
||||||
@ -326,13 +396,20 @@ def create_a_csr(context: pyopencl.Context) -> operation:
|
|||||||
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
|
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
|
||||||
v_in_args = ctype + ' *v_in'
|
v_in_args = ctype + ' *v_in'
|
||||||
|
|
||||||
spmv_kernel = ElementwiseKernel(context,
|
spmv_kernel = ElementwiseKernel(
|
||||||
name='csr_spmv',
|
context,
|
||||||
preamble=preamble,
|
name='csr_spmv',
|
||||||
operation=spmv_source,
|
preamble=preamble,
|
||||||
arguments=', '.join((v_out_args, m_args, v_in_args)))
|
operation=spmv_source,
|
||||||
|
arguments=', '.join((v_out_args, m_args, v_in_args)),
|
||||||
|
)
|
||||||
|
|
||||||
def spmv(v_out, m, v_in, e):
|
def spmv(
|
||||||
|
v_out,
|
||||||
|
m,
|
||||||
|
v_in,
|
||||||
|
e: List[pyopencl.Event],
|
||||||
|
) -> List[pyopencl.Event]:
|
||||||
return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)]
|
return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)]
|
||||||
|
|
||||||
return spmv
|
return spmv
|
||||||
|
Loading…
Reference in New Issue
Block a user