From 67cafd870e2326236e01df18719606eeb9c2eb27 Mon Sep 17 00:00:00 2001 From: jan Date: Mon, 4 Jul 2016 23:30:06 -0700 Subject: [PATCH] add csr.cg_solve --- opencl_fdfd/csr.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/opencl_fdfd/csr.py b/opencl_fdfd/csr.py index e42ad8f..e24e780 100644 --- a/opencl_fdfd/csr.py +++ b/opencl_fdfd/csr.py @@ -5,6 +5,7 @@ import pyopencl.array import time +import fdfd_tools.operators from . import ops @@ -106,6 +107,28 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose print('final error', errs[-1]) print('overhead {} sec'.format(start_time2 - start_time)) - print('Post-everything residual:', norm(a @ x - b) / norm(b)) + print('Final residual:', norm(a @ x - b) / norm(b)) return x + +def cg_solver(omega, dxes, J, epsilon, mu=None, pec=None, pmc=None, adjoint=False, solver_opts=None): + b0 = -1j * omega * J + A0 = fdfd_tools.operators.e_full(omega, dxes, epsilon=epsilon, mu=mu, pec=pec, pmc=pmc) + + Pl, Pr = fdfd_tools.operators.e_full_preconditioners(dxes) + + if adjoint: + A = (Pl @ A0 @ Pr).H + b = Pr.H @ b0 + else: + A = Pl @ A0 @ Pr + b = Pl @ b0 + + x = cg(A.tocsr(), b, **solver_opts) + + if adjoint: + x0 = Pl.H @ x + else: + x0 = Pr @ x + + return x0