[fdfd.waveguide_3d] improve handling of out-of-bounds overlap_e windows
This commit is contained in:
parent
593098bf8f
commit
f35b334100
2 changed files with 218 additions and 19 deletions
|
|
@ -5,6 +5,8 @@ This module relies heavily on `waveguide_2d` and mostly just transforms
|
|||
its parameters into 2D equivalents and expands the results back into 3D.
|
||||
"""
|
||||
from typing import Any, cast
|
||||
import warnings
|
||||
from typing import Any
|
||||
from collections.abc import Sequence
|
||||
import numpy
|
||||
from numpy.typing import NDArray
|
||||
|
|
@ -200,17 +202,33 @@ def compute_overlap_e(
|
|||
Ee = expand_e(E=E, wavenumber=wavenumber, dxes=dxes,
|
||||
axis=axis, polarity=polarity, slices=slices)
|
||||
|
||||
start, stop = sorted((slices[axis].start, slices[axis].start - 2 * polarity))
|
||||
axis_size = E.shape[axis + 1]
|
||||
if polarity > 0:
|
||||
start = slices[axis].start - 2
|
||||
stop = slices[axis].start
|
||||
else:
|
||||
start = slices[axis].stop
|
||||
stop = slices[axis].stop + 2
|
||||
|
||||
clipped_start = max(0, start)
|
||||
clipped_stop = min(axis_size, stop)
|
||||
if clipped_start >= clipped_stop:
|
||||
raise ValueError('Requested overlap window lies outside the domain')
|
||||
if clipped_start != start or clipped_stop != stop:
|
||||
warnings.warn('Requested overlap window was clipped to fit within the domain', RuntimeWarning)
|
||||
|
||||
slices2_l = list(slices)
|
||||
slices2_l[axis] = slice(start, stop)
|
||||
slices2_l[axis] = slice(clipped_start, clipped_stop)
|
||||
slices2 = (slice(None), *slices2_l)
|
||||
|
||||
Etgt = numpy.zeros_like(Ee)
|
||||
Etgt[slices2] = Ee[slices2]
|
||||
|
||||
# Note: We normalize so that (Etgt @ E.conj()) == 1, so (Etgt @ Etgt.conj) != 1
|
||||
Etgt /= (Etgt.conj() * Etgt).sum()
|
||||
norm = (Etgt.conj() * Etgt).sum()
|
||||
if norm == 0:
|
||||
raise ValueError('Requested overlap window contains no overlap field support')
|
||||
Etgt /= norm
|
||||
return cfdfield_t(Etgt)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,29 +1,56 @@
|
|||
import contextlib
|
||||
import io
|
||||
import numpy
|
||||
from numpy.linalg import norm
|
||||
import pytest
|
||||
import warnings
|
||||
|
||||
from ..fdmath import vec
|
||||
from ..fdfd import waveguide_3d, waveguide_cyl
|
||||
from ..fdmath import vec, unvec
|
||||
from ..fdfd import waveguide_2d, waveguide_3d, waveguide_cyl
|
||||
|
||||
|
||||
OMEGA = 1 / 1500
|
||||
|
||||
|
||||
def test_waveguide_3d_solve_mode_and_expand_e_are_phase_consistent() -> None:
|
||||
def build_waveguide_3d_mode(
|
||||
*,
|
||||
slice_start: int,
|
||||
polarity: int,
|
||||
) -> tuple[numpy.ndarray, list[list[numpy.ndarray]], tuple[slice, slice, slice], dict[str, complex | numpy.ndarray]]:
|
||||
epsilon = numpy.ones((3, 5, 5, 1), dtype=float)
|
||||
dxes = [[numpy.ones(5), numpy.ones(5), numpy.ones(1)] for _ in range(2)]
|
||||
axis = 0
|
||||
polarity = 1
|
||||
slices = (slice(0, 1), slice(None), slice(None))
|
||||
|
||||
slices = (slice(slice_start, slice_start + 1), slice(None), slice(None))
|
||||
result = waveguide_3d.solve_mode(
|
||||
0,
|
||||
omega=OMEGA,
|
||||
dxes=dxes,
|
||||
axis=axis,
|
||||
axis=0,
|
||||
polarity=polarity,
|
||||
slices=slices,
|
||||
epsilon=epsilon,
|
||||
)
|
||||
return epsilon, dxes, slices, result
|
||||
|
||||
|
||||
def build_waveguide_cyl_fixture(
|
||||
*,
|
||||
nonuniform: bool = False,
|
||||
) -> tuple[list[list[numpy.ndarray]], numpy.ndarray, float]:
|
||||
if nonuniform:
|
||||
dxes = [
|
||||
[numpy.array([1.0, 1.5, 1.2, 0.8, 1.1]), numpy.ones(5)],
|
||||
[numpy.array([0.9, 1.4, 1.0, 0.7, 1.2]), numpy.ones(5)],
|
||||
]
|
||||
else:
|
||||
dxes = [[numpy.ones(5), numpy.ones(5)] for _ in range(2)]
|
||||
epsilon = vec(numpy.ones((3, 5, 5), dtype=float))
|
||||
return dxes, epsilon, 10.0
|
||||
|
||||
|
||||
def test_waveguide_3d_solve_mode_and_expand_e_are_phase_consistent() -> None:
|
||||
epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=0, polarity=1)
|
||||
axis = 0
|
||||
polarity = 1
|
||||
expanded = waveguide_3d.expand_e(
|
||||
E=result['E'],
|
||||
wavenumber=result['wavenumber'],
|
||||
|
|
@ -55,11 +82,88 @@ def test_waveguide_3d_solve_mode_and_expand_e_are_phase_consistent() -> None:
|
|||
numpy.testing.assert_allclose(ratios, expected_ratio, rtol=1e-6, atol=1e-9)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('polarity', 'expected_range'),
|
||||
[(1, (0, 1)), (-1, (3, 4))],
|
||||
)
|
||||
def test_waveguide_3d_compute_overlap_e_uses_adjacent_window(
|
||||
polarity: int,
|
||||
expected_range: tuple[int, int],
|
||||
) -> None:
|
||||
_epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=2, polarity=polarity)
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
overlap = waveguide_3d.compute_overlap_e(
|
||||
E=result['E'],
|
||||
wavenumber=result['wavenumber'],
|
||||
dxes=dxes,
|
||||
axis=0,
|
||||
polarity=polarity,
|
||||
slices=slices,
|
||||
omega=OMEGA,
|
||||
)
|
||||
|
||||
nonzero = numpy.argwhere(numpy.abs(overlap) > 0)
|
||||
|
||||
assert not caught
|
||||
assert numpy.isfinite(overlap).all()
|
||||
assert nonzero[:, 1].min() == expected_range[0]
|
||||
assert nonzero[:, 1].max() == expected_range[1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('polarity', 'slice_start', 'expected_index'),
|
||||
[(1, 1, 0), (-1, 3, 4)],
|
||||
)
|
||||
def test_waveguide_3d_compute_overlap_e_warns_when_window_is_clipped(
|
||||
polarity: int,
|
||||
slice_start: int,
|
||||
expected_index: int,
|
||||
) -> None:
|
||||
_epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=slice_start, polarity=polarity)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match='clipped'):
|
||||
overlap = waveguide_3d.compute_overlap_e(
|
||||
E=result['E'],
|
||||
wavenumber=result['wavenumber'],
|
||||
dxes=dxes,
|
||||
axis=0,
|
||||
polarity=polarity,
|
||||
slices=slices,
|
||||
omega=OMEGA,
|
||||
)
|
||||
|
||||
nonzero = numpy.argwhere(numpy.abs(overlap) > 0)
|
||||
|
||||
assert numpy.isfinite(overlap).all()
|
||||
assert nonzero[:, 1].min() == expected_index
|
||||
assert nonzero[:, 1].max() == expected_index
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('polarity', 'slice_start'),
|
||||
[(1, 0), (-1, 4)],
|
||||
)
|
||||
def test_waveguide_3d_compute_overlap_e_rejects_empty_overlap_window(
|
||||
polarity: int,
|
||||
slice_start: int,
|
||||
) -> None:
|
||||
_epsilon, dxes, slices, result = build_waveguide_3d_mode(slice_start=slice_start, polarity=polarity)
|
||||
|
||||
with pytest.raises(ValueError, match='outside the domain'):
|
||||
waveguide_3d.compute_overlap_e(
|
||||
E=result['E'],
|
||||
wavenumber=result['wavenumber'],
|
||||
dxes=dxes,
|
||||
axis=0,
|
||||
polarity=polarity,
|
||||
slices=slices,
|
||||
omega=OMEGA,
|
||||
)
|
||||
|
||||
|
||||
def test_waveguide_cyl_solved_modes_are_ordered_and_low_residual() -> None:
|
||||
shape = (5, 5)
|
||||
dxes = [[numpy.ones(shape[0]), numpy.ones(shape[1])] for _ in range(2)]
|
||||
epsilon = vec(numpy.ones((3, *shape), dtype=float))
|
||||
rmin = 10.0
|
||||
dxes, epsilon, rmin = build_waveguide_cyl_fixture()
|
||||
|
||||
e_xys, angular_wavenumbers = waveguide_cyl.solve_modes(
|
||||
[0, 1],
|
||||
|
|
@ -79,9 +183,7 @@ def test_waveguide_cyl_solved_modes_are_ordered_and_low_residual() -> None:
|
|||
|
||||
|
||||
def test_waveguide_cyl_linear_wavenumbers_are_finite_and_ordered() -> None:
|
||||
shape = (5, 5)
|
||||
dxes = [[numpy.ones(shape[0]), numpy.ones(shape[1])] for _ in range(2)]
|
||||
epsilon = vec(numpy.ones((3, *shape), dtype=float))
|
||||
dxes, epsilon, rmin = build_waveguide_cyl_fixture()
|
||||
|
||||
e_xys, angular_wavenumbers = waveguide_cyl.solve_modes(
|
||||
[0, 1],
|
||||
|
|
@ -95,9 +197,88 @@ def test_waveguide_cyl_linear_wavenumbers_are_finite_and_ordered() -> None:
|
|||
angular_wavenumbers,
|
||||
epsilon=epsilon,
|
||||
dxes=dxes,
|
||||
rmin=10.0,
|
||||
rmin=rmin,
|
||||
)
|
||||
|
||||
assert numpy.isfinite(linear_wavenumbers).all()
|
||||
assert numpy.all(numpy.real(linear_wavenumbers) > 0)
|
||||
assert numpy.all(numpy.diff(numpy.real(linear_wavenumbers)) <= 0)
|
||||
|
||||
|
||||
def test_waveguide_cyl_dxes2t_matches_expected_radius_scaling() -> None:
|
||||
dxes, _epsilon, rmin = build_waveguide_cyl_fixture(nonuniform=True)
|
||||
Ta, Tb = waveguide_cyl.dxes2T(dxes, rmin)
|
||||
|
||||
ta = (rmin + numpy.cumsum(dxes[0][0])) / rmin
|
||||
tb = (rmin + dxes[0][0] / 2 + numpy.cumsum(dxes[1][0])) / rmin
|
||||
|
||||
numpy.testing.assert_allclose(Ta.diagonal(), numpy.repeat(ta, dxes[0][1].size))
|
||||
numpy.testing.assert_allclose(Tb.diagonal(), numpy.repeat(tb, dxes[1][1].size))
|
||||
|
||||
|
||||
def test_waveguide_cyl_exy2e_and_exy2h_return_finite_full_fields() -> None:
|
||||
dxes, epsilon, rmin = build_waveguide_cyl_fixture()
|
||||
mu = vec(2 * numpy.ones((3, 5, 5), dtype=float))
|
||||
e_xy, angular_wavenumber = waveguide_cyl.solve_mode(
|
||||
0,
|
||||
omega=OMEGA,
|
||||
dxes=dxes,
|
||||
epsilon=epsilon,
|
||||
rmin=rmin,
|
||||
)
|
||||
|
||||
e_field = waveguide_cyl.exy2e(
|
||||
angular_wavenumber=angular_wavenumber,
|
||||
omega=OMEGA,
|
||||
dxes=dxes,
|
||||
rmin=rmin,
|
||||
epsilon=epsilon,
|
||||
) @ e_xy
|
||||
h_field = waveguide_cyl.exy2h(
|
||||
angular_wavenumber=angular_wavenumber,
|
||||
omega=OMEGA,
|
||||
dxes=dxes,
|
||||
rmin=rmin,
|
||||
epsilon=epsilon,
|
||||
mu=mu,
|
||||
) @ e_xy
|
||||
|
||||
assert e_field.shape == (3 * 25,)
|
||||
assert h_field.shape == (3 * 25,)
|
||||
assert numpy.isfinite(e_field).all()
|
||||
assert numpy.isfinite(h_field).all()
|
||||
assert unvec(e_field, (5, 5)).shape == (3, 5, 5)
|
||||
assert unvec(h_field, (5, 5)).shape == (3, 5, 5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_mu', [False, True])
|
||||
def test_waveguide_cyl_normalized_fields_are_unit_norm_and_silent(use_mu: bool) -> None:
|
||||
dxes, epsilon, rmin = build_waveguide_cyl_fixture()
|
||||
mu = vec(2 * numpy.ones((3, 5, 5), dtype=float)) if use_mu else None
|
||||
e_xy, angular_wavenumber = waveguide_cyl.solve_mode(
|
||||
0,
|
||||
omega=OMEGA,
|
||||
dxes=dxes,
|
||||
epsilon=epsilon,
|
||||
rmin=rmin,
|
||||
)
|
||||
|
||||
output = io.StringIO()
|
||||
with contextlib.redirect_stdout(output):
|
||||
e_field, h_field = waveguide_cyl.normalized_fields_e(
|
||||
e_xy,
|
||||
angular_wavenumber=angular_wavenumber,
|
||||
omega=OMEGA,
|
||||
dxes=dxes,
|
||||
rmin=rmin,
|
||||
epsilon=epsilon,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
overlap = waveguide_2d.inner_product(e_field, h_field, dxes, conj_h=True)
|
||||
|
||||
assert output.getvalue() == ''
|
||||
assert numpy.isfinite(e_field).all()
|
||||
assert numpy.isfinite(h_field).all()
|
||||
assert abs(overlap.real - 1.0) < 1e-10
|
||||
assert abs(overlap.imag) < 1e-10
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue