From f3d13e148653c855218de611e6a367d80c344582 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Fri, 17 Apr 2026 22:24:53 -0700 Subject: [PATCH] [fdfd.eme] do a better job of enforcing no gain --- meanas/fdfd/eme.py | 4 +- meanas/test/test_eme_numerics.py | 125 +++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 2 deletions(-) create mode 100644 meanas/test/test_eme_numerics.py diff --git a/meanas/fdfd/eme.py b/meanas/fdfd/eme.py index cb1b99e..5165ef1 100644 --- a/meanas/fdfd/eme.py +++ b/meanas/fdfd/eme.py @@ -81,8 +81,8 @@ def get_s( if force_nogain: # force S @ S.H diagonal - U, sing, V = numpy.linalg.svd(ss) - ss = numpy.diag(sing) @ U @ V + U, sing, Vh = numpy.linalg.svd(ss) + ss = U @ numpy.diag(numpy.minimum(sing, 1.0)) @ Vh if force_reciprocal: ss = 0.5 * (ss + ss.T) diff --git a/meanas/test/test_eme_numerics.py b/meanas/test/test_eme_numerics.py new file mode 100644 index 0000000..40ca5ed --- /dev/null +++ b/meanas/test/test_eme_numerics.py @@ -0,0 +1,125 @@ +import numpy +from numpy.testing import assert_allclose +from scipy import sparse + +from ..fdmath import vec +from ..fdfd import eme + + +SHAPE = (3, 2, 2) +DXES = [[numpy.ones(2), numpy.ones(2)] for _ in range(2)] +WAVENUMBERS_L = numpy.array([1.0, 0.8]) +WAVENUMBERS_R = numpy.array([0.9, 0.7]) + + +def _mode(scale: float) -> tuple[numpy.ndarray, numpy.ndarray]: + e_field = (numpy.arange(12).reshape(SHAPE) + 1.0 + scale).astype(complex) + h_field = (numpy.arange(12).reshape(SHAPE) * 0.2 + 2.0 + 0.05j * scale).astype(complex) + return vec(e_field), vec(h_field) + + +def _mode_sets() -> tuple[list[tuple[numpy.ndarray, numpy.ndarray]], list[tuple[numpy.ndarray, numpy.ndarray]]]: + left_modes = [_mode(0.0), _mode(0.7)] + right_modes = [_mode(1.4), _mode(2.1)] + return left_modes, right_modes + + +def test_get_tr_returns_finite_bounded_transfer_matrices() -> None: + left_modes, right_modes = _mode_sets() + + transmission, reflection = eme.get_tr( + left_modes, + WAVENUMBERS_L, + right_modes, + WAVENUMBERS_R, + dxes=DXES, + ) + + singular_values = numpy.linalg.svd(transmission, compute_uv=False) + + assert transmission.shape == (2, 2) + assert reflection.shape == (2, 2) + assert numpy.isfinite(transmission).all() + assert numpy.isfinite(reflection).all() + assert (singular_values <= 1.0 + 1e-12).all() + + +def test_get_abcd_matches_explicit_block_formula() -> None: + left_modes, right_modes = _mode_sets() + t12, r12 = eme.get_tr(left_modes, WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES) + t21, r21 = eme.get_tr(right_modes, WAVENUMBERS_R, left_modes, WAVENUMBERS_L, dxes=DXES) + t21_inv = numpy.linalg.pinv(t21) + + expected = numpy.block([ + [t12 - r21 @ t21_inv @ r12, r21 @ t21_inv], + [-t21_inv @ r12, t21_inv], + ]) + abcd = eme.get_abcd(left_modes, WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES) + + assert sparse.issparse(abcd) + assert abcd.shape == (4, 4) + assert_allclose(abcd.toarray(), expected) + + +def test_get_s_plain_matches_block_assembly_from_get_tr() -> None: + left_modes, right_modes = _mode_sets() + t12, r12 = eme.get_tr(left_modes, WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES) + t21, r21 = eme.get_tr(right_modes, WAVENUMBERS_R, left_modes, WAVENUMBERS_L, dxes=DXES) + expected = numpy.block([[r12, t12], [t21, r21]]) + + ss = eme.get_s(left_modes, WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES) + + assert ss.shape == (4, 4) + assert numpy.isfinite(ss).all() + assert_allclose(ss, expected) + + +def test_get_s_force_nogain_caps_singular_values(monkeypatch) -> None: + def fake_get_tr(*args, **kwargs): + return numpy.array([[2.0, 0.0], [0.0, 0.5]]), numpy.zeros((2, 2)) + + monkeypatch.setattr(eme, 'get_tr', fake_get_tr) + + plain_s = eme.get_s(None, None, None, None) + clipped_s = eme.get_s(None, None, None, None, force_nogain=True) + + plain_singular_values = numpy.linalg.svd(plain_s, compute_uv=False) + clipped_singular_values = numpy.linalg.svd(clipped_s, compute_uv=False) + + assert plain_singular_values.max() > 1.0 + assert (clipped_singular_values <= 1.0 + 1e-12).all() + assert numpy.isfinite(clipped_s).all() + + +def test_get_s_force_reciprocal_symmetrizes_output(monkeypatch) -> None: + left = object() + right = object() + + def fake_get_tr(_eh_left, wavenumbers_left, _eh_right, _wavenumbers_right, **kwargs): + if wavenumbers_left is left: + return ( + numpy.array([[1.0, 2.0], [0.5, 1.0]]), + numpy.array([[0.0, 1.0], [2.0, 0.0]]), + ) + return ( + numpy.array([[1.0, -1.0], [0.0, 1.0]]), + numpy.array([[0.0, 0.5], [1.5, 0.0]]), + ) + + monkeypatch.setattr(eme, 'get_tr', fake_get_tr) + ss = eme.get_s(None, left, None, right, force_reciprocal=True) + + assert_allclose(ss, ss.T) + + +def test_get_s_force_nogain_and_reciprocal_returns_finite_output(monkeypatch) -> None: + def fake_get_tr(*args, **kwargs): + return numpy.array([[2.0, 0.0], [0.0, 0.5]]), numpy.array([[0.0, 1.0], [2.0, 0.0]]) + + monkeypatch.setattr(eme, 'get_tr', fake_get_tr) + ss = eme.get_s(None, None, None, None, force_nogain=True, force_reciprocal=True) + + assert ss.shape == (4, 4) + assert numpy.isfinite(ss).all() + assert_allclose(ss, ss.T) + assert (numpy.linalg.svd(ss, compute_uv=False) <= 1.0 + 1e-12).all()