160 lines
5.6 KiB
Python
160 lines
5.6 KiB
Python
from typing import ClassVar, Self
|
|
from collections.abc import Callable, Sequence
|
|
|
|
import numpy
|
|
from numpy.typing import NDArray, ArrayLike
|
|
|
|
import pickle
|
|
import warnings
|
|
import copy
|
|
|
|
from . import GridError
|
|
from .base import GridBase
|
|
from .draw import GridDrawMixin
|
|
from .read import GridReadMixin
|
|
from .position import GridPosMixin
|
|
|
|
|
|
# Type of a callable taking three arrays and returning an array.
# NOTE(review): presumably the three inputs are per-axis (x, y, z) coordinate
# arrays consumed by the drawing mixin — confirm against its callers.
foreground_callable_type = Callable[
    [NDArray, NDArray, NDArray],
    NDArray,
]
|
|
|
|
|
|
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
    """
    Simulation grid metadata for finite-difference simulations.

    Can be used to generate non-uniform rectangular grids (the entire grid
    is generated based on the coordinates of the boundary points). Also does
    straightforward natural <-> grid unit conversion.

    This class handles data describing the grid, and should be paired with a
    (separate) ndarray that contains the actual data in each cell. The `allocate()`
    method can be used to create this ndarray.

    The resulting `cell_data[i, a, b, c]` should correspond to the value in the
    `i`-th grid, in the cell centered around
    ```
    (xyz[0][a] + dxyz[0][a] * shifts[i, 0],
     xyz[1][b] + dxyz[1][b] * shifts[i, 1],
     xyz[2][c] + dxyz[2][c] * shifts[i, 2]).
    ```
    You can get raw edge coordinates (`exyz`),
               center coordinates (`xyz`),
                       cell sizes (`dxyz`),
    from the properties named as above, or get them for a given grid by using the
    `self.shifted_*xyz(which_shifts)` functions.

    The sizes of adjacent cells are taken into account when applying shifts. The
    total shift for each edge is chosen using `(shift * dx_of_cell_being_moved_through)`.

    It is tricky to determine the size of the right-most cell after shifting,
    since its right boundary should shift by `shifts[i][a] * dxyz[a][dxyz[a].size]`,
    where the dxyz element refers to a cell that does not exist.
    Because of this, we either assume this 'ghost' cell is the same size as the last
    real cell, or, if `self.periodic[a]` is set to `True`, the same size as the first cell.
    """
    exyz: list[NDArray]
    """Cell edges. Monotonically increasing without duplicates."""

    periodic: list[bool]
    """For each axis, determines how far the rightmost boundary gets shifted. """

    shifts: NDArray
    """Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`"""

    Yee_Shifts_E: ClassVar[NDArray] = 0.5 * numpy.array([
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        ], dtype=float)
    """Default shifts for Yee grid E-field"""

    Yee_Shifts_H: ClassVar[NDArray] = 0.5 * numpy.array([
        [0, 1, 1],
        [1, 0, 1],
        [1, 1, 0],
        ], dtype=float)
    """Default shifts for Yee grid H-field"""

    def __init__(
            self,
            pixel_edge_coordinates: Sequence[ArrayLike],
            shifts: ArrayLike = Yee_Shifts_E,
            periodic: bool | Sequence[bool] = False,
            ) -> None:
        """
        Args:
            pixel_edge_coordinates: 3-element list of (ndarrays or lists) specifying the
                coordinates of the pixel edges in each dimension
                (ie, `[[x0, x1, x2,...], [y0,...], [z0,...]]` where the first pixel has x-edges x=`x0` and
                x=`x1`, the second has edges x=`x1` and x=`x2`, etc.)
                Edges are sorted and de-duplicated; a warning is emitted when duplicates are dropped.
            shifts: Nx3 array containing `[x, y, z]` offsets for each of N grids.
                E-field Yee shifts are used by default. Each offset must lie in [-1, 1].
            periodic: Specifies how the sizes of edge cells are calculated; see main class
                documentation. List of 3 bool, or a single bool that gets broadcast. Default `False`.

        Raises:
            `GridError` on invalid input: edge coordinates not given for exactly 3 axes,
                `periodic` not a single bool or 3 entries, shifts not of shape Nx3,
                or shifts outside [-1, 1] (including NaN).
        """
        # `asarray` avoids copying when it can. `numpy.array(..., copy=False)` raises
        # under NumPy >= 2 whenever a copy is unavoidable (e.g. for plain lists).
        edge_arrs = [numpy.asarray(cc) for cc in pixel_edge_coordinates]
        if len(edge_arrs) != 3:
            raise GridError(f'Edge coordinates must be given for exactly 3 axes, got {len(edge_arrs)}')

        # numpy.unique sorts and removes duplicates -> monotonically increasing edges
        self.exyz = [numpy.unique(edges) for edges in edge_arrs]
        # numpy.array copies, so the class-level default shifts array is never aliased
        self.shifts = numpy.array(shifts, dtype=float)

        for axis, (unique_edges, raw_edges) in enumerate(zip(self.exyz, edge_arrs)):
            if unique_edges.size != raw_edges.size:
                warnings.warn(f'Dimension {axis} had duplicate edge coordinates', stacklevel=2)

        # Scalars (including numpy booleans) are broadcast to all three axes
        if isinstance(periodic, (bool, numpy.bool_)):
            self.periodic = [bool(periodic)] * 3
        else:
            self.periodic = list(periodic)
            if len(self.periodic) != 3:
                raise GridError('periodic must be a single bool or have exactly 3 entries,'
                                f' got {len(self.periodic)}')

        if len(self.shifts.shape) != 2:
            raise GridError('Misshapen shifts: shifts must have two axes! '
                            f' The given shifts has shape {self.shifts.shape}')
        if self.shifts.shape[1] != 3:
            raise GridError('Misshapen shifts; second axis size should be 3,'
                            f' shape is {self.shifts.shape}')

        # Phrased as "not all within range" so that NaN shifts are rejected too
        if not (numpy.abs(self.shifts) <= 1).all():
            raise GridError('Only shifts in the range [-1, 1] are currently supported')

        if (self.shifts < 0).any():
            # TODO: Test negative shifts
            warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2)

    @staticmethod
    def load(filename: str) -> 'Grid':
        """
        Load a grid from a file written by `save()`.

        The file is unpickled, so only load files from trusted sources:
        unpickling untrusted data can execute arbitrary code.

        Args:
            filename: Filename to load from.

        Returns:
            A new `Grid` whose attributes are replaced by the saved ones.
        """
        # NOTE(security): pickle.load must not be used on untrusted files.
        with open(filename, 'rb') as f:
            tmp_dict = pickle.load(f)

        # Build a placeholder grid so every attribute exists, then overwrite it with
        # the saved state; attributes absent from the file keep the placeholder values.
        g = Grid([[-1, 1]] * 3)
        g.__dict__.update(tmp_dict)
        return g

    def save(self, filename: str) -> Self:
        """
        Save to file by pickling the instance's attribute dictionary.

        Args:
            filename: Filename to save to.

        Returns:
            self
        """
        # NOTE(review): protocol 2 is kept as-is, presumably for compatibility
        # with older readers — confirm before changing.
        with open(filename, 'wb') as f:
            pickle.dump(self.__dict__, f, protocol=2)
        return self

    def copy(self) -> Self:
        """
        Returns:
            Deep copy of the grid.
        """
        # `copy` here is the module-level import; the method name does not shadow it
        # inside the method body.
        return copy.deepcopy(self)
|