improve type annotations, formatting, comment styles

This commit is contained in:
Jan Petykiewicz 2022-11-20 21:57:43 -08:00
parent 81bb1dd2c0
commit efeb29479b
3 changed files with 261 additions and 162 deletions

View File

@ -14,14 +14,16 @@ satisfy the constraints for the 'conjugate gradient' algorithm
(positive definite, symmetric) and some that don't. (positive definite, symmetric) and some that don't.
""" """
from typing import Dict, Any from typing import Dict, Any, Optional
import time import time
import logging import logging
import numpy import numpy
from numpy.typing import NDArray, ArrayLike
from numpy.linalg import norm from numpy.linalg import norm
import pyopencl import pyopencl
import pyopencl.array import pyopencl.array
import scipy
import meanas.fdfd.solvers import meanas.fdfd.solvers
@ -33,39 +35,45 @@ __author__ = 'Jan Petykiewicz'
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CSRMatrix(object): class CSRMatrix:
""" """
Matrix stored in Compressed Sparse Row format, in GPU RAM. Matrix stored in Compressed Sparse Row format, in GPU RAM.
""" """
row_ptr = None # type: pyopencl.array.Array row_ptr: pyopencl.array.Array
col_ind = None # type: pyopencl.array.Array col_ind: pyopencl.array.Array
data = None # type: pyopencl.array.Array data: pyopencl.array.Array
def __init__(self, def __init__(
self,
queue: pyopencl.CommandQueue, queue: pyopencl.CommandQueue,
m: 'scipy.sparse.csr_matrix'): m: 'scipy.sparse.csr_matrix',
) -> None:
self.row_ptr = pyopencl.array.to_device(queue, m.indptr) self.row_ptr = pyopencl.array.to_device(queue, m.indptr)
self.col_ind = pyopencl.array.to_device(queue, m.indices) self.col_ind = pyopencl.array.to_device(queue, m.indices)
self.data = pyopencl.array.to_device(queue, m.data.astype(numpy.complex128)) self.data = pyopencl.array.to_device(queue, m.data.astype(numpy.complex128))
def cg(A: 'scipy.sparse.csr_matrix', def cg(
b: numpy.ndarray, A: 'scipy.sparse.csr_matrix',
b: ArrayLike,
max_iters: int = 10000, max_iters: int = 10000,
err_threshold: float = 1e-6, err_threshold: float = 1e-6,
context: pyopencl.Context = None, context: Optional[pyopencl.Context] = None,
queue: pyopencl.CommandQueue = None, queue: Optional[pyopencl.CommandQueue] = None,
) -> numpy.ndarray: ) -> NDArray:
""" """
General conjugate-gradient solver for sparse matrices, where A @ x = b. General conjugate-gradient solver for sparse matrices, where A @ x = b.
:param A: Matrix to solve (CSR format) Args:
:param b: Right-hand side vector (dense ndarray) A: Matrix to solve (CSR format)
:param max_iters: Maximum number of iterations b: Right-hand side vector (dense ndarray)
:param err_threshold: Error threshold for successful solve, relative to norm(b) max_iters: Maximum number of iterations
:param context: PyOpenCL context. Will be created if not given. err_threshold: Error threshold for successful solve, relative to norm(b)
:param queue: PyOpenCL command queue. Will be created if not given. context: PyOpenCL context. Will be created if not given.
:return: Solution vector x; returned even if solve doesn't converge. queue: PyOpenCL command queue. Will be created if not given.
Returns:
Solution vector x; returned even if solve doesn't converge.
""" """
start_time = time.perf_counter() start_time = time.perf_counter()
@ -151,29 +159,37 @@ def cg(A: 'scipy.sparse.csr_matrix',
return x return x
def fdfd_cg_solver(solver_opts: Dict[str, Any] = None, def fdfd_cg_solver(
solver_opts: Optional[Dict[str, Any]] = None,
**fdfd_args **fdfd_args
) -> numpy.ndarray: ) -> NDArray:
""" """
Conjugate gradient FDFD solver using CSR sparse matrices, mainly for Conjugate gradient FDFD solver using CSR sparse matrices, mainly for
testing and development since it's much slower than the solver in main.py. testing and development since it's much slower than the solver in main.py.
Calls meanas.fdfd.solvers.generic(**fdfd_args, Calls meanas.fdfd.solvers.generic(
**fdfd_args,
matrix_solver=opencl_fdfd.csr.cg, matrix_solver=opencl_fdfd.csr.cg,
matrix_solver_opts=solver_opts) matrix_solver_opts=solver_opts,
)
:param solver_opts: Passed as matrix_solver_opts to fdfd_tools.solver.generic(...). Args:
solver_opts: Passed as matrix_solver_opts to fdfd_tools.solver.generic(...).
Default {}. Default {}.
:param fdfd_args: Passed as **fdfd_args to fdfd_tools.solver.generic(...). fdfd_args: Passed as **fdfd_args to fdfd_tools.solver.generic(...).
Should include all of the arguments **except** matrix_solver and matrix_solver_opts Should include all of the arguments **except** matrix_solver and matrix_solver_opts
:return: E-field which solves the system.
Returns:
E-field which solves the system.
""" """
if solver_opts is None: if solver_opts is None:
solver_opts = dict() solver_opts = dict()
x = meanas.fdfd.solvers.generic(matrix_solver=cg, x = meanas.fdfd.solvers.generic(
matrix_solver=cg,
matrix_solver_opts=solver_opts, matrix_solver_opts=solver_opts,
**fdfd_args) **fdfd_args,
)
return x return x

View File

@ -6,11 +6,12 @@ operator implemented directly as OpenCL arithmetic (rather than as
a matrix). a matrix).
""" """
from typing import List from typing import List, Optional, cast
import time import time
import logging import logging
import numpy import numpy
from numpy.typing import NDArray, ArrayLike
from numpy.linalg import norm from numpy.linalg import norm
import pyopencl import pyopencl
import pyopencl.array import pyopencl.array
@ -25,18 +26,19 @@ __author__ = 'Jan Petykiewicz'
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def cg_solver(omega: complex, def cg_solver(
dxes: List[List[numpy.ndarray]], omega: complex,
J: numpy.ndarray, dxes: List[List[NDArray]],
epsilon: numpy.ndarray, J: ArrayLike,
mu: numpy.ndarray = None, epsilon: ArrayLike,
pec: numpy.ndarray = None, mu: Optional[ArrayLike] = None,
pmc: numpy.ndarray = None, pec: Optional[ArrayLike] = None,
pmc: Optional[ArrayLike] = None,
adjoint: bool = False, adjoint: bool = False,
max_iters: int = 40000, max_iters: int = 40000,
err_threshold: float = 1e-6, err_threshold: float = 1e-6,
context: pyopencl.Context = None, context: Optional[pyopencl.Context] = None,
) -> numpy.ndarray: ) -> NDArray:
""" """
OpenCL FDFD solver using the iterative conjugate gradient (cg) method OpenCL FDFD solver using the iterative conjugate gradient (cg) method
and implementing the diagonalized E-field wave operator directly in and implementing the diagonalized E-field wave operator directly in
@ -46,28 +48,30 @@ def cg_solver(omega: complex,
either use meanas.fdmath.vec() or numpy: either use meanas.fdmath.vec() or numpy:
f_1D = numpy.hstack(tuple((fi.flatten(order='F') for fi in [f_x, f_y, f_z]))) f_1D = numpy.hstack(tuple((fi.flatten(order='F') for fi in [f_x, f_y, f_z])))
:param omega: Complex frequency to solve at. Args:
:param dxes: [[dx_e, dy_e, dz_e], [dx_h, dy_h, dz_h]] (complex cell sizes) omega: Complex frequency to solve at.
:param J: Electric current distribution (at E-field locations) dxes: [[dx_e, dy_e, dz_e], [dx_h, dy_h, dz_h]] (complex cell sizes)
:param epsilon: Dielectric constant distribution (at E-field locations) J: Electric current distribution (at E-field locations)
:param mu: Magnetic permeability distribution (at H-field locations) epsilon: Dielectric constant distribution (at E-field locations)
:param pec: Perfect electric conductor distribution mu: Magnetic permeability distribution (at H-field locations)
pec: Perfect electric conductor distribution
(at E-field locations; non-zero value indicates PEC is present) (at E-field locations; non-zero value indicates PEC is present)
:param pmc: Perfect magnetic conductor distribution pmc: Perfect magnetic conductor distribution
(at H-field locations; non-zero value indicates PMC is present) (at H-field locations; non-zero value indicates PMC is present)
:param adjoint: If true, solves the adjoint problem. adjoint: If true, solves the adjoint problem.
:param max_iters: Maximum number of iterations. Default 40,000. max_iters: Maximum number of iterations. Default 40,000.
:param err_threshold: If (r @ r.conj()) / norm(1j * omega * J) < err_threshold, success. err_threshold: If (r @ r.conj()) / norm(1j * omega * J) < err_threshold, success.
Default 1e-6. Default 1e-6.
:param context: PyOpenCL context to run in. If not given, construct a new context. context: PyOpenCL context to run in. If not given, construct a new context.
:return: E-field which solves the system. Returned even if we did not converge.
"""
Returns:
E-field which solves the system. Returned even if we did not converge.
"""
start_time = time.perf_counter() start_time = time.perf_counter()
b = -1j * omega * J shape = [dd.size for dd in dxes[0]]
shape = [d.size for d in dxes[0]] b = -1j * omega * numpy.array(J, copy=False)
''' '''
** In this comment, I use the following notation: ** In this comment, I use the following notation:
@ -96,9 +100,10 @@ def cg_solver(omega: complex,
We can accomplish all this simply by conjugating everything (except J) and We can accomplish all this simply by conjugating everything (except J) and
reversing the order of L and R reversing the order of L and R
''' '''
epsilon = numpy.array(epsilon, copy=False)
if adjoint: if adjoint:
# Conjugate everything # Conjugate everything
dxes = [[numpy.conj(d) for d in dd] for dd in dxes] dxes = [[numpy.conj(dd) for dd in dds] for dds in dxes]
omega = numpy.conj(omega) omega = numpy.conj(omega)
epsilon = numpy.conj(epsilon) epsilon = numpy.conj(epsilon)
if mu is not None: if mu is not None:
@ -132,7 +137,7 @@ def cg_solver(omega: complex,
rho = 1.0 + 0j rho = 1.0 + 0j
errs = [] errs = []
inv_dxes = [[load_field(1 / d) for d in dd] for dd in dxes] inv_dxes = [[load_field(1 / numpy.array(dd, copy=False)) for dd in dds] for dds in dxes]
oeps = load_field(-omega ** 2 * epsilon) oeps = load_field(-omega ** 2 * epsilon)
Pl = load_field(L.diagonal()) Pl = load_field(L.diagonal())
Pr = load_field(R.diagonal()) Pr = load_field(R.diagonal())
@ -140,17 +145,18 @@ def cg_solver(omega: complex,
if mu is None: if mu is None:
invm = load_field(numpy.array([])) invm = load_field(numpy.array([]))
else: else:
invm = load_field(1 / mu) invm = load_field(1 / numpy.array(mu, copy=False))
mu = numpy.array(mu, copy=False)
if pec is None: if pec is None:
gpec = load_field(numpy.array([]), dtype=numpy.int8) gpec = load_field(numpy.array([]), dtype=numpy.int8)
else: else:
gpec = load_field(pec.astype(bool), dtype=numpy.int8) gpec = load_field(numpy.array(pec, dtype=bool, copy=False), dtype=numpy.int8)
if pmc is None: if pmc is None:
gpmc = load_field(numpy.array([]), dtype=numpy.int8) gpmc = load_field(numpy.array([]), dtype=numpy.int8)
else: else:
gpmc = load_field(pmc.astype(bool), dtype=numpy.int8) gpmc = load_field(numpy.array(pmc, dtype=bool, copy=False), dtype=numpy.int8)
''' '''
Generate OpenCL kernels Generate OpenCL kernels

View File

@ -7,10 +7,11 @@ kernels for use by the other solvers.
See kernels/ for any of the .cl files loaded in this file. See kernels/ for any of the .cl files loaded in this file.
""" """
from typing import List, Callable from typing import List, Callable, Union, Type, Sequence, Optional, Tuple
import logging import logging
import numpy import numpy
from numpy.typing import NDArray, ArrayLike
import jinja2 import jinja2
import pyopencl import pyopencl
@ -28,12 +29,17 @@ jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels'))
operation = Callable[..., List[pyopencl.Event]] operation = Callable[..., List[pyopencl.Event]]
def type_to_C(float_type: numpy.float32 or numpy.float64) -> str: def type_to_C(
float_type: Type,
) -> str:
""" """
Returns a string corresponding to the C equivalent of a numpy type. Returns a string corresponding to the C equivalent of a numpy type.
:param float_type: numpy type: float32, float64, complex64, complex128 Args:
:return: string containing the corresponding C type (eg. 'double') float_type: numpy type: float32, float64, complex64, complex128
Returns:
string containing the corresponding C type (eg. 'double')
""" """
types = { types = {
numpy.float32: 'float', numpy.float32: 'float',
@ -68,8 +74,9 @@ def ptrs(*args: str) -> List[str]:
return [ctype + ' *' + s for s in args] return [ctype + ' *' + s for s in args]
def create_a(context: pyopencl.Context, def create_a(
shape: numpy.ndarray, context: pyopencl.Context,
shape: ArrayLike,
mu: bool = False, mu: bool = False,
pec: bool = False, pec: bool = False,
pmc: bool = False, pmc: bool = False,
@ -94,12 +101,15 @@ def create_a(context: pyopencl.Context,
and returns a list of pyopencl.Event. and returns a list of pyopencl.Event.
:param context: PyOpenCL context Args:
:param shape: Dimensions of the E-field context: PyOpenCL context
:param mu: False iff (mu == 1) everywhere shape: Dimensions of the E-field
:param pec: False iff no PEC anywhere mu: False iff (mu == 1) everywhere
:param pmc: False iff no PMC anywhere pec: False iff no PEC anywhere
:return: Function for computing (A @ p) pmc: False iff no PMC anywhere
Returns:
Function for computing (A @ p)
""" """
common_source = jinja_env.get_template('common.cl').render(shape=shape) common_source = jinja_env.get_template('common.cl').render(shape=shape)
@ -113,45 +123,67 @@ def create_a(context: pyopencl.Context,
Convert p to initial E (ie, apply right preconditioner and PEC) Convert p to initial E (ie, apply right preconditioner and PEC)
''' '''
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec) p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
P2E_kernel = ElementwiseKernel(context, P2E_kernel = ElementwiseKernel(
context,
name='P2E', name='P2E',
preamble=preamble, preamble=preamble,
operation=p2e_source, operation=p2e_source,
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg)) arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg),
)
''' '''
Calculate intermediate H from intermediate E Calculate intermediate H from intermediate E
''' '''
e2h_source = jinja_env.get_template('e2h.cl').render(mu=mu, e2h_source = jinja_env.get_template('e2h.cl').render(
mu=mu,
pmc=pmc, pmc=pmc,
common_cl=common_source) common_cl=common_source,
E2H_kernel = ElementwiseKernel(context, )
E2H_kernel = ElementwiseKernel(
context,
name='E2H', name='E2H',
preamble=preamble, preamble=preamble,
operation=e2h_source, operation=e2h_source,
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des)) arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des),
)
''' '''
Calculate final E (including left preconditioner) Calculate final E (including left preconditioner)
''' '''
h2e_source = jinja_env.get_template('h2e.cl').render(pec=pec, h2e_source = jinja_env.get_template('h2e.cl').render(
common_cl=common_source) pec=pec,
H2E_kernel = ElementwiseKernel(context, common_cl=common_source,
)
H2E_kernel = ElementwiseKernel(
context,
name='H2E', name='H2E',
preamble=preamble, preamble=preamble,
operation=h2e_source, operation=h2e_source,
arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs)) arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs),
)
def spmv(E, H, p, idxes, oeps, inv_mu, pec, pmc, Pl, Pr, e): def spmv(
E: pyopencl.array.Array,
H: pyopencl.array.Array,
p: pyopencl.array.Array,
idxes: Sequence[Sequence[pyopencl.array.Array]],
oeps: pyopencl.array.Array,
inv_mu: Optional[pyopencl.array.Array],
pec: Optional[pyopencl.array.Array],
pmc: Optional[pyopencl.array.Array],
Pl: pyopencl.array.Array,
Pr: pyopencl.array.Array,
e: List[pyopencl.Event],
) -> List[pyopencl.Event]:
e2 = P2E_kernel(E, p, Pr, pec, wait_for=e) e2 = P2E_kernel(E, p, Pr, pec, wait_for=e)
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2]) e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2]) e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
return [e2] return [e2]
logger.debug('Preamble: \n{}'.format(preamble)) logger.debug(f'Preamble: \n{preamble}')
logger.debug('p2e: \n{}'.format(p2e_source)) logger.debug(f'p2e: \n{p2e_source}')
logger.debug('e2h: \n{}'.format(e2h_source)) logger.debug(f'e2h: \n{e2h_source}')
logger.debug('h2e: \n{}'.format(h2e_source)) logger.debug(f'h2e: \n{h2e_source}')
return spmv return spmv
@ -167,8 +199,11 @@ def create_xr_step(context: pyopencl.Context) -> operation:
after waiting for all in the list e after waiting for all in the list e
and returns a list of pyopencl.Event and returns a list of pyopencl.Event
:param context: PyOpenCL context Args:
:return: Function for performing x and r updates context: PyOpenCL context
Returns:
Function for performing x and r updates
""" """
update_xr_source = ''' update_xr_source = '''
x[i] = add(x[i], mul(alpha, p[i])); x[i] = add(x[i], mul(alpha, p[i]));
@ -177,19 +212,28 @@ def create_xr_step(context: pyopencl.Context) -> operation:
xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha']) xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha'])
xr_kernel = ElementwiseKernel(context, xr_kernel = ElementwiseKernel(
context,
name='XR', name='XR',
preamble=preamble, preamble=preamble,
operation=update_xr_source, operation=update_xr_source,
arguments=xr_args) arguments=xr_args,
)
def xr_update(x, p, r, v, alpha, e): def xr_update(
x: pyopencl.array.Array,
p: pyopencl.array.Array,
r: pyopencl.array.Array,
v: pyopencl.array.Array,
alpha: complex,
e: List[pyopencl.Event],
) -> List[pyopencl.Event]:
return [xr_kernel(x, p, r, v, alpha, wait_for=e)] return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
return xr_update return xr_update
def create_rhoerr_step(context: pyopencl.Context) -> operation: def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex, complex]]:
""" """
Return a function Return a function
ri_update(r, e) ri_update(r, e)
@ -200,8 +244,11 @@ def create_rhoerr_step(context: pyopencl.Context) -> operation:
after waiting for all pyopencl.Event in the list e after waiting for all pyopencl.Event in the list e
and returns a list of pyopencl.Event and returns a list of pyopencl.Event
:param context: PyOpenCL context Args:
:return: Function for performing x and r updates context: PyOpenCL context
Returns:
Function for performing x and r updates
""" """
update_ri_source = ''' update_ri_source = '''
@ -213,16 +260,18 @@ def create_rhoerr_step(context: pyopencl.Context) -> operation:
# Use a vector type (double3) to make the reduction simpler # Use a vector type (double3) to make the reduction simpler
ri_dtype = pyopencl.array.vec.double3 ri_dtype = pyopencl.array.vec.double3
ri_kernel = ReductionKernel(context, ri_kernel = ReductionKernel(
context,
name='RHOERR', name='RHOERR',
preamble=preamble, preamble=preamble,
dtype_out=ri_dtype, dtype_out=ri_dtype,
neutral='(double3)(0.0, 0.0, 0.0)', neutral='(double3)(0.0, 0.0, 0.0)',
map_expr=update_ri_source, map_expr=update_ri_source,
reduce_expr='a+b', reduce_expr='a+b',
arguments=ctype + ' *r') arguments=ctype + ' *r',
)
def ri_update(r, e): def ri_update(r: pyopencl.array.Array, e: List[pyopencl.Event]) -> Tuple[complex, complex]:
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get() g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
rr, ri, ii = [g[q] for q in 'xyz'] rr, ri, ii = [g[q] for q in 'xyz']
rho = rr + 2j * ri - ii rho = rr + 2j * ri - ii
@ -242,48 +291,66 @@ def create_p_step(context: pyopencl.Context) -> operation:
after waiting for all pyopencl.Event in the list e after waiting for all pyopencl.Event in the list e
and returns a list of pyopencl.Event and returns a list of pyopencl.Event
:param context: PyOpenCL context Args:
:return: Function for performing the p update context: PyOpenCL context
Returns:
Function for performing the p update
""" """
update_p_source = ''' update_p_source = '''
p[i] = add(r[i], mul(beta, p[i])); p[i] = add(r[i], mul(beta, p[i]));
''' '''
p_args = ptrs('p', 'r') + [ctype + ' beta'] p_args = ptrs('p', 'r') + [ctype + ' beta']
p_kernel = ElementwiseKernel(context, p_kernel = ElementwiseKernel(
context,
name='P', name='P',
preamble=preamble, preamble=preamble,
operation=update_p_source, operation=update_p_source,
arguments=', '.join(p_args)) arguments=', '.join(p_args),
)
def p_update(p, r, beta, e): def p_update(
p: pyopencl.array.Array,
r: pyopencl.array.Array,
beta: complex,
e: List[pyopencl.Event]) -> List[pyopencl.Event]:
return [p_kernel(p, r, beta, wait_for=e)] return [p_kernel(p, r, beta, wait_for=e)]
return p_update return p_update
def create_dot(context: pyopencl.Context) -> operation: def create_dot(context: pyopencl.Context) -> Callable[..., complex]:
""" """
Return a function for performing the dot product Return a function for performing the dot product
p @ v p @ v
with the signature with the signature
dot(p, v, e) -> float dot(p, v, e) -> complex
:param context: PyOpenCL context Args:
:return: Function for performing the dot product context: PyOpenCL context
Returns:
Function for performing the dot product
""" """
dot_dtype = numpy.complex128 dot_dtype = numpy.complex128
dot_kernel = ReductionKernel(context, dot_kernel = ReductionKernel(
context,
name='dot', name='dot',
preamble=preamble, preamble=preamble,
dtype_out=dot_dtype, dtype_out=dot_dtype,
neutral='zero', neutral='zero',
map_expr='mul(p[i], v[i])', map_expr='mul(p[i], v[i])',
reduce_expr='add(a, b)', reduce_expr='add(a, b)',
arguments=ptrs('p', 'v')) arguments=ptrs('p', 'v'),
)
def dot(p, v, e): def dot(
p: pyopencl.array.Array,
v: pyopencl.array.Array,
e: List[pyopencl.Event],
) -> complex:
g = dot_kernel(p, v, wait_for=e) g = dot_kernel(p, v, wait_for=e)
return g.get() return g.get()
@ -304,8 +371,11 @@ def create_a_csr(context: pyopencl.Context) -> operation:
The function waits on all the pyopencl.Event in e before running, and returns The function waits on all the pyopencl.Event in e before running, and returns
a list of pyopencl.Event. a list of pyopencl.Event.
:param context: PyOpenCL context Args:
:return: Function for sparse (M @ v) operation where M is in CSR format context: PyOpenCL context
Returns:
Function for sparse (M @ v) operation where M is in CSR format
""" """
spmv_source = ''' spmv_source = '''
int start = m_row_ptr[i]; int start = m_row_ptr[i];
@ -326,13 +396,20 @@ def create_a_csr(context: pyopencl.Context) -> operation:
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data' m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
v_in_args = ctype + ' *v_in' v_in_args = ctype + ' *v_in'
spmv_kernel = ElementwiseKernel(context, spmv_kernel = ElementwiseKernel(
context,
name='csr_spmv', name='csr_spmv',
preamble=preamble, preamble=preamble,
operation=spmv_source, operation=spmv_source,
arguments=', '.join((v_out_args, m_args, v_in_args))) arguments=', '.join((v_out_args, m_args, v_in_args)),
)
def spmv(v_out, m, v_in, e): def spmv(
v_out,
m,
v_in,
e: List[pyopencl.Event],
) -> List[pyopencl.Event]:
return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)] return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)]
return spmv return spmv