From f35b334100c28e2fd4b0e8787517f41360f7430a Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Fri, 17 Apr 2026 20:44:36 -0700 Subject: [PATCH] [fdfd.waveguide_3d] improve handling of out-of-bounds overlap_e windows --- meanas/fdfd/waveguide_3d.py | 24 ++- meanas/test/test_waveguide_mode_helpers.py | 213 +++++++++++++++++++-- 2 files changed, 218 insertions(+), 19 deletions(-) diff --git a/meanas/fdfd/waveguide_3d.py b/meanas/fdfd/waveguide_3d.py index 61ed38e..19975db 100644 --- a/meanas/fdfd/waveguide_3d.py +++ b/meanas/fdfd/waveguide_3d.py @@ -5,6 +5,8 @@ This module relies heavily on `waveguide_2d` and mostly just transforms its parameters into 2D equivalents and expands the results back into 3D. """ from typing import Any, cast +import warnings +from typing import Any from collections.abc import Sequence import numpy from numpy.typing import NDArray @@ -200,17 +202,33 @@ def compute_overlap_e( Ee = expand_e(E=E, wavenumber=wavenumber, dxes=dxes, axis=axis, polarity=polarity, slices=slices) - start, stop = sorted((slices[axis].start, slices[axis].start - 2 * polarity)) + axis_size = E.shape[axis + 1] + if polarity > 0: + start = slices[axis].start - 2 + stop = slices[axis].start + else: + start = slices[axis].stop + stop = slices[axis].stop + 2 + + clipped_start = max(0, start) + clipped_stop = min(axis_size, stop) + if clipped_start >= clipped_stop: + raise ValueError('Requested overlap window lies outside the domain') + if clipped_start != start or clipped_stop != stop: + warnings.warn('Requested overlap window was clipped to fit within the domain', RuntimeWarning) slices2_l = list(slices) - slices2_l[axis] = slice(start, stop) + slices2_l[axis] = slice(clipped_start, clipped_stop) slices2 = (slice(None), *slices2_l) Etgt = numpy.zeros_like(Ee) Etgt[slices2] = Ee[slices2] # Note: We normalize so that (Etgt @ E.conj()) == 1, so (Etgt @ Etgt.conj) != 1 - Etgt /= (Etgt.conj() * Etgt).sum() + norm = (Etgt.conj() * Etgt).sum() + if norm == 0: + raise ValueError('Requested overlap window contains no overlap field support') + Etgt /= norm return cfdfield_t(Etgt) diff --git a/meanas/test/test_waveguide_mode_helpers.py b/meanas/test/test_waveguide_mode_helpers.py index 2bf77f2..7bbcd88 100644 --- a/meanas/test/test_waveguide_mode_helpers.py +++ b/meanas/test/test_waveguide_mode_helpers.py @@ -1,29 +1,56 @@ +import contextlib +import io import numpy from numpy.linalg import norm +import pytest +import warnings -from ..fdmath import vec -from ..fdfd import waveguide_3d, waveguide_cyl +from ..fdmath import vec, unvec +from ..fdfd import waveguide_2d, waveguide_3d, waveguide_cyl OMEGA = 1 / 1500 -def test_waveguide_3d_solve_mode_and_expand_e_are_phase_consistent() -> None: +def build_waveguide_3d_mode( + *, + slice_start: int, + polarity: int, + ) -> tuple[numpy.ndarray, list[list[numpy.ndarray]], tuple[slice, slice, slice], dict[str, complex | numpy.ndarray]]: epsilon = numpy.ones((3, 5, 5, 1), dtype=float) dxes = [[numpy.ones(5), numpy.ones(5), numpy.ones(1)] for _ in range(2)] - axis = 0 - polarity = 1 - slices = (slice(0, 1), slice(None), slice(None)) - + slices = (slice(slice_start, slice_start + 1), slice(None), slice(None)) result = waveguide_3d.solve_mode( 0, omega=OMEGA, dxes=dxes, - axis=axis, + axis=0, polarity=polarity, slices=slices, epsilon=epsilon, ) + return epsilon, dxes, slices, result + + +def build_waveguide_cyl_fixture( + *, + nonuniform: bool = False, + ) -> tuple[list[list[numpy.ndarray]], numpy.ndarray, float]: + if nonuniform: + dxes = [ + [numpy.array([1.0, 1.5, 1.2, 0.8, 1.1]), numpy.ones(5)], + [numpy.array([0.9, 1.4, 1.0, 0.7, 1.2]), numpy.ones(5)], + ] + else: + dxes = [[numpy.ones(5), numpy.ones(5)] for _ in range(2)] + epsilon = vec(numpy.ones((3, 5, 5), dtype=float)) + return dxes, epsilon, 10.0 + + +def test_waveguide_3d_solve_mode_and_expand_e_are_phase_consistent() -> None: + epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=0, polarity=1) + axis = 0 + polarity = 1 expanded = waveguide_3d.expand_e( E=result['E'], wavenumber=result['wavenumber'], @@ -55,11 +82,88 @@ def test_waveguide_3d_solve_mode_and_expand_e_are_phase_consistent() -> None: numpy.testing.assert_allclose(ratios, expected_ratio, rtol=1e-6, atol=1e-9) +@pytest.mark.parametrize( + ('polarity', 'expected_range'), + [(1, (0, 1)), (-1, (3, 4))], + ) +def test_waveguide_3d_compute_overlap_e_uses_adjacent_window( + polarity: int, + expected_range: tuple[int, int], + ) -> None: + _epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=2, polarity=polarity) + + with warnings.catch_warnings(record=True) as caught: + overlap = waveguide_3d.compute_overlap_e( + E=result['E'], + wavenumber=result['wavenumber'], + dxes=dxes, + axis=0, + polarity=polarity, + slices=slices, + omega=OMEGA, + ) + + nonzero = numpy.argwhere(numpy.abs(overlap) > 0) + + assert not caught + assert numpy.isfinite(overlap).all() + assert nonzero[:, 1].min() == expected_range[0] + assert nonzero[:, 1].max() == expected_range[1] + + +@pytest.mark.parametrize( + ('polarity', 'slice_start', 'expected_index'), + [(1, 1, 0), (-1, 3, 4)], + ) +def test_waveguide_3d_compute_overlap_e_warns_when_window_is_clipped( + polarity: int, + slice_start: int, + expected_index: int, + ) -> None: + _epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=slice_start, polarity=polarity) + + with pytest.warns(RuntimeWarning, match='clipped'): + overlap = waveguide_3d.compute_overlap_e( + E=result['E'], + wavenumber=result['wavenumber'], + dxes=dxes, + axis=0, + polarity=polarity, + slices=slices, + omega=OMEGA, + ) + + nonzero = numpy.argwhere(numpy.abs(overlap) > 0) + + assert numpy.isfinite(overlap).all() + assert nonzero[:, 1].min() == expected_index + assert nonzero[:, 1].max() == expected_index + + +@pytest.mark.parametrize( + ('polarity', 'slice_start'), + [(1, 0), (-1, 4)], + ) +def test_waveguide_3d_compute_overlap_e_rejects_empty_overlap_window( + polarity: int, + slice_start: int, + ) -> None: + _epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=slice_start, polarity=polarity) + + with pytest.raises(ValueError, match='outside the domain'): + waveguide_3d.compute_overlap_e( + E=result['E'], + wavenumber=result['wavenumber'], + dxes=dxes, + axis=0, + polarity=polarity, + slices=slices, + omega=OMEGA, + ) + + def test_waveguide_cyl_solved_modes_are_ordered_and_low_residual() -> None: - shape = (5, 5) - dxes = [[numpy.ones(shape[0]), numpy.ones(shape[1])] for _ in range(2)] - epsilon = vec(numpy.ones((3, *shape), dtype=float)) - rmin = 10.0 + dxes, epsilon, rmin = build_waveguide_cyl_fixture() e_xys, angular_wavenumbers = waveguide_cyl.solve_modes( [0, 1], @@ -79,9 +183,7 @@ def test_waveguide_cyl_solved_modes_are_ordered_and_low_residual() -> None: def test_waveguide_cyl_linear_wavenumbers_are_finite_and_ordered() -> None: - shape = (5, 5) - dxes = [[numpy.ones(shape[0]), numpy.ones(shape[1])] for _ in range(2)] - epsilon = vec(numpy.ones((3, *shape), dtype=float)) + dxes, epsilon, rmin = build_waveguide_cyl_fixture() e_xys, angular_wavenumbers = waveguide_cyl.solve_modes( [0, 1], @@ -95,9 +197,88 @@ def test_waveguide_cyl_linear_wavenumbers_are_finite_and_ordered() -> None: angular_wavenumbers, epsilon=epsilon, dxes=dxes, - rmin=10.0, + rmin=rmin, ) assert numpy.isfinite(linear_wavenumbers).all() assert numpy.all(numpy.real(linear_wavenumbers) > 0) assert numpy.all(numpy.diff(numpy.real(linear_wavenumbers)) <= 0) + + +def test_waveguide_cyl_dxes2t_matches_expected_radius_scaling() -> None: + dxes, _epsilon, rmin = build_waveguide_cyl_fixture(nonuniform=True) + Ta, Tb = waveguide_cyl.dxes2T(dxes, rmin) + + ta = (rmin + numpy.cumsum(dxes[0][0])) / rmin + tb = (rmin + dxes[0][0] / 2 + numpy.cumsum(dxes[1][0])) / rmin + + numpy.testing.assert_allclose(Ta.diagonal(), numpy.repeat(ta, dxes[0][1].size)) + numpy.testing.assert_allclose(Tb.diagonal(), numpy.repeat(tb, dxes[1][1].size)) + + +def test_waveguide_cyl_exy2e_and_exy2h_return_finite_full_fields() -> None: + dxes, epsilon, rmin = build_waveguide_cyl_fixture() + mu = vec(2 * numpy.ones((3, 5, 5), dtype=float)) + e_xy, angular_wavenumber = waveguide_cyl.solve_mode( + 0, + omega=OMEGA, + dxes=dxes, + epsilon=epsilon, + rmin=rmin, + ) + + e_field = waveguide_cyl.exy2e( + angular_wavenumber=angular_wavenumber, + omega=OMEGA, + dxes=dxes, + rmin=rmin, + epsilon=epsilon, + ) @ e_xy + h_field = waveguide_cyl.exy2h( + angular_wavenumber=angular_wavenumber, + omega=OMEGA, + dxes=dxes, + rmin=rmin, + epsilon=epsilon, + mu=mu, + ) @ e_xy + + assert e_field.shape == (3 * 25,) + assert h_field.shape == (3 * 25,) + assert numpy.isfinite(e_field).all() + assert numpy.isfinite(h_field).all() + assert unvec(e_field, (5, 5)).shape == (3, 5, 5) + assert unvec(h_field, (5, 5)).shape == (3, 5, 5) + + +@pytest.mark.parametrize('use_mu', [False, True]) +def test_waveguide_cyl_normalized_fields_are_unit_norm_and_silent(use_mu: bool) -> None: + dxes, epsilon, rmin = build_waveguide_cyl_fixture() + mu = vec(2 * numpy.ones((3, 5, 5), dtype=float)) if use_mu else None + e_xy, angular_wavenumber = waveguide_cyl.solve_mode( + 0, + omega=OMEGA, + dxes=dxes, + epsilon=epsilon, + rmin=rmin, + ) + + output = io.StringIO() + with contextlib.redirect_stdout(output): + e_field, h_field = waveguide_cyl.normalized_fields_e( + e_xy, + angular_wavenumber=angular_wavenumber, + omega=OMEGA, + dxes=dxes, + rmin=rmin, + epsilon=epsilon, + mu=mu, + ) + + overlap = waveguide_2d.inner_product(e_field, h_field, dxes, conj_h=True) + + assert output.getvalue() == '' + assert numpy.isfinite(e_field).all() + assert numpy.isfinite(h_field).all() + assert abs(overlap.real - 1.0) < 1e-10 + assert abs(overlap.imag) < 1e-10