diff --git a/meanas/test/test_waveguide_fdtd_fdfd.py b/meanas/test/test_waveguide_fdtd_fdfd.py new file mode 100644 index 0000000..396dfda --- /dev/null +++ b/meanas/test/test_waveguide_fdtd_fdfd.py @@ -0,0 +1,224 @@ +import dataclasses +from functools import lru_cache + +import numpy + +from .. import fdfd, fdtd +from ..fdmath import vec, unvec +from ..fdfd import functional, scpml, waveguide_3d + + +DT = 0.25 +PERIOD_STEPS = 64 +OMEGA = 2 * numpy.pi / (PERIOD_STEPS * DT) +CPML_THICKNESS = 3 +WARMUP_PERIODS = 9 +ACCUMULATION_PERIODS = 9 +SHAPE = (3, 25, 13, 13) +SOURCE_SLICES = (slice(4, 5), slice(None), slice(None)) +MONITOR_SLICES = (slice(18, 19), slice(None), slice(None)) +CHOSEN_VARIANT = 'base' + + +@dataclasses.dataclass(frozen=True) +class WaveguideCalibrationResult: + variant: str + e_ph: numpy.ndarray + h_ph: numpy.ndarray + j_ph: numpy.ndarray + e_fdfd: numpy.ndarray + h_fdfd: numpy.ndarray + overlap_td: complex + overlap_fd: complex + flux_td: float + flux_fd: float + + @property + def overlap_rel_err(self) -> float: + return float(abs(self.overlap_td - self.overlap_fd) / abs(self.overlap_fd)) + + @property + def overlap_mag_rel_err(self) -> float: + return float(abs(abs(self.overlap_td) - abs(self.overlap_fd)) / abs(self.overlap_fd)) + + @property + def overlap_phase_deg(self) -> float: + return float(abs(numpy.degrees(numpy.angle(self.overlap_td / self.overlap_fd)))) + + @property + def flux_rel_err(self) -> float: + return float(abs(self.flux_td - self.flux_fd) / abs(self.flux_fd)) + + @property + def combined_error(self) -> float: + return self.overlap_mag_rel_err + self.flux_rel_err + + +def _build_base_dxes() -> list[list[numpy.ndarray]]: + return [[numpy.ones(SHAPE[axis + 1]) for axis in range(3)] for _ in range(2)] + + +def _build_stretched_dxes(base_dxes: list[list[numpy.ndarray]]) -> list[list[numpy.ndarray]]: + stretched_dxes = [[dx.copy() for dx in group] for group in base_dxes] + for axis in (0, 1, 2): + for polarity in (-1, 1): + stretched_dxes = scpml.stretch_with_scpml( + stretched_dxes, + axis=axis, + polarity=polarity, + omega=OMEGA, + epsilon_effective=1.0, + thickness=CPML_THICKNESS, + ) + return stretched_dxes + + +def _build_epsilon() -> numpy.ndarray: + epsilon = numpy.ones(SHAPE, dtype=float) + y0 = (SHAPE[2] - 3) // 2 + z0 = (SHAPE[3] - 3) // 2 + epsilon[:, :, y0:y0 + 3, z0:z0 + 3] = 12.0 + return epsilon + + +@lru_cache(maxsize=2) +def _run_straight_waveguide_case(variant: str) -> WaveguideCalibrationResult: + assert variant in ('stretched', 'base') + + epsilon = _build_epsilon() + base_dxes = _build_base_dxes() + stretched_dxes = _build_stretched_dxes(base_dxes) + mode_dxes = stretched_dxes if variant == 'stretched' else base_dxes + + source_mode = waveguide_3d.solve_mode( + 0, + omega=OMEGA, + dxes=mode_dxes, + axis=0, + polarity=1, + slices=SOURCE_SLICES, + epsilon=epsilon, + ) + j_mode = waveguide_3d.compute_source( + E=source_mode['E'], + wavenumber=source_mode['wavenumber'], + omega=OMEGA, + dxes=mode_dxes, + axis=0, + polarity=1, + slices=SOURCE_SLICES, + epsilon=epsilon, + ) + monitor_mode = waveguide_3d.solve_mode( + 0, + omega=OMEGA, + dxes=mode_dxes, + axis=0, + polarity=1, + slices=MONITOR_SLICES, + epsilon=epsilon, + ) + overlap_e = waveguide_3d.compute_overlap_e( + E=monitor_mode['E'], + wavenumber=monitor_mode['wavenumber'], + dxes=mode_dxes, + axis=0, + polarity=1, + slices=MONITOR_SLICES, + omega=OMEGA, + ) + + pml_params = [ + [fdtd.cpml_params(axis=axis, polarity=polarity, dt=DT, thickness=CPML_THICKNESS, epsilon_eff=1.0) + for polarity in (-1, 1)] + for axis in range(3) + ] + update_e, update_h = fdtd.updates_with_cpml(cpml_params=pml_params, dt=DT, dxes=base_dxes, epsilon=epsilon) + + e_field = numpy.zeros_like(epsilon) + h_field = numpy.zeros_like(epsilon) + e_accumulator = numpy.zeros((1, *SHAPE), dtype=complex) + h_accumulator = numpy.zeros((1, *SHAPE), dtype=complex) + j_accumulator = numpy.zeros((1, *SHAPE), dtype=complex) + + warmup_steps = WARMUP_PERIODS * PERIOD_STEPS + accumulation_steps = ACCUMULATION_PERIODS * PERIOD_STEPS + for step in range(warmup_steps + accumulation_steps): + update_e(e_field, h_field, epsilon) + + t_half = (step + 0.5) * DT + j_real = (j_mode.real * numpy.cos(OMEGA * t_half) - j_mode.imag * numpy.sin(OMEGA * t_half)).real + e_field -= DT * j_real / epsilon + + if step >= warmup_steps: + fdtd.accumulate_phasor_j(j_accumulator, OMEGA, DT, j_real, step) + fdtd.accumulate_phasor_e(e_accumulator, OMEGA, DT, e_field, step + 1) + + update_h(e_field, h_field) + + if step >= warmup_steps: + fdtd.accumulate_phasor_h(h_accumulator, OMEGA, DT, h_field, step + 1) + + e_ph = e_accumulator[0] + h_ph = h_accumulator[0] + j_ph = j_accumulator[0] + + e_fdfd = unvec( + fdfd.solvers.generic( + J=vec(j_ph), + omega=OMEGA, + dxes=stretched_dxes, + epsilon=vec(epsilon), + matrix_solver_opts={'atol': 1e-10, 'rtol': 1e-7}, + ), + SHAPE[1:], + ) + h_fdfd = functional.e2h(OMEGA, stretched_dxes)(e_fdfd) + + overlap_td = vec(e_ph) @ vec(overlap_e).conj() + overlap_fd = vec(e_fdfd) @ vec(overlap_e).conj() + + poynting_td = functional.poynting_e_cross_h(stretched_dxes)(e_ph, h_ph.conj()) + poynting_fd = functional.poynting_e_cross_h(stretched_dxes)(e_fdfd, h_fdfd.conj()) + flux_td = float(0.5 * poynting_td[0, MONITOR_SLICES[0], :, :].real.sum()) + flux_fd = float(0.5 * poynting_fd[0, MONITOR_SLICES[0], :, :].real.sum()) + + return WaveguideCalibrationResult( + variant=variant, + e_ph=e_ph, + h_ph=h_ph, + j_ph=j_ph, + e_fdfd=e_fdfd, + h_fdfd=h_fdfd, + overlap_td=overlap_td, + overlap_fd=overlap_fd, + flux_td=flux_td, + flux_fd=flux_fd, + ) + + +def test_straight_waveguide_base_variant_outperforms_stretched_variant() -> None: + base_result = _run_straight_waveguide_case('base') + stretched_result = _run_straight_waveguide_case('stretched') + + assert base_result.variant == CHOSEN_VARIANT + assert base_result.combined_error < stretched_result.combined_error + + +def test_straight_waveguide_fdtd_fdfd_overlap_and_flux_agree() -> None: + result = _run_straight_waveguide_case(CHOSEN_VARIANT) + + assert numpy.isfinite(result.e_ph).all() + assert numpy.isfinite(result.h_ph).all() + assert numpy.isfinite(result.j_ph).all() + assert numpy.isfinite(result.e_fdfd).all() + assert numpy.isfinite(result.h_fdfd).all() + assert abs(result.overlap_td) > 0 + assert abs(result.overlap_fd) > 0 + assert abs(result.flux_td) > 0 + assert abs(result.flux_fd) > 0 + + assert result.overlap_mag_rel_err < 0.01 + assert result.flux_rel_err < 0.01 + assert result.overlap_rel_err < 0.01 + assert result.overlap_phase_deg < 0.5