diff --git a/meanas/fdfd/eme.py b/meanas/fdfd/eme.py index 5165ef1..366de8e 100644 --- a/meanas/fdfd/eme.py +++ b/meanas/fdfd/eme.py @@ -1,3 +1,24 @@ +""" +Low-level mode-matching helpers for waveguide / EME workflows. + +These helpers operate on already-solved and already-normalized port fields. +They do not build geometries or solve modes themselves; downstream users are +expected to supply compatible `(E, H)` modal field pairs from +`waveguide_2d`, `waveguide_3d`, or `waveguide_cyl`. + +The returned matrices follow the usual port ordering: + +- `get_tr(...)` returns `(T, R)` for left-incident modes. +- `get_abcd(...)` returns the 2-port block transfer matrix built from the two + directional `T/R` solves. +- `get_s(...)` returns the full block scattering matrix + `[[R12, T12], [T21, R21]]`. + +This module is intentionally a thin library layer rather than an integrated +simulation suite. It provides the overlap algebra that downstream users can +compose into larger workflows. +""" + from collections.abc import Sequence import numpy from numpy.typing import NDArray @@ -7,6 +28,37 @@ from ..fdmath import dx_lists2_t, vcfdfield2 from .waveguide_2d import inner_product +def _validate_port_modes( + name: str, + ehs: Sequence[Sequence[vcfdfield2]], + wavenumbers: Sequence[complex], + ) -> tuple[tuple[int, ...], tuple[int, ...]]: + if len(ehs) != len(wavenumbers): + raise ValueError(f'{name} mode list and wavenumber list must have the same length') + if not ehs: + raise ValueError(f'{name} must contain at least one mode') + + e_shape: tuple[int, ...] | None = None + h_shape: tuple[int, ...] | None = None + for index, mode in enumerate(ehs): + if len(mode) != 2: + raise ValueError(f'{name}[{index}] must be a 2-tuple of (E, H) modal fields') + e_field, h_field = mode + mode_e_shape = numpy.shape(e_field) + mode_h_shape = numpy.shape(h_field) + if mode_e_shape != mode_h_shape: + raise ValueError(f'{name}[{index}] has mismatched E/H field shapes') + if e_shape is None: + e_shape = mode_e_shape + h_shape = mode_h_shape + elif mode_e_shape != e_shape or mode_h_shape != h_shape: + raise ValueError(f'{name} modal fields must all share the same shape') + + assert e_shape is not None + assert h_shape is not None + return e_shape, h_shape + + def get_tr( ehLs: Sequence[Sequence[vcfdfield2]], wavenumbers_L: Sequence[complex], @@ -14,6 +66,29 @@ def get_tr( wavenumbers_R: Sequence[complex], dxes: dx_lists2_t, ) -> tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]: + """ + Compute left-incident transmission and reflection matrices. + + Args: + ehLs: Left-port modes as `(E, H)` field pairs. + wavenumbers_L: Propagation constants for `ehLs`. + ehRs: Right-port modes as `(E, H)` field pairs. + wavenumbers_R: Propagation constants for `ehRs`. + dxes: Two-dimensional Yee-cell edge lengths for the shared port plane. + + Returns: + `(T12, R12)` where columns index left-incident modes and rows index + outgoing right-going / left-going modes respectively. + + Raises: + ValueError: If the port mode lists are empty, malformed, or defined on + incompatible field shapes. + """ + left_e_shape, left_h_shape = _validate_port_modes('ehLs', ehLs, wavenumbers_L) + right_e_shape, right_h_shape = _validate_port_modes('ehRs', ehRs, wavenumbers_R) + if left_e_shape != right_e_shape or left_h_shape != right_h_shape: + raise ValueError('left and right modal fields must share the same E/H shapes') + nL = len(wavenumbers_L) nR = len(wavenumbers_R) A12 = numpy.zeros((nL, nR), dtype=complex) @@ -48,6 +123,16 @@ def get_abcd( wavenumbers_R: Sequence[complex], **kwargs, ) -> sparse.sparray: + """ + Build the 2-port block transfer matrix for an interface. + + The blocks are assembled from the forward and reverse `get_tr(...)` + solutions using the standard + + `[[A, B], [C, D]] = [[T12 - R21 T21^-1 R12, R21 T21^-1], [-T21^-1 R12, T21^-1]]` + + convention. + """ t12, r12 = get_tr(ehLs, wavenumbers_L, ehRs, wavenumbers_R, **kwargs) t21, r21 = get_tr(ehRs, wavenumbers_R, ehLs, wavenumbers_L, **kwargs) t21i = numpy.linalg.pinv(t21) @@ -73,6 +158,19 @@ def get_s( force_reciprocal: bool = False, **kwargs, ) -> NDArray[numpy.complex128]: + """ + Build the full block scattering matrix for a two-sided interface. + + The returned matrix is ordered as `[[R12, T12], [T21, R21]]`, where the + first block-row/column corresponds to the left port and the second to the + right port. + + Args: + force_nogain: If `True`, clamp singular values of the assembled + scattering matrix to at most one. + force_reciprocal: If `True`, symmetrize the assembled matrix as + `0.5 * (S + S.T)`. + """ t12, r12 = get_tr(ehLs, wavenumbers_L, ehRs, wavenumbers_R, **kwargs) t21, r21 = get_tr(ehRs, wavenumbers_R, ehLs, wavenumbers_L, **kwargs) diff --git a/meanas/test/test_eme_numerics.py b/meanas/test/test_eme_numerics.py index 8798e0d..7486128 100644 --- a/meanas/test/test_eme_numerics.py +++ b/meanas/test/test_eme_numerics.py @@ -1,8 +1,9 @@ import numpy +import pytest from scipy import sparse from ..fdmath import vec -from ..fdfd import eme +from ..fdfd import eme, waveguide_2d, waveguide_cyl from ._test_builders import complex_ramp, unit_dxes from .utils import assert_close @@ -11,6 +12,8 @@ SHAPE = (3, 2, 2) DXES = unit_dxes((2, 2)) WAVENUMBERS_L = numpy.array([1.0, 0.8]) WAVENUMBERS_R = numpy.array([0.9, 0.7]) +OMEGA = 1 / 1500 +REAL_DXES = unit_dxes((5, 5)) def _mode(scale: float) -> tuple[numpy.ndarray, numpy.ndarray]: @@ -130,3 +133,102 @@ def test_get_s_force_nogain_and_reciprocal_returns_finite_output(monkeypatch) -> assert numpy.isfinite(ss).all() assert_close(ss, ss.T) assert (numpy.linalg.svd(ss, compute_uv=False) <= 1.0 + 1e-12).all() + + +def test_get_tr_rejects_length_mismatches() -> None: + left_modes, right_modes = _mode_sets() + + with pytest.raises(ValueError, match='same length'): + eme.get_tr(left_modes[:1], WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES) + + +def test_get_tr_rejects_malformed_mode_tuples() -> None: + bad_modes = [(numpy.ones(4),)] + + with pytest.raises(ValueError, match='2-tuple'): + eme.get_tr(bad_modes, [1.0], bad_modes, [1.0], dxes=DXES) + + +def test_get_tr_rejects_incompatible_field_shapes() -> None: + left_modes = [(numpy.ones(4), numpy.ones(4))] + right_modes = [(numpy.ones(6), numpy.ones(6))] + + with pytest.raises(ValueError, match='same E/H shapes'): + eme.get_tr(left_modes, [1.0], right_modes, [1.0], dxes=DXES) + + +def _build_real_epsilon() -> numpy.ndarray: + epsilon = numpy.ones((3, 5, 5), dtype=float) + epsilon[:, 2, 1] = 2.0 + return vec(epsilon) + + +def _build_straight_mode() -> tuple[tuple[numpy.ndarray, numpy.ndarray], complex, numpy.ndarray]: + epsilon = _build_real_epsilon() + e_xy, wavenumber = waveguide_2d.solve_mode( + 0, + omega=OMEGA, + dxes=REAL_DXES, + epsilon=epsilon, + ) + e_field, h_field = waveguide_2d.normalized_fields_e( + e_xy, + wavenumber=wavenumber, + omega=OMEGA, + dxes=REAL_DXES, + epsilon=epsilon, + ) + return (e_field, h_field), wavenumber, epsilon + + +def _build_bend_mode() -> tuple[tuple[numpy.ndarray, numpy.ndarray], complex]: + epsilon = vec(numpy.ones((3, 5, 5), dtype=float)) + rmin = 10.0 + e_xy, angular_wavenumber = waveguide_cyl.solve_mode( + 0, + omega=OMEGA, + dxes=REAL_DXES, + epsilon=epsilon, + rmin=rmin, + ) + linear_wavenumber = waveguide_cyl.linear_wavenumbers( + [e_xy], + [angular_wavenumber], + epsilon=epsilon, + dxes=REAL_DXES, + rmin=rmin, + )[0] + e_field, h_field = waveguide_cyl.normalized_fields_e( + e_xy, + angular_wavenumber=angular_wavenumber, + omega=OMEGA, + dxes=REAL_DXES, + epsilon=epsilon, + rmin=rmin, + ) + return (e_field, h_field), linear_wavenumber + + +def test_get_s_is_near_identity_for_identical_solved_straight_modes() -> None: + mode, wavenumber, _epsilon = _build_straight_mode() + + ss = eme.get_s([mode], [wavenumber], [mode], [wavenumber], dxes=REAL_DXES) + + assert ss.shape == (2, 2) + assert numpy.isfinite(ss).all() + assert abs(ss[0, 0]) < 1e-6 + assert abs(ss[1, 1]) < 1e-6 + assert abs(abs(ss[0, 1]) - 1.0) < 1e-6 + assert abs(abs(ss[1, 0]) - 1.0) < 1e-6 + assert numpy.linalg.svd(ss, compute_uv=False).max() <= 1.0 + 1e-10 + + +def test_get_s_returns_finite_passive_output_for_small_straight_to_bend_fixture() -> None: + straight_mode, straight_wavenumber, _epsilon = _build_straight_mode() + bend_mode, bend_wavenumber = _build_bend_mode() + + ss = eme.get_s([straight_mode], [straight_wavenumber], [bend_mode], [bend_wavenumber], dxes=REAL_DXES) + + assert ss.shape == (2, 2) + assert numpy.isfinite(ss).all() + assert numpy.linalg.svd(ss, compute_uv=False).max() <= 1.0 + 1e-10