[fdfd.waveguide_3d] improve handling of out-of-bounds overlap_e windows

This commit is contained in:
Jan Petykiewicz 2026-04-17 20:44:36 -07:00
commit f35b334100
2 changed files with 218 additions and 19 deletions

View file

@ -5,6 +5,8 @@ This module relies heavily on `waveguide_2d` and mostly just transforms
its parameters into 2D equivalents and expands the results back into 3D. its parameters into 2D equivalents and expands the results back into 3D.
""" """
from typing import Any, cast from typing import Any, cast
import warnings
from typing import Any
from collections.abc import Sequence from collections.abc import Sequence
import numpy import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
@ -200,17 +202,33 @@ def compute_overlap_e(
Ee = expand_e(E=E, wavenumber=wavenumber, dxes=dxes, Ee = expand_e(E=E, wavenumber=wavenumber, dxes=dxes,
axis=axis, polarity=polarity, slices=slices) axis=axis, polarity=polarity, slices=slices)
start, stop = sorted((slices[axis].start, slices[axis].start - 2 * polarity)) axis_size = E.shape[axis + 1]
if polarity > 0:
start = slices[axis].start - 2
stop = slices[axis].start
else:
start = slices[axis].stop
stop = slices[axis].stop + 2
clipped_start = max(0, start)
clipped_stop = min(axis_size, stop)
if clipped_start >= clipped_stop:
raise ValueError('Requested overlap window lies outside the domain')
if clipped_start != start or clipped_stop != stop:
warnings.warn('Requested overlap window was clipped to fit within the domain', RuntimeWarning)
slices2_l = list(slices) slices2_l = list(slices)
slices2_l[axis] = slice(start, stop) slices2_l[axis] = slice(clipped_start, clipped_stop)
slices2 = (slice(None), *slices2_l) slices2 = (slice(None), *slices2_l)
Etgt = numpy.zeros_like(Ee) Etgt = numpy.zeros_like(Ee)
Etgt[slices2] = Ee[slices2] Etgt[slices2] = Ee[slices2]
# Note: We normalize so that (Etgt @ E.conj()) == 1, so (Etgt @ Etgt.conj) != 1 # Note: We normalize so that (Etgt @ E.conj()) == 1, so (Etgt @ Etgt.conj) != 1
Etgt /= (Etgt.conj() * Etgt).sum() norm = (Etgt.conj() * Etgt).sum()
if norm == 0:
raise ValueError('Requested overlap window contains no overlap field support')
Etgt /= norm
return cfdfield_t(Etgt) return cfdfield_t(Etgt)

View file

@ -1,29 +1,56 @@
import contextlib
import io
import numpy import numpy
from numpy.linalg import norm from numpy.linalg import norm
import pytest
import warnings
from ..fdmath import vec from ..fdmath import vec, unvec
from ..fdfd import waveguide_3d, waveguide_cyl from ..fdfd import waveguide_2d, waveguide_3d, waveguide_cyl
OMEGA = 1 / 1500 OMEGA = 1 / 1500
def test_waveguide_3d_solve_mode_and_expand_e_are_phase_consistent() -> None: def build_waveguide_3d_mode(
*,
slice_start: int,
polarity: int,
) -> tuple[numpy.ndarray, list[list[numpy.ndarray]], tuple[slice, slice, slice], dict[str, complex | numpy.ndarray]]:
epsilon = numpy.ones((3, 5, 5, 1), dtype=float) epsilon = numpy.ones((3, 5, 5, 1), dtype=float)
dxes = [[numpy.ones(5), numpy.ones(5), numpy.ones(1)] for _ in range(2)] dxes = [[numpy.ones(5), numpy.ones(5), numpy.ones(1)] for _ in range(2)]
axis = 0 slices = (slice(slice_start, slice_start + 1), slice(None), slice(None))
polarity = 1
slices = (slice(0, 1), slice(None), slice(None))
result = waveguide_3d.solve_mode( result = waveguide_3d.solve_mode(
0, 0,
omega=OMEGA, omega=OMEGA,
dxes=dxes, dxes=dxes,
axis=axis, axis=0,
polarity=polarity, polarity=polarity,
slices=slices, slices=slices,
epsilon=epsilon, epsilon=epsilon,
) )
return epsilon, dxes, slices, result
def build_waveguide_cyl_fixture(
*,
nonuniform: bool = False,
) -> tuple[list[list[numpy.ndarray]], numpy.ndarray, float]:
if nonuniform:
dxes = [
[numpy.array([1.0, 1.5, 1.2, 0.8, 1.1]), numpy.ones(5)],
[numpy.array([0.9, 1.4, 1.0, 0.7, 1.2]), numpy.ones(5)],
]
else:
dxes = [[numpy.ones(5), numpy.ones(5)] for _ in range(2)]
epsilon = vec(numpy.ones((3, 5, 5), dtype=float))
return dxes, epsilon, 10.0
def test_waveguide_3d_solve_mode_and_expand_e_are_phase_consistent() -> None:
epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=0, polarity=1)
axis = 0
polarity = 1
expanded = waveguide_3d.expand_e( expanded = waveguide_3d.expand_e(
E=result['E'], E=result['E'],
wavenumber=result['wavenumber'], wavenumber=result['wavenumber'],
@ -55,11 +82,88 @@ def test_waveguide_3d_solve_mode_and_expand_e_are_phase_consistent() -> None:
numpy.testing.assert_allclose(ratios, expected_ratio, rtol=1e-6, atol=1e-9) numpy.testing.assert_allclose(ratios, expected_ratio, rtol=1e-6, atol=1e-9)
@pytest.mark.parametrize(
('polarity', 'expected_range'),
[(1, (0, 1)), (-1, (3, 4))],
)
def test_waveguide_3d_compute_overlap_e_uses_adjacent_window(
polarity: int,
expected_range: tuple[int, int],
) -> None:
_epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=2, polarity=polarity)
with warnings.catch_warnings(record=True) as caught:
overlap = waveguide_3d.compute_overlap_e(
E=result['E'],
wavenumber=result['wavenumber'],
dxes=dxes,
axis=0,
polarity=polarity,
slices=slices,
omega=OMEGA,
)
nonzero = numpy.argwhere(numpy.abs(overlap) > 0)
assert not caught
assert numpy.isfinite(overlap).all()
assert nonzero[:, 1].min() == expected_range[0]
assert nonzero[:, 1].max() == expected_range[1]
@pytest.mark.parametrize(
('polarity', 'slice_start', 'expected_index'),
[(1, 1, 0), (-1, 3, 4)],
)
def test_waveguide_3d_compute_overlap_e_warns_when_window_is_clipped(
polarity: int,
slice_start: int,
expected_index: int,
) -> None:
_epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=slice_start, polarity=polarity)
with pytest.warns(RuntimeWarning, match='clipped'):
overlap = waveguide_3d.compute_overlap_e(
E=result['E'],
wavenumber=result['wavenumber'],
dxes=dxes,
axis=0,
polarity=polarity,
slices=slices,
omega=OMEGA,
)
nonzero = numpy.argwhere(numpy.abs(overlap) > 0)
assert numpy.isfinite(overlap).all()
assert nonzero[:, 1].min() == expected_index
assert nonzero[:, 1].max() == expected_index
@pytest.mark.parametrize(
('polarity', 'slice_start'),
[(1, 0), (-1, 4)],
)
def test_waveguide_3d_compute_overlap_e_rejects_empty_overlap_window(
polarity: int,
slice_start: int,
) -> None:
_epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=slice_start, polarity=polarity)
with pytest.raises(ValueError, match='outside the domain'):
waveguide_3d.compute_overlap_e(
E=result['E'],
wavenumber=result['wavenumber'],
dxes=dxes,
axis=0,
polarity=polarity,
slices=slices,
omega=OMEGA,
)
def test_waveguide_cyl_solved_modes_are_ordered_and_low_residual() -> None: def test_waveguide_cyl_solved_modes_are_ordered_and_low_residual() -> None:
shape = (5, 5) dxes, epsilon, rmin = build_waveguide_cyl_fixture()
dxes = [[numpy.ones(shape[0]), numpy.ones(shape[1])] for _ in range(2)]
epsilon = vec(numpy.ones((3, *shape), dtype=float))
rmin = 10.0
e_xys, angular_wavenumbers = waveguide_cyl.solve_modes( e_xys, angular_wavenumbers = waveguide_cyl.solve_modes(
[0, 1], [0, 1],
@ -79,9 +183,7 @@ def test_waveguide_cyl_solved_modes_are_ordered_and_low_residual() -> None:
def test_waveguide_cyl_linear_wavenumbers_are_finite_and_ordered() -> None: def test_waveguide_cyl_linear_wavenumbers_are_finite_and_ordered() -> None:
shape = (5, 5) dxes, epsilon, rmin = build_waveguide_cyl_fixture()
dxes = [[numpy.ones(shape[0]), numpy.ones(shape[1])] for _ in range(2)]
epsilon = vec(numpy.ones((3, *shape), dtype=float))
e_xys, angular_wavenumbers = waveguide_cyl.solve_modes( e_xys, angular_wavenumbers = waveguide_cyl.solve_modes(
[0, 1], [0, 1],
@ -95,9 +197,88 @@ def test_waveguide_cyl_linear_wavenumbers_are_finite_and_ordered() -> None:
angular_wavenumbers, angular_wavenumbers,
epsilon=epsilon, epsilon=epsilon,
dxes=dxes, dxes=dxes,
rmin=10.0, rmin=rmin,
) )
assert numpy.isfinite(linear_wavenumbers).all() assert numpy.isfinite(linear_wavenumbers).all()
assert numpy.all(numpy.real(linear_wavenumbers) > 0) assert numpy.all(numpy.real(linear_wavenumbers) > 0)
assert numpy.all(numpy.diff(numpy.real(linear_wavenumbers)) <= 0) assert numpy.all(numpy.diff(numpy.real(linear_wavenumbers)) <= 0)
def test_waveguide_cyl_dxes2t_matches_expected_radius_scaling() -> None:
dxes, _epsilon, rmin = build_waveguide_cyl_fixture(nonuniform=True)
Ta, Tb = waveguide_cyl.dxes2T(dxes, rmin)
ta = (rmin + numpy.cumsum(dxes[0][0])) / rmin
tb = (rmin + dxes[0][0] / 2 + numpy.cumsum(dxes[1][0])) / rmin
numpy.testing.assert_allclose(Ta.diagonal(), numpy.repeat(ta, dxes[0][1].size))
numpy.testing.assert_allclose(Tb.diagonal(), numpy.repeat(tb, dxes[1][1].size))
def test_waveguide_cyl_exy2e_and_exy2h_return_finite_full_fields() -> None:
dxes, epsilon, rmin = build_waveguide_cyl_fixture()
mu = vec(2 * numpy.ones((3, 5, 5), dtype=float))
e_xy, angular_wavenumber = waveguide_cyl.solve_mode(
0,
omega=OMEGA,
dxes=dxes,
epsilon=epsilon,
rmin=rmin,
)
e_field = waveguide_cyl.exy2e(
angular_wavenumber=angular_wavenumber,
omega=OMEGA,
dxes=dxes,
rmin=rmin,
epsilon=epsilon,
) @ e_xy
h_field = waveguide_cyl.exy2h(
angular_wavenumber=angular_wavenumber,
omega=OMEGA,
dxes=dxes,
rmin=rmin,
epsilon=epsilon,
mu=mu,
) @ e_xy
assert e_field.shape == (3 * 25,)
assert h_field.shape == (3 * 25,)
assert numpy.isfinite(e_field).all()
assert numpy.isfinite(h_field).all()
assert unvec(e_field, (5, 5)).shape == (3, 5, 5)
assert unvec(h_field, (5, 5)).shape == (3, 5, 5)
@pytest.mark.parametrize('use_mu', [False, True])
def test_waveguide_cyl_normalized_fields_are_unit_norm_and_silent(use_mu: bool) -> None:
dxes, epsilon, rmin = build_waveguide_cyl_fixture()
mu = vec(2 * numpy.ones((3, 5, 5), dtype=float)) if use_mu else None
e_xy, angular_wavenumber = waveguide_cyl.solve_mode(
0,
omega=OMEGA,
dxes=dxes,
epsilon=epsilon,
rmin=rmin,
)
output = io.StringIO()
with contextlib.redirect_stdout(output):
e_field, h_field = waveguide_cyl.normalized_fields_e(
e_xy,
angular_wavenumber=angular_wavenumber,
omega=OMEGA,
dxes=dxes,
rmin=rmin,
epsilon=epsilon,
mu=mu,
)
overlap = waveguide_2d.inner_product(e_field, h_field, dxes, conj_h=True)
assert output.getvalue() == ''
assert numpy.isfinite(e_field).all()
assert numpy.isfinite(h_field).all()
assert abs(overlap.real - 1.0) < 1e-10
assert abs(overlap.imag) < 1e-10