misc cleanup

This commit is contained in:
Jan Petykiewicz 2024-07-30 22:43:29 -07:00
parent 9282bfe8c0
commit d72c5e254f
3 changed files with 34 additions and 39 deletions

View File

@ -31,7 +31,6 @@ from . import ops
if TYPE_CHECKING: if TYPE_CHECKING:
import scipy import scipy
__author__ = 'Jan Petykiewicz'
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -99,18 +98,18 @@ def cg(
m = CSRMatrix(queue, A) m = CSRMatrix(queue, A)
''' #
Generate OpenCL kernels # Generate OpenCL kernels
''' #
a_step = ops.create_a_csr(context) a_step = ops.create_a_csr(context)
xr_step = ops.create_xr_step(context) xr_step = ops.create_xr_step(context)
rhoerr_step = ops.create_rhoerr_step(context) rhoerr_step = ops.create_rhoerr_step(context)
p_step = ops.create_p_step(context) p_step = ops.create_p_step(context)
dot = ops.create_dot(context) dot = ops.create_dot(context)
''' #
Start the solve # Start the solve
''' #
start_time2 = time.perf_counter() start_time2 = time.perf_counter()
_, err2 = rhoerr_step(r, []) _, err2 = rhoerr_step(r, [])
@ -140,9 +139,9 @@ def cg(
if k % 1000 == 0: if k % 1000 == 0:
logger.info(f'iteration {k}') logger.info(f'iteration {k}')
''' #
Done solving # Done solving
''' #
time_elapsed = time.perf_counter() - start_time time_elapsed = time.perf_counter() - start_time
x = x.get() x = x.get()

View File

@ -20,8 +20,6 @@ import meanas.fdfd.operators
from . import ops from . import ops
__author__ = 'Jan Petykiewicz'
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -113,9 +111,9 @@ def cg_solver(
L, R = meanas.fdfd.operators.e_full_preconditioners(dxes) L, R = meanas.fdfd.operators.e_full_preconditioners(dxes)
b_preconditioned = (R if adjoint else L) @ b b_preconditioned = (R if adjoint else L) @ b
''' #
Allocate GPU memory and load in data # Allocate GPU memory and load in data
''' #
if context is None: if context is None:
context = pyopencl.create_some_context(interactive=True) context = pyopencl.create_some_context(interactive=True)
@ -155,10 +153,10 @@ def cg_solver(
else: else:
gpmc = load_field(numpy.asarray(pmc, dtype=bool), dtype=numpy.int8) gpmc = load_field(numpy.asarray(pmc, dtype=bool), dtype=numpy.int8)
''' #
Generate OpenCL kernels # Generate OpenCL kernels
''' #
has_mu, has_pec, has_pmc = [q is not None for q in (mu, pec, pmc)] has_mu, has_pec, has_pmc = (qq is not None for qq in (mu, pec, pmc))
a_step_full = ops.create_a(context, shape, has_mu, has_pec, has_pmc) a_step_full = ops.create_a(context, shape, has_mu, has_pec, has_pmc)
xr_step = ops.create_xr_step(context) xr_step = ops.create_xr_step(context)
@ -174,9 +172,9 @@ def cg_solver(
) -> list[pyopencl.Event]: ) -> list[pyopencl.Event]:
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events) return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)
''' #
Start the solve # Start the solve
''' #
start_time2 = time.perf_counter() start_time2 = time.perf_counter()
_, err2 = rhoerr_step(r, []) _, err2 = rhoerr_step(r, [])
@ -209,16 +207,13 @@ def cg_solver(
if k % 1000 == 0: if k % 1000 == 0:
logger.info(f'iteration {k}') logger.info(f'iteration {k}')
''' #
Done solving # Done solving
''' #
time_elapsed = time.perf_counter() - start_time time_elapsed = time.perf_counter() - start_time
# Undo preconditioners # Undo preconditioners
if adjoint: x = ((Pl if adjoint else Pr) * x).get()
x = (Pl * x).get()
else:
x = (Pr * x).get()
if success: if success:
logger.info('Solve success') logger.info('Solve success')

View File

@ -56,6 +56,7 @@ def type_to_C(
return types[float_type] return types[float_type]
# Type names # Type names
ctype = type_to_C(numpy.complex128) ctype = type_to_C(numpy.complex128)
ctype_bare = 'cdouble' ctype_bare = 'cdouble'
@ -123,9 +124,9 @@ def create_a(
des = [ctype + ' *inv_de' + a for a in 'xyz'] des = [ctype + ' *inv_de' + a for a in 'xyz']
dhs = [ctype + ' *inv_dh' + a for a in 'xyz'] dhs = [ctype + ' *inv_dh' + a for a in 'xyz']
''' #
Convert p to initial E (ie, apply right preconditioner and PEC) # Convert p to initial E (ie, apply right preconditioner and PEC)
''' #
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec) p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
P2E_kernel = ElementwiseKernel( P2E_kernel = ElementwiseKernel(
context, context,
@ -135,9 +136,9 @@ def create_a(
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg), arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg),
) )
''' #
Calculate intermediate H from intermediate E # Calculate intermediate H from intermediate E
''' #
e2h_source = jinja_env.get_template('e2h.cl').render( e2h_source = jinja_env.get_template('e2h.cl').render(
mu=mu, mu=mu,
pmc=pmc, pmc=pmc,
@ -151,9 +152,9 @@ def create_a(
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des), arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des),
) )
''' #
Calculate final E (including left preconditioner) # Calculate final E (including left preconditioner)
''' #
h2e_source = jinja_env.get_template('h2e.cl').render( h2e_source = jinja_env.get_template('h2e.cl').render(
pec=pec, pec=pec,
common_cl=common_source, common_cl=common_source,
@ -277,7 +278,7 @@ def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., tuple[complex
def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]: def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]:
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get() g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
rr, ri, ii = [g[q] for q in 'xyz'] rr, ri, ii = (g[qq] for qq in 'xyz')
rho = rr + 2j * ri - ii rho = rr + 2j * ri - ii
err = rr + ii err = rr + ii
return rho, err return rho, err