refactor to avoid class-scoped imports

This commit is contained in:
Jan Petykiewicz 2024-07-29 01:37:58 -07:00
parent f84a75f35a
commit 8c33a39c02
5 changed files with 801 additions and 774 deletions

198
gridlock/base.py Normal file
View File

@ -0,0 +1,198 @@
from typing import ClassVar, Self, Protocol
from collections.abc import Callable, Sequence
import numpy
from numpy.typing import NDArray, ArrayLike
from . import GridError
class GridBase(Protocol):
exyz: list[NDArray]
"""Cell edges. Monotonically increasing without duplicates."""
periodic: list[bool]
"""For each axis, determines how far the rightmost boundary gets shifted. """
shifts: NDArray
"""Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`"""
@property
def dxyz(self) -> list[NDArray]:
"""
Cell sizes for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell sizes
"""
return [numpy.diff(ee) for ee in self.exyz]
@property
def xyz(self) -> list[NDArray]:
"""
Cell centers for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell edges
"""
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
@property
def shape(self) -> NDArray[numpy.int_]:
"""
The number of cells in x, y, and z
Returns:
ndarray of [x_centers.size, y_centers.size, z_centers.size]
"""
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
@property
def num_grids(self) -> int:
"""
The number of grids (number of shifts)
"""
return self.shifts.shape[0]
@property
def cell_data_shape(self):
"""
The shape of the cell_data ndarray (num_grids, *self.shape).
"""
return numpy.hstack((self.num_grids, self.shape))
@property
def dxyz_with_ghost(self) -> list[NDArray]:
"""
Gives dxyz with an additional 'ghost' cell at the end, whose value depends
on whether or not the axis has periodic boundary conditions. See main description
above to learn why this is necessary.
If periodic, final edge shifts same amount as first
Otherwise, final edge shifts same amount as second-to-last
Returns:
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
"""
el = [0 if p else -1 for p in self.periodic]
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
@property
def center(self) -> NDArray[numpy.float64]:
"""
Center position of the entire grid, no shifts applied
Returns:
ndarray of [x_center, y_center, z_center]
"""
# center is just average of first and last xyz, which is just the average of the
# first two and last two exyz
centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
return numpy.array(centers, dtype=float)
@property
def dxyz_limits(self) -> tuple[NDArray, NDArray]:
"""
Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
weighted average is performed when shifting).
Returns:
Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
"""
d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
return d_min, d_max
def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
"""
Returns edges for which_shifts.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell edges
"""
if which_shifts is None:
return self.exyz
dxyz = self.dxyz_with_ghost
shifts = self.shifts[which_shifts, :]
# If shift is negative, use left cell's dx to determine shift
for a in range(3):
if shifts[a] < 0:
dxyz[a] = numpy.roll(dxyz[a], 1)
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
"""
Returns cell sizes for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell sizes
"""
if which_shifts is None:
return self.dxyz
shifts = self.shifts[which_shifts, :]
dxyz = self.dxyz_with_ghost
# If shift is negative, use left cell's dx to determine size
sdxyz = []
for a in range(3):
if shifts[a] < 0:
roll_dxyz = numpy.roll(dxyz[a], 1)
abs_shift = numpy.abs(shifts[a])
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
else:
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
return sdxyz
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
"""
Returns cell centers for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell centers
"""
if which_shifts is None:
return self.xyz
exyz = self.shifted_exyz(which_shifts)
dxyz = self.shifted_dxyz(which_shifts)
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
"""
Return cell widths, with each dimension shifted by the corresponding shifts.
Returns:
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
"""
if self.num_grids != 3:
raise GridError('Autoshifting requires exactly 3 grids')
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray:
"""
Allocate an ndarray for storing grid data.
Args:
fill_value: Value to initialize the grid to. If None, an
uninitialized array is returned.
dtype: Numpy dtype for the array. Default is `numpy.float32`.
Returns:
The allocated array
"""
if fill_value is None:
return numpy.empty(self.cell_data_shape, dtype=dtype)
else:
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)

