improve type annotations, formatting, comment styles

This commit is contained in:
Jan Petykiewicz 2022-11-20 21:57:43 -08:00
commit efeb29479b
3 changed files with 268 additions and 169 deletions

View file

@ -7,10 +7,11 @@ kernels for use by the other solvers.
See kernels/ for any of the .cl files loaded in this file.
"""
from typing import List, Callable
from typing import List, Callable, Union, Type, Sequence, Optional, Tuple
import logging
import numpy
from numpy.typing import NDArray, ArrayLike
import jinja2
import pyopencl
@ -28,12 +29,17 @@ jinja_env = jinja2.Environment(loader=jinja2.PackageLoader(__name__, 'kernels'))
operation = Callable[..., List[pyopencl.Event]]
def type_to_C(float_type: numpy.float32 or numpy.float64) -> str:
def type_to_C(
float_type: Type,
) -> str:
"""
Returns a string corresponding to the C equivalent of a numpy type.
:param float_type: numpy type: float32, float64, complex64, complex128
:return: string containing the corresponding C type (eg. 'double')
Args:
float_type: numpy type: float32, float64, complex64, complex128
Returns:
string containing the corresponding C type (eg. 'double')
"""
types = {
numpy.float32: 'float',
@ -68,12 +74,13 @@ def ptrs(*args: str) -> List[str]:
return [ctype + ' *' + s for s in args]
def create_a(context: pyopencl.Context,
shape: numpy.ndarray,
mu: bool = False,
pec: bool = False,
pmc: bool = False,
) -> operation:
def create_a(
context: pyopencl.Context,
shape: ArrayLike,
mu: bool = False,
pec: bool = False,
pmc: bool = False,
) -> operation:
"""
Return a function which performs (A @ p), where A is the FDFD wave equation for E-field.
@ -94,12 +101,15 @@ def create_a(context: pyopencl.Context,
and returns a list of pyopencl.Event.
:param context: PyOpenCL context
:param shape: Dimensions of the E-field
:param mu: False iff (mu == 1) everywhere
:param pec: False iff no PEC anywhere
:param pmc: False iff no PMC anywhere
:return: Function for computing (A @ p)
Args:
context: PyOpenCL context
shape: Dimensions of the E-field
mu: False iff (mu == 1) everywhere
pec: False iff no PEC anywhere
pmc: False iff no PMC anywhere
Returns:
Function for computing (A @ p)
"""
common_source = jinja_env.get_template('common.cl').render(shape=shape)
@ -113,45 +123,67 @@ def create_a(context: pyopencl.Context,
Convert p to initial E (ie, apply right preconditioner and PEC)
'''
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
P2E_kernel = ElementwiseKernel(context,
name='P2E',
preamble=preamble,
operation=p2e_source,
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg))
P2E_kernel = ElementwiseKernel(
context,
name='P2E',
preamble=preamble,
operation=p2e_source,
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg),
)
'''
Calculate intermediate H from intermediate E
'''
e2h_source = jinja_env.get_template('e2h.cl').render(mu=mu,
pmc=pmc,
common_cl=common_source)
E2H_kernel = ElementwiseKernel(context,
name='E2H',
preamble=preamble,
operation=e2h_source,
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des))
e2h_source = jinja_env.get_template('e2h.cl').render(
mu=mu,
pmc=pmc,
common_cl=common_source,
)
E2H_kernel = ElementwiseKernel(
context,
name='E2H',
preamble=preamble,
operation=e2h_source,
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des),
)
'''
Calculate final E (including left preconditioner)
'''
h2e_source = jinja_env.get_template('h2e.cl').render(pec=pec,
common_cl=common_source)
H2E_kernel = ElementwiseKernel(context,
name='H2E',
preamble=preamble,
operation=h2e_source,
arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs))
h2e_source = jinja_env.get_template('h2e.cl').render(
pec=pec,
common_cl=common_source,
)
H2E_kernel = ElementwiseKernel(
context,
name='H2E',
preamble=preamble,
operation=h2e_source,
arguments=', '.join(ptrs('E', 'H', 'oeps', 'Pl') + pec_arg + dhs),
)
def spmv(E, H, p, idxes, oeps, inv_mu, pec, pmc, Pl, Pr, e):
def spmv(
    E: pyopencl.array.Array,
    H: pyopencl.array.Array,
    p: pyopencl.array.Array,
    idxes: Sequence[Sequence[pyopencl.array.Array]],
    oeps: pyopencl.array.Array,
    inv_mu: Optional[pyopencl.array.Array],
    pec: Optional[pyopencl.array.Array],
    pmc: Optional[pyopencl.array.Array],
    Pl: pyopencl.array.Array,
    Pr: pyopencl.array.Array,
    e: List[pyopencl.Event],
) -> List[pyopencl.Event]:
    """
    Launch the P2E, E2H and H2E kernels in sequence, each launch waiting on
    the event returned by the previous one.

    Args:
        E: Forwarded as the `E` argument of all three kernels
        H: Forwarded as the `H` argument of the E2H and H2E kernels
        p: Forwarded to the P2E kernel
        idxes: `idxes[0]` is unpacked into the E2H call and `idxes[1]`
            into the H2E call
        oeps: Forwarded to the H2E kernel
        inv_mu: Forwarded to the E2H kernel
        pec: Forwarded to the P2E and H2E kernels
        pmc: Forwarded to the E2H kernel
        Pl: Forwarded to the H2E kernel
        Pr: Forwarded to the P2E kernel
        e: Events the first (P2E) launch must wait for

    Returns:
        Single-element list containing the event of the final (H2E) launch
    """
    # Each stage depends only on the event produced by the stage before it.
    p2e_event = P2E_kernel(E, p, Pr, pec, wait_for=e)
    e2h_event = E2H_kernel(E, H, inv_mu, pmc, *idxes[0], wait_for=[p2e_event])
    h2e_event = H2E_kernel(E, H, oeps, Pl, pec, *idxes[1], wait_for=[e2h_event])
    return [h2e_event]
logger.debug('Preamble: \n{}'.format(preamble))
logger.debug('p2e: \n{}'.format(p2e_source))
logger.debug('e2h: \n{}'.format(e2h_source))
logger.debug('h2e: \n{}'.format(h2e_source))
logger.debug(f'Preamble: \n{preamble}')
logger.debug(f'p2e: \n{p2e_source}')
logger.debug(f'e2h: \n{e2h_source}')
logger.debug(f'h2e: \n{h2e_source}')
return spmv
@ -167,8 +199,11 @@ def create_xr_step(context: pyopencl.Context) -> operation:
after waiting for all in the list e
and returns a list of pyopencl.Event
:param context: PyOpenCL context
:return: Function for performing x and r updates
Args:
context: PyOpenCL context
Returns:
Function for performing x and r updates
"""
update_xr_source = '''
x[i] = add(x[i], mul(alpha, p[i]));
@ -177,19 +212,28 @@ def create_xr_step(context: pyopencl.Context) -> operation:
xr_args = ', '.join(ptrs('x', 'p', 'r', 'v') + [ctype + ' alpha'])
xr_kernel = ElementwiseKernel(context,
name='XR',
preamble=preamble,
operation=update_xr_source,
arguments=xr_args)
xr_kernel = ElementwiseKernel(
context,
name='XR',
preamble=preamble,
operation=update_xr_source,
arguments=xr_args,
)
def xr_update(x, p, r, v, alpha, e):
def xr_update(
    x: pyopencl.array.Array,
    p: pyopencl.array.Array,
    r: pyopencl.array.Array,
    v: pyopencl.array.Array,
    alpha: complex,
    e: List[pyopencl.Event],
) -> List[pyopencl.Event]:
    """
    Enqueue the XR elementwise kernel on (x, p, r, v) with scalar alpha.

    Args:
        x: Forwarded as the kernel's `x` argument
        p: Forwarded as the kernel's `p` argument
        r: Forwarded as the kernel's `r` argument
        v: Forwarded as the kernel's `v` argument
        alpha: Scalar forwarded as the kernel's `alpha` argument
        e: Events the kernel launch must wait for

    Returns:
        Single-element list containing the event of the kernel launch
    """
    launch_event = xr_kernel(x, p, r, v, alpha, wait_for=e)
    return [launch_event]
return xr_update
def create_rhoerr_step(context: pyopencl.Context) -> operation:
def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., Tuple[complex, complex]]:
"""
Return a function
ri_update(r, e)
@ -200,8 +244,11 @@ def create_rhoerr_step(context: pyopencl.Context) -> operation:
after waiting for all pyopencl.Event in the list e
and returns a tuple of two complex values
:param context: PyOpenCL context
:return: Function for performing x and r updates
Args:
context: PyOpenCL context
Returns:
Function for computing rho and err from r
"""
update_ri_source = '''
@ -213,16 +260,18 @@ def create_rhoerr_step(context: pyopencl.Context) -> operation:
# Use a vector type (double3) to make the reduction simpler
ri_dtype = pyopencl.array.vec.double3
ri_kernel = ReductionKernel(context,
name='RHOERR',
preamble=preamble,
dtype_out=ri_dtype,
neutral='(double3)(0.0, 0.0, 0.0)',
map_expr=update_ri_source,
reduce_expr='a+b',
arguments=ctype + ' *r')
ri_kernel = ReductionKernel(
context,
name='RHOERR',
preamble=preamble,
dtype_out=ri_dtype,
neutral='(double3)(0.0, 0.0, 0.0)',
map_expr=update_ri_source,
reduce_expr='a+b',
arguments=ctype + ' *r',
)
def ri_update(r, e):
def ri_update(r: pyopencl.array.Array, e: List[pyopencl.Event]) -> Tuple[complex, complex]:
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
rr, ri, ii = [g[q] for q in 'xyz']
rho = rr + 2j * ri - ii
@ -242,48 +291,66 @@ def create_p_step(context: pyopencl.Context) -> operation:
after waiting for all pyopencl.Event in the list e
and returns a list of pyopencl.Event
:param context: PyOpenCL context
:return: Function for performing the p update
Args:
context: PyOpenCL context
Returns:
Function for performing the p update
"""
update_p_source = '''
p[i] = add(r[i], mul(beta, p[i]));
'''
p_args = ptrs('p', 'r') + [ctype + ' beta']
p_kernel = ElementwiseKernel(context,
name='P',
preamble=preamble,
operation=update_p_source,
arguments=', '.join(p_args))
p_kernel = ElementwiseKernel(
context,
name='P',
preamble=preamble,
operation=update_p_source,
arguments=', '.join(p_args),
)
def p_update(p, r, beta, e):
def p_update(
    p: pyopencl.array.Array,
    r: pyopencl.array.Array,
    beta: complex,
    e: List[pyopencl.Event],
) -> List[pyopencl.Event]:
    """
    Enqueue the P elementwise kernel, which evaluates
     p[i] = r[i] + beta * p[i]

    Args:
        p: Forwarded as the kernel's `p` argument (assigned to by the kernel)
        r: Forwarded as the kernel's `r` argument
        beta: Scalar forwarded as the kernel's `beta` argument
        e: Events the kernel launch must wait for

    Returns:
        Single-element list containing the event of the kernel launch
    """
    launch_event = p_kernel(p, r, beta, wait_for=e)
    return [launch_event]
return p_update
def create_dot(context: pyopencl.Context) -> operation:
def create_dot(context: pyopencl.Context) -> Callable[..., complex]:
"""
Return a function for performing the dot product
p @ v
with the signature
dot(p, v, e) -> float
dot(p, v, e) -> complex
:param context: PyOpenCL context
:return: Function for performing the dot product
Args:
context: PyOpenCL context
Returns:
Function for performing the dot product
"""
dot_dtype = numpy.complex128
dot_kernel = ReductionKernel(context,
name='dot',
preamble=preamble,
dtype_out=dot_dtype,
neutral='zero',
map_expr='mul(p[i], v[i])',
reduce_expr='add(a, b)',
arguments=ptrs('p', 'v'))
dot_kernel = ReductionKernel(
context,
name='dot',
preamble=preamble,
dtype_out=dot_dtype,
neutral='zero',
map_expr='mul(p[i], v[i])',
reduce_expr='add(a, b)',
arguments=ptrs('p', 'v'),
)
def dot(p, v, e):
def dot(
    p: pyopencl.array.Array,
    v: pyopencl.array.Array,
    e: List[pyopencl.Event],
) -> complex:
    """
    Run the `dot` reduction kernel (sum over i of p[i] * v[i]) and fetch
    its result via `.get()`.

    Args:
        p: First operand of the elementwise product
        v: Second operand of the elementwise product
        e: Events the reduction must wait for

    Returns:
        Whatever `.get()` yields for the reduction output
        (annotated as a complex value)
    """
    return dot_kernel(p, v, wait_for=e).get()
@ -304,8 +371,11 @@ def create_a_csr(context: pyopencl.Context) -> operation:
The function waits on all the pyopencl.Event in e before running, and returns
a list of pyopencl.Event.
:param context: PyOpenCL context
:return: Function for sparse (M @ v) operation where M is in CSR format
Args:
context: PyOpenCL context
Returns:
Function for sparse (M @ v) operation where M is in CSR format
"""
spmv_source = '''
int start = m_row_ptr[i];
@ -326,13 +396,20 @@ def create_a_csr(context: pyopencl.Context) -> operation:
m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
v_in_args = ctype + ' *v_in'
spmv_kernel = ElementwiseKernel(context,
name='csr_spmv',
preamble=preamble,
operation=spmv_source,
arguments=', '.join((v_out_args, m_args, v_in_args)))
spmv_kernel = ElementwiseKernel(
context,
name='csr_spmv',
preamble=preamble,
operation=spmv_source,
arguments=', '.join((v_out_args, m_args, v_in_args)),
)
def spmv(v_out, m, v_in, e):
def spmv(
    v_out,
    m,
    v_in,
    e: List[pyopencl.Event],
) -> List[pyopencl.Event]:
    """
    Enqueue the `csr_spmv` elementwise kernel for the matrix `m`.

    Args:
        v_out: Forwarded as the kernel's first argument
            (presumably the output vector -- confirm against the kernel args)
        m: Object exposing `row_ptr`, `col_ind` and `data` attributes, which
            are forwarded as `m_row_ptr`, `m_col_ind` and `m_data`
        v_in: Forwarded as the kernel's `v_in` argument
        e: Events the kernel launch must wait for

    Returns:
        Single-element list containing the event of the kernel launch
    """
    launch_event = spmv_kernel(
        v_out,
        m.row_ptr,
        m.col_ind,
        m.data,
        v_in,
        wait_for=e,
    )
    return [launch_event]
return spmv