style and type fixes

This commit is contained in:
jan 2022-11-24 16:12:40 -08:00
parent 77c10feead
commit 689b3176cc
3 changed files with 157 additions and 118 deletions

74
fdtd.py
View File

@ -30,55 +30,64 @@ def perturbed_l3(a: float, radius: float, **kwargs) -> Pattern:
""" """
Generate a masque.Pattern object containing a perturbed L3 cavity. Generate a masque.Pattern object containing a perturbed L3 cavity.
:param a: Lattice constant. Args:
:param radius: Hole radius, in units of a (lattice constant). a: Lattice constant.
:param kwargs: Keyword arguments: radius: Hole radius, in units of a (lattice constant).
hole_dose, trench_dose, hole_layer, trench_layer: Shape properties for Pattern. hole_dose: Dose for all holes. Default 1.
Defaults *_dose=1, hole_layer=0, trench_layer=1. trench_dose: Dose for undercut trenches. Default 1.
shifts_a, shifts_r: passed to pcgen.l3_shift; specifies lattice constant (1 - hole_layer: Layer for holes. Default (0, 0).
multiplicative factor) and radius (multiplicative factor) for shifting trench_layer: Layer for undercut trenches. Default (1, 0).
holes adjacent to the defect (same row). Defaults are 0.15 shift for shifts_a: passed to pcgen.l3_shift
first hole, 0.075 shift for third hole, and no radius change. shifts_r: passed to pcgen.l3_shift
xy_size: [x, y] number of mirror periods in each direction; total size is xy_size: [x, y] number of mirror periods in each direction; total size is
2 * n + 1 holes in each direction. Default [10, 10]. 2 * n + 1 holes in each direction. Default (10, 10).
perturbed_radius: radius of holes perturbed to form an upwards-driected beam perturbed_radius: radius of holes perturbed to form an upwards-driected beam
(multiplicative factor). Default 1.1. (multiplicative factor). Default 1.1.
trench width: Width of the undercut trenches. Default 1.2e3. trench width: Width of the undercut trenches. Default 1.2e3.
:return: masque.Pattern object containing the L3 design
Returns:
`masque.Pattern` object containing the L3 design
""" """
default_args = {'hole_dose': 1, default_args = {
'trench_dose': 1, 'hole_dose': 1,
'hole_layer': 0, 'trench_dose': 1,
'trench_layer': 1, 'hole_layer': 0,
'shifts_a': (0.15, 0, 0.075), 'trench_layer': 1,
'shifts_r': (1.0, 1.0, 1.0), 'shifts_a': (0.15, 0, 0.075),
'xy_size': (10, 10), 'shifts_r': (1.0, 1.0, 1.0),
'perturbed_radius': 1.1, 'xy_size': (10, 10),
'trench_width': 1.2e3, 'perturbed_radius': 1.1,
} 'trench_width': 1.2e3,
}
kwargs = {**default_args, **kwargs} kwargs = {**default_args, **kwargs}
xyr = pcgen.l3_shift_perturbed_defect(mirror_dims=kwargs['xy_size'], xyr = pcgen.l3_shift_perturbed_defect(
perturbed_radius=kwargs['perturbed_radius'], mirror_dims=kwargs['xy_size'],
shifts_a=kwargs['shifts_a'], perturbed_radius=kwargs['perturbed_radius'],
shifts_r=kwargs['shifts_r']) shifts_a=kwargs['shifts_a'],
shifts_r=kwargs['shifts_r'],
)
xyr *= a xyr *= a
xyr[:, 2] *= radius xyr[:, 2] *= radius
pat = Pattern() pat = Pattern()
pat.name = 'L3p-a{:g}r{:g}rp{:g}'.format(a, radius, kwargs['perturbed_radius']) pat.name = f'L3p-a{a:g}r{radius:g}rp{kwargs["perturbed_radius"]:g}'
pat.shapes += [shapes.Circle(radius=r, offset=(x, y), pat.shapes += [shapes.Circle(radius=r, offset=(x, y),
dose=kwargs['hole_dose'], dose=kwargs['hole_dose'],
layer=kwargs['hole_layer']) layer=kwargs['hole_layer'])
for x, y, r in xyr] for x, y, r in xyr]
maxes = numpy.max(numpy.fabs(xyr), axis=0) maxes = numpy.max(numpy.fabs(xyr), axis=0)
pat.shapes += [shapes.Polygon.rectangle( pat.shapes += [
lx=(2 * maxes[0]), ly=kwargs['trench_width'], shapes.Polygon.rectangle(
offset=(0, s * (maxes[1] + a + kwargs['trench_width'] / 2)), lx=(2 * maxes[0]),
dose=kwargs['trench_dose'], layer=kwargs['trench_layer']) ly=kwargs['trench_width'],
for s in (-1, 1)] offset=(0, s * (maxes[1] + a + kwargs['trench_width'] / 2)),
dose=kwargs['trench_dose'],
layer=kwargs['trench_layer'],
)
for s in (-1, 1)]
return pat return pat
@ -226,7 +235,8 @@ def main():
# pml_thickness+m:-pml_thickness-m, :].sum() * dx * dx * dx # pml_thickness+m:-pml_thickness-m, :].sum() * dx * dx * dx
if t % 100 == 0: if t % 100 == 0:
logger.info('iteration {}: average {} iterations per sec'.format(t, (t+1)/(time.perf_counter()-start))) avg = (t + 1) / (time.perf_counter() - start)
logger.info(f'iteration {t}: average {avg} iterations per sec')
sys.stdout.flush() sys.stdout.flush()
with lzma.open('saved_simulation', 'wb') as f: with lzma.open('saved_simulation', 'wb') as f:

