improve type annotations, formatting, comment styles
This commit is contained in:
parent
81bb1dd2c0
commit
efeb29479b
3 changed files with 268 additions and 169 deletions
|
|
@ -7,10 +7,11 @@ kernels for use by the other solvers.
|
|||
See kernels/ for any of the .cl files loaded in this file.
|
||||
"""
|
||||
|
||||
from typing import List, Callable
|
||||
from typing import List, Callable, Union, Type, Sequence, Optional, Tuple
|
||||
import logging
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
import jinja2
|
||||
|
||||
import pyopencl
|
||||
|
|
@ -28,12 +29,17 @@ jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels'))
|
|||
operation = Callable[..., List[pyopencl.Event]]
|
||||
|
||||
|
||||
def type_to_C(float_type: numpy.float32 or numpy.float64) -> str:
|
||||
def type_to_C(
|
||||
float_type: Type,
|
||||
) -> str:
|
||||
"""
|
||||
Returns a string corresponding to the C equivalent of a numpy type.
|
||||
|
||||
:param float_type: numpy type: float32, float64, complex64, complex128
|
||||
:return: string containing the corresponding C type (eg. 'double')
|
||||
Args:
|
||||
float_type: numpy type: float32, float64, complex64, complex128
|
||||
|
||||
Returns:
|
||||
string containing the corresponding C type (eg. 'double')
|
||||
"""
|
||||
types = {
|
||||
numpy.float32: 'float',
|
||||
|
|
@ -68,12 +74,13 @@ def ptrs(*args: str) -> List[str]:
|
|||
return [ctype + ' *' + s for s in args]
|
||||
|
||||
|
||||
def create_a(context: pyopencl.Context,
|
||||
shape: numpy.ndarray,
|
||||
mu: bool = False,
|
||||
pec: bool = False,
|
||||
pmc: bool = False,
|
||||
) -> operation:
|
||||
def create_a(
|
||||
context: pyopencl.Context,
|
||||
shape: ArrayLike,
|
||||
mu: bool = False,
|
||||
pec: bool = False,
|
||||
pmc: bool = False,
|
||||
) -> operation:
|
||||
"""
|
||||
Return a function which performs (A @ p), where A is the FDFD wave equation for E-field.
|
||||
|
||||
|
|
@ -94,12 +101,15 @@ def create_a(context: pyopencl.Context,
|
|||
|
||||
and returns a list of pyopencl.Event.
|
||||
|
||||
:param context: PyOpenCL context
|
||||
:param shape: Dimensions of the E-field
|
||||
:param mu: False iff (mu == 1) everywhere
|
||||
:param pec: False iff no PEC anywhere
|
||||
:param pmc: False iff no PMC anywhere
|
||||
:return: Function for computing (A @ p)
|
||||
Args:
|
||||
context: PyOpenCL context
|
||||
shape: Dimensions of the E-field
|
||||
mu: False iff (mu == 1) everywhere
|
||||
pec: False iff no PEC anywhere
|
||||
pmc: False iff no PMC anywhere
|
||||
|
||||
Returns:
|
||||
Function for computing (A @ p)
|
||||
"""
|
||||
|
||||
common_source = jinja_env.get_template('common.cl').render(shape=shape)
|
||||
|
|
@ -113,45 +123,67 @@ def create_a(context: pyopencl.Context,
|
|||
Convert p to initial E (ie, apply right preconditioner and PEC)
|
||||
'''
|
||||
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
|
||||
P2E_kernel = ElementwiseKernel(context,
|
||||
name='P2E',
|
||||
preamble=preamble,
|
||||
operation=p2e_source,
|
||||
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg))
|
||||
P2E_kernel = ElementwiseKernel(
|
||||
context,
|
||||
name='P2E',
|
||||
preamble=preamble,
|
||||
operation=p2e_source,
|
||||
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg),
|
||||
)
|
||||
|
||||
'''
|
||||
Calculate intermediate H from intermediate E
|
||||
'''
|
||||
e2h_source = jinja_env.get_template('e2h.cl').render(mu=mu,
|
||||
pmc=pmc,
|
||||
common_cl=common_source)
|
||||
E2H_kernel = ElementwiseKernel(context,
|
||||
name='E2H',
|
||||
preamble=preamble,
|
||||
operation=e2h_source,
|
||||
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des))
|
||||
e2h_source = jinja_env.get_template('e2h.cl').render(
|
||||
mu=mu,
|
||||
pmc=pmc,
|
||||
common_cl=common_source,
|
||||
)
|
||||
E2H_kernel = ElementwiseKernel(
|
||||
context,
|
||||
name='E2H',
|
||||
preamble=preamble,
|
||||
operation=e2h_source,
|
||||
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des),
|
||||
)
|
||||
|
||||
'''
|
||||
Calculate final E (including left preconditioner)
|
||||
'''
|
||||
h2e_source = jinja_env.get_template('h2e.cl').render(pec=pec,
|
||||
common_cl=common_source)
|
||||
H2E_kernel = ElementwiseKernel(context,
|
||||
name='H2E',
|
||||
preamble=preamble,
|
||||
operation=h2e_source,
|
||||
arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs))
|
||||
h2e_source = jinja_env.get_template('h2e.cl').render(
|
||||
pec=pec,
|
||||
common_cl=common_source,
|
||||
)
|
||||
H2E_kernel = ElementwiseKernel(
|
||||
context,
|
||||
name='H2E',
|
||||
preamble=preamble,
|
||||
operation=h2e_source,
|
||||
arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs),
|
||||
)
|
||||
|
||||
def spmv(E, H, p, idxes, oeps, inv_mu, pec, pmc, Pl, Pr, e):
|
||||
def spmv(
|
||||
E: pyopencl.array.Array,
|
||||
H: pyopencl.array.Array,
|
||||
p: pyopencl.array.Array,
|
||||
idxes: Sequence[Sequence[pyopencl.array.Array]],
|
||||
oeps: pyopencl.array.Array,
|
||||
inv_mu: Optional[pyopencl.array.Array],
|
||||
pec: Optional[pyopencl.array.Array],
|
||||
pmc: Optional[pyopencl.array.Array],
|
||||
Pl: pyopencl.array.Array,
|
||||
Pr: pyopencl.array.Array,
|
||||
e: List[pyopencl.Event],
|
||||
) -> List[pyopencl.Event]:
|
||||
e2 = P2E_kernel(E, p, Pr, pec, wait_for=e)
|
||||
e2 = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[e2])
|
||||
e2 = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2])
|
||||
return [e2]
|
||||
|
||||
logger.debug('Preamble: \n{}'.format(preamble))
|
||||
logger.debug('p2e: \n{}'.format(p2e_source))
|
||||
logger.debug('e2h: \n{}'.format(e2h_source))
|
||||
logger.debug('h2e: \n{}'.format(h2e_source))
|
||||
logger.debug(f'Preamble: \n{preamble}')
|
||||
logger.debug(f'p2e: \n{p2e_source}')
|
||||
logger.debug(f'e2h: \n{e2h_source}')
|
||||
logger.debug(f'h2e: \n{h2e_source}')
|
||||
|
||||
return spmv
|
||||
|
||||
|
|
@ -167,8 +199,11 @@ def create_xr_step(context: pyopencl.Context) -> operation:
|
|||
after waiting for all in the list e
|
||||
and returns a list of pyopencl.Event
|
||||
|
||||
:param context: PyOpenCL context
|
||||
:return: Function for performing x and r updates
|
||||
Args:
|
||||
context: PyOpenCL context
|
||||
|
||||
Returns:
|
||||
Function for performing x and r updates
|
||||
"""
|
||||
update_xr_source = '''
|
||||
x[i] = add(x[i], mul(alpha, p[i]));
|
||||
|
|
@ -177,19 +212,28 @@ def create_xr_step(context: pyopencl.Context) -> operation:
|
|||
|
||||
xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha'])
|
||||
|
||||
xr_kernel = ElementwiseKernel(context,
|
||||
name='XR',
|
||||
preamble=preamble,
|
||||
operation=update_xr_source,
|
||||
arguments=xr_args)
|
||||
xr_kernel = ElementwiseKernel(
|
||||
context,
|
||||
name='XR',
|
||||
preamble=preamble,
|
||||
operation=update_xr_source,
|
||||
arguments=xr_args,
|
||||
)
|
||||
|
||||
def xr_update(x, p, r, v, alpha, e):
|
||||
def xr_update(
|
||||
x: pyopencl.array.Array,
|
||||
p: pyopencl.array.Array,
|
||||
r: pyopencl.array.Array,
|
||||
v: pyopencl.array.Array,
|
||||
alpha: complex,
|
||||
e: List[pyopencl.Event],
|
||||
) -> List[pyopencl.Event]:
|
||||
return [xr_kernel(x, p, r, v, alpha, wait_for=e)]
|
||||
|
||||
return xr_update
|
||||
|
||||
|
||||
def create_rhoerr_step(context: pyopencl.Context) -> operation:
|
||||
def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex, complex]]:
|
||||
"""
|
||||
Return a function
|
||||
ri_update(r, e)
|
||||
|
|
@ -200,8 +244,11 @@ def create_rhoerr_step(context: pyopencl.Context) -> operation:
|
|||
after waiting for all pyopencl.Event in the list e
|
||||
and returns a list of pyopencl.Event
|
||||
|
||||
:param context: PyOpenCL context
|
||||
:return: Function for performing x and r updates
|
||||
Args:
|
||||
context: PyOpenCL context
|
||||
|
||||
Returns:
|
||||
Function for performing x and r updates
|
||||
"""
|
||||
|
||||
update_ri_source = '''
|
||||
|
|
@ -213,16 +260,18 @@ def create_rhoerr_step(context: pyopencl.Context) -> operation:
|
|||
# Use a vector type (double3) to make the reduction simpler
|
||||
ri_dtype = pyopencl.array.vec.double3
|
||||
|
||||
ri_kernel = ReductionKernel(context,
|
||||
name='RHOERR',
|
||||
preamble=preamble,
|
||||
dtype_out=ri_dtype,
|
||||
neutral='(double3)(0.0, 0.0, 0.0)',
|
||||
map_expr=update_ri_source,
|
||||
reduce_expr='a+b',
|
||||
arguments=ctype + ' *r')
|
||||
ri_kernel = ReductionKernel(
|
||||
context,
|
||||
name='RHOERR',
|
||||
preamble=preamble,
|
||||
dtype_out=ri_dtype,
|
||||
neutral='(double3)(0.0, 0.0, 0.0)',
|
||||
map_expr=update_ri_source,
|
||||
reduce_expr='a+b',
|
||||
arguments=ctype + ' *r',
|
||||
)
|
||||
|
||||
def ri_update(r, e):
|
||||
def ri_update(r: pyopencl.array.Array, e: List[pyopencl.Event]) -> Tuple[complex, complex]:
|
||||
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
|
||||
rr, ri, ii = [g[q] for q in 'xyz']
|
||||
rho = rr + 2j * ri - ii
|
||||
|
|
@ -242,48 +291,66 @@ def create_p_step(context: pyopencl.Context) -> operation:
|
|||
after waiting for all pyopencl.Event in the list e
|
||||
and returns a list of pyopencl.Event
|
||||
|
||||
:param context: PyOpenCL context
|
||||
:return: Function for performing the p update
|
||||
Args:
|
||||
context: PyOpenCL context
|
||||
|
||||
Returns:
|
||||
Function for performing the p update
|
||||
"""
|
||||
update_p_source = '''
|
||||
p[i] = add(r[i], mul(beta, p[i]));
|
||||
'''
|
||||
p_args = ptrs('p', 'r') + [ctype + ' beta']
|
||||
|
||||
p_kernel = ElementwiseKernel(context,
|
||||
name='P',
|
||||
preamble=preamble,
|
||||
operation=update_p_source,
|
||||
arguments=', '.join(p_args))
|
||||
p_kernel = ElementwiseKernel(
|
||||
context,
|
||||
name='P',
|
||||
preamble=preamble,
|
||||
operation=update_p_source,
|
||||
arguments=', '.join(p_args),
|
||||
)
|
||||
|
||||
def p_update(p, r, beta, e):
|
||||
def p_update(
|
||||
p: pyopencl.array.Array,
|
||||
r: pyopencl.array.Array,
|
||||
beta: complex,
|
||||
e: List[pyopencl.Event]) -> List[pyopencl.Event]:
|
||||
return [p_kernel(p, r, beta, wait_for=e)]
|
||||
|
||||
return p_update
|
||||
|
||||
|
||||
def create_dot(context: pyopencl.Context) -> operation:
|
||||
def create_dot(context: pyopencl.Context) -> Callable[..., complex]:
|
||||
"""
|
||||
Return a function for performing the dot product
|
||||
p @ v
|
||||
with the signature
|
||||
dot(p, v, e) -> float
|
||||
dot(p, v, e) -> complex
|
||||
|
||||
:param context: PyOpenCL context
|
||||
:return: Function for performing the dot product
|
||||
Args:
|
||||
context: PyOpenCL context
|
||||
|
||||
Returns:
|
||||
Function for performing the dot product
|
||||
"""
|
||||
dot_dtype = numpy.complex128
|
||||
|
||||
dot_kernel = ReductionKernel(context,
|
||||
name='dot',
|
||||
preamble=preamble,
|
||||
dtype_out=dot_dtype,
|
||||
neutral='zero',
|
||||
map_expr='mul(p[i], v[i])',
|
||||
reduce_expr='add(a, b)',
|
||||
arguments=ptrs('p', 'v'))
|
||||
dot_kernel = ReductionKernel(
|
||||
context,
|
||||
name='dot',
|
||||
preamble=preamble,
|
||||
dtype_out=dot_dtype,
|
||||
neutral='zero',
|
||||
map_expr='mul(p[i], v[i])',
|
||||
reduce_expr='add(a, b)',
|
||||
arguments=ptrs('p', 'v'),
|
||||
)
|
||||
|
||||
def dot(p, v, e):
|
||||
def dot(
|
||||
p: pyopencl.array.Array,
|
||||
v: pyopencl.array.Array,
|
||||
e: List[pyopencl.Event],
|
||||
) -> complex:
|
||||
g = dot_kernel(p, v, wait_for=e)
|
||||
return g.get()
|
||||
|
||||
|
|
@ -304,8 +371,11 @@ def create_a_csr(context: pyopencl.Context) -> operation:
|
|||
The function waits on all the pyopencl.Event in e before running, and returns
|
||||
a list of pyopencl.Event.
|
||||
|
||||
:param context: PyOpenCL context
|
||||
:return: Function for sparse (M @ v) operation where M is in CSR format
|
||||
Args:
|
||||
context: PyOpenCL context
|
||||
|
||||
Returns:
|
||||
Function for sparse (M @ v) operation where M is in CSR format
|
||||
"""
|
||||
spmv_source = '''
|
||||
int start = m_row_ptr[i];
|
||||
|
|
@ -326,13 +396,20 @@ def create_a_csr(context: pyopencl.Context) -> operation:
|
|||
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
|
||||
v_in_args = ctype + ' *v_in'
|
||||
|
||||
spmv_kernel = ElementwiseKernel(context,
|
||||
name='csr_spmv',
|
||||
preamble=preamble,
|
||||
operation=spmv_source,
|
||||
arguments=', '.join((v_out_args, m_args, v_in_args)))
|
||||
spmv_kernel = ElementwiseKernel(
|
||||
context,
|
||||
name='csr_spmv',
|
||||
preamble=preamble,
|
||||
operation=spmv_source,
|
||||
arguments=', '.join((v_out_args, m_args, v_in_args)),
|
||||
)
|
||||
|
||||
def spmv(v_out, m, v_in, e):
|
||||
def spmv(
|
||||
v_out,
|
||||
m,
|
||||
v_in,
|
||||
e: List[pyopencl.Event],
|
||||
) -> List[pyopencl.Event]:
|
||||
return [spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)]
|
||||
|
||||
return spmv
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue