[EME] add more docs and tests

This commit is contained in:
Forgejo Actions 2026-04-21 19:40:32 -07:00
commit 1a2c6ab524
2 changed files with 201 additions and 1 deletions

View file

@ -1,3 +1,24 @@
"""
Low-level mode-matching helpers for waveguide / EME workflows.
These helpers operate on already-solved and already-normalized port fields.
They do not build geometries or solve modes themselves; downstream users are
expected to supply compatible `(E, H)` modal field pairs from
`waveguide_2d`, `waveguide_3d`, or `waveguide_cyl`.
The returned matrices follow the usual port ordering:
- `get_tr(...)` returns `(T, R)` for left-incident modes.
- `get_abcd(...)` returns the 2-port block transfer matrix built from the two
directional `T/R` solves.
- `get_s(...)` returns the full block scattering matrix
`[[R12, T12], [T21, R21]]`.
This module is intentionally a thin library layer rather than an integrated
simulation suite. It provides the overlap algebra that downstream users can
compose into larger workflows.
"""
from collections.abc import Sequence
import numpy
from numpy.typing import NDArray
@ -7,6 +28,37 @@ from ..fdmath import dx_lists2_t, vcfdfield2
from .waveguide_2d import inner_product
def _validate_port_modes(
name: str,
ehs: Sequence[Sequence[vcfdfield2]],
wavenumbers: Sequence[complex],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
if len(ehs) != len(wavenumbers):
raise ValueError(f'{name} mode list and wavenumber list must have the same length')
if not ehs:
raise ValueError(f'{name} must contain at least one mode')
e_shape: tuple[int, ...] | None = None
h_shape: tuple[int, ...] | None = None
for index, mode in enumerate(ehs):
if len(mode) != 2:
raise ValueError(f'{name}[{index}] must be a 2-tuple of (E, H) modal fields')
e_field, h_field = mode
mode_e_shape = numpy.shape(e_field)
mode_h_shape = numpy.shape(h_field)
if mode_e_shape != mode_h_shape:
raise ValueError(f'{name}[{index}] has mismatched E/H field shapes')
if e_shape is None:
e_shape = mode_e_shape
h_shape = mode_h_shape
elif mode_e_shape != e_shape or mode_h_shape != h_shape:
raise ValueError(f'{name} modal fields must all share the same shape')
assert e_shape is not None
assert h_shape is not None
return e_shape, h_shape
def get_tr(
ehLs: Sequence[Sequence[vcfdfield2]],
wavenumbers_L: Sequence[complex],
@ -14,6 +66,29 @@ def get_tr(
wavenumbers_R: Sequence[complex],
dxes: dx_lists2_t,
) -> tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]:
"""
Compute left-incident transmission and reflection matrices.
Args:
ehLs: Left-port modes as `(E, H)` field pairs.
wavenumbers_L: Propagation constants for `ehLs`.
ehRs: Right-port modes as `(E, H)` field pairs.
wavenumbers_R: Propagation constants for `ehRs`.
dxes: Two-dimensional Yee-cell edge lengths for the shared port plane.
Returns:
`(T12, R12)` where columns index left-incident modes and rows index
outgoing right-going / left-going modes respectively.
Raises:
ValueError: If the port mode lists are empty, malformed, or defined on
incompatible field shapes.
"""
left_e_shape, left_h_shape = _validate_port_modes('ehLs', ehLs, wavenumbers_L)
right_e_shape, right_h_shape = _validate_port_modes('ehRs', ehRs, wavenumbers_R)
if left_e_shape != right_e_shape or left_h_shape != right_h_shape:
raise ValueError('left and right modal fields must share the same E/H shapes')
nL = len(wavenumbers_L)
nR = len(wavenumbers_R)
A12 = numpy.zeros((nL, nR), dtype=complex)
@ -48,6 +123,16 @@ def get_abcd(
wavenumbers_R: Sequence[complex],
**kwargs,
) -> sparse.sparray:
"""
Build the 2-port block transfer matrix for an interface.
The blocks are assembled from the forward and reverse `get_tr(...)`
solutions using the standard
`[[A, B], [C, D]] = [[T12 - R21 T21^-1 R12, R21 T21^-1], [-T21^-1 R12, T21^-1]]`
convention.
"""
t12, r12 = get_tr(ehLs, wavenumbers_L, ehRs, wavenumbers_R, **kwargs)
t21, r21 = get_tr(ehRs, wavenumbers_R, ehLs, wavenumbers_L, **kwargs)
t21i = numpy.linalg.pinv(t21)
@ -73,6 +158,19 @@ def get_s(
force_reciprocal: bool = False,
**kwargs,
) -> NDArray[numpy.complex128]:
"""
Build the full block scattering matrix for a two-sided interface.
The returned matrix is ordered as `[[R12, T12], [T21, R21]]`, where the
first block-row/column corresponds to the left port and the second to the
right port.
Args:
force_nogain: If `True`, clamp singular values of the assembled
scattering matrix to at most one.
force_reciprocal: If `True`, symmetrize the assembled matrix as
`0.5 * (S + S.T)`.
"""
t12, r12 = get_tr(ehLs, wavenumbers_L, ehRs, wavenumbers_R, **kwargs)
t21, r21 = get_tr(ehRs, wavenumbers_R, ehLs, wavenumbers_L, **kwargs)

View file

@ -1,8 +1,9 @@
import numpy
import pytest
from scipy import sparse
from ..fdmath import vec
from ..fdfd import eme
from ..fdfd import eme, waveguide_2d, waveguide_cyl
from ._test_builders import complex_ramp, unit_dxes
from .utils import assert_close
@ -11,6 +12,8 @@ SHAPE = (3, 2, 2)
DXES = unit_dxes((2, 2))
WAVENUMBERS_L = numpy.array([1.0, 0.8])
WAVENUMBERS_R = numpy.array([0.9, 0.7])
OMEGA = 1 / 1500
REAL_DXES = unit_dxes((5, 5))
def _mode(scale: float) -> tuple[numpy.ndarray, numpy.ndarray]:
@ -130,3 +133,102 @@ def test_get_s_force_nogain_and_reciprocal_returns_finite_output(monkeypatch) ->
assert numpy.isfinite(ss).all()
assert_close(ss, ss.T)
assert (numpy.linalg.svd(ss, compute_uv=False) <= 1.0 + 1e-12).all()
def test_get_tr_rejects_length_mismatches() -> None:
left_modes, right_modes = _mode_sets()
with pytest.raises(ValueError, match='same length'):
eme.get_tr(left_modes[:1], WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES)
def test_get_tr_rejects_malformed_mode_tuples() -> None:
bad_modes = [(numpy.ones(4),)]
with pytest.raises(ValueError, match='2-tuple'):
eme.get_tr(bad_modes, [1.0], bad_modes, [1.0], dxes=DXES)
def test_get_tr_rejects_incompatible_field_shapes() -> None:
left_modes = [(numpy.ones(4), numpy.ones(4))]
right_modes = [(numpy.ones(6), numpy.ones(6))]
with pytest.raises(ValueError, match='same E/H shapes'):
eme.get_tr(left_modes, [1.0], right_modes, [1.0], dxes=DXES)
def _build_real_epsilon() -> numpy.ndarray:
epsilon = numpy.ones((3, 5, 5), dtype=float)
epsilon[:, 2, 1] = 2.0
return vec(epsilon)
def _build_straight_mode() -> tuple[tuple[numpy.ndarray, numpy.ndarray], complex, numpy.ndarray]:
epsilon = _build_real_epsilon()
e_xy, wavenumber = waveguide_2d.solve_mode(
0,
omega=OMEGA,
dxes=REAL_DXES,
epsilon=epsilon,
)
e_field, h_field = waveguide_2d.normalized_fields_e(
e_xy,
wavenumber=wavenumber,
omega=OMEGA,
dxes=REAL_DXES,
epsilon=epsilon,
)
return (e_field, h_field), wavenumber, epsilon
def _build_bend_mode() -> tuple[tuple[numpy.ndarray, numpy.ndarray], complex]:
epsilon = vec(numpy.ones((3, 5, 5), dtype=float))
rmin = 10.0
e_xy, angular_wavenumber = waveguide_cyl.solve_mode(
0,
omega=OMEGA,
dxes=REAL_DXES,
epsilon=epsilon,
rmin=rmin,
)
linear_wavenumber = waveguide_cyl.linear_wavenumbers(
[e_xy],
[angular_wavenumber],
epsilon=epsilon,
dxes=REAL_DXES,
rmin=rmin,
)[0]
e_field, h_field = waveguide_cyl.normalized_fields_e(
e_xy,
angular_wavenumber=angular_wavenumber,
omega=OMEGA,
dxes=REAL_DXES,
epsilon=epsilon,
rmin=rmin,
)
return (e_field, h_field), linear_wavenumber
def test_get_s_is_near_identity_for_identical_solved_straight_modes() -> None:
mode, wavenumber, _epsilon = _build_straight_mode()
ss = eme.get_s([mode], [wavenumber], [mode], [wavenumber], dxes=REAL_DXES)
assert ss.shape == (2, 2)
assert numpy.isfinite(ss).all()
assert abs(ss[0, 0]) < 1e-6
assert abs(ss[1, 1]) < 1e-6
assert abs(abs(ss[0, 1]) - 1.0) < 1e-6
assert abs(abs(ss[1, 0]) - 1.0) < 1e-6
assert numpy.linalg.svd(ss, compute_uv=False).max() <= 1.0 + 1e-10
def test_get_s_returns_finite_passive_output_for_small_straight_to_bend_fixture() -> None:
straight_mode, straight_wavenumber, _epsilon = _build_straight_mode()
bend_mode, bend_wavenumber = _build_bend_mode()
ss = eme.get_s([straight_mode], [straight_wavenumber], [bend_mode], [bend_wavenumber], dxes=REAL_DXES)
assert ss.shape == (2, 2)
assert numpy.isfinite(ss).all()
assert numpy.linalg.svd(ss, compute_uv=False).max() <= 1.0 + 1e-10