[EME] add more docs and tests
This commit is contained in:
parent
8b67696d7f
commit
1a2c6ab524
2 changed files with 201 additions and 1 deletions
|
|
@ -1,3 +1,24 @@
|
|||
"""
|
||||
Low-level mode-matching helpers for waveguide / EME workflows.
|
||||
|
||||
These helpers operate on already-solved and already-normalized port fields.
|
||||
They do not build geometries or solve modes themselves; downstream users are
|
||||
expected to supply compatible `(E, H)` modal field pairs from
|
||||
`waveguide_2d`, `waveguide_3d`, or `waveguide_cyl`.
|
||||
|
||||
The returned matrices follow the usual port ordering:
|
||||
|
||||
- `get_tr(...)` returns `(T, R)` for left-incident modes.
|
||||
- `get_abcd(...)` returns the 2-port block transfer matrix built from the two
|
||||
directional `T/R` solves.
|
||||
- `get_s(...)` returns the full block scattering matrix
|
||||
`[[R12, T12], [T21, R21]]`.
|
||||
|
||||
This module is intentionally a thin library layer rather than an integrated
|
||||
simulation suite. It provides the overlap algebra that downstream users can
|
||||
compose into larger workflows.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import numpy
|
||||
from numpy.typing import NDArray
|
||||
|
|
@ -7,6 +28,37 @@ from ..fdmath import dx_lists2_t, vcfdfield2
|
|||
from .waveguide_2d import inner_product
|
||||
|
||||
|
||||
def _validate_port_modes(
|
||||
name: str,
|
||||
ehs: Sequence[Sequence[vcfdfield2]],
|
||||
wavenumbers: Sequence[complex],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
if len(ehs) != len(wavenumbers):
|
||||
raise ValueError(f'{name} mode list and wavenumber list must have the same length')
|
||||
if not ehs:
|
||||
raise ValueError(f'{name} must contain at least one mode')
|
||||
|
||||
e_shape: tuple[int, ...] | None = None
|
||||
h_shape: tuple[int, ...] | None = None
|
||||
for index, mode in enumerate(ehs):
|
||||
if len(mode) != 2:
|
||||
raise ValueError(f'{name}[{index}] must be a 2-tuple of (E, H) modal fields')
|
||||
e_field, h_field = mode
|
||||
mode_e_shape = numpy.shape(e_field)
|
||||
mode_h_shape = numpy.shape(h_field)
|
||||
if mode_e_shape != mode_h_shape:
|
||||
raise ValueError(f'{name}[{index}] has mismatched E/H field shapes')
|
||||
if e_shape is None:
|
||||
e_shape = mode_e_shape
|
||||
h_shape = mode_h_shape
|
||||
elif mode_e_shape != e_shape or mode_h_shape != h_shape:
|
||||
raise ValueError(f'{name} modal fields must all share the same shape')
|
||||
|
||||
assert e_shape is not None
|
||||
assert h_shape is not None
|
||||
return e_shape, h_shape
|
||||
|
||||
|
||||
def get_tr(
|
||||
ehLs: Sequence[Sequence[vcfdfield2]],
|
||||
wavenumbers_L: Sequence[complex],
|
||||
|
|
@ -14,6 +66,29 @@ def get_tr(
|
|||
wavenumbers_R: Sequence[complex],
|
||||
dxes: dx_lists2_t,
|
||||
) -> tuple[NDArray[numpy.complex128], NDArray[numpy.complex128]]:
|
||||
"""
|
||||
Compute left-incident transmission and reflection matrices.
|
||||
|
||||
Args:
|
||||
ehLs: Left-port modes as `(E, H)` field pairs.
|
||||
wavenumbers_L: Propagation constants for `ehLs`.
|
||||
ehRs: Right-port modes as `(E, H)` field pairs.
|
||||
wavenumbers_R: Propagation constants for `ehRs`.
|
||||
dxes: Two-dimensional Yee-cell edge lengths for the shared port plane.
|
||||
|
||||
Returns:
|
||||
`(T12, R12)` where columns index left-incident modes and rows index
|
||||
outgoing right-going / left-going modes respectively.
|
||||
|
||||
Raises:
|
||||
ValueError: If the port mode lists are empty, malformed, or defined on
|
||||
incompatible field shapes.
|
||||
"""
|
||||
left_e_shape, left_h_shape = _validate_port_modes('ehLs', ehLs, wavenumbers_L)
|
||||
right_e_shape, right_h_shape = _validate_port_modes('ehRs', ehRs, wavenumbers_R)
|
||||
if left_e_shape != right_e_shape or left_h_shape != right_h_shape:
|
||||
raise ValueError('left and right modal fields must share the same E/H shapes')
|
||||
|
||||
nL = len(wavenumbers_L)
|
||||
nR = len(wavenumbers_R)
|
||||
A12 = numpy.zeros((nL, nR), dtype=complex)
|
||||
|
|
@ -48,6 +123,16 @@ def get_abcd(
|
|||
wavenumbers_R: Sequence[complex],
|
||||
**kwargs,
|
||||
) -> sparse.sparray:
|
||||
"""
|
||||
Build the 2-port block transfer matrix for an interface.
|
||||
|
||||
The blocks are assembled from the forward and reverse `get_tr(...)`
|
||||
solutions using the standard
|
||||
|
||||
`[[A, B], [C, D]] = [[T12 - R21 T21^-1 R12, R21 T21^-1], [-T21^-1 R12, T21^-1]]`
|
||||
|
||||
convention.
|
||||
"""
|
||||
t12, r12 = get_tr(ehLs, wavenumbers_L, ehRs, wavenumbers_R, **kwargs)
|
||||
t21, r21 = get_tr(ehRs, wavenumbers_R, ehLs, wavenumbers_L, **kwargs)
|
||||
t21i = numpy.linalg.pinv(t21)
|
||||
|
|
@ -73,6 +158,19 @@ def get_s(
|
|||
force_reciprocal: bool = False,
|
||||
**kwargs,
|
||||
) -> NDArray[numpy.complex128]:
|
||||
"""
|
||||
Build the full block scattering matrix for a two-sided interface.
|
||||
|
||||
The returned matrix is ordered as `[[R12, T12], [T21, R21]]`, where the
|
||||
first block-row/column corresponds to the left port and the second to the
|
||||
right port.
|
||||
|
||||
Args:
|
||||
force_nogain: If `True`, clamp singular values of the assembled
|
||||
scattering matrix to at most one.
|
||||
force_reciprocal: If `True`, symmetrize the assembled matrix as
|
||||
`0.5 * (S + S.T)`.
|
||||
"""
|
||||
t12, r12 = get_tr(ehLs, wavenumbers_L, ehRs, wavenumbers_R, **kwargs)
|
||||
t21, r21 = get_tr(ehRs, wavenumbers_R, ehLs, wavenumbers_L, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import numpy
|
||||
import pytest
|
||||
from scipy import sparse
|
||||
|
||||
from ..fdmath import vec
|
||||
from ..fdfd import eme
|
||||
from ..fdfd import eme, waveguide_2d, waveguide_cyl
|
||||
from ._test_builders import complex_ramp, unit_dxes
|
||||
from .utils import assert_close
|
||||
|
||||
|
|
@ -11,6 +12,8 @@ SHAPE = (3, 2, 2)
|
|||
DXES = unit_dxes((2, 2))
|
||||
WAVENUMBERS_L = numpy.array([1.0, 0.8])
|
||||
WAVENUMBERS_R = numpy.array([0.9, 0.7])
|
||||
OMEGA = 1 / 1500
|
||||
REAL_DXES = unit_dxes((5, 5))
|
||||
|
||||
|
||||
def _mode(scale: float) -> tuple[numpy.ndarray, numpy.ndarray]:
|
||||
|
|
@ -130,3 +133,102 @@ def test_get_s_force_nogain_and_reciprocal_returns_finite_output(monkeypatch) ->
|
|||
assert numpy.isfinite(ss).all()
|
||||
assert_close(ss, ss.T)
|
||||
assert (numpy.linalg.svd(ss, compute_uv=False) <= 1.0 + 1e-12).all()
|
||||
|
||||
|
||||
def test_get_tr_rejects_length_mismatches() -> None:
|
||||
left_modes, right_modes = _mode_sets()
|
||||
|
||||
with pytest.raises(ValueError, match='same length'):
|
||||
eme.get_tr(left_modes[:1], WAVENUMBERS_L, right_modes, WAVENUMBERS_R, dxes=DXES)
|
||||
|
||||
|
||||
def test_get_tr_rejects_malformed_mode_tuples() -> None:
|
||||
bad_modes = [(numpy.ones(4),)]
|
||||
|
||||
with pytest.raises(ValueError, match='2-tuple'):
|
||||
eme.get_tr(bad_modes, [1.0], bad_modes, [1.0], dxes=DXES)
|
||||
|
||||
|
||||
def test_get_tr_rejects_incompatible_field_shapes() -> None:
|
||||
left_modes = [(numpy.ones(4), numpy.ones(4))]
|
||||
right_modes = [(numpy.ones(6), numpy.ones(6))]
|
||||
|
||||
with pytest.raises(ValueError, match='same E/H shapes'):
|
||||
eme.get_tr(left_modes, [1.0], right_modes, [1.0], dxes=DXES)
|
||||
|
||||
|
||||
def _build_real_epsilon() -> numpy.ndarray:
|
||||
epsilon = numpy.ones((3, 5, 5), dtype=float)
|
||||
epsilon[:, 2, 1] = 2.0
|
||||
return vec(epsilon)
|
||||
|
||||
|
||||
def _build_straight_mode() -> tuple[tuple[numpy.ndarray, numpy.ndarray], complex, numpy.ndarray]:
|
||||
epsilon = _build_real_epsilon()
|
||||
e_xy, wavenumber = waveguide_2d.solve_mode(
|
||||
0,
|
||||
omega=OMEGA,
|
||||
dxes=REAL_DXES,
|
||||
epsilon=epsilon,
|
||||
)
|
||||
e_field, h_field = waveguide_2d.normalized_fields_e(
|
||||
e_xy,
|
||||
wavenumber=wavenumber,
|
||||
omega=OMEGA,
|
||||
dxes=REAL_DXES,
|
||||
epsilon=epsilon,
|
||||
)
|
||||
return (e_field, h_field), wavenumber, epsilon
|
||||
|
||||
|
||||
def _build_bend_mode() -> tuple[tuple[numpy.ndarray, numpy.ndarray], complex]:
|
||||
epsilon = vec(numpy.ones((3, 5, 5), dtype=float))
|
||||
rmin = 10.0
|
||||
e_xy, angular_wavenumber = waveguide_cyl.solve_mode(
|
||||
0,
|
||||
omega=OMEGA,
|
||||
dxes=REAL_DXES,
|
||||
epsilon=epsilon,
|
||||
rmin=rmin,
|
||||
)
|
||||
linear_wavenumber = waveguide_cyl.linear_wavenumbers(
|
||||
[e_xy],
|
||||
[angular_wavenumber],
|
||||
epsilon=epsilon,
|
||||
dxes=REAL_DXES,
|
||||
rmin=rmin,
|
||||
)[0]
|
||||
e_field, h_field = waveguide_cyl.normalized_fields_e(
|
||||
e_xy,
|
||||
angular_wavenumber=angular_wavenumber,
|
||||
omega=OMEGA,
|
||||
dxes=REAL_DXES,
|
||||
epsilon=epsilon,
|
||||
rmin=rmin,
|
||||
)
|
||||
return (e_field, h_field), linear_wavenumber
|
||||
|
||||
|
||||
def test_get_s_is_near_identity_for_identical_solved_straight_modes() -> None:
|
||||
mode, wavenumber, _epsilon = _build_straight_mode()
|
||||
|
||||
ss = eme.get_s([mode], [wavenumber], [mode], [wavenumber], dxes=REAL_DXES)
|
||||
|
||||
assert ss.shape == (2, 2)
|
||||
assert numpy.isfinite(ss).all()
|
||||
assert abs(ss[0, 0]) < 1e-6
|
||||
assert abs(ss[1, 1]) < 1e-6
|
||||
assert abs(abs(ss[0, 1]) - 1.0) < 1e-6
|
||||
assert abs(abs(ss[1, 0]) - 1.0) < 1e-6
|
||||
assert numpy.linalg.svd(ss, compute_uv=False).max() <= 1.0 + 1e-10
|
||||
|
||||
|
||||
def test_get_s_returns_finite_passive_output_for_small_straight_to_bend_fixture() -> None:
|
||||
straight_mode, straight_wavenumber, _epsilon = _build_straight_mode()
|
||||
bend_mode, bend_wavenumber = _build_bend_mode()
|
||||
|
||||
ss = eme.get_s([straight_mode], [straight_wavenumber], [bend_mode], [bend_wavenumber], dxes=REAL_DXES)
|
||||
|
||||
assert ss.shape == (2, 2)
|
||||
assert numpy.isfinite(ss).all()
|
||||
assert numpy.linalg.svd(ss, compute_uv=False).max() <= 1.0 + 1e-10
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue