add per-iteration callback
This commit is contained in:
parent
a82d8dfc7e
commit
c7c71a3a82
@ -411,9 +411,10 @@ def find_k(
|
|||||||
epsilon: fdfield_t,
|
epsilon: fdfield_t,
|
||||||
mu: Optional[fdfield_t] = None,
|
mu: Optional[fdfield_t] = None,
|
||||||
band: int = 0,
|
band: int = 0,
|
||||||
solve_callback: Optional[Callable] = None
|
|
||||||
k_bounds: Tuple[float, float] = (0, 0.5),
|
k_bounds: Tuple[float, float] = (0, 0.5),
|
||||||
k_guess: Optional[float] = None,
|
k_guess: Optional[float] = None,
|
||||||
|
solve_callback: Optional[Callable[[...], None]] = None,
|
||||||
|
iter_callback: Optional[Callable[[...], None]] = None,
|
||||||
) -> Tuple[float, float, NDArray[numpy.complex128], NDArray[numpy.complex128]]:
|
) -> Tuple[float, float, NDArray[numpy.complex128], NDArray[numpy.complex128]]:
|
||||||
"""
|
"""
|
||||||
Search for a bloch vector that has a given frequency.
|
Search for a bloch vector that has a given frequency.
|
||||||
@ -430,6 +431,8 @@ def find_k(
|
|||||||
band: Which band to search in. Default 0 (lowest frequency).
|
band: Which band to search in. Default 0 (lowest frequency).
|
||||||
k_bounds: Minimum and maximum values for k. Default (0, 0.5).
|
k_bounds: Minimum and maximum values for k. Default (0, 0.5).
|
||||||
k_guess: Initial value for k.
|
k_guess: Initial value for k.
|
||||||
|
solve_callback: TODO
|
||||||
|
iter_callback: TODO
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`(k, actual_frequency, eigenvalues, eigenvectors)`
|
`(k, actual_frequency, eigenvalues, eigenvectors)`
|
||||||
@ -449,7 +452,7 @@ def find_k(
|
|||||||
def get_f(k0_mag: float, band: int = 0) -> float:
|
def get_f(k0_mag: float, band: int = 0) -> float:
|
||||||
nonlocal n, v
|
nonlocal n, v
|
||||||
k0 = direction * k0_mag # type: ignore
|
k0 = direction * k0_mag # type: ignore
|
||||||
n, v = eigsolve(band + 1, k0, G_matrix=G_matrix, epsilon=epsilon, mu=mu)
|
n, v = eigsolve(band + 1, k0, G_matrix=G_matrix, epsilon=epsilon, mu=mu, callback=iter_callback)
|
||||||
f = numpy.sqrt(numpy.abs(numpy.real(n[band])))
|
f = numpy.sqrt(numpy.abs(numpy.real(n[band])))
|
||||||
if solve_callback:
|
if solve_callback:
|
||||||
solve_callback(k0_mag, n, v, f)
|
solve_callback(k0_mag, n, v, f)
|
||||||
@ -477,6 +480,7 @@ def eigsolve(
|
|||||||
tolerance: float = 1e-20,
|
tolerance: float = 1e-20,
|
||||||
max_iters: int = 10000,
|
max_iters: int = 10000,
|
||||||
reset_iters: int = 100,
|
reset_iters: int = 100,
|
||||||
|
callback: Optional[Callable[[...], None]] = None,
|
||||||
) -> Tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]:
|
) -> Tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]:
|
||||||
"""
|
"""
|
||||||
Find the first (lowest-frequency) num_modes eigenmodes with Bloch wavevector
|
Find the first (lowest-frequency) num_modes eigenmodes with Bloch wavevector
|
||||||
@ -491,6 +495,9 @@ def eigsolve(
|
|||||||
Default `None` (1 everywhere).
|
Default `None` (1 everywhere).
|
||||||
tolerance: Solver stops when fractional change in the objective
|
tolerance: Solver stops when fractional change in the objective
|
||||||
`trace(Z.H @ A @ Z @ inv(Z Z.H))` is smaller than the tolerance
|
`trace(Z.H @ A @ Z @ inv(Z Z.H))` is smaller than the tolerance
|
||||||
|
max_iters: TODO
|
||||||
|
reset_iters: TODO
|
||||||
|
callback: TODO
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`(eigenvalues, eigenvectors)` where `eigenvalues[i]` corresponds to the
|
`(eigenvalues, eigenvectors)` where `eigenvalues[i]` corresponds to the
|
||||||
@ -659,6 +666,9 @@ def eigsolve(
|
|||||||
#prev_theta = theta
|
#prev_theta = theta
|
||||||
prev_E = E
|
prev_E = E
|
||||||
|
|
||||||
|
if callback:
|
||||||
|
callback()
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Recover eigenvectors from Z
|
Recover eigenvectors from Z
|
||||||
'''
|
'''
|
||||||
|
Loading…
Reference in New Issue
Block a user