add per-iteration callback

This commit is contained in:
Jan Petykiewicz 2022-11-14 12:20:50 -08:00
parent a82d8dfc7e
commit c7c71a3a82

View File

@ -411,9 +411,10 @@ def find_k(
epsilon: fdfield_t, epsilon: fdfield_t,
mu: Optional[fdfield_t] = None, mu: Optional[fdfield_t] = None,
band: int = 0, band: int = 0,
solve_callback: Optional[Callable] = None
k_bounds: Tuple[float, float] = (0, 0.5), k_bounds: Tuple[float, float] = (0, 0.5),
k_guess: Optional[float] = None, k_guess: Optional[float] = None,
solve_callback: Optional[Callable[[...], None]] = None,
iter_callback: Optional[Callable[[...], None]] = None,
) -> Tuple[float, float, NDArray[numpy.complex128], NDArray[numpy.complex128]]: ) -> Tuple[float, float, NDArray[numpy.complex128], NDArray[numpy.complex128]]:
""" """
Search for a bloch vector that has a given frequency. Search for a bloch vector that has a given frequency.
@ -430,6 +431,8 @@ def find_k(
band: Which band to search in. Default 0 (lowest frequency). band: Which band to search in. Default 0 (lowest frequency).
k_bounds: Minimum and maximum values for k. Default (0, 0.5). k_bounds: Minimum and maximum values for k. Default (0, 0.5).
k_guess: Initial value for k. k_guess: Initial value for k.
solve_callback: TODO
iter_callback: TODO
Returns: Returns:
`(k, actual_frequency, eigenvalues, eigenvectors)` `(k, actual_frequency, eigenvalues, eigenvectors)`
@ -449,7 +452,7 @@ def find_k(
def get_f(k0_mag: float, band: int = 0) -> float: def get_f(k0_mag: float, band: int = 0) -> float:
nonlocal n, v nonlocal n, v
k0 = direction * k0_mag # type: ignore k0 = direction * k0_mag # type: ignore
n, v = eigsolve(band + 1, k0, G_matrix=G_matrix, epsilon=epsilon, mu=mu) n, v = eigsolve(band + 1, k0, G_matrix=G_matrix, epsilon=epsilon, mu=mu, callback=iter_callback)
f = numpy.sqrt(numpy.abs(numpy.real(n[band]))) f = numpy.sqrt(numpy.abs(numpy.real(n[band])))
if solve_callback: if solve_callback:
solve_callback(k0_mag, n, v, f) solve_callback(k0_mag, n, v, f)
@ -477,6 +480,7 @@ def eigsolve(
tolerance: float = 1e-20, tolerance: float = 1e-20,
max_iters: int = 10000, max_iters: int = 10000,
reset_iters: int = 100, reset_iters: int = 100,
callback: Optional[Callable[[...], None]] = None,
) -> Tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]: ) -> Tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]:
""" """
Find the first (lowest-frequency) num_modes eigenmodes with Bloch wavevector Find the first (lowest-frequency) num_modes eigenmodes with Bloch wavevector
@ -491,6 +495,9 @@ def eigsolve(
Default `None` (1 everywhere). Default `None` (1 everywhere).
tolerance: Solver stops when fractional change in the objective tolerance: Solver stops when fractional change in the objective
`trace(Z.H @ A @ Z @ inv(Z Z.H))` is smaller than the tolerance `trace(Z.H @ A @ Z @ inv(Z Z.H))` is smaller than the tolerance
max_iters: TODO
reset_iters: TODO
callback: TODO
Returns: Returns:
`(eigenvalues, eigenvectors)` where `eigenvalues[i]` corresponds to the `(eigenvalues, eigenvectors)` where `eigenvalues[i]` corresponds to the
@ -659,6 +666,9 @@ def eigsolve(
#prev_theta = theta #prev_theta = theta
prev_E = E prev_E = E
if callback:
callback()
''' '''
Recover eigenvectors from Z Recover eigenvectors from Z
''' '''