diff --git a/opencl_fdfd/csr.py b/opencl_fdfd/csr.py index ec49c09..b84eec0 100644 --- a/opencl_fdfd/csr.py +++ b/opencl_fdfd/csr.py @@ -31,7 +31,6 @@ from . import ops if TYPE_CHECKING: import scipy -__author__ = 'Jan Petykiewicz' logger = logging.getLogger(__name__) @@ -99,18 +98,18 @@ def cg( m = CSRMatrix(queue, A) - ''' - Generate OpenCL kernels - ''' + # + # Generate OpenCL kernels + # a_step = ops.create_a_csr(context) xr_step = ops.create_xr_step(context) rhoerr_step = ops.create_rhoerr_step(context) p_step = ops.create_p_step(context) dot = ops.create_dot(context) - ''' - Start the solve - ''' + # + # Start the solve + # start_time2 = time.perf_counter() _, err2 = rhoerr_step(r, []) @@ -140,9 +139,9 @@ def cg( if k % 1000 == 0: logger.info(f'iteration {k}') - ''' - Done solving - ''' + # + # Done solving + # time_elapsed = time.perf_counter() - start_time x = x.get() diff --git a/opencl_fdfd/main.py b/opencl_fdfd/main.py index 769aee3..9e57395 100644 --- a/opencl_fdfd/main.py +++ b/opencl_fdfd/main.py @@ -20,8 +20,6 @@ import meanas.fdfd.operators from . import ops -__author__ = 'Jan Petykiewicz' - logger = logging.getLogger(__name__) @@ -113,9 +111,9 @@ def cg_solver( L, R = meanas.fdfd.operators.e_full_preconditioners(dxes) b_preconditioned = (R if adjoint else L) @ b - ''' - Allocate GPU memory and load in data - ''' + # + # Allocate GPU memory and load in data + # if context is None: context = pyopencl.create_some_context(interactive=True) @@ -155,10 +153,10 @@ def cg_solver( else: gpmc = load_field(numpy.asarray(pmc, dtype=bool), dtype=numpy.int8) - ''' - Generate OpenCL kernels - ''' - has_mu, has_pec, has_pmc = [q is not None for q in (mu, pec, pmc)] + # + # Generate OpenCL kernels + # + has_mu, has_pec, has_pmc = (qq is not None for qq in (mu, pec, pmc)) a_step_full = ops.create_a(context, shape, has_mu, has_pec, has_pmc) xr_step = ops.create_xr_step(context) @@ -174,9 +172,9 @@ def cg_solver( ) -> list[pyopencl.Event]: return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events) - ''' - Start the solve - ''' + # + # Start the solve + # start_time2 = time.perf_counter() _, err2 = rhoerr_step(r, []) @@ -209,16 +207,13 @@ def cg_solver( if k % 1000 == 0: logger.info(f'iteration {k}') - ''' - Done solving - ''' + # + # Done solving + # time_elapsed = time.perf_counter() - start_time # Undo preconditioners - if adjoint: - x = (Pl * x).get() - else: - x = (Pr * x).get() + x = ((Pl if adjoint else Pr) * x).get() if success: logger.info('Solve success') diff --git a/opencl_fdfd/ops.py b/opencl_fdfd/ops.py index 80bf836..c2d73ed 100644 --- a/opencl_fdfd/ops.py +++ b/opencl_fdfd/ops.py @@ -56,6 +56,7 @@ def type_to_C( return types[float_type] + # Type names ctype = type_to_C(numpy.complex128) ctype_bare = 'cdouble' @@ -123,9 +124,9 @@ def create_a( des = [ctype + ' *inv_de' + a for a in 'xyz'] dhs = [ctype + ' *inv_dh' + a for a in 'xyz'] - ''' - Convert p to initial E (ie, apply right preconditioner and PEC) - ''' + # + # Convert p to initial E (ie, apply right preconditioner and PEC) + # p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec) P2E_kernel = ElementwiseKernel( context, @@ -135,9 +136,9 @@ def create_a( arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg), ) - ''' - Calculate intermediate H from intermediate E - ''' + # + # Calculate intermediate H from intermediate E + # e2h_source = jinja_env.get_template('e2h.cl').render( mu=mu, pmc=pmc, @@ -151,9 +152,9 @@ def create_a( arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des), ) - ''' - Calculate final E (including left preconditioner) - ''' + # + # Calculate final E (including left preconditioner) + # h2e_source = jinja_env.get_template('h2e.cl').render( pec=pec, common_cl=common_source, @@ -277,7 +278,7 @@ def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., tuple[complex def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]: g = ri_kernel(r, wait_for=e).astype(ri_dtype).get() - rr, ri, ii = [g[q] for q in 'xyz'] + rr, ri, ii = (g[qq] for qq in 'xyz') rho = rr + 2j * ri - ii err = rr + ii return rho, err