refactor to avoid class-scoped imports
This commit is contained in:
parent
f84a75f35a
commit
8c33a39c02
198
gridlock/base.py
Normal file
198
gridlock/base.py
Normal file
@ -0,0 +1,198 @@
|
||||
from typing import ClassVar, Self, Protocol
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
|
||||
from . import GridError
|
||||
|
||||
|
||||
class GridBase(Protocol):
|
||||
exyz: list[NDArray]
|
||||
"""Cell edges. Monotonically increasing without duplicates."""
|
||||
|
||||
periodic: list[bool]
|
||||
"""For each axis, determines how far the rightmost boundary gets shifted. """
|
||||
|
||||
shifts: NDArray
|
||||
"""Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`"""
|
||||
|
||||
@property
|
||||
def dxyz(self) -> list[NDArray]:
|
||||
"""
|
||||
Cell sizes for each axis, no shifts applied
|
||||
|
||||
Returns:
|
||||
List of 3 ndarrays of cell sizes
|
||||
"""
|
||||
return [numpy.diff(ee) for ee in self.exyz]
|
||||
|
||||
@property
|
||||
def xyz(self) -> list[NDArray]:
|
||||
"""
|
||||
Cell centers for each axis, no shifts applied
|
||||
|
||||
Returns:
|
||||
List of 3 ndarrays of cell edges
|
||||
"""
|
||||
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
|
||||
|
||||
@property
|
||||
def shape(self) -> NDArray[numpy.int_]:
|
||||
"""
|
||||
The number of cells in x, y, and z
|
||||
|
||||
Returns:
|
||||
ndarray of [x_centers.size, y_centers.size, z_centers.size]
|
||||
"""
|
||||
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
|
||||
|
||||
@property
|
||||
def num_grids(self) -> int:
|
||||
"""
|
||||
The number of grids (number of shifts)
|
||||
"""
|
||||
return self.shifts.shape[0]
|
||||
|
||||
@property
|
||||
def cell_data_shape(self):
|
||||
"""
|
||||
The shape of the cell_data ndarray (num_grids, *self.shape).
|
||||
"""
|
||||
return numpy.hstack((self.num_grids, self.shape))
|
||||
|
||||
@property
|
||||
def dxyz_with_ghost(self) -> list[NDArray]:
|
||||
"""
|
||||
Gives dxyz with an additional 'ghost' cell at the end, whose value depends
|
||||
on whether or not the axis has periodic boundary conditions. See main description
|
||||
above to learn why this is necessary.
|
||||
|
||||
If periodic, final edge shifts same amount as first
|
||||
Otherwise, final edge shifts same amount as second-to-last
|
||||
|
||||
Returns:
|
||||
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
|
||||
"""
|
||||
el = [0 if p else -1 for p in self.periodic]
|
||||
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
|
||||
|
||||
@property
|
||||
def center(self) -> NDArray[numpy.float64]:
|
||||
"""
|
||||
Center position of the entire grid, no shifts applied
|
||||
|
||||
Returns:
|
||||
ndarray of [x_center, y_center, z_center]
|
||||
"""
|
||||
# center is just average of first and last xyz, which is just the average of the
|
||||
# first two and last two exyz
|
||||
centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
|
||||
return numpy.array(centers, dtype=float)
|
||||
|
||||
@property
|
||||
def dxyz_limits(self) -> tuple[NDArray, NDArray]:
|
||||
"""
|
||||
Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
|
||||
ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
|
||||
weighted average is performed when shifting).
|
||||
|
||||
Returns:
|
||||
Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
|
||||
"""
|
||||
d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
|
||||
d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
|
||||
return d_min, d_max
|
||||
|
||||
def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||
"""
|
||||
Returns edges for which_shifts.
|
||||
|
||||
Args:
|
||||
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
||||
|
||||
Returns:
|
||||
List of 3 ndarrays of cell edges
|
||||
"""
|
||||
if which_shifts is None:
|
||||
return self.exyz
|
||||
dxyz = self.dxyz_with_ghost
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
|
||||
# If shift is negative, use left cell's dx to determine shift
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
dxyz[a] = numpy.roll(dxyz[a], 1)
|
||||
|
||||
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
|
||||
|
||||
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||
"""
|
||||
Returns cell sizes for `which_shifts`.
|
||||
|
||||
Args:
|
||||
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
||||
|
||||
Returns:
|
||||
List of 3 ndarrays of cell sizes
|
||||
"""
|
||||
if which_shifts is None:
|
||||
return self.dxyz
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
dxyz = self.dxyz_with_ghost
|
||||
|
||||
# If shift is negative, use left cell's dx to determine size
|
||||
sdxyz = []
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
roll_dxyz = numpy.roll(dxyz[a], 1)
|
||||
abs_shift = numpy.abs(shifts[a])
|
||||
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
|
||||
else:
|
||||
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
|
||||
|
||||
return sdxyz
|
||||
|
||||
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||
"""
|
||||
Returns cell centers for `which_shifts`.
|
||||
|
||||
Args:
|
||||
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
||||
|
||||
Returns:
|
||||
List of 3 ndarrays of cell centers
|
||||
"""
|
||||
if which_shifts is None:
|
||||
return self.xyz
|
||||
exyz = self.shifted_exyz(which_shifts)
|
||||
dxyz = self.shifted_dxyz(which_shifts)
|
||||
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
|
||||
|
||||
def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
|
||||
"""
|
||||
Return cell widths, with each dimension shifted by the corresponding shifts.
|
||||
|
||||
Returns:
|
||||
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
|
||||
"""
|
||||
if self.num_grids != 3:
|
||||
raise GridError('Autoshifting requires exactly 3 grids')
|
||||
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
|
||||
|
||||
def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray:
|
||||
"""
|
||||
Allocate an ndarray for storing grid data.
|
||||
|
||||
Args:
|
||||
fill_value: Value to initialize the grid to. If None, an
|
||||
uninitialized array is returned.
|
||||
dtype: Numpy dtype for the array. Default is `numpy.float32`.
|
||||
|
||||
Returns:
|
||||
The allocated array
|
||||
"""
|
||||
if fill_value is None:
|
||||
return numpy.empty(self.cell_data_shape, dtype=dtype)
|
||||
else:
|
||||
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
|
@ -1,13 +1,16 @@
|
||||
"""
|
||||
Drawing-related methods for Grid class
|
||||
"""
|
||||
from typing import Union, Sequence, Callable
|
||||
from typing import Union
|
||||
from collections.abc import Sequence, Callable
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
from float_raster import raster
|
||||
|
||||
from . import GridError
|
||||
from .base import GridBase
|
||||
from .position import GridPosMixin
|
||||
|
||||
|
||||
# NOTE: Maybe it would make sense to create a GridDrawer class
|
||||
@ -20,12 +23,13 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
|
||||
foreground_t = Union[float, foreground_callable_t]
|
||||
|
||||
|
||||
class GridDrawMixin(GridPosMixin):
|
||||
def draw_polygons(
|
||||
self,
|
||||
cell_data: NDArray,
|
||||
surface_normal: int,
|
||||
center: ArrayLike,
|
||||
polygons: Sequence[NDArray],
|
||||
polygons: Sequence[ArrayLike],
|
||||
thickness: float,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
) -> None:
|
||||
@ -50,13 +54,13 @@ def draw_polygons(
|
||||
"""
|
||||
if surface_normal not in range(3):
|
||||
raise GridError('Invalid surface_normal direction')
|
||||
|
||||
center = numpy.squeeze(center)
|
||||
poly_list = [numpy.array(poly, copy=False) for poly in polygons]
|
||||
|
||||
# Check polygons, and remove redundant coordinates
|
||||
surface = numpy.delete(range(3), surface_normal)
|
||||
|
||||
for i, polygon in enumerate(polygons):
|
||||
for i, polygon in enumerate(poly_list):
|
||||
malformed = f'Malformed polygon: ({i})'
|
||||
if polygon.shape[1] not in (2, 3):
|
||||
raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
|
||||
@ -82,7 +86,7 @@ def draw_polygons(
|
||||
# 1) Compute outer bounds (bd) of polygons
|
||||
bd_2d_min = numpy.array([0, 0])
|
||||
bd_2d_max = numpy.array([0, 0])
|
||||
for polygon in polygons:
|
||||
for polygon in poly_list:
|
||||
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
|
||||
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
|
||||
bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
|
||||
@ -100,7 +104,7 @@ def draw_polygons(
|
||||
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
|
||||
|
||||
# 3) Adjust polygons for center
|
||||
polygons = [poly + center[surface] for poly in polygons]
|
||||
poly_list = [poly + center[surface] for poly in poly_list]
|
||||
|
||||
# ## Generate weighing function
|
||||
def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
|
||||
@ -108,6 +112,7 @@ def draw_polygons(
|
||||
return numpy.insert(v_2d, surface_normal, (val,))
|
||||
|
||||
# iterate over grids
|
||||
foreground_val: NDArray | float
|
||||
for i, _ in enumerate(cell_data):
|
||||
# ## Evaluate or expand foregrounds[i]
|
||||
foregrounds_i = foregrounds[i]
|
||||
@ -129,7 +134,7 @@ def draw_polygons(
|
||||
w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int))
|
||||
|
||||
# Draw each polygon separately
|
||||
for polygon in polygons:
|
||||
for polygon in poly_list:
|
||||
|
||||
# Get the boundaries of the polygon
|
||||
pbd_min = polygon.min(axis=0)
|
||||
|
201
gridlock/grid.py
201
gridlock/grid.py
@ -1,4 +1,5 @@
|
||||
from typing import Callable, Sequence, ClassVar, Self
|
||||
from typing import ClassVar, Self
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
@ -8,12 +9,16 @@ import warnings
|
||||
import copy
|
||||
|
||||
from . import GridError
|
||||
from .base import GridBase
|
||||
from .draw import GridDrawMixin
|
||||
from .read import GridReadMixin
|
||||
from .position import GridPosMixin
|
||||
|
||||
|
||||
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
|
||||
|
||||
|
||||
class Grid:
|
||||
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||
"""
|
||||
Simulation grid metadata for finite-difference simulations.
|
||||
|
||||
@ -70,193 +75,6 @@ class Grid:
|
||||
], dtype=float)
|
||||
"""Default shifts for Yee grid H-field"""
|
||||
|
||||
from .draw import (
|
||||
draw_polygons, draw_polygon, draw_slab, draw_cuboid,
|
||||
draw_cylinder, draw_extrude_rectangle,
|
||||
)
|
||||
from .read import get_slice, visualize_slice, visualize_isosurface
|
||||
from .position import ind2pos, pos2ind
|
||||
|
||||
@property
|
||||
def dxyz(self) -> list[NDArray]:
|
||||
"""
|
||||
Cell sizes for each axis, no shifts applied
|
||||
|
||||
Returns:
|
||||
List of 3 ndarrays of cell sizes
|
||||
"""
|
||||
return [numpy.diff(ee) for ee in self.exyz]
|
||||
|
||||
@property
|
||||
def xyz(self) -> list[NDArray]:
|
||||
"""
|
||||
Cell centers for each axis, no shifts applied
|
||||
|
||||
Returns:
|
||||
List of 3 ndarrays of cell edges
|
||||
"""
|
||||
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
|
||||
|
||||
@property
|
||||
def shape(self) -> NDArray[numpy.int_]:
|
||||
"""
|
||||
The number of cells in x, y, and z
|
||||
|
||||
Returns:
|
||||
ndarray of [x_centers.size, y_centers.size, z_centers.size]
|
||||
"""
|
||||
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
|
||||
|
||||
@property
|
||||
def num_grids(self) -> int:
|
||||
"""
|
||||
The number of grids (number of shifts)
|
||||
"""
|
||||
return self.shifts.shape[0]
|
||||
|
||||
@property
|
||||
def cell_data_shape(self):
|
||||
"""
|
||||
The shape of the cell_data ndarray (num_grids, *self.shape).
|
||||
"""
|
||||
return numpy.hstack((self.num_grids, self.shape))
|
||||
|
||||
@property
|
||||
def dxyz_with_ghost(self) -> list[NDArray]:
|
||||
"""
|
||||
Gives dxyz with an additional 'ghost' cell at the end, whose value depends
|
||||
on whether or not the axis has periodic boundary conditions. See main description
|
||||
above to learn why this is necessary.
|
||||
|
||||
If periodic, final edge shifts same amount as first
|
||||
Otherwise, final edge shifts same amount as second-to-last
|
||||
|
||||
Returns:
|
||||
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
|
||||
"""
|
||||
el = [0 if p else -1 for p in self.periodic]
|
||||
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
|
||||
|
||||
@property
|
||||
def center(self) -> NDArray[numpy.float64]:
|
||||
"""
|
||||
Center position of the entire grid, no shifts applied
|
||||
|
||||
Returns:
|
||||
ndarray of [x_center, y_center, z_center]
|
||||
"""
|
||||
# center is just average of first and last xyz, which is just the average of the
|
||||
# first two and last two exyz
|
||||
centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
|
||||
return numpy.array(centers, dtype=float)
|
||||
|
||||
@property
|
||||
def dxyz_limits(self) -> tuple[NDArray, NDArray]:
|
||||
"""
|
||||
Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
|
||||
ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
|
||||
weighted average is performed when shifting).
|
||||
|
||||
Returns:
|
||||
Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
|
||||
"""
|
||||
d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
|
||||
d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
|
||||
return d_min, d_max
|
||||
|
||||
def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||
"""
|
||||
Returns edges for which_shifts.
|
||||
|
||||
Args:
|
||||
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
||||
|
||||
Returns:
|
||||
List of 3 ndarrays of cell edges
|
||||
"""
|
||||
if which_shifts is None:
|
||||
return self.exyz
|
||||
dxyz = self.dxyz_with_ghost
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
|
||||
# If shift is negative, use left cell's dx to determine shift
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
dxyz[a] = numpy.roll(dxyz[a], 1)
|
||||
|
||||
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
|
||||
|
||||
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||
"""
|
||||
Returns cell sizes for `which_shifts`.
|
||||
|
||||
Args:
|
||||
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
||||
|
||||
Returns:
|
||||
List of 3 ndarrays of cell sizes
|
||||
"""
|
||||
if which_shifts is None:
|
||||
return self.dxyz
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
dxyz = self.dxyz_with_ghost
|
||||
|
||||
# If shift is negative, use left cell's dx to determine size
|
||||
sdxyz = []
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
roll_dxyz = numpy.roll(dxyz[a], 1)
|
||||
abs_shift = numpy.abs(shifts[a])
|
||||
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
|
||||
else:
|
||||
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
|
||||
|
||||
return sdxyz
|
||||
|
||||
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||
"""
|
||||
Returns cell centers for `which_shifts`.
|
||||
|
||||
Args:
|
||||
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
||||
|
||||
Returns:
|
||||
List of 3 ndarrays of cell centers
|
||||
"""
|
||||
if which_shifts is None:
|
||||
return self.xyz
|
||||
exyz = self.shifted_exyz(which_shifts)
|
||||
dxyz = self.shifted_dxyz(which_shifts)
|
||||
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
|
||||
|
||||
def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
|
||||
"""
|
||||
Return cell widths, with each dimension shifted by the corresponding shifts.
|
||||
|
||||
Returns:
|
||||
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
|
||||
"""
|
||||
if self.num_grids != 3:
|
||||
raise GridError('Autoshifting requires exactly 3 grids')
|
||||
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
|
||||
|
||||
def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray:
|
||||
"""
|
||||
Allocate an ndarray for storing grid data.
|
||||
|
||||
Args:
|
||||
fill_value: Value to initialize the grid to. If None, an
|
||||
uninitialized array is returned.
|
||||
dtype: Numpy dtype for the array. Default is `numpy.float32`.
|
||||
|
||||
Returns:
|
||||
The allocated array
|
||||
"""
|
||||
if fill_value is None:
|
||||
return numpy.empty(self.cell_data_shape, dtype=dtype)
|
||||
else:
|
||||
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pixel_edge_coordinates: Sequence[ArrayLike],
|
||||
@ -277,11 +95,12 @@ class Grid:
|
||||
Raises:
|
||||
`GridError` on invalid input
|
||||
"""
|
||||
self.exyz = [numpy.unique(pixel_edge_coordinates[i]) for i in range(3)]
|
||||
edge_arrs = [numpy.array(cc, copy=False) for cc in pixel_edge_coordinates]
|
||||
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
|
||||
self.shifts = numpy.array(shifts, dtype=float)
|
||||
|
||||
for i in range(3):
|
||||
if len(self.exyz[i]) != len(pixel_edge_coordinates[i]):
|
||||
if self.exyz[i].size != edge_arrs[i].size:
|
||||
warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2)
|
||||
|
||||
if isinstance(periodic, bool):
|
||||
|
@ -5,8 +5,10 @@ import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
|
||||
from . import GridError
|
||||
from .base import GridBase
|
||||
|
||||
|
||||
class GridPosMixin(GridBase):
|
||||
def ind2pos(
|
||||
self,
|
||||
ind: NDArray,
|
||||
|
@ -7,6 +7,8 @@ import numpy
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from . import GridError
|
||||
from .base import GridBase
|
||||
from .position import GridPosMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import matplotlib.axes
|
||||
@ -18,6 +20,7 @@ if TYPE_CHECKING:
|
||||
# .visualize_isosurface uses mpl_toolkits.mplot3d
|
||||
|
||||
|
||||
class GridReadMixin(GridPosMixin):
|
||||
def get_slice(
|
||||
self,
|
||||
cell_data: NDArray,
|
||||
|
Loading…
Reference in New Issue
Block a user