add csr.cg_solve

This commit is contained in:
jan 2016-07-04 23:30:06 -07:00
parent e94a07db28
commit 67cafd870e

View File

@ -5,6 +5,7 @@ import pyopencl.array
import time import time
import fdfd_tools.operators
from . import ops from . import ops
@ -106,6 +107,28 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose
print('final error', errs[-1]) print('final error', errs[-1])
print('overhead {} sec'.format(start_time2 - start_time)) print('overhead {} sec'.format(start_time2 - start_time))
print('Post-everything residual:', norm(a @ x - b) / norm(b)) print('Final residual:', norm(a @ x - b) / norm(b))
return x return x
def cg_solver(omega, dxes, J, epsilon, mu=None, pec=None, pmc=None, adjoint=False, solver_opts=None):
b0 = -1j * omega * J
A0 = fdfd_tools.operators.e_full(omega, dxes, epsilon=epsilon, mu=mu, pec=pec, pmc=pmc)
Pl, Pr = fdfd_tools.operators.e_full_preconditioners(dxes)
if adjoint:
A = (Pl @ A0 @ Pr).H
b = Pr.H @ b0
else:
A = Pl @ A0 @ Pr
b = Pl @ b0
x = cg(A.tocsr(), b, **solver_opts)
if adjoint:
x0 = Pl.H @ x
else:
x0 = Pr @ x
return x0