View File

@ -1,13 +1,16 @@
""" """
Drawing-related methods for Grid class Drawing-related methods for Grid class
""" """
from typing import Union, Sequence, Callable from typing import Union
from collections.abc import Sequence, Callable
import numpy import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray, ArrayLike
from float_raster import raster from float_raster import raster
from . import GridError from . import GridError
from .base import GridBase
from .position import GridPosMixin
# NOTE: Maybe it would make sense to create a GridDrawer class # NOTE: Maybe it would make sense to create a GridDrawer class
@ -20,12 +23,13 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
foreground_t = Union[float, foreground_callable_t] foreground_t = Union[float, foreground_callable_t]
def draw_polygons( class GridDrawMixin(GridPosMixin):
def draw_polygons(
self, self,
cell_data: NDArray, cell_data: NDArray,
surface_normal: int, surface_normal: int,
center: ArrayLike, center: ArrayLike,
polygons: Sequence[NDArray], polygons: Sequence[ArrayLike],
thickness: float, thickness: float,
foreground: Sequence[foreground_t] | foreground_t, foreground: Sequence[foreground_t] | foreground_t,
) -> None: ) -> None:
@ -50,13 +54,13 @@ def draw_polygons(
""" """
if surface_normal not in range(3): if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction') raise GridError('Invalid surface_normal direction')
center = numpy.squeeze(center) center = numpy.squeeze(center)
poly_list = [numpy.array(poly, copy=False) for poly in polygons]
# Check polygons, and remove redundant coordinates # Check polygons, and remove redundant coordinates
surface = numpy.delete(range(3), surface_normal) surface = numpy.delete(range(3), surface_normal)
for i, polygon in enumerate(polygons): for i, polygon in enumerate(poly_list):
malformed = f'Malformed polygon: ({i})' malformed = f'Malformed polygon: ({i})'
if polygon.shape[1] not in (2, 3): if polygon.shape[1] not in (2, 3):
raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
@ -82,7 +86,7 @@ def draw_polygons(
# 1) Compute outer bounds (bd) of polygons # 1) Compute outer bounds (bd) of polygons
bd_2d_min = numpy.array([0, 0]) bd_2d_min = numpy.array([0, 0])
bd_2d_max = numpy.array([0, 0]) bd_2d_max = numpy.array([0, 0])
for polygon in polygons: for polygon in poly_list:
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
@ -100,7 +104,7 @@ def draw_polygons(
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
# 3) Adjust polygons for center # 3) Adjust polygons for center
polygons = [poly + center[surface] for poly in polygons] poly_list = [poly + center[surface] for poly in poly_list]
# ## Generate weighing function # ## Generate weighing function
def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
@ -108,6 +112,7 @@ def draw_polygons(
return numpy.insert(v_2d, surface_normal, (val,)) return numpy.insert(v_2d, surface_normal, (val,))
# iterate over grids # iterate over grids
foreground_val: NDArray | float
for i, _ in enumerate(cell_data): for i, _ in enumerate(cell_data):
# ## Evaluate or expand foregrounds[i] # ## Evaluate or expand foregrounds[i]
foregrounds_i = foregrounds[i] foregrounds_i = foregrounds[i]
@ -129,7 +134,7 @@ def draw_polygons(
w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int))
# Draw each polygon separately # Draw each polygon separately
for polygon in polygons: for polygon in poly_list:
# Get the boundaries of the polygon # Get the boundaries of the polygon
pbd_min = polygon.min(axis=0) pbd_min = polygon.min(axis=0)
@ -195,7 +200,7 @@ def draw_polygons(
cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val
def draw_polygon( def draw_polygon(
self, self,
cell_data: NDArray, cell_data: NDArray,
surface_normal: int, surface_normal: int,
@ -220,7 +225,7 @@ def draw_polygon(
self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground)
def draw_slab( def draw_slab(
self, self,
cell_data: NDArray, cell_data: NDArray,
surface_normal: int, surface_normal: int,
@ -271,7 +276,7 @@ def draw_slab(
self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground)
def draw_cuboid( def draw_cuboid(
self, self,
cell_data: NDArray, cell_data: NDArray,
center: ArrayLike, center: ArrayLike,
@ -297,7 +302,7 @@ def draw_cuboid(
self.draw_polygon(cell_data, 2, center, p, thickness, foreground) self.draw_polygon(cell_data, 2, center, p, thickness, foreground)
def draw_cylinder( def draw_cylinder(
self, self,
cell_data: NDArray, cell_data: NDArray,
surface_normal: int, surface_normal: int,
@ -326,7 +331,7 @@ def draw_cylinder(
self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground)
def draw_extrude_rectangle( def draw_extrude_rectangle(
self, self,
cell_data: NDArray, cell_data: NDArray,
rectangle: ArrayLike, rectangle: ArrayLike,

View File

@ -1,4 +1,5 @@
from typing import Callable, Sequence, ClassVar, Self from typing import ClassVar, Self
from collections.abc import Callable, Sequence
import numpy import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray, ArrayLike
@ -8,12 +9,16 @@ import warnings
import copy import copy
from . import GridError from . import GridError
from .base import GridBase
from .draw import GridDrawMixin
from .read import GridReadMixin
from .position import GridPosMixin
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
class Grid: class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
""" """
Simulation grid metadata for finite-difference simulations. Simulation grid metadata for finite-difference simulations.
@ -70,193 +75,6 @@ class Grid:
], dtype=float) ], dtype=float)
"""Default shifts for Yee grid H-field""" """Default shifts for Yee grid H-field"""
from .draw import (
draw_polygons, draw_polygon, draw_slab, draw_cuboid,
draw_cylinder, draw_extrude_rectangle,
)
from .read import get_slice, visualize_slice, visualize_isosurface
from .position import ind2pos, pos2ind
@property
def dxyz(self) -> list[NDArray]:
"""
Cell sizes for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell sizes
"""
return [numpy.diff(ee) for ee in self.exyz]
@property
def xyz(self) -> list[NDArray]:
"""
Cell centers for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell edges
"""
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
@property
def shape(self) -> NDArray[numpy.int_]:
"""
The number of cells in x, y, and z
Returns:
ndarray of [x_centers.size, y_centers.size, z_centers.size]
"""
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
@property
def num_grids(self) -> int:
"""
The number of grids (number of shifts)
"""
return self.shifts.shape[0]
@property
def cell_data_shape(self):
"""
The shape of the cell_data ndarray (num_grids, *self.shape).
"""
return numpy.hstack((self.num_grids, self.shape))
@property
def dxyz_with_ghost(self) -> list[NDArray]:
"""
Gives dxyz with an additional 'ghost' cell at the end, whose value depends
on whether or not the axis has periodic boundary conditions. See main description
above to learn why this is necessary.
If periodic, final edge shifts same amount as first
Otherwise, final edge shifts same amount as second-to-last
Returns:
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
"""
el = [0 if p else -1 for p in self.periodic]
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
@property
def center(self) -> NDArray[numpy.float64]:
"""
Center position of the entire grid, no shifts applied
Returns:
ndarray of [x_center, y_center, z_center]
"""
# center is just average of first and last xyz, which is just the average of the
# first two and last two exyz
centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
return numpy.array(centers, dtype=float)
@property
def dxyz_limits(self) -> tuple[NDArray, NDArray]:
"""
Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
weighted average is performed when shifting).
Returns:
Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
"""
d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
return d_min, d_max
def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
"""
Returns edges for which_shifts.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell edges
"""
if which_shifts is None:
return self.exyz
dxyz = self.dxyz_with_ghost
shifts = self.shifts[which_shifts, :]
# If shift is negative, use left cell's dx to determine shift
for a in range(3):
if shifts[a] < 0:
dxyz[a] = numpy.roll(dxyz[a], 1)
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
"""
Returns cell sizes for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell sizes
"""
if which_shifts is None:
return self.dxyz
shifts = self.shifts[which_shifts, :]
dxyz = self.dxyz_with_ghost
# If shift is negative, use left cell's dx to determine size
sdxyz = []
for a in range(3):
if shifts[a] < 0:
roll_dxyz = numpy.roll(dxyz[a], 1)
abs_shift = numpy.abs(shifts[a])
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
else:
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
return sdxyz
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
"""
Returns cell centers for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell centers
"""
if which_shifts is None:
return self.xyz
exyz = self.shifted_exyz(which_shifts)
dxyz = self.shifted_dxyz(which_shifts)
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
"""
Return cell widths, with each dimension shifted by the corresponding shifts.
Returns:
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
"""
if self.num_grids != 3:
raise GridError('Autoshifting requires exactly 3 grids')
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray:
"""
Allocate an ndarray for storing grid data.
Args:
fill_value: Value to initialize the grid to. If None, an
uninitialized array is returned.
dtype: Numpy dtype for the array. Default is `numpy.float32`.
Returns:
The allocated array
"""
if fill_value is None:
return numpy.empty(self.cell_data_shape, dtype=dtype)
else:
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
def __init__( def __init__(
self, self,
pixel_edge_coordinates: Sequence[ArrayLike], pixel_edge_coordinates: Sequence[ArrayLike],
@ -277,11 +95,12 @@ class Grid:
Raises: Raises:
`GridError` on invalid input `GridError` on invalid input
""" """
self.exyz = [numpy.unique(pixel_edge_coordinates[i]) for i in range(3)] edge_arrs = [numpy.array(cc, copy=False) for cc in pixel_edge_coordinates]
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
self.shifts = numpy.array(shifts, dtype=float) self.shifts = numpy.array(shifts, dtype=float)
for i in range(3): for i in range(3):
if len(self.exyz[i]) != len(pixel_edge_coordinates[i]): if self.exyz[i].size != edge_arrs[i].size:
warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2) warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2)
if isinstance(periodic, bool): if isinstance(periodic, bool):

View File

@ -5,9 +5,11 @@ import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray, ArrayLike
from . import GridError from . import GridError
from .base import GridBase
def ind2pos( class GridPosMixin(GridBase):
def ind2pos(
self, self,
ind: NDArray, ind: NDArray,
which_shifts: int | None = None, which_shifts: int | None = None,
@ -59,7 +61,7 @@ def ind2pos(
return numpy.array(position, dtype=float) return numpy.array(position, dtype=float)
def pos2ind( def pos2ind(
self, self,
r: ArrayLike, r: ArrayLike,
which_shifts: int | None, which_shifts: int | None,

View File

@ -7,6 +7,8 @@ import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
from . import GridError from . import GridError
from .base import GridBase
from .position import GridPosMixin
if TYPE_CHECKING: if TYPE_CHECKING:
import matplotlib.axes import matplotlib.axes
@ -18,7 +20,8 @@ if TYPE_CHECKING:
# .visualize_isosurface uses mpl_toolkits.mplot3d # .visualize_isosurface uses mpl_toolkits.mplot3d
def get_slice( class GridReadMixin(GridPosMixin):
def get_slice(
self, self,
cell_data: NDArray, cell_data: NDArray,
surface_normal: int, surface_normal: int,
@ -82,7 +85,7 @@ def get_slice(
return sliced_grid return sliced_grid
def visualize_slice( def visualize_slice(
self, self,
cell_data: NDArray, cell_data: NDArray,
surface_normal: int, surface_normal: int,
@ -135,7 +138,7 @@ def visualize_slice(
return fig, ax return fig, ax
def visualize_isosurface( def visualize_isosurface(
self, self,
cell_data: NDArray, cell_data: NDArray,
level: float | None = None, level: float | None = None,