add per-iteration callback
This commit is contained in:
parent
a82d8dfc7e
commit
c7c71a3a82
@ -411,9 +411,10 @@ def find_k(
|
||||
epsilon: fdfield_t,
|
||||
mu: Optional[fdfield_t] = None,
|
||||
band: int = 0,
|
||||
solve_callback: Optional[Callable] = None
|
||||
k_bounds: Tuple[float, float] = (0, 0.5),
|
||||
k_guess: Optional[float] = None,
|
||||
solve_callback: Optional[Callable[[...], None]] = None,
|
||||
iter_callback: Optional[Callable[[...], None]] = None,
|
||||
) -> Tuple[float, float, NDArray[numpy.complex128], NDArray[numpy.complex128]]:
|
||||
"""
|
||||
Search for a bloch vector that has a given frequency.
|
||||
@ -430,6 +431,8 @@ def find_k(
|
||||
band: Which band to search in. Default 0 (lowest frequency).
|
||||
k_bounds: Minimum and maximum values for k. Default (0, 0.5).
|
||||
k_guess: Initial value for k.
|
||||
solve_callback: TODO
|
||||
iter_callback: TODO
|
||||
|
||||
Returns:
|
||||
`(k, actual_frequency, eigenvalues, eigenvectors)`
|
||||
@ -449,7 +452,7 @@ def find_k(
|
||||
def get_f(k0_mag: float, band: int = 0) -> float:
|
||||
nonlocal n, v
|
||||
k0 = direction * k0_mag # type: ignore
|
||||
n, v = eigsolve(band + 1, k0, G_matrix=G_matrix, epsilon=epsilon, mu=mu)
|
||||
n, v = eigsolve(band + 1, k0, G_matrix=G_matrix, epsilon=epsilon, mu=mu, callback=iter_callback)
|
||||
f = numpy.sqrt(numpy.abs(numpy.real(n[band])))
|
||||
if solve_callback:
|
||||
solve_callback(k0_mag, n, v, f)
|
||||
@ -477,6 +480,7 @@ def eigsolve(
|
||||
tolerance: float = 1e-20,
|
||||
max_iters: int = 10000,
|
||||
reset_iters: int = 100,
|
||||
callback: Optional[Callable[[...], None]] = None,
|
||||
) -> Tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]:
|
||||
"""
|
||||
Find the first (lowest-frequency) num_modes eigenmodes with Bloch wavevector
|
||||
@ -491,6 +495,9 @@ def eigsolve(
|
||||
Default `None` (1 everywhere).
|
||||
tolerance: Solver stops when fractional change in the objective
|
||||
`trace(Z.H @ A @ Z @ inv(Z Z.H))` is smaller than the tolerance
|
||||
max_iters: TODO
|
||||
reset_iters: TODO
|
||||
callback: TODO
|
||||
|
||||
Returns:
|
||||
`(eigenvalues, eigenvectors)` where `eigenvalues[i]` corresponds to the
|
||||
@ -659,6 +666,9 @@ def eigsolve(
|
||||
#prev_theta = theta
|
||||
prev_E = E
|
||||
|
||||
if callback:
|
||||
callback()
|
||||
|
||||
'''
|
||||
Recover eigenvectors from Z
|
||||
'''
|
||||
|
Loading…
Reference in New Issue
Block a user