From d09eff990f9160e70bc837174990305a7be791db Mon Sep 17 00:00:00 2001 From: jan Date: Sun, 17 Dec 2017 20:51:34 -0800 Subject: [PATCH] Update Rayleigh quotient iteration to allow arbitrary linear operators --- fdfd_tools/eigensolvers.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/fdfd_tools/eigensolvers.py b/fdfd_tools/eigensolvers.py index 24d7339..e348cd7 100644 --- a/fdfd_tools/eigensolvers.py +++ b/fdfd_tools/eigensolvers.py @@ -33,10 +33,11 @@ def power_iteration(operator: sparse.spmatrix, return lm_eigval, v -def rayleigh_quotient_iteration(operator: sparse.spmatrix, +def rayleigh_quotient_iteration(operator: sparse.spmatrix or spalg.LinearOperator, guess_vector: numpy.ndarray, iterations: int = 40, tolerance: float = 1e-13, + solver=None, ) -> Tuple[complex, numpy.ndarray]: """ Use Rayleigh quotient iteration to refine an eigenvector guess. @@ -46,16 +47,33 @@ def rayleigh_quotient_iteration(operator: sparse.spmatrix, :param iterations: Maximum number of iterations to perform. Default 40. :param tolerance: Stop iteration if (A - I*eigenvalue) @ v < tolerance. Default 1e-13. + :param solver: Solver function of the form x = solver(A, b). + By default, use scipy.sparse.spsolve for sparse matrices and + scipy.sparse.bicgstab for general LinearOperator instances. :return: (eigenvalue, eigenvector) """ + try: + _test = operator - sparse.eye(operator.shape) + shift = lambda eigval: eigval * sparse.eye(operator.shape[0]) + if solver is None: + solver = spalg.spsolve + except TypeError: + shift = lambda eigval: spalg.LinearOperator(shape=operator.shape, + dtype=operator.dtype, + matvec=lambda v: eigval * v) + if solver is None: + solver = lambda A, b: spalg.bicgstab(A, b)[0] + v = guess_vector + v /= norm(v) for _ in range(iterations): - eigval = v.conj() @ operator @ v + eigval = v.conj() @ (operator @ v) if norm(operator @ v - eigval * v) < tolerance: break - v = spalg.spsolve(operator - eigval * sparse.eye(operator.shape[0]), v) - v /= norm(v) + shifted_operator = operator - shift(eigval) + v = solver(shifted_operator, v) + v /= norm(v) return eigval, v