diff --git a/opencl_fdfd/csr.py b/opencl_fdfd/csr.py index e42ad8f..e24e780 100644 --- a/opencl_fdfd/csr.py +++ b/opencl_fdfd/csr.py @@ -5,6 +5,7 @@ import pyopencl.array import time +import fdfd_tools.operators from . import ops @@ -106,6 +107,28 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose print('final error', errs[-1]) print('overhead {} sec'.format(start_time2 - start_time)) - print('Post-everything residual:', norm(a @ x - b) / norm(b)) + print('Final residual:', norm(a @ x - b) / norm(b)) return x + +def cg_solver(omega, dxes, J, epsilon, mu=None, pec=None, pmc=None, adjoint=False, solver_opts=None): + b0 = -1j * omega * J + A0 = fdfd_tools.operators.e_full(omega, dxes, epsilon=epsilon, mu=mu, pec=pec, pmc=pmc) + + Pl, Pr = fdfd_tools.operators.e_full_preconditioners(dxes) + + if adjoint: + A = (Pl @ A0 @ Pr).H + b = Pr.H @ b0 + else: + A = Pl @ A0 @ Pr + b = Pl @ b0 + + x = cg(A.tocsr(), b, **solver_opts) + + if adjoint: + x0 = Pl.H @ x + else: + x0 = Pr @ x + + return x0