[bloch] fixup some vectorization and add tests

This commit is contained in:
Jan Petykiewicz 2026-04-17 20:59:24 -07:00
commit 07b16ad86a
2 changed files with 101 additions and 14 deletions

View file

@ -136,6 +136,14 @@ except ImportError:
logger.info('Using numpy fft')
def _assemble_hmn_vector(
h_m: NDArray[numpy.complex128],
h_n: NDArray[numpy.complex128],
) -> NDArray[numpy.complex128]:
stacked = numpy.concatenate((numpy.ravel(h_m), numpy.ravel(h_n)))
return stacked[:, None]
def generate_kmn(
k0: ArrayLike,
G_matrix: ArrayLike,
@ -253,8 +261,8 @@ def maxwell_operator(
h_m, h_n = b_m, b_n
else:
# transform from mn to xyz
b_xyz = (m * b_m[:, :, :, None]
+ n * b_n[:, :, :, None])
b_xyz = (m * b_m
+ n * b_n) # noqa: E128
# divide by mu
temp = ifftn(b_xyz, axes=range(3))
@ -265,10 +273,7 @@ def maxwell_operator(
h_m = numpy.sum(h_xyz * m, axis=3)
h_n = numpy.sum(h_xyz * n, axis=3)
h.shape = (h.size,)
h = numpy.concatenate((h_m.ravel(), h_n.ravel()), axis=None, out=h) # ravel and merge
h.shape = (h.size, 1)
return h
return _assemble_hmn_vector(h_m, h_n)
return operator
@ -403,8 +408,8 @@ def inverse_maxwell_operator_approx(
b_m, b_n = hin_m, hin_n
else:
# transform from mn to xyz
h_xyz = (m * hin_m[:, :, :, None]
+ n * hin_n[:, :, :, None])
h_xyz = (m * hin_m
+ n * hin_n) # noqa: E128
# multiply by mu
temp = ifftn(h_xyz, axes=range(3))
@ -412,8 +417,8 @@ def inverse_maxwell_operator_approx(
b_xyz = fftn(temp, axes=range(3))
# transform back to mn
b_m = numpy.sum(b_xyz * m, axis=3)
b_n = numpy.sum(b_xyz * n, axis=3)
b_m = numpy.sum(b_xyz * m, axis=3, keepdims=True)
b_n = numpy.sum(b_xyz * n, axis=3, keepdims=True)
# cross product and transform into xyz basis
e_xyz = (n * b_m
@ -428,10 +433,7 @@ def inverse_maxwell_operator_approx(
h_m = numpy.sum(d_xyz * n, axis=3, keepdims=True) / +k_mag
h_n = numpy.sum(d_xyz * m, axis=3, keepdims=True) / -k_mag
h.shape = (h.size,)
h = numpy.concatenate((h_m, h_n), axis=None, out=h)
h.shape = (h.size, 1)
return h
return _assemble_hmn_vector(h_m, h_n)
return operator