misc cleanup
This commit is contained in:
parent
9282bfe8c0
commit
d72c5e254f
@ -31,7 +31,6 @@ from . import ops
|
||||
if TYPE_CHECKING:
|
||||
import scipy
|
||||
|
||||
__author__ = 'Jan Petykiewicz'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -99,18 +98,18 @@ def cg(
|
||||
|
||||
m = CSRMatrix(queue, A)
|
||||
|
||||
'''
|
||||
Generate OpenCL kernels
|
||||
'''
|
||||
#
|
||||
# Generate OpenCL kernels
|
||||
#
|
||||
a_step = ops.create_a_csr(context)
|
||||
xr_step = ops.create_xr_step(context)
|
||||
rhoerr_step = ops.create_rhoerr_step(context)
|
||||
p_step = ops.create_p_step(context)
|
||||
dot = ops.create_dot(context)
|
||||
|
||||
'''
|
||||
Start the solve
|
||||
'''
|
||||
#
|
||||
# Start the solve
|
||||
#
|
||||
start_time2 = time.perf_counter()
|
||||
|
||||
_, err2 = rhoerr_step(r, [])
|
||||
@ -140,9 +139,9 @@ def cg(
|
||||
if k % 1000 == 0:
|
||||
logger.info(f'iteration {k}')
|
||||
|
||||
'''
|
||||
Done solving
|
||||
'''
|
||||
#
|
||||
# Done solving
|
||||
#
|
||||
time_elapsed = time.perf_counter() - start_time
|
||||
|
||||
x = x.get()
|
||||
|
@ -20,8 +20,6 @@ import meanas.fdfd.operators
|
||||
from . import ops
|
||||
|
||||
|
||||
__author__ = 'Jan Petykiewicz'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -113,9 +111,9 @@ def cg_solver(
|
||||
L, R = meanas.fdfd.operators.e_full_preconditioners(dxes)
|
||||
b_preconditioned = (R if adjoint else L) @ b
|
||||
|
||||
'''
|
||||
Allocate GPU memory and load in data
|
||||
'''
|
||||
#
|
||||
# Allocate GPU memory and load in data
|
||||
#
|
||||
if context is None:
|
||||
context = pyopencl.create_some_context(interactive=True)
|
||||
|
||||
@ -155,10 +153,10 @@ def cg_solver(
|
||||
else:
|
||||
gpmc = load_field(numpy.asarray(pmc, dtype=bool), dtype=numpy.int8)
|
||||
|
||||
'''
|
||||
Generate OpenCL kernels
|
||||
'''
|
||||
has_mu, has_pec, has_pmc = [q is not None for q in (mu, pec, pmc)]
|
||||
#
|
||||
# Generate OpenCL kernels
|
||||
#
|
||||
has_mu, has_pec, has_pmc = (qq is not None for qq in (mu, pec, pmc))
|
||||
|
||||
a_step_full = ops.create_a(context, shape, has_mu, has_pec, has_pmc)
|
||||
xr_step = ops.create_xr_step(context)
|
||||
@ -174,9 +172,9 @@ def cg_solver(
|
||||
) -> list[pyopencl.Event]:
|
||||
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)
|
||||
|
||||
'''
|
||||
Start the solve
|
||||
'''
|
||||
#
|
||||
# Start the solve
|
||||
#
|
||||
start_time2 = time.perf_counter()
|
||||
|
||||
_, err2 = rhoerr_step(r, [])
|
||||
@ -209,16 +207,13 @@ def cg_solver(
|
||||
if k % 1000 == 0:
|
||||
logger.info(f'iteration {k}')
|
||||
|
||||
'''
|
||||
Done solving
|
||||
'''
|
||||
#
|
||||
# Done solving
|
||||
#
|
||||
time_elapsed = time.perf_counter() - start_time
|
||||
|
||||
# Undo preconditioners
|
||||
if adjoint:
|
||||
x = (Pl * x).get()
|
||||
else:
|
||||
x = (Pr * x).get()
|
||||
x = ((Pl if adjoint else Pr) * x).get()
|
||||
|
||||
if success:
|
||||
logger.info('Solve success')
|
||||
|
@ -56,6 +56,7 @@ def type_to_C(
|
||||
|
||||
return types[float_type]
|
||||
|
||||
|
||||
# Type names
|
||||
ctype = type_to_C(numpy.complex128)
|
||||
ctype_bare = 'cdouble'
|
||||
@ -123,9 +124,9 @@ def create_a(
|
||||
des = [ctype + ' *inv_de' + a for a in 'xyz']
|
||||
dhs = [ctype + ' *inv_dh' + a for a in 'xyz']
|
||||
|
||||
'''
|
||||
Convert p to initial E (ie, apply right preconditioner and PEC)
|
||||
'''
|
||||
#
|
||||
# Convert p to initial E (ie, apply right preconditioner and PEC)
|
||||
#
|
||||
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
|
||||
P2E_kernel = ElementwiseKernel(
|
||||
context,
|
||||
@ -135,9 +136,9 @@ def create_a(
|
||||
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg),
|
||||
)
|
||||
|
||||
'''
|
||||
Calculate intermediate H from intermediate E
|
||||
'''
|
||||
#
|
||||
# Calculate intermediate H from intermediate E
|
||||
#
|
||||
e2h_source = jinja_env.get_template('e2h.cl').render(
|
||||
mu=mu,
|
||||
pmc=pmc,
|
||||
@ -151,9 +152,9 @@ def create_a(
|
||||
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des),
|
||||
)
|
||||
|
||||
'''
|
||||
Calculate final E (including left preconditioner)
|
||||
'''
|
||||
#
|
||||
# Calculate final E (including left preconditioner)
|
||||
#
|
||||
h2e_source = jinja_env.get_template('h2e.cl').render(
|
||||
pec=pec,
|
||||
common_cl=common_source,
|
||||
@ -277,7 +278,7 @@ def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., tuple[complex
|
||||
|
||||
def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]:
|
||||
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
|
||||
rr, ri, ii = [g[q] for q in 'xyz']
|
||||
rr, ri, ii = (g[qq] for qq in 'xyz')
|
||||
rho = rr + 2j * ri - ii
|
||||
err = rr + ii
|
||||
return rho, err
|
||||
|
Loading…
Reference in New Issue
Block a user