misc cleanup
This commit is contained in:
parent
9282bfe8c0
commit
d72c5e254f
@ -31,7 +31,6 @@ from . import ops
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import scipy
|
import scipy
|
||||||
|
|
||||||
__author__ = 'Jan Petykiewicz'
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -99,18 +98,18 @@ def cg(
|
|||||||
|
|
||||||
m = CSRMatrix(queue, A)
|
m = CSRMatrix(queue, A)
|
||||||
|
|
||||||
'''
|
#
|
||||||
Generate OpenCL kernels
|
# Generate OpenCL kernels
|
||||||
'''
|
#
|
||||||
a_step = ops.create_a_csr(context)
|
a_step = ops.create_a_csr(context)
|
||||||
xr_step = ops.create_xr_step(context)
|
xr_step = ops.create_xr_step(context)
|
||||||
rhoerr_step = ops.create_rhoerr_step(context)
|
rhoerr_step = ops.create_rhoerr_step(context)
|
||||||
p_step = ops.create_p_step(context)
|
p_step = ops.create_p_step(context)
|
||||||
dot = ops.create_dot(context)
|
dot = ops.create_dot(context)
|
||||||
|
|
||||||
'''
|
#
|
||||||
Start the solve
|
# Start the solve
|
||||||
'''
|
#
|
||||||
start_time2 = time.perf_counter()
|
start_time2 = time.perf_counter()
|
||||||
|
|
||||||
_, err2 = rhoerr_step(r, [])
|
_, err2 = rhoerr_step(r, [])
|
||||||
@ -140,9 +139,9 @@ def cg(
|
|||||||
if k % 1000 == 0:
|
if k % 1000 == 0:
|
||||||
logger.info(f'iteration {k}')
|
logger.info(f'iteration {k}')
|
||||||
|
|
||||||
'''
|
#
|
||||||
Done solving
|
# Done solving
|
||||||
'''
|
#
|
||||||
time_elapsed = time.perf_counter() - start_time
|
time_elapsed = time.perf_counter() - start_time
|
||||||
|
|
||||||
x = x.get()
|
x = x.get()
|
||||||
|
@ -20,8 +20,6 @@ import meanas.fdfd.operators
|
|||||||
from . import ops
|
from . import ops
|
||||||
|
|
||||||
|
|
||||||
__author__ = 'Jan Petykiewicz'
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -113,9 +111,9 @@ def cg_solver(
|
|||||||
L, R = meanas.fdfd.operators.e_full_preconditioners(dxes)
|
L, R = meanas.fdfd.operators.e_full_preconditioners(dxes)
|
||||||
b_preconditioned = (R if adjoint else L) @ b
|
b_preconditioned = (R if adjoint else L) @ b
|
||||||
|
|
||||||
'''
|
#
|
||||||
Allocate GPU memory and load in data
|
# Allocate GPU memory and load in data
|
||||||
'''
|
#
|
||||||
if context is None:
|
if context is None:
|
||||||
context = pyopencl.create_some_context(interactive=True)
|
context = pyopencl.create_some_context(interactive=True)
|
||||||
|
|
||||||
@ -155,10 +153,10 @@ def cg_solver(
|
|||||||
else:
|
else:
|
||||||
gpmc = load_field(numpy.asarray(pmc, dtype=bool), dtype=numpy.int8)
|
gpmc = load_field(numpy.asarray(pmc, dtype=bool), dtype=numpy.int8)
|
||||||
|
|
||||||
'''
|
#
|
||||||
Generate OpenCL kernels
|
# Generate OpenCL kernels
|
||||||
'''
|
#
|
||||||
has_mu, has_pec, has_pmc = [q is not None for q in (mu, pec, pmc)]
|
has_mu, has_pec, has_pmc = (qq is not None for qq in (mu, pec, pmc))
|
||||||
|
|
||||||
a_step_full = ops.create_a(context, shape, has_mu, has_pec, has_pmc)
|
a_step_full = ops.create_a(context, shape, has_mu, has_pec, has_pmc)
|
||||||
xr_step = ops.create_xr_step(context)
|
xr_step = ops.create_xr_step(context)
|
||||||
@ -174,9 +172,9 @@ def cg_solver(
|
|||||||
) -> list[pyopencl.Event]:
|
) -> list[pyopencl.Event]:
|
||||||
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)
|
return a_step_full(E, H, p, inv_dxes, oeps, invm, gpec, gpmc, Pl, Pr, events)
|
||||||
|
|
||||||
'''
|
#
|
||||||
Start the solve
|
# Start the solve
|
||||||
'''
|
#
|
||||||
start_time2 = time.perf_counter()
|
start_time2 = time.perf_counter()
|
||||||
|
|
||||||
_, err2 = rhoerr_step(r, [])
|
_, err2 = rhoerr_step(r, [])
|
||||||
@ -209,16 +207,13 @@ def cg_solver(
|
|||||||
if k % 1000 == 0:
|
if k % 1000 == 0:
|
||||||
logger.info(f'iteration {k}')
|
logger.info(f'iteration {k}')
|
||||||
|
|
||||||
'''
|
#
|
||||||
Done solving
|
# Done solving
|
||||||
'''
|
#
|
||||||
time_elapsed = time.perf_counter() - start_time
|
time_elapsed = time.perf_counter() - start_time
|
||||||
|
|
||||||
# Undo preconditioners
|
# Undo preconditioners
|
||||||
if adjoint:
|
x = ((Pl if adjoint else Pr) * x).get()
|
||||||
x = (Pl * x).get()
|
|
||||||
else:
|
|
||||||
x = (Pr * x).get()
|
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.info('Solve success')
|
logger.info('Solve success')
|
||||||
|
@ -56,6 +56,7 @@ def type_to_C(
|
|||||||
|
|
||||||
return types[float_type]
|
return types[float_type]
|
||||||
|
|
||||||
|
|
||||||
# Type names
|
# Type names
|
||||||
ctype = type_to_C(numpy.complex128)
|
ctype = type_to_C(numpy.complex128)
|
||||||
ctype_bare = 'cdouble'
|
ctype_bare = 'cdouble'
|
||||||
@ -123,9 +124,9 @@ def create_a(
|
|||||||
des = [ctype + ' *inv_de' + a for a in 'xyz']
|
des = [ctype + ' *inv_de' + a for a in 'xyz']
|
||||||
dhs = [ctype + ' *inv_dh' + a for a in 'xyz']
|
dhs = [ctype + ' *inv_dh' + a for a in 'xyz']
|
||||||
|
|
||||||
'''
|
#
|
||||||
Convert p to initial E (ie, apply right preconditioner and PEC)
|
# Convert p to initial E (ie, apply right preconditioner and PEC)
|
||||||
'''
|
#
|
||||||
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
|
p2e_source = jinja_env.get_template('p2e.cl').render(pec=pec)
|
||||||
P2E_kernel = ElementwiseKernel(
|
P2E_kernel = ElementwiseKernel(
|
||||||
context,
|
context,
|
||||||
@ -135,9 +136,9 @@ def create_a(
|
|||||||
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg),
|
arguments=', '.join(ptrs('E', 'p', 'Pr') + pec_arg),
|
||||||
)
|
)
|
||||||
|
|
||||||
'''
|
#
|
||||||
Calculate intermediate H from intermediate E
|
# Calculate intermediate H from intermediate E
|
||||||
'''
|
#
|
||||||
e2h_source = jinja_env.get_template('e2h.cl').render(
|
e2h_source = jinja_env.get_template('e2h.cl').render(
|
||||||
mu=mu,
|
mu=mu,
|
||||||
pmc=pmc,
|
pmc=pmc,
|
||||||
@ -151,9 +152,9 @@ def create_a(
|
|||||||
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des),
|
arguments=', '.join(ptrs('E', 'H', 'inv_mu') + pmc_arg + des),
|
||||||
)
|
)
|
||||||
|
|
||||||
'''
|
#
|
||||||
Calculate final E (including left preconditioner)
|
# Calculate final E (including left preconditioner)
|
||||||
'''
|
#
|
||||||
h2e_source = jinja_env.get_template('h2e.cl').render(
|
h2e_source = jinja_env.get_template('h2e.cl').render(
|
||||||
pec=pec,
|
pec=pec,
|
||||||
common_cl=common_source,
|
common_cl=common_source,
|
||||||
@ -277,7 +278,7 @@ def create_rhoerr_step(context: pyopencl.Context) -> Callable[..., tuple[complex
|
|||||||
|
|
||||||
def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]:
|
def ri_update(r: pyopencl.array.Array, e: list[pyopencl.Event]) -> tuple[complex, complex]:
|
||||||
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
|
g = ri_kernel(r, wait_for=e).astype(ri_dtype).get()
|
||||||
rr, ri, ii = [g[q] for q in 'xyz']
|
rr, ri, ii = (g[qq] for qq in 'xyz')
|
||||||
rho = rr + 2j * ri - ii
|
rho = rr + 2j * ri - ii
|
||||||
err = rr + ii
|
err = rr + ii
|
||||||
return rho, err
|
return rho, err
|
||||||
|
Loading…
Reference in New Issue
Block a user