[fdfd.eme] do a better job of enforcing no gain

This commit is contained in:
Jan Petykiewicz 2026-04-17 22:24:53 -07:00
commit f3d13e1486
2 changed files with 127 additions and 2 deletions

View file

@ -81,8 +81,8 @@ def get_s(
if force_nogain:
# force S @ S.H diagonal
U, sing, V = numpy.linalg.svd(ss)
ss = numpy.diag(sing) @ U @ V
U, sing, Vh = numpy.linalg.svd(ss)
ss = U @ numpy.diag(numpy.minimum(sing, 1.0)) @ Vh
if force_reciprocal:
ss = 0.5 * (ss + ss.T)

View file

@ -0,0 +1,125 @@
import numpy
from numpy.testing import assert_allclose
from scipy import sparse
from ..fdmath import vec
from ..fdfd import eme
SHAPE = (3, 2, 2)
DXES = [[numpy.ones(2), numpy.ones(2)] for _ in range(2)]
WAVENUMBERS_L = numpy.array([1.0, 0.8])
WAVENUMBERS_R = numpy.array([0.9, 0.7])
def _mode(scale: float) -> tuple[numpy.ndarray, numpy.ndarray]:
e_field = (numpy.arange(12).reshape(SHAPE) + 1.0 + scale).astype(complex)
h_field = (numpy.arange(12).reshape(SHAPE) * 0.2 + 2.0 + 0.05j * scale).astype(complex)
return vec(e_field), vec(h_field)
def _mode_sets() -> tuple[list[tuple[numpy.ndarray, numpy.ndarray]], list[tuple[numpy.ndarray, numpy.ndarray]]]:
left_modes = [_mode(0.0), _mode(0.7)]
right_modes = [_mode(1.4), _mode(2.1)]
return left_modes, right_modes
def test_get_tr_returns_finite_bounded_transfer_matrices() -> None:
left_modes, right_modes = _mode_sets()
transmission, reflection = eme.get_tr(
left_modes,
WAVENUMBERS_L,
right_modes,
WAVENUMBERS_R,
dxes=DXES,
)
singular_values = numpy.linalg.svd(transmission, compute_uv=False)
assert transmission.shape == (2, 2)
assert reflection.shape == (2, 2)
assert numpy.isfinite(transmission).all()
assert numpy.isfinite(reflection).all()
assert (singular_values <= 1.0 + 1e-12).all()
def test_get_abcd_matches_explicit_block_formula() -> None:
left_modes, right_modes = _mode_sets()
t12, r12 = eme.get_tr(left_modes, WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES)
t21, r21 = eme.get_tr(right_modes, WAVENUMBERS_R, left_modes, WAVENUMBERS_L, dxes=DXES)
t21_inv = numpy.linalg.pinv(t21)
expected = numpy.block([
[t12 - r21 @ t21_inv @ r12, r21 @ t21_inv],
[-t21_inv @ r12, t21_inv],
])
abcd = eme.get_abcd(left_modes, WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES)
assert sparse.issparse(abcd)
assert abcd.shape == (4, 4)
assert_allclose(abcd.toarray(), expected)
def test_get_s_plain_matches_block_assembly_from_get_tr() -> None:
left_modes, right_modes = _mode_sets()
t12, r12 = eme.get_tr(left_modes, WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES)
t21, r21 = eme.get_tr(right_modes, WAVENUMBERS_R, left_modes, WAVENUMBERS_L, dxes=DXES)
expected = numpy.block([[r12, t12], [t21, r21]])
ss = eme.get_s(left_modes, WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES)
assert ss.shape == (4, 4)
assert numpy.isfinite(ss).all()
assert_allclose(ss, expected)
def test_get_s_force_nogain_caps_singular_values(monkeypatch) -> None:
def fake_get_tr(*args, **kwargs):
return numpy.array([[2.0, 0.0], [0.0, 0.5]]), numpy.zeros((2, 2))
monkeypatch.setattr(eme, 'get_tr', fake_get_tr)
plain_s = eme.get_s(None, None, None, None)
clipped_s = eme.get_s(None, None, None, None, force_nogain=True)
plain_singular_values = numpy.linalg.svd(plain_s, compute_uv=False)
clipped_singular_values = numpy.linalg.svd(clipped_s, compute_uv=False)
assert plain_singular_values.max() > 1.0
assert (clipped_singular_values <= 1.0 + 1e-12).all()
assert numpy.isfinite(clipped_s).all()
def test_get_s_force_reciprocal_symmetrizes_output(monkeypatch) -> None:
left = object()
right = object()
def fake_get_tr(_eh_left, wavenumbers_left, _eh_right, _wavenumbers_right, **kwargs):
if wavenumbers_left is left:
return (
numpy.array([[1.0, 2.0], [0.5, 1.0]]),
numpy.array([[0.0, 1.0], [2.0, 0.0]]),
)
return (
numpy.array([[1.0, -1.0], [0.0, 1.0]]),
numpy.array([[0.0, 0.5], [1.5, 0.0]]),
)
monkeypatch.setattr(eme, 'get_tr', fake_get_tr)
ss = eme.get_s(None, left, None, right, force_reciprocal=True)
assert_allclose(ss, ss.T)
def test_get_s_force_nogain_and_reciprocal_returns_finite_output(monkeypatch) -> None:
def fake_get_tr(*args, **kwargs):
return numpy.array([[2.0, 0.0], [0.0, 0.5]]), numpy.array([[0.0, 1.0], [2.0, 0.0]])
monkeypatch.setattr(eme, 'get_tr', fake_get_tr)
ss = eme.get_s(None, None, None, None, force_nogain=True, force_reciprocal=True)
assert ss.shape == (4, 4)
assert numpy.isfinite(ss).all()
assert_allclose(ss, ss.T)
assert (numpy.linalg.svd(ss, compute_uv=False) <= 1.0 + 1e-12).all()