[tests] add a waveguide scattering test

This commit is contained in:
Jan Petykiewicz 2026-04-18 14:07:15 -07:00
commit 0568e1ba50

View file

@ -18,6 +18,13 @@ SHAPE = (3, 25, 13, 13)
SOURCE_SLICES = (slice(4, 5), slice(None), slice(None)) SOURCE_SLICES = (slice(4, 5), slice(None), slice(None))
MONITOR_SLICES = (slice(18, 19), slice(None), slice(None)) MONITOR_SLICES = (slice(18, 19), slice(None), slice(None))
CHOSEN_VARIANT = 'base' CHOSEN_VARIANT = 'base'
SCATTERING_SHAPE = (3, 35, 15, 15)
SCATTERING_SOURCE_SLICES = (slice(4, 5), slice(None), slice(None))
SCATTERING_REFLECT_SLICES = (slice(10, 11), slice(None), slice(None))
SCATTERING_TRANSMIT_SLICES = (slice(29, 30), slice(None), slice(None))
SCATTERING_STEP_X = 18
SCATTERING_WARMUP_PERIODS = 10
SCATTERING_ACCUMULATION_PERIODS = 10
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
@ -54,8 +61,45 @@ class WaveguideCalibrationResult:
return self.overlap_mag_rel_err + self.flux_rel_err return self.overlap_mag_rel_err + self.flux_rel_err
@dataclasses.dataclass(frozen=True)
class WaveguideScatteringResult:
e_ph: numpy.ndarray
h_ph: numpy.ndarray
j_ph: numpy.ndarray
e_fdfd: numpy.ndarray
h_fdfd: numpy.ndarray
reflected_td: complex
reflected_fd: complex
transmitted_td: complex
transmitted_fd: complex
reflected_flux_td: float
reflected_flux_fd: float
transmitted_flux_td: float
transmitted_flux_fd: float
@property
def reflected_overlap_mag_rel_err(self) -> float:
return float(abs(abs(self.reflected_td) - abs(self.reflected_fd)) / abs(self.reflected_fd))
@property
def transmitted_overlap_mag_rel_err(self) -> float:
return float(abs(abs(self.transmitted_td) - abs(self.transmitted_fd)) / abs(self.transmitted_fd))
@property
def reflected_flux_rel_err(self) -> float:
return float(abs(self.reflected_flux_td - self.reflected_flux_fd) / abs(self.reflected_flux_fd))
@property
def transmitted_flux_rel_err(self) -> float:
return float(abs(self.transmitted_flux_td - self.transmitted_flux_fd) / abs(self.transmitted_flux_fd))
def _build_uniform_dxes(shape: tuple[int, int, int, int]) -> list[list[numpy.ndarray]]:
return [[numpy.ones(shape[axis + 1]) for axis in range(3)] for _ in range(2)]
def _build_base_dxes() -> list[list[numpy.ndarray]]: def _build_base_dxes() -> list[list[numpy.ndarray]]:
return [[numpy.ones(SHAPE[axis + 1]) for axis in range(3)] for _ in range(2)] return _build_uniform_dxes(SHAPE)
def _build_stretched_dxes(base_dxes: list[list[numpy.ndarray]]) -> list[list[numpy.ndarray]]: def _build_stretched_dxes(base_dxes: list[list[numpy.ndarray]]) -> list[list[numpy.ndarray]]:
@ -81,6 +125,23 @@ def _build_epsilon() -> numpy.ndarray:
return epsilon return epsilon
def _build_scattering_epsilon() -> numpy.ndarray:
epsilon = numpy.ones(SCATTERING_SHAPE, dtype=float)
y0 = SCATTERING_SHAPE[2] // 2
z0 = SCATTERING_SHAPE[3] // 2
epsilon[:, :SCATTERING_STEP_X, y0 - 1:y0 + 2, z0 - 1:z0 + 2] = 12.0
epsilon[:, SCATTERING_STEP_X:, y0 - 2:y0 + 3, z0 - 2:z0 + 3] = 12.0
return epsilon
def _build_cpml_params() -> list[list[dict[str, numpy.ndarray | float]]]:
return [
[fdtd.cpml_params(axis=axis, polarity=polarity, dt=DT, thickness=CPML_THICKNESS, epsilon_eff=1.0)
for polarity in (-1, 1)]
for axis in range(3)
]
@lru_cache(maxsize=2) @lru_cache(maxsize=2)
def _run_straight_waveguide_case(variant: str) -> WaveguideCalibrationResult: def _run_straight_waveguide_case(variant: str) -> WaveguideCalibrationResult:
assert variant in ('stretched', 'base') assert variant in ('stretched', 'base')
@ -128,12 +189,7 @@ def _run_straight_waveguide_case(variant: str) -> WaveguideCalibrationResult:
omega=OMEGA, omega=OMEGA,
) )
pml_params = [ update_e, update_h = fdtd.updates_with_cpml(cpml_params=_build_cpml_params(), dt=DT, dxes=base_dxes, epsilon=epsilon)
[fdtd.cpml_params(axis=axis, polarity=polarity, dt=DT, thickness=CPML_THICKNESS, epsilon_eff=1.0)
for polarity in (-1, 1)]
for axis in range(3)
]
update_e, update_h = fdtd.updates_with_cpml(cpml_params=pml_params, dt=DT, dxes=base_dxes, epsilon=epsilon)
e_field = numpy.zeros_like(epsilon) e_field = numpy.zeros_like(epsilon)
h_field = numpy.zeros_like(epsilon) h_field = numpy.zeros_like(epsilon)
@ -197,6 +253,139 @@ def _run_straight_waveguide_case(variant: str) -> WaveguideCalibrationResult:
) )
@lru_cache(maxsize=1)
def _run_width_step_scattering_case() -> WaveguideScatteringResult:
epsilon = _build_scattering_epsilon()
base_dxes = _build_uniform_dxes(SCATTERING_SHAPE)
stretched_dxes = _build_stretched_dxes(base_dxes)
source_mode = waveguide_3d.solve_mode(
0,
omega=OMEGA,
dxes=base_dxes,
axis=0,
polarity=1,
slices=SCATTERING_SOURCE_SLICES,
epsilon=epsilon,
)
j_mode = waveguide_3d.compute_source(
E=source_mode['E'],
wavenumber=source_mode['wavenumber'],
omega=OMEGA,
dxes=base_dxes,
axis=0,
polarity=1,
slices=SCATTERING_SOURCE_SLICES,
epsilon=epsilon,
)
reflected_mode = waveguide_3d.solve_mode(
0,
omega=OMEGA,
dxes=base_dxes,
axis=0,
polarity=-1,
slices=SCATTERING_REFLECT_SLICES,
epsilon=epsilon,
)
reflected_overlap = waveguide_3d.compute_overlap_e(
E=reflected_mode['E'],
wavenumber=reflected_mode['wavenumber'],
dxes=base_dxes,
axis=0,
polarity=-1,
slices=SCATTERING_REFLECT_SLICES,
omega=OMEGA,
)
transmitted_mode = waveguide_3d.solve_mode(
0,
omega=OMEGA,
dxes=base_dxes,
axis=0,
polarity=1,
slices=SCATTERING_TRANSMIT_SLICES,
epsilon=epsilon,
)
transmitted_overlap = waveguide_3d.compute_overlap_e(
E=transmitted_mode['E'],
wavenumber=transmitted_mode['wavenumber'],
dxes=base_dxes,
axis=0,
polarity=1,
slices=SCATTERING_TRANSMIT_SLICES,
omega=OMEGA,
)
update_e, update_h = fdtd.updates_with_cpml(cpml_params=_build_cpml_params(), dt=DT, dxes=base_dxes, epsilon=epsilon)
e_field = numpy.zeros_like(epsilon)
h_field = numpy.zeros_like(epsilon)
e_accumulator = numpy.zeros((1, *SCATTERING_SHAPE), dtype=complex)
h_accumulator = numpy.zeros((1, *SCATTERING_SHAPE), dtype=complex)
j_accumulator = numpy.zeros((1, *SCATTERING_SHAPE), dtype=complex)
warmup_steps = SCATTERING_WARMUP_PERIODS * PERIOD_STEPS
accumulation_steps = SCATTERING_ACCUMULATION_PERIODS * PERIOD_STEPS
for step in range(warmup_steps + accumulation_steps):
update_e(e_field, h_field, epsilon)
t_half = (step + 0.5) * DT
j_real = (j_mode.real * numpy.cos(OMEGA * t_half) - j_mode.imag * numpy.sin(OMEGA * t_half)).real
e_field -= DT * j_real / epsilon
if step >= warmup_steps:
fdtd.accumulate_phasor_j(j_accumulator, OMEGA, DT, j_real, step)
fdtd.accumulate_phasor_e(e_accumulator, OMEGA, DT, e_field, step + 1)
update_h(e_field, h_field)
if step >= warmup_steps:
fdtd.accumulate_phasor_h(h_accumulator, OMEGA, DT, h_field, step + 1)
e_ph = e_accumulator[0]
h_ph = h_accumulator[0]
j_ph = j_accumulator[0]
e_fdfd = unvec(
fdfd.solvers.generic(
J=vec(j_ph),
omega=OMEGA,
dxes=stretched_dxes,
epsilon=vec(epsilon),
matrix_solver_opts={'atol': 1e-10, 'rtol': 1e-7},
),
SCATTERING_SHAPE[1:],
)
h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd)
reflected_td = vec(e_ph) @ vec(reflected_overlap).conj()
reflected_fd = vec(e_fdfd) @ vec(reflected_overlap).conj()
transmitted_td = vec(e_ph) @ vec(transmitted_overlap).conj()
transmitted_fd = vec(e_fdfd) @ vec(transmitted_overlap).conj()
poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj())
poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj())
reflected_flux_td = float(0.5 * poynting_td[0, SCATTERING_REFLECT_SLICES[0], :, :].real.sum())
reflected_flux_fd = float(0.5 * poynting_fd[0, SCATTERING_REFLECT_SLICES[0], :, :].real.sum())
transmitted_flux_td = float(0.5 * poynting_td[0, SCATTERING_TRANSMIT_SLICES[0], :, :].real.sum())
transmitted_flux_fd = float(0.5 * poynting_fd[0, SCATTERING_TRANSMIT_SLICES[0], :, :].real.sum())
return WaveguideScatteringResult(
e_ph=e_ph,
h_ph=h_ph,
j_ph=j_ph,
e_fdfd=e_fdfd,
h_fdfd=h_fdfd,
reflected_td=reflected_td,
reflected_fd=reflected_fd,
transmitted_td=transmitted_td,
transmitted_fd=transmitted_fd,
reflected_flux_td=reflected_flux_td,
reflected_flux_fd=reflected_flux_fd,
transmitted_flux_td=transmitted_flux_td,
transmitted_flux_fd=transmitted_flux_fd,
)
def test_straight_waveguide_base_variant_outperforms_stretched_variant() -> None: def test_straight_waveguide_base_variant_outperforms_stretched_variant() -> None:
base_result = _run_straight_waveguide_case('base') base_result = _run_straight_waveguide_case('base')
stretched_result = _run_straight_waveguide_case('stretched') stretched_result = _run_straight_waveguide_case('stretched')
@ -222,3 +411,26 @@ def test_straight_waveguide_fdtd_fdfd_overlap_and_flux_agree() -> None:
assert result.flux_rel_err < 0.01 assert result.flux_rel_err < 0.01
assert result.overlap_rel_err < 0.01 assert result.overlap_rel_err < 0.01
assert result.overlap_phase_deg < 0.5 assert result.overlap_phase_deg < 0.5
def test_width_step_waveguide_fdtd_fdfd_modal_powers_and_flux_agree() -> None:
result = _run_width_step_scattering_case()
assert numpy.isfinite(result.e_ph).all()
assert numpy.isfinite(result.h_ph).all()
assert numpy.isfinite(result.j_ph).all()
assert numpy.isfinite(result.e_fdfd).all()
assert numpy.isfinite(result.h_fdfd).all()
assert abs(result.reflected_td) > 0
assert abs(result.reflected_fd) > 0
assert abs(result.transmitted_td) > 0
assert abs(result.transmitted_fd) > 0
assert abs(result.reflected_flux_td) > 0
assert abs(result.reflected_flux_fd) > 0
assert abs(result.transmitted_flux_td) > 0
assert abs(result.transmitted_flux_fd) > 0
assert result.transmitted_overlap_mag_rel_err < 0.03
assert result.reflected_overlap_mag_rel_err < 0.03
assert result.transmitted_flux_rel_err < 0.01
assert result.reflected_flux_rel_err < 0.01