View File

@ -2,9 +2,10 @@
Class for constructing and holding the basic FDTD operations and fields Class for constructing and holding the basic FDTD operations and fields
""" """
from typing import List, Dict, Callable from typing import List, Dict, Callable, Type, Union, Optional, Sequence
from collections import OrderedDict from collections import OrderedDict
import numpy import numpy
from numpy.typing import NDArray
import jinja2 import jinja2
import warnings import warnings
@ -31,8 +32,8 @@ class Simulation(object):
pmls = [{'axis': a, 'polarity': p} for a in 'xyz' for p in 'np'] pmls = [{'axis': a, 'polarity': p} for a in 'xyz' for p in 'np']
sim = Simulation(grid.grids, do_poynting=True, pmls=pmls) sim = Simulation(grid.grids, do_poynting=True, pmls=pmls)
with open('sources.c', 'wt') as f: with open('sources.c', 'w') as f:
f.write(repr(sim.sources)) f.write(f'{sim.sources}')
for t in range(max_t): for t in range(max_t):
sim.update_E([]).wait() sim.update_E([]).wait()
@ -56,38 +57,38 @@ class Simulation(object):
event0 and event1 to occur (i.e. previous operations to finish) before starting execution. event0 and event1 to occur (i.e. previous operations to finish) before starting execution.
event2 can then be used to prepare further operations to be run after update_H. event2 can then be used to prepare further operations to be run after update_H.
""" """
E = None # type: pyopencl.array.Array E: pyopencl.array.Array
H = None # type: pyopencl.array.Array H: pyopencl.array.Array
S = None # type: pyopencl.array.Array S: pyopencl.array.Array
eps = None # type: pyopencl.array.Array eps: pyopencl.array.Array
dt = None # type: float dt: float
inv_dxes = None # type: List[pyopencl.array.Array] inv_dxes: List[pyopencl.array.Array]
arg_type = None # type: numpy.float32 or numpy.float64 arg_type: Type
context = None # type: pyopencl.Context context: pyopencl.Context
queue = None # type: pyopencl.CommandQueue queue: pyopencl.CommandQueue
update_E = None # type: Callable[[List[pyopencl.Event]], pyopencl.Event] update_E: Callable[[List[pyopencl.Event]], pyopencl.Event]
update_H = None # type: Callable[[List[pyopencl.Event]], pyopencl.Event] update_H: Callable[[List[pyopencl.Event]], pyopencl.Event]
update_S = None # type: Callable[[List[pyopencl.Event]], pyopencl.Event] update_S: Callable[[List[pyopencl.Event]], pyopencl.Event]
update_J = None # type: Callable[[List[pyopencl.Event]], pyopencl.Event] update_J: Callable[[List[pyopencl.Event]], pyopencl.Event]
sources = None # type: Dict[str, str] sources: Dict[str, str]
def __init__(self, def __init__(
epsilon: List[numpy.ndarray], self,
pmls: List[Dict[str, int or float]], epsilon: NDArray,
bloch_boundaries: List[Dict[str, int or float]] = (), pmls: Sequence[Dict[str, float]],
dxes: List[List[numpy.ndarray]] or float = None, bloch_boundaries: Sequence[Dict[str, float]] = (),
dt: float = None, dxes: Union[List[List[NDArray]], float, None] = None,
initial_fields: Dict[str, List[numpy.ndarray]] = None, dt: Optional[float] = None,
context: pyopencl.Context = None, initial_fields: Optional[Dict[str, NDArray]] = None,
queue: pyopencl.CommandQueue = None, context: Optional[pyopencl.Context] = None,
float_type: numpy.float32 or numpy.float64 = numpy.float32, queue: Optional[pyopencl.CommandQueue] = None,
do_poynting: bool = True, float_type: Type = numpy.float32,
do_poynting_halves: bool = False, do_poynting: bool = True,
do_fieldsrc: bool = False, do_fieldsrc: bool = False,
) -> None: ) -> None:
""" """
Initialize the simulation. Initialize the simulation.
@ -113,14 +114,13 @@ class Simulation(object):
context: pyOpenCL context. If not given, pyopencl.create_some_context(False) is called. context: pyOpenCL context. If not given, pyopencl.create_some_context(False) is called.
queue: pyOpenCL command queue. If not given, pyopencl.CommandQueue(context) is called. queue: pyOpenCL command queue. If not given, pyopencl.CommandQueue(context) is called.
float_type: numpy.float32 or numpy.float64. Default numpy.float32. float_type: numpy.float32 or numpy.float64. Default numpy.float32.
do_poynting: If true, enables calculation of the poynting vector, S. do_poynting: If True, enables calculation of the poynting vector, S.
Poynting vector calculation adds the following computational burdens: Poynting vector calculation adds the following computational burdens:
* During update_H, ~6 extra additions/cell are performed in order to temporally * During update_H, 12 extra additions/cell are performed in order to temporally
sum H. The results are then multiplied by E (6 multiplications/cell) and sum E and H. The results are then multiplied by E (6 multiplications/cell) and
then stored (6 writes/cell, cache-friendly). The E-field components are then stored (6 writes/cell, cache-friendly). The E-field components are
reused from the H-field update and do not require additional H reused from the H-field update and do not require additional H
* GPU memory requirements increase by 50% (for storing S) * GPU memory requirements increase by 50% (for storing S)
do_poynting_halves: TODO DOCUMENT
""" """
if initial_fields is None: if initial_fields is None:
initial_fields = {} initial_fields = {}
@ -147,9 +147,9 @@ class Simulation(object):
if dt is None: if dt is None:
self.dt = max_dt self.dt = max_dt
elif dt > max_dt: elif dt > max_dt:
warnings.warn('Warning: unstable dt: {}'.format(dt)) warnings.warn(f'Warning: unstable dt: {dt}')
elif dt <= 0: elif dt <= 0:
raise Exception('Invalid dt: {}'.format(dt)) raise Exception(f'Invalid dt: {dt}')
else: else:
self.dt = dt self.dt = dt
@ -216,10 +216,12 @@ class Simulation(object):
if bloch_boundaries: if bloch_boundaries:
bloch_args = jinja_args.copy() bloch_args = jinja_args.copy()
bloch_args['do_poynting'] = False bloch_args['do_poynting'] = False
bloch_args['bloch'] = [{'axis': b['axis'], bloch_args['bloch'] = [
'real': b['imag'], {'axis': b['axis'],
'imag': b['real']} 'real': b['imag'],
for b in bloch_boundaries] 'imag': b['real'],
}
for b in bloch_boundaries]
F_source = jinja_env.get_template('update_e.cl').render(**bloch_args) F_source = jinja_env.get_template('update_e.cl').render(**bloch_args)
G_source = jinja_env.get_template('update_h.cl').render(**bloch_args) G_source = jinja_env.get_template('update_h.cl').render(**bloch_args)
self.sources['F'] = F_source self.sources['F'] = F_source
@ -316,15 +318,22 @@ class Simulation(object):
pml_h_fields[ptr(nh)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type) pml_h_fields[ptr(nh)] = pyopencl.array.zeros(self.queue, tuple(psi_shape), dtype=self.arg_type)
return pml_e_fields, pml_h_fields return pml_e_fields, pml_h_fields
def _create_operation(self, source, args_fields): def _create_operation(self, source, args_fields) -> Callable[..., pyopencl.Event]:
args = OrderedDict() args = OrderedDict()
[args.update(d) for d in args_fields] for d in args_fields:
update = ElementwiseKernel(self.context, operation=source, args.update(d)
arguments=', '.join(args.keys())) update = ElementwiseKernel(
self.context,
operation=source,
arguments=', '.join(args.keys()),
)
return lambda e: update(*args.values(), wait_for=e) return lambda e: update(*args.values(), wait_for=e)
def _create_context(self, context: pyopencl.Context = None, def _create_context(
queue: pyopencl.CommandQueue = None): self,
context: Optional[pyopencl.Context] = None,
queue: Optional[pyopencl.CommandQueue] = None,
) -> None:
if context is None: if context is None:
self.context = pyopencl.create_some_context() self.context = pyopencl.create_some_context()
else: else:
@ -335,16 +344,16 @@ class Simulation(object):
else: else:
self.queue = queue self.queue = queue
def _create_eps(self, epsilon: List[numpy.ndarray]): def _create_eps(self, epsilon: NDArray) -> pyopencl.array.Array:
if len(epsilon) != 3: if len(epsilon) != 3:
raise Exception('Epsilon must be a list with length of 3') raise Exception('Epsilon must be a list with length of 3')
if not all((e.shape == epsilon[0].shape for e in epsilon[1:])): if not all((e.shape == epsilon[0].shape for e in epsilon[1:])):
raise Exception('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon]) raise Exception('All epsilon grids must have the same shape. Shapes are {}', [e.shape for e in epsilon])
if not epsilon[0].shape == self.shape: if not epsilon[0].shape == self.shape:
raise Exception('Epsilon shape mismatch. Expected {}, got {}'.format(self.shape, epsilon[0].shape)) raise Exception(f'Epsilon shape mismatch. Expected {self.shape}, got {epsilon[0].shape}')
self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(self.arg_type)) self.eps = pyopencl.array.to_device(self.queue, vec(epsilon).astype(self.arg_type))
def _create_field(self, initial_value: List[numpy.ndarray] = None): def _create_field(self, initial_value: Optional[NDArray] = None) -> pyopencl.array.Array:
if initial_value is None: if initial_value is None:
return pyopencl.array.zeros_like(self.eps) return pyopencl.array.zeros_like(self.eps)
else: else:
@ -355,23 +364,30 @@ class Simulation(object):
return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type)) return pyopencl.array.to_device(self.queue, vec(initial_value).astype(self.arg_type))
def type_to_C(float_type: numpy.dtype) -> str: def type_to_C(
float_type: Type,
) -> str:
""" """
Returns a string corresponding to the C equivalent of a numpy type. Returns a string corresponding to the C equivalent of a numpy type.
Only works for float16, float32, float64. Only works for float16, float32, float64.
:param float_type: e.g. numpy.float32 Args:
:return: string containing the corresponding C type (eg. 'double') float_type: e.g. numpy.float32
Returns:
string containing the corresponding C type (eg. 'double')
""" """
if float_type == numpy.float16: types = {
arg_type = 'half' numpy.float16: 'half',
elif float_type == numpy.float32: numpy.float32: 'float',
arg_type = 'float' numpy.float64: 'double',
elif float_type == numpy.float64: numpy.complex64: 'cfloat_t',
arg_type = 'double' numpy.complex128: 'cdouble_t',
else: }
raise Exception('Unsupported type') if float_type not in types:
return arg_type raise Exception(f'Unsupported type: {float_type}')
return types[float_type]
# def par(x): # def par(x):
# scaling = ((x / (pml['thickness'])) ** pml['m']) # scaling = ((x / (pml['thickness'])) ** pml['m'])

