[bloch] add some more tests and clean up solves

This commit is contained in:
Jan Petykiewicz 2026-04-17 22:10:18 -07:00
commit e6756742be
2 changed files with 171 additions and 15 deletions

View file

@ -451,7 +451,7 @@ def find_k(
solve_callback: Callable[..., None] | None = None,
iter_callback: Callable[..., None] | None = None,
v0: NDArray[numpy.complex128] | None = None,
) -> tuple[float, float, NDArray[numpy.complex128], NDArray[numpy.complex128]]:
) -> tuple[NDArray[numpy.float64], float, NDArray[numpy.complex128], NDArray[numpy.complex128]]:
"""
Search for a bloch vector that has a given frequency.
@ -496,15 +496,15 @@ def find_k(
res = scipy.optimize.minimize_scalar(
lambda x: abs(get_f(x, band) - frequency),
k_guess,
method='Bounded',
method='bounded',
bounds=k_bounds,
options={'xatol': abs(tolerance)},
)
assert n is not None
assert v is not None
return float(res.x * direction), float(res.fun + frequency), n, v
actual_frequency = get_f(float(res.x), band)
return direction * float(res.x), float(actual_frequency), n, v
def eigsolve(
@ -725,7 +725,12 @@ def eigsolve(
amax=pi,
)
result = scipy.optimize.minimize_scalar(trace_func, bounds=(0, pi), tol=tolerance)
result = scipy.optimize.minimize_scalar(
trace_func,
method='bounded',
bounds=(0, pi),
options={'xatol': tolerance},
)
new_E = result.fun
theta = result.x
@ -754,7 +759,7 @@ def eigsolve(
v = eigvecs[:, i]
n = eigvals[i]
v /= norm(v)
Av = (scipy_op @ v.copy())[:, 0]
Av = numpy.asarray(scipy_op @ v.copy()).reshape(-1)
eigness = norm(Av - (v.conj() @ Av) * v)
f = numpy.sqrt(-numpy.real(n))
df = numpy.sqrt(-numpy.real(n) + eigness)
@ -823,18 +828,18 @@ def inner_product(
# eRxhR_x = numpy.cross(eR.reshape(3, -1), hR.reshape(3, -1), axis=0).reshape(eR.shape)[0] / normR
# logger.info(f'power {eRxhR_x.sum() / 2})
eR /= numpy.sqrt(norm2R)
hR /= numpy.sqrt(norm2R)
eL /= numpy.sqrt(norm2L)
hL /= numpy.sqrt(norm2L)
eR_norm = eR / numpy.sqrt(abs(norm2R))
hR_norm = hR / numpy.sqrt(abs(norm2R))
eL_norm = eL / numpy.sqrt(abs(norm2L))
hL_norm = hL / numpy.sqrt(abs(norm2L))
# (eR x hL)[0] and (eL x hR)[0]
eRxhL_x = eR[1] * hL[2] - eR[2] - hL[1]
eLxhR_x = eL[1] * hR[2] - eL[2] - hR[1]
eRxhL_x = eR_norm[1] * hL_norm[2] - eR_norm[2] * hL_norm[1]
eLxhR_x = eL_norm[1] * hR_norm[2] - eL_norm[2] * hR_norm[1]
#return 1j * (eRxhL_x - eLxhR_x).sum() / numpy.sqrt(norm2R * norm2L)
#return (eRxhL_x.sum() - eLxhR_x.sum()) / numpy.sqrt(norm2R * norm2L)
return eRxhL_x.sum() - eLxhR_x.sum()
return eLxhR_x.sum() - eRxhL_x.sum()
def trq(
@ -848,8 +853,8 @@ def trq(
np = inner_product(eO, -hO, eI, hI)
nn = inner_product(eO, -hO, eI, -hI)
assert pp == -nn
assert pn == -np
assert numpy.allclose(pp, -nn, atol=1e-12, rtol=1e-12)
assert numpy.allclose(pn, -np, atol=1e-12, rtol=1e-12)
logger.info(f'''
{pp=:4g} {pn=:4g}