View File

@ -2,7 +2,7 @@ import unittest
import numpy import numpy
from opencl_fdtd import Simulation from opencl_fdtd import Simulation
from fdfd_tools import fdtd from meanas import fdtd
class BasicTests(): class BasicTests():
@ -25,16 +25,18 @@ class BasicTests():
dxes = self.dxes if self.dxes is not None else tuple(tuple(numpy.ones(s) for s in e0.shape[1:]) for _ in range(2)) dxes = self.dxes if self.dxes is not None else tuple(tuple(numpy.ones(s) for s in e0.shape[1:]) for _ in range(2))
dV = numpy.prod(numpy.meshgrid(*dxes[0], indexing='ij'), axis=0) dV = numpy.prod(numpy.meshgrid(*dxes[0], indexing='ij'), axis=0)
u0 = self.j_mag * self.j_mag / self.epsilon[self.src_mask] * dV[mask] u0 = self.j_mag * self.j_mag / self.epsilon[self.src_mask] * dV[mask]
args = {'dxes': self.dxes, args = {
'epsilon': self.epsilon} 'dxes': self.dxes,
'epsilon': self.epsilon,
}
# Make sure initial energy and E dot J are correct # Make sure initial energy and E dot J are correct
energy0 = fdtd.energy_estep(h0=h0, e1=e0, h2=self.hs[1], **args) energy0 = fdtd.energy_estep(h0=h0, e1=e0, h2=self.hs[1], **args)
e_dot_j_0 = fdtd.delta_energy_j(j0=(e0 - 0) * self.epsilon, e1=e0, dxes=self.dxes) e_dot_j_0 = fdtd.delta_energy_j(j0=(e0 - 0) * self.epsilon, e1=e0, dxes=self.dxes)
self.assertTrue(numpy.allclose(energy0[mask], u0)) self.assertTrue(numpy.allclose(energy0[mask], u0))
self.assertFalse(energy0[~mask].any(), msg='energy0: {}'.format(energy0)) self.assertFalse(energy0[~mask].any(), msg=f'{energy0=}')
self.assertTrue(numpy.allclose(e_dot_j_0[mask], u0)) self.assertTrue(numpy.allclose(e_dot_j_0[mask], u0))
self.assertFalse(e_dot_j_0[~mask].any(), msg='e_dot_j_0: {}'.format(e_dot_j_0)) self.assertFalse(e_dot_j_0[~mask].any(), msg=f'{e_dot_j_0=}')
def test_energy_conservation(self): def test_energy_conservation(self):
@ -47,22 +49,25 @@ class BasicTests():
with self.subTest(i=ii): with self.subTest(i=ii):
u_hstep = fdtd.energy_hstep(e0=self.es[ii-1], h1=self.hs[ii], e2=self.es[ii], **args) u_hstep = fdtd.energy_hstep(e0=self.es[ii-1], h1=self.hs[ii], e2=self.es[ii], **args)
u_estep = fdtd.energy_estep(h0=self.hs[ii], e1=self.es[ii], h2=self.hs[ii + 1], **args) u_estep = fdtd.energy_estep(h0=self.hs[ii], e1=self.es[ii], h2=self.hs[ii + 1], **args)
self.assertTrue(numpy.allclose(u_hstep.sum(), u0), msg='u_hstep: {}\n{}'.format(u_hstep.sum(), numpy.rollaxis(u_hstep, -1))) self.assertTrue(numpy.allclose(u_hstep.sum(), u0), msg=f'u_hstep: {u_hstep.sum()}\n{numpy.moveaxis(u_hstep, -1, 0)}')
self.assertTrue(numpy.allclose(u_estep.sum(), u0), msg='u_estep: {}\n{}'.format(u_estep.sum(), numpy.rollaxis(u_estep, -1))) self.assertTrue(numpy.allclose(u_estep.sum(), u0), msg=f'u_estep: {u_estep.sum()}\n{numpy.moveaxis(u_estep, -1, 0)}')
def test_poynting(self): def test_poynting(self):
for ii in range(1, 3): for ii in range(1, 3):
with self.subTest(i=ii): with self.subTest(i=ii):
s = fdtd.poynting(e=self.es[ii], h=self.hs[ii+1] + self.hs[ii]) s = fdtd.poynting(e=self.es[ii], h=self.hs[ii+1] + self.hs[ii])
sf = numpy.moveaxis(s, -1, 0)
ss = numpy.moveaxis(self.ss[ii], -1, 0)
self.assertTrue(numpy.allclose(s, self.ss[ii], rtol=1e-4), self.assertTrue(numpy.allclose(s, self.ss[ii], rtol=1e-4),
msg='From ExH:\n{}\nFrom sim.S:\n{}'.format(numpy.rollaxis(s, -1), msg=f'From ExH:\n{sf}\nFrom sim.S:\n{ss}')
numpy.rollaxis(self.ss[ii], -1)))
def test_poynting_divergence(self): def test_poynting_divergence(self):
args = {'dxes': self.dxes, args = {
'epsilon': self.epsilon} 'dxes': self.dxes,
'epsilon': self.epsilon,
}
dxes = self.dxes if self.dxes is not None else tuple(tuple(numpy.ones(s) for s in self.epsilon.shape[1:]) for _ in range(2)) dxes = self.dxes if self.dxes is not None else tuple(tuple(numpy.ones(s) for s in self.epsilon.shape[1:]) for _ in range(2))
dV = numpy.prod(numpy.meshgrid(*dxes[0], indexing='ij'), axis=0) dV = numpy.prod(numpy.meshgrid(*dxes[0], indexing='ij'), axis=0)
@ -75,9 +80,11 @@ class BasicTests():
du_half_h2e = u_estep - u_hstep du_half_h2e = u_estep - u_hstep
div_s_h2e = self.dt * fdtd.poynting_divergence(e=self.es[ii], h=self.hs[ii], dxes=self.dxes) * dV div_s_h2e = self.dt * fdtd.poynting_divergence(e=self.es[ii], h=self.hs[ii], dxes=self.dxes) * dV
du_half_h2e_f = numpy.moveaxis(du_half_h2e, -1, 0)
div_s_h2e_f = -numpy.moveaxis(div_s_h2e, -1, 0)
self.assertTrue(numpy.allclose(du_half_h2e, -div_s_h2e, rtol=1e-4), self.assertTrue(numpy.allclose(du_half_h2e, -div_s_h2e, rtol=1e-4),
msg='du_half_h2e\n{}\ndiv_s_h2e\n{}'.format(numpy.rollaxis(du_half_h2e, -1), msg=f'du_half_h2e\n{du_half_h2e_f}\ndiv_s_h2e\n{div_s_h2e_f}')
-numpy.rollaxis(div_s_h2e, -1)))
if u_eprev is None: if u_eprev is None:
u_eprev = u_estep u_eprev = u_estep
@ -86,15 +93,18 @@ class BasicTests():
# previous half-step # previous half-step
du_half_e2h = u_hstep - u_eprev du_half_e2h = u_hstep - u_eprev
div_s_e2h = self.dt * fdtd.poynting_divergence(e=self.es[ii-1], h=self.hs[ii], dxes=self.dxes) * dV div_s_e2h = self.dt * fdtd.poynting_divergence(e=self.es[ii-1], h=self.hs[ii], dxes=self.dxes) * dV
du_half_e2h_f = numpy.moveaxis(du_half_e2h, -1, 0)
div_s_e2h_f = -numpy.moveaxis(div_s_e2h, -1, 0)
self.assertTrue(numpy.allclose(du_half_e2h, -div_s_e2h, rtol=1e-4), self.assertTrue(numpy.allclose(du_half_e2h, -div_s_e2h, rtol=1e-4),
msg='du_half_e2h\n{}\ndiv_s_e2h\n{}'.format(numpy.rollaxis(du_half_e2h, -1), msg=f'du_half_e2h\n{du_half_e2h_f}\ndiv_s_e2h\n{div_s_e2h_f}')
-numpy.rollaxis(div_s_e2h, -1)))
u_eprev = u_estep u_eprev = u_estep
def test_poynting_planes(self): def test_poynting_planes(self):
args = {'dxes': self.dxes, args = {
'epsilon': self.epsilon} 'dxes': self.dxes,
'epsilon': self.epsilon,
}
dxes = self.dxes if self.dxes is not None else tuple(tuple(numpy.ones(s) for s in self.epsilon.shape[1:]) for _ in range(2)) dxes = self.dxes if self.dxes is not None else tuple(tuple(numpy.ones(s) for s in self.epsilon.shape[1:]) for _ in range(2))
dV = numpy.prod(numpy.meshgrid(*dxes[0], indexing='ij'), axis=0) dV = numpy.prod(numpy.meshgrid(*dxes[0], indexing='ij'), axis=0)
@ -118,8 +128,9 @@ class BasicTests():
planes = [s_h2e[px].sum(), -s_h2e[mx].sum(), planes = [s_h2e[px].sum(), -s_h2e[mx].sum(),
s_h2e[py].sum(), -s_h2e[my].sum(), s_h2e[py].sum(), -s_h2e[my].sum(),
s_h2e[pz].sum(), -s_h2e[mz].sum()] s_h2e[pz].sum(), -s_h2e[mz].sum()]
du = (u_estep - u_hstep)[self.src_mask[1]]
self.assertTrue(numpy.allclose(sum(planes), (u_estep - u_hstep)[self.src_mask[1]]), self.assertTrue(numpy.allclose(sum(planes), (u_estep - u_hstep)[self.src_mask[1]]),
msg='planes: {} (sum: {})\n du:\n {}'.format(planes, sum(planes), (u_estep - u_hstep)[self.src_mask[1]])) msg=f'planes: {planes} (sum: {sum(planes)})\n du:\n {du}')
if u_eprev is None: if u_eprev is None:
u_eprev = u_estep u_eprev = u_estep
@ -132,15 +143,14 @@ class BasicTests():
planes = [s_e2h[px].sum(), -s_e2h[mx].sum(), planes = [s_e2h[px].sum(), -s_e2h[mx].sum(),
s_e2h[py].sum(), -s_e2h[my].sum(), s_e2h[py].sum(), -s_e2h[my].sum(),
s_e2h[pz].sum(), -s_e2h[mz].sum()] s_e2h[pz].sum(), -s_e2h[mz].sum()]
du = (u_hstep - u_eprev)[self.src_mask[1]]
self.assertTrue(numpy.allclose(sum(planes), (u_hstep - u_eprev)[self.src_mask[1]]), self.assertTrue(numpy.allclose(sum(planes), (u_hstep - u_eprev)[self.src_mask[1]]),
msg='planes: {} (sum: {})\n du:\n {}'.format(planes, sum(planes), (u_hstep - u_eprev)[self.src_mask[1]])) msg=f'planes: {du} (sum: {sum(planes)})\n du:\n {du}')
# previous half-step # previous half-step
u_eprev = u_estep u_eprev = u_estep
class Basic2DNoDXOnlyVacuum(unittest.TestCase, BasicTests): class Basic2DNoDXOnlyVacuum(unittest.TestCase, BasicTests):
def setUp(self): def setUp(self):
shape = [3, 5, 5, 1] shape = [3, 5, 5, 1]
@ -348,8 +358,10 @@ class JdotE_3DUniformDX(unittest.TestCase):
e1 = self.es[2] e1 = self.es[2]
j0 = numpy.zeros_like(e0) j0 = numpy.zeros_like(e0)
j0[self.src_mask] = self.j_mag j0[self.src_mask] = self.j_mag
args = {'dxes': self.dxes, args = {
'epsilon': self.epsilon} 'dxes': self.dxes,
'epsilon': self.epsilon,
}
e2h = fdtd.maxwell_h(dt=self.dt, dxes=self.dxes) e2h = fdtd.maxwell_h(dt=self.dt, dxes=self.dxes)
#ee = j0 * (2 * e0 - j0) #ee = j0 * (2 * e0 - j0)
@ -365,4 +377,5 @@ class JdotE_3DUniformDX(unittest.TestCase):
u_hstep = fdtd.energy_hstep(e0=self.es[0], h1=self.hs[1], e2=self.es[1], **args) u_hstep = fdtd.energy_hstep(e0=self.es[0], h1=self.hs[1], e2=self.es[1], **args)
u_estep = fdtd.energy_estep(h0=self.hs[-2], e1=self.es[-2], h2=self.hs[-1], **args) u_estep = fdtd.energy_estep(h0=self.hs[-2], e1=self.es[-2], h2=self.hs[-1], **args)
#breakpoint() #breakpoint()
self.assertTrue(numpy.allclose(u0.sum(), (u_estep - u_hstep).sum()), msg='{} != {}'.format(u0.sum(), (u_estep - u_hstep).sum())) du = (u_estep - u_hstep).sum()
self.assertTrue(numpy.allclose(u0.sum(), (u_estep - u_hstep).sum()), msg=f'{u0.sum()} != {du}')