diff --git a/gridlock/base.py b/gridlock/base.py new file mode 100644 index 0000000..6bd5fb8 --- /dev/null +++ b/gridlock/base.py @@ -0,0 +1,198 @@ +from typing import ClassVar, Self, Protocol +from collections.abc import Callable, Sequence + +import numpy +from numpy.typing import NDArray, ArrayLike + +from . import GridError + + +class GridBase(Protocol): + exyz: list[NDArray] + """Cell edges. Monotonically increasing without duplicates.""" + + periodic: list[bool] + """For each axis, determines how far the rightmost boundary gets shifted. """ + + shifts: NDArray + """Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`""" + + @property + def dxyz(self) -> list[NDArray]: + """ + Cell sizes for each axis, no shifts applied + + Returns: + List of 3 ndarrays of cell sizes + """ + return [numpy.diff(ee) for ee in self.exyz] + + @property + def xyz(self) -> list[NDArray]: + """ + Cell centers for each axis, no shifts applied + + Returns: + List of 3 ndarrays of cell edges + """ + return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] + + @property + def shape(self) -> NDArray[numpy.int_]: + """ + The number of cells in x, y, and z + + Returns: + ndarray of [x_centers.size, y_centers.size, z_centers.size] + """ + return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int) + + @property + def num_grids(self) -> int: + """ + The number of grids (number of shifts) + """ + return self.shifts.shape[0] + + @property + def cell_data_shape(self): + """ + The shape of the cell_data ndarray (num_grids, *self.shape). + """ + return numpy.hstack((self.num_grids, self.shape)) + + @property + def dxyz_with_ghost(self) -> list[NDArray]: + """ + Gives dxyz with an additional 'ghost' cell at the end, whose value depends + on whether or not the axis has periodic boundary conditions. See main description + above to learn why this is necessary. + + If periodic, final edge shifts same amount as first + Otherwise, final edge shifts same amount as second-to-last + + Returns: + list of [dxs, dys, dzs] with each element same length as elements of `self.xyz` + """ + el = [0 if p else -1 for p in self.periodic] + return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] + + @property + def center(self) -> NDArray[numpy.float64]: + """ + Center position of the entire grid, no shifts applied + + Returns: + ndarray of [x_center, y_center, z_center] + """ + # center is just average of first and last xyz, which is just the average of the + # first two and last two exyz + centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)] + return numpy.array(centers, dtype=float) + + @property + def dxyz_limits(self) -> tuple[NDArray, NDArray]: + """ + Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element + ndarrays. No shifts are applied, so these are extreme bounds on these values (as a + weighted average is performed when shifting). + + Returns: + Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]` + """ + d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float) + d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float) + return d_min, d_max + + def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]: + """ + Returns edges for which_shifts. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell edges + """ + if which_shifts is None: + return self.exyz + dxyz = self.dxyz_with_ghost + shifts = self.shifts[which_shifts, :] + + # If shift is negative, use left cell's dx to determine shift + for a in range(3): + if shifts[a] < 0: + dxyz[a] = numpy.roll(dxyz[a], 1) + + return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] + + def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: + """ + Returns cell sizes for `which_shifts`. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell sizes + """ + if which_shifts is None: + return self.dxyz + shifts = self.shifts[which_shifts, :] + dxyz = self.dxyz_with_ghost + + # If shift is negative, use left cell's dx to determine size + sdxyz = [] + for a in range(3): + if shifts[a] < 0: + roll_dxyz = numpy.roll(dxyz[a], 1) + abs_shift = numpy.abs(shifts[a]) + sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) + else: + sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) + + return sdxyz + + def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: + """ + Returns cell centers for `which_shifts`. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell centers + """ + if which_shifts is None: + return self.xyz + exyz = self.shifted_exyz(which_shifts) + dxyz = self.shifted_dxyz(which_shifts) + return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] + + def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]: + """ + Return cell widths, with each dimension shifted by the corresponding shifts. + + Returns: + `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]` + """ + if self.num_grids != 3: + raise GridError('Autoshifting requires exactly 3 grids') + return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] + + def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray: + """ + Allocate an ndarray for storing grid data. + + Args: + fill_value: Value to initialize the grid to. If None, an + uninitialized array is returned. + dtype: Numpy dtype for the array. Default is `numpy.float32`. + + Returns: + The allocated array + """ + if fill_value is None: + return numpy.empty(self.cell_data_shape, dtype=dtype) + else: + return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) diff --git a/gridlock/draw.py b/gridlock/draw.py index b4b2176..7146e15 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -1,13 +1,16 @@ """ Drawing-related methods for Grid class """ -from typing import Union, Sequence, Callable +from typing import Union +from collections.abc import Sequence, Callable import numpy from numpy.typing import NDArray, ArrayLike from float_raster import raster from . import GridError +from .base import GridBase +from .position import GridPosMixin # NOTE: Maybe it would make sense to create a GridDrawer class @@ -20,372 +23,374 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_t = Union[float, foreground_callable_t] -def draw_polygons( - self, - cell_data: NDArray, - surface_normal: int, - center: ArrayLike, - polygons: Sequence[NDArray], - thickness: float, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw polygons on an axis-aligned plane. +class GridDrawMixin(GridPosMixin): + def draw_polygons( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + polygons: Sequence[ArrayLike], + thickness: float, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw polygons on an axis-aligned plane. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying an offset applied to all the polygons - polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon - (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each - polygon must have at least 3 vertices. - thickness: Thickness of the layer to draw - foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list - of any of these (1 per grid). Callable values should take an ndarray the shape of the - grid and return an ndarray of equal shape containing the foreground value at the given x, y, - and z (natural, not grid coordinates). + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: 3-element ndarray or list specifying an offset applied to all the polygons + polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon + (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each + polygon must have at least 3 vertices. + thickness: Thickness of the layer to draw + foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list + of any of these (1 per grid). Callable values should take an ndarray the shape of the + grid and return an ndarray of equal shape containing the foreground value at the given x, y, + and z (natural, not grid coordinates). - Raises: - GridError - """ - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') - - center = numpy.squeeze(center) - - # Check polygons, and remove redundant coordinates - surface = numpy.delete(range(3), surface_normal) - - for i, polygon in enumerate(polygons): - malformed = f'Malformed polygon: ({i})' - if polygon.shape[1] not in (2, 3): - raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') - if polygon.shape[1] == 3: - polygon = polygon[surface, :] - - if not polygon.shape[0] > 2: - raise GridError(malformed + 'must consist of more than 2 points') - if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: - raise GridError(malformed + 'must be in plane with surface normal ' - + 'xyz'[surface_normal]) - - # Broadcast foreground where necessary - foregrounds: Sequence[foreground_callable_t] | Sequence[float] - if numpy.size(foreground) == 1: # type: ignore - foregrounds = [foreground] * len(cell_data) # type: ignore - elif isinstance(foreground, numpy.ndarray): - raise GridError('ndarray not supported for foreground') - else: - foregrounds = foreground # type: ignore - - # ## Compute sub-domain of the grid occupied by polygons - # 1) Compute outer bounds (bd) of polygons - bd_2d_min = numpy.array([0, 0]) - bd_2d_max = numpy.array([0, 0]) - for polygon in polygons: - bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) - bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) - bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center - bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center - - # 2) Find indices (bdi) just outside bd elements - buf = 2 # size of safety buffer - # Use s_min and s_max with unshifted pos2ind to get absolute limits on - # the indices the polygons might affect - s_min = self.shifts.min(axis=0) - s_max = self.shifts.max(axis=0) - bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf - bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf - bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) - bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) - - # 3) Adjust polygons for center - polygons = [poly + center[surface] for poly in polygons] - - # ## Generate weighing function - def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: - v_2d = numpy.array(vector, dtype=float) - return numpy.insert(v_2d, surface_normal, (val,)) - - # iterate over grids - for i, _ in enumerate(cell_data): - # ## Evaluate or expand foregrounds[i] - foregrounds_i = foregrounds[i] - if callable(foregrounds_i): - # meshgrid over the (shifted) domain - domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)] - (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') - - # evaluate on the meshgrid - foreground_val = foregrounds_i(x0, y0, z0) - if not numpy.isfinite(foreground_val).all(): - raise GridError(f'Non-finite values in foreground[{i}]') - elif numpy.size(foregrounds_i) != 1: - raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}') - else: - # foreground[i] is scalar non-callable - foreground_val = foregrounds_i - - w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) - - # Draw each polygon separately - for polygon in polygons: - - # Get the boundaries of the polygon - pbd_min = polygon.min(axis=0) - pbd_max = polygon.max(axis=0) - - # Find indices in w_xy just outside polygon - # using per-grid xy-weights (self.shifted_xyz()) - corner_min = self.pos2ind(to_3d(pbd_min), i, - check_bounds=False)[surface].astype(int) - corner_max = self.pos2ind(to_3d(pbd_max), i, - check_bounds=False)[surface].astype(int) - - # Find indices in w_xy which are modified by polygon - # First for the edge coordinates (+1 since we're indexing edges) - edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)] - # Then for the pixel centers (-bdi_min since we're - # calculating weights within a subspace) - centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface], - corner_max - bdi_min[surface], strict=True)) - - aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True)) - w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y) - - # Clamp overlapping polygons to 1 - w_xy = numpy.minimum(w_xy, 1.0) - - # 2) Generate weights in z-direction - w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) - - def get_zi(offset, i=i, w_z=w_z): - edges = self.shifted_exyz(i)[surface_normal] - point = center[surface_normal] + offset - grid_coord = numpy.digitize(point, edges) - 1 - w_coord = grid_coord - bdi_min[surface_normal] - - if w_coord < 0: - w_coord = 0 - f = 0 - elif w_coord >= w_z.size: - w_coord = w_z.size - 1 - f = 1 - else: - dz = self.shifted_dxyz(i)[surface_normal][grid_coord] - f = (point - edges[grid_coord]) / dz - return f, w_coord - - zi_top_f, zi_top = get_zi(+thickness / 2.0) - zi_bot_f, zi_bot = get_zi(-thickness / 2.0) - - w_z[zi_bot + 1:zi_top] = 1 - - if zi_bot < zi_top: - w_z[zi_top] = zi_top_f - w_z[zi_bot] = 1 - zi_bot_f - else: - w_z[zi_bot] = zi_top_f - zi_bot_f - - # 3) Generate total weight function - w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,))) - - # ## Modify the grid - g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) - cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val - - -def draw_polygon( - self, - cell_data: NDArray, - surface_normal: int, - center: ArrayLike, - polygon: ArrayLike, - thickness: float, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw a polygon on an axis-aligned plane. - - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying an offset applied to the polygon - polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, - clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at - least 3 vertices. - thickness: Thickness of the layer to draw - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) - - -def draw_slab( - self, - cell_data: NDArray, - surface_normal: int, - center: ArrayLike, - thickness: float, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw an axis-aligned infinite slab. - - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: `surface_normal` coordinate value at the center of the slab - thickness: Thickness of the layer to draw - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - # Turn surface_normal into its integer representation - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') - - if numpy.size(center) != 1: + Raises: + GridError + """ + if surface_normal not in range(3): + raise GridError('Invalid surface_normal direction') center = numpy.squeeze(center) - if len(center) == 3: - center = center[surface_normal] + poly_list = [numpy.array(poly, copy=False) for poly in polygons] + + # Check polygons, and remove redundant coordinates + surface = numpy.delete(range(3), surface_normal) + + for i, polygon in enumerate(poly_list): + malformed = f'Malformed polygon: ({i})' + if polygon.shape[1] not in (2, 3): + raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') + if polygon.shape[1] == 3: + polygon = polygon[surface, :] + + if not polygon.shape[0] > 2: + raise GridError(malformed + 'must consist of more than 2 points') + if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: + raise GridError(malformed + 'must be in plane with surface normal ' + + 'xyz'[surface_normal]) + + # Broadcast foreground where necessary + foregrounds: Sequence[foreground_callable_t] | Sequence[float] + if numpy.size(foreground) == 1: # type: ignore + foregrounds = [foreground] * len(cell_data) # type: ignore + elif isinstance(foreground, numpy.ndarray): + raise GridError('ndarray not supported for foreground') else: - raise GridError(f'Bad center: {center}') + foregrounds = foreground # type: ignore - # Find center of slab - center_shift = self.center - center_shift[surface_normal] = center + # ## Compute sub-domain of the grid occupied by polygons + # 1) Compute outer bounds (bd) of polygons + bd_2d_min = numpy.array([0, 0]) + bd_2d_max = numpy.array([0, 0]) + for polygon in poly_list: + bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) + bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) + bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center + bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center - surface = numpy.delete(range(3), surface_normal) + # 2) Find indices (bdi) just outside bd elements + buf = 2 # size of safety buffer + # Use s_min and s_max with unshifted pos2ind to get absolute limits on + # the indices the polygons might affect + s_min = self.shifts.min(axis=0) + s_max = self.shifts.max(axis=0) + bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf + bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf + bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) + bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) - xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface] - xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface] + # 3) Adjust polygons for center + poly_list = [poly + center[surface] for poly in poly_list] - dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float) + # ## Generate weighing function + def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: + v_2d = numpy.array(vector, dtype=float) + return numpy.insert(v_2d, surface_normal, (val,)) - xyz_min -= 4 * dxyz - xyz_max += 4 * dxyz + # iterate over grids + foreground_val: NDArray | float + for i, _ in enumerate(cell_data): + # ## Evaluate or expand foregrounds[i] + foregrounds_i = foregrounds[i] + if callable(foregrounds_i): + # meshgrid over the (shifted) domain + domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)] + (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') - p = numpy.array([[xyz_min[0], xyz_max[1]], - [xyz_max[0], xyz_max[1]], - [xyz_max[0], xyz_min[1]], - [xyz_min[0], xyz_min[1]]], dtype=float) + # evaluate on the meshgrid + foreground_val = foregrounds_i(x0, y0, z0) + if not numpy.isfinite(foreground_val).all(): + raise GridError(f'Non-finite values in foreground[{i}]') + elif numpy.size(foregrounds_i) != 1: + raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}') + else: + # foreground[i] is scalar non-callable + foreground_val = foregrounds_i - self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) + w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) + + # Draw each polygon separately + for polygon in poly_list: + + # Get the boundaries of the polygon + pbd_min = polygon.min(axis=0) + pbd_max = polygon.max(axis=0) + + # Find indices in w_xy just outside polygon + # using per-grid xy-weights (self.shifted_xyz()) + corner_min = self.pos2ind(to_3d(pbd_min), i, + check_bounds=False)[surface].astype(int) + corner_max = self.pos2ind(to_3d(pbd_max), i, + check_bounds=False)[surface].astype(int) + + # Find indices in w_xy which are modified by polygon + # First for the edge coordinates (+1 since we're indexing edges) + edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)] + # Then for the pixel centers (-bdi_min since we're + # calculating weights within a subspace) + centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface], + corner_max - bdi_min[surface], strict=True)) + + aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True)) + w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y) + + # Clamp overlapping polygons to 1 + w_xy = numpy.minimum(w_xy, 1.0) + + # 2) Generate weights in z-direction + w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) + + def get_zi(offset, i=i, w_z=w_z): + edges = self.shifted_exyz(i)[surface_normal] + point = center[surface_normal] + offset + grid_coord = numpy.digitize(point, edges) - 1 + w_coord = grid_coord - bdi_min[surface_normal] + + if w_coord < 0: + w_coord = 0 + f = 0 + elif w_coord >= w_z.size: + w_coord = w_z.size - 1 + f = 1 + else: + dz = self.shifted_dxyz(i)[surface_normal][grid_coord] + f = (point - edges[grid_coord]) / dz + return f, w_coord + + zi_top_f, zi_top = get_zi(+thickness / 2.0) + zi_bot_f, zi_bot = get_zi(-thickness / 2.0) + + w_z[zi_bot + 1:zi_top] = 1 + + if zi_bot < zi_top: + w_z[zi_top] = zi_top_f + w_z[zi_bot] = 1 - zi_bot_f + else: + w_z[zi_bot] = zi_top_f - zi_bot_f + + # 3) Generate total weight function + w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,))) + + # ## Modify the grid + g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) + cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val -def draw_cuboid( - self, - cell_data: NDArray, - center: ArrayLike, - dimensions: ArrayLike, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw an axis-aligned cuboid + def draw_polygon( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + polygon: ArrayLike, + thickness: float, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw a polygon on an axis-aligned plane. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - center: 3-element ndarray or list specifying the cuboid's center - dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge - sizes of the cuboid - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - dimensions = numpy.array(dimensions, copy=False) - p = numpy.array([[-dimensions[0], +dimensions[1]], - [+dimensions[0], +dimensions[1]], - [+dimensions[0], -dimensions[1]], - [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0 - thickness = dimensions[2] - self.draw_polygon(cell_data, 2, center, p, thickness, foreground) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: 3-element ndarray or list specifying an offset applied to the polygon + polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, + clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at + least 3 vertices. + thickness: Thickness of the layer to draw + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) -def draw_cylinder( - self, - cell_data: NDArray, - surface_normal: int, - center: ArrayLike, - radius: float, - thickness: float, - num_points: int, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw an axis-aligned cylinder. Approximated by a num_points-gon + def draw_slab( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + thickness: float, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw an axis-aligned infinite slab. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying the cylinder's center - radius: cylinder radius - thickness: Thickness of the layer to draw - num_points: The circle is approximated by a polygon with `num_points` vertices - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False) - x = radius * numpy.sin(theta) - y = radius * numpy.cos(theta) - polygon = numpy.hstack((x[:, None], y[:, None])) - self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: `surface_normal` coordinate value at the center of the slab + thickness: Thickness of the layer to draw + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + # Turn surface_normal into its integer representation + if surface_normal not in range(3): + raise GridError('Invalid surface_normal direction') + + if numpy.size(center) != 1: + center = numpy.squeeze(center) + if len(center) == 3: + center = center[surface_normal] + else: + raise GridError(f'Bad center: {center}') + + # Find center of slab + center_shift = self.center + center_shift[surface_normal] = center + + surface = numpy.delete(range(3), surface_normal) + + xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface] + xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface] + + dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float) + + xyz_min -= 4 * dxyz + xyz_max += 4 * dxyz + + p = numpy.array([[xyz_min[0], xyz_max[1]], + [xyz_max[0], xyz_max[1]], + [xyz_max[0], xyz_min[1]], + [xyz_min[0], xyz_min[1]]], dtype=float) + + self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) -def draw_extrude_rectangle( - self, - cell_data: NDArray, - rectangle: ArrayLike, - direction: int, - polarity: int, - distance: float, - ) -> None: - """ - Extrude a rectangle of a previously-drawn structure along an axis. + def draw_cuboid( + self, + cell_data: NDArray, + center: ArrayLike, + dimensions: ArrayLike, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw an axis-aligned cuboid - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - rectangle: 2x3 ndarray or list specifying the rectangle's corners - direction: Direction to extrude in. Integer in `range(3)`. - polarity: +1 or -1, direction along axis to extrude in - distance: How far to extrude - """ - s = numpy.sign(polarity) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + center: 3-element ndarray or list specifying the cuboid's center + dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge + sizes of the cuboid + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + dimensions = numpy.array(dimensions, copy=False) + p = numpy.array([[-dimensions[0], +dimensions[1]], + [+dimensions[0], +dimensions[1]], + [+dimensions[0], -dimensions[1]], + [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0 + thickness = dimensions[2] + self.draw_polygon(cell_data, 2, center, p, thickness, foreground) - rectangle = numpy.array(rectangle, dtype=float) - if s == 0: - raise GridError('0 is not a valid polarity') - if direction not in range(3): - raise GridError(f'Invalid direction: {direction}') - if rectangle[0, direction] != rectangle[1, direction]: - raise GridError('Rectangle entries along extrusion direction do not match.') - center = rectangle.sum(axis=0) / 2.0 - center[direction] += s * distance / 2.0 + def draw_cylinder( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + radius: float, + thickness: float, + num_points: int, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw an axis-aligned cylinder. Approximated by a num_points-gon - surface = numpy.delete(range(3), direction) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: 3-element ndarray or list specifying the cylinder's center + radius: cylinder radius + thickness: Thickness of the layer to draw + num_points: The circle is approximated by a polygon with `num_points` vertices + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False) + x = radius * numpy.sin(theta) + y = radius * numpy.cos(theta) + polygon = numpy.hstack((x[:, None], y[:, None])) + self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) - dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] - p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, - numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T - thickness = distance - foreground_func = [] - for i, grid in enumerate(cell_data): - z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] + def draw_extrude_rectangle( + self, + cell_data: NDArray, + rectangle: ArrayLike, + direction: int, + polarity: int, + distance: float, + ) -> None: + """ + Extrude a rectangle of a previously-drawn structure along an axis. - ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + rectangle: 2x3 ndarray or list specifying the rectangle's corners + direction: Direction to extrude in. Integer in `range(3)`. + polarity: +1 or -1, direction along axis to extrude in + distance: How far to extrude + """ + s = numpy.sign(polarity) - fpart = z - numpy.floor(z) - mult = [1 - fpart, fpart][::s] # reverses if s negative + rectangle = numpy.array(rectangle, dtype=float) + if s == 0: + raise GridError('0 is not a valid polarity') + if direction not in range(3): + raise GridError(f'Invalid direction: {direction}') + if rectangle[0, direction] != rectangle[1, direction]: + raise GridError('Rectangle entries along extrusion direction do not match.') - foreground = mult[0] * grid[tuple(ind)] - ind[direction] += 1 # type: ignore #(known safe) - foreground += mult[1] * grid[tuple(ind)] + center = rectangle.sum(axis=0) / 2.0 + center[direction] += s * distance / 2.0 - def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]: - # transform from natural position to index - xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) - for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=int) - # reshape to original shape and keep only in-plane components - qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) - return foreground[qi, ri] + surface = numpy.delete(range(3), direction) - foreground_func.append(f_foreground) + dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] + p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, + numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T + thickness = distance - self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func) + foreground_func = [] + for i, grid in enumerate(cell_data): + z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] + + ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] + + fpart = z - numpy.floor(z) + mult = [1 - fpart, fpart][::s] # reverses if s negative + + foreground = mult[0] * grid[tuple(ind)] + ind[direction] += 1 # type: ignore #(known safe) + foreground += mult[1] * grid[tuple(ind)] + + def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]: + # transform from natural position to index + xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) + for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=int) + # reshape to original shape and keep only in-plane components + qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) + return foreground[qi, ri] + + foreground_func.append(f_foreground) + + self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func) diff --git a/gridlock/grid.py b/gridlock/grid.py index 2fb721b..55b1abb 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,4 +1,5 @@ -from typing import Callable, Sequence, ClassVar, Self +from typing import ClassVar, Self +from collections.abc import Callable, Sequence import numpy from numpy.typing import NDArray, ArrayLike @@ -8,12 +9,16 @@ import warnings import copy from . import GridError +from .base import GridBase +from .draw import GridDrawMixin +from .read import GridReadMixin +from .position import GridPosMixin foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] -class Grid: +class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): """ Simulation grid metadata for finite-difference simulations. @@ -70,193 +75,6 @@ class Grid: ], dtype=float) """Default shifts for Yee grid H-field""" - from .draw import ( - draw_polygons, draw_polygon, draw_slab, draw_cuboid, - draw_cylinder, draw_extrude_rectangle, - ) - from .read import get_slice, visualize_slice, visualize_isosurface - from .position import ind2pos, pos2ind - - @property - def dxyz(self) -> list[NDArray]: - """ - Cell sizes for each axis, no shifts applied - - Returns: - List of 3 ndarrays of cell sizes - """ - return [numpy.diff(ee) for ee in self.exyz] - - @property - def xyz(self) -> list[NDArray]: - """ - Cell centers for each axis, no shifts applied - - Returns: - List of 3 ndarrays of cell edges - """ - return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] - - @property - def shape(self) -> NDArray[numpy.int_]: - """ - The number of cells in x, y, and z - - Returns: - ndarray of [x_centers.size, y_centers.size, z_centers.size] - """ - return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int) - - @property - def num_grids(self) -> int: - """ - The number of grids (number of shifts) - """ - return self.shifts.shape[0] - - @property - def cell_data_shape(self): - """ - The shape of the cell_data ndarray (num_grids, *self.shape). - """ - return numpy.hstack((self.num_grids, self.shape)) - - @property - def dxyz_with_ghost(self) -> list[NDArray]: - """ - Gives dxyz with an additional 'ghost' cell at the end, whose value depends - on whether or not the axis has periodic boundary conditions. See main description - above to learn why this is necessary. - - If periodic, final edge shifts same amount as first - Otherwise, final edge shifts same amount as second-to-last - - Returns: - list of [dxs, dys, dzs] with each element same length as elements of `self.xyz` - """ - el = [0 if p else -1 for p in self.periodic] - return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] - - @property - def center(self) -> NDArray[numpy.float64]: - """ - Center position of the entire grid, no shifts applied - - Returns: - ndarray of [x_center, y_center, z_center] - """ - # center is just average of first and last xyz, which is just the average of the - # first two and last two exyz - centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)] - return numpy.array(centers, dtype=float) - - @property - def dxyz_limits(self) -> tuple[NDArray, NDArray]: - """ - Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element - ndarrays. No shifts are applied, so these are extreme bounds on these values (as a - weighted average is performed when shifting). - - Returns: - Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]` - """ - d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float) - d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float) - return d_min, d_max - - def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]: - """ - Returns edges for which_shifts. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell edges - """ - if which_shifts is None: - return self.exyz - dxyz = self.dxyz_with_ghost - shifts = self.shifts[which_shifts, :] - - # If shift is negative, use left cell's dx to determine shift - for a in range(3): - if shifts[a] < 0: - dxyz[a] = numpy.roll(dxyz[a], 1) - - return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] - - def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: - """ - Returns cell sizes for `which_shifts`. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell sizes - """ - if which_shifts is None: - return self.dxyz - shifts = self.shifts[which_shifts, :] - dxyz = self.dxyz_with_ghost - - # If shift is negative, use left cell's dx to determine size - sdxyz = [] - for a in range(3): - if shifts[a] < 0: - roll_dxyz = numpy.roll(dxyz[a], 1) - abs_shift = numpy.abs(shifts[a]) - sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) - else: - sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) - - return sdxyz - - def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: - """ - Returns cell centers for `which_shifts`. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell centers - """ - if which_shifts is None: - return self.xyz - exyz = self.shifted_exyz(which_shifts) - dxyz = self.shifted_dxyz(which_shifts) - return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] - - def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]: - """ - Return cell widths, with each dimension shifted by the corresponding shifts. - - Returns: - `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]` - """ - if self.num_grids != 3: - raise GridError('Autoshifting requires exactly 3 grids') - return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] - - def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray: - """ - Allocate an ndarray for storing grid data. - - Args: - fill_value: Value to initialize the grid to. If None, an - uninitialized array is returned. - dtype: Numpy dtype for the array. Default is `numpy.float32`. - - Returns: - The allocated array - """ - if fill_value is None: - return numpy.empty(self.cell_data_shape, dtype=dtype) - else: - return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) - def __init__( self, pixel_edge_coordinates: Sequence[ArrayLike], @@ -277,11 +95,12 @@ class Grid: Raises: `GridError` on invalid input """ - self.exyz = [numpy.unique(pixel_edge_coordinates[i]) for i in range(3)] + edge_arrs = [numpy.array(cc, copy=False) for cc in pixel_edge_coordinates] + self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) for i in range(3): - if len(self.exyz[i]) != len(pixel_edge_coordinates[i]): + if self.exyz[i].size != edge_arrs[i].size: warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2) if isinstance(periodic, bool): diff --git a/gridlock/position.py b/gridlock/position.py index 5928174..b705b99 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -5,112 +5,114 @@ import numpy from numpy.typing import NDArray, ArrayLike from . import GridError +from .base import GridBase -def ind2pos( - self, - ind: NDArray, - which_shifts: int | None = None, - round_ind: bool = True, - check_bounds: bool = True - ) -> NDArray[numpy.float64]: - """ - Returns the natural position corresponding to the specified cell center indices. - The resulting position is clipped to the bounds of the grid - (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`) +class GridPosMixin(GridBase): + def ind2pos( + self, + ind: NDArray, + which_shifts: int | None = None, + round_ind: bool = True, + check_bounds: bool = True + ) -> NDArray[numpy.float64]: + """ + Returns the natural position corresponding to the specified cell center indices. + The resulting position is clipped to the bounds of the grid + (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`) - Args: - ind: Indices of the position. Can be fractional. (3-element ndarray or list) - which_shifts: which grid number (`shifts`) to use - round_ind: Whether to round ind to the nearest integer position before indexing - (default `True`) - check_bounds: Whether to raise an `GridError` if the provided ind is outside of - the grid, as defined above (centers if `round_ind`, else edges) (default `True`) + Args: + ind: Indices of the position. Can be fractional. (3-element ndarray or list) + which_shifts: which grid number (`shifts`) to use + round_ind: Whether to round ind to the nearest integer position before indexing + (default `True`) + check_bounds: Whether to raise an `GridError` if the provided ind is outside of + the grid, as defined above (centers if `round_ind`, else edges) (default `True`) - Returns: - 3-element ndarray specifying the natural position + Returns: + 3-element ndarray specifying the natural position - Raises: - `GridError` if invalid `which_shifts` - `GridError` if `check_bounds` and out of bounds - """ - if which_shifts is not None and which_shifts >= self.shifts.shape[0]: - raise GridError('Invalid shifts') - ind = numpy.array(ind, dtype=float) + Raises: + `GridError` if invalid `which_shifts` + `GridError` if `check_bounds` and out of bounds + """ + if which_shifts is not None and which_shifts >= self.shifts.shape[0]: + raise GridError('Invalid shifts') + ind = numpy.array(ind, dtype=float) + + if check_bounds: + if round_ind: + low_bound = 0.0 + high_bound = -1.0 + else: + low_bound = -0.5 + high_bound = -0.5 + if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): + raise GridError(f'Position outside of grid: {ind}') - if check_bounds: if round_ind: - low_bound = 0.0 - high_bound = -1.0 + rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) + sxyz = self.shifted_xyz(which_shifts) + position = [sxyz[a][rind[a]].astype(int) for a in range(3)] else: - low_bound = -0.5 - high_bound = -0.5 - if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): - raise GridError(f'Position outside of grid: {ind}') + sexyz = self.shifted_exyz(which_shifts) + position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) + for a in range(3)] + return numpy.array(position, dtype=float) + + + def pos2ind( + self, + r: ArrayLike, + which_shifts: int | None, + round_ind: bool = True, + check_bounds: bool = True + ) -> NDArray[numpy.float64]: + """ + Returns the cell-center indices corresponding to the specified natural position. + The resulting position is clipped to within the outer centers of the grid. + + Args: + r: Natural position that we will convert into indices (3-element ndarray or list) + which_shifts: which grid number (`shifts`) to use + round_ind: Whether to round the returned indices to the nearest integers. + check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges + + Returns: + 3-element ndarray specifying the indices + + Raises: + `GridError` if invalid `which_shifts` + `GridError` if `check_bounds` and out of bounds + """ + r = numpy.squeeze(r) + if r.size != 3: + raise GridError(f'r must be 3-element vector: {r}') + + if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]): + raise GridError(f'Invalid which_shifts: {which_shifts}') - if round_ind: - rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) - sxyz = self.shifted_xyz(which_shifts) - position = [sxyz[a][rind[a]].astype(int) for a in range(3)] - else: sexyz = self.shifted_exyz(which_shifts) - position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) - for a in range(3)] - return numpy.array(position, dtype=float) + if check_bounds: + for a in range(3): + if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]): + raise GridError(f'Position[{a}] outside of grid!') -def pos2ind( - self, - r: ArrayLike, - which_shifts: int | None, - round_ind: bool = True, - check_bounds: bool = True - ) -> NDArray[numpy.float64]: - """ - Returns the cell-center indices corresponding to the specified natural position. - The resulting position is clipped to within the outer centers of the grid. - - Args: - r: Natural position that we will convert into indices (3-element ndarray or list) - which_shifts: which grid number (`shifts`) to use - round_ind: Whether to round the returned indices to the nearest integers. - check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges - - Returns: - 3-element ndarray specifying the indices - - Raises: - `GridError` if invalid `which_shifts` - `GridError` if `check_bounds` and out of bounds - """ - r = numpy.squeeze(r) - if r.size != 3: - raise GridError(f'r must be 3-element vector: {r}') - - if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]): - raise GridError(f'Invalid which_shifts: {which_shifts}') - - sexyz = self.shifted_exyz(which_shifts) - - if check_bounds: + grid_pos = numpy.zeros((3,)) for a in range(3): - if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]): - raise GridError(f'Position[{a}] outside of grid!') + xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in + xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds - grid_pos = numpy.zeros((3,)) - for a in range(3): - xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in - xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds + # No need to interpolate if round_ind is true or we were outside the grid + if round_ind or xi != xi_clipped: + grid_pos[a] = xi_clipped + else: + # Interpolate + x = self.shifted_xyz(which_shifts)[a][xi] + dx = self.shifted_dxyz(which_shifts)[a][xi] + f = (r[a] - x) / dx - # No need to interpolate if round_ind is true or we were outside the grid - if round_ind or xi != xi_clipped: - grid_pos[a] = xi_clipped - else: - # Interpolate - x = self.shifted_xyz(which_shifts)[a][xi] - dx = self.shifted_dxyz(which_shifts)[a][xi] - f = (r[a] - x) / dx - - # Clip to centers - grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1) - return grid_pos + # Clip to centers + grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1) + return grid_pos diff --git a/gridlock/read.py b/gridlock/read.py index 31b583d..e82ffcc 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -7,6 +7,8 @@ import numpy from numpy.typing import NDArray from . import GridError +from .base import GridBase +from .position import GridPosMixin if TYPE_CHECKING: import matplotlib.axes @@ -18,186 +20,187 @@ if TYPE_CHECKING: # .visualize_isosurface uses mpl_toolkits.mplot3d -def get_slice( - self, - cell_data: NDArray, - surface_normal: int, - center: float, - which_shifts: int = 0, - sample_period: int = 1 - ) -> NDArray: - """ - Retrieve a slice of a grid. - Interpolates if given a position between two planes. +class GridReadMixin(GridPosMixin): + def get_slice( + self, + cell_data: NDArray, + surface_normal: int, + center: float, + which_shifts: int = 0, + sample_period: int = 1 + ) -> NDArray: + """ + Retrieve a slice of a grid. + Interpolates if given a position between two planes. - Args: - cell_data: Cell data to slice - surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. - center: Scalar specifying position along surface_normal axis. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) + Args: + cell_data: Cell data to slice + surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. + center: Scalar specifying position along surface_normal axis. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) - Returns: - Array containing the portion of the grid. - """ - if numpy.size(center) != 1 or not numpy.isreal(center): - raise GridError('center must be a real scalar') + Returns: + Array containing the portion of the grid. + """ + if numpy.size(center) != 1 or not numpy.isreal(center): + raise GridError('center must be a real scalar') - sp = round(sample_period) - if sp <= 0: - raise GridError('sample_period must be positive') + sp = round(sample_period) + if sp <= 0: + raise GridError('sample_period must be positive') - if numpy.size(which_shifts) != 1 or which_shifts < 0: - raise GridError('Invalid which_shifts') + if numpy.size(which_shifts) != 1 or which_shifts < 0: + raise GridError('Invalid which_shifts') - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') + if surface_normal not in range(3): + raise GridError('Invalid surface_normal direction') - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), surface_normal) - # Extract indices and weights of planes - center3 = numpy.insert([0, 0], surface_normal, (center,)) - center_index = self.pos2ind(center3, which_shifts, - round_ind=False, check_bounds=False)[surface_normal] - centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) - if len(centers) == 2: - fpart = center_index - numpy.floor(center_index) - w = [1 - fpart, fpart] # longer distance -> less weight - else: - w = [1] + # Extract indices and weights of planes + center3 = numpy.insert([0, 0], surface_normal, (center,)) + center_index = self.pos2ind(center3, which_shifts, + round_ind=False, check_bounds=False)[surface_normal] + centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) + if len(centers) == 2: + fpart = center_index - numpy.floor(center_index) + w = [1 - fpart, fpart] # longer distance -> less weight + else: + w = [1] - c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1]) - if center < c_min or center > c_max: - raise GridError('Coordinate of selected plane must be within simulation domain') + c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1]) + if center < c_min or center > c_max: + raise GridError('Coordinate of selected plane must be within simulation domain') - # Extract grid values from planes above and below visualized slice - sliced_grid = numpy.zeros(self.shape[surface]) - for ci, weight in zip(centers, w, strict=True): - s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) - sliced_grid += weight * cell_data[which_shifts][tuple(s)] + # Extract grid values from planes above and below visualized slice + sliced_grid = numpy.zeros(self.shape[surface]) + for ci, weight in zip(centers, w, strict=True): + s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) + sliced_grid += weight * cell_data[which_shifts][tuple(s)] - # Remove extra dimensions - sliced_grid = numpy.squeeze(sliced_grid) + # Remove extra dimensions + sliced_grid = numpy.squeeze(sliced_grid) - return sliced_grid + return sliced_grid -def visualize_slice( - self, - cell_data: NDArray, - surface_normal: int, - center: float, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - pcolormesh_args: dict[str, Any] | None = None, - ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: - """ - Visualize a slice of a grid. - Interpolates if given a position between two planes. + def visualize_slice( + self, + cell_data: NDArray, + surface_normal: int, + center: float, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + pcolormesh_args: dict[str, Any] | None = None, + ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: + """ + Visualize a slice of a grid. + Interpolates if given a position between two planes. - Args: - surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. - center: Scalar specifying position along surface_normal axis. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + Args: + surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. + center: Scalar specifying position along surface_normal axis. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - Returns: - (Figure, Axes) - """ - from matplotlib import pyplot + Returns: + (Figure, Axes) + """ + from matplotlib import pyplot - if pcolormesh_args is None: - pcolormesh_args = {} + if pcolormesh_args is None: + pcolormesh_args = {} - grid_slice = self.get_slice(cell_data=cell_data, - surface_normal=surface_normal, - center=center, - which_shifts=which_shifts, - sample_period=sample_period) + grid_slice = self.get_slice(cell_data=cell_data, + surface_normal=surface_normal, + center=center, + which_shifts=which_shifts, + sample_period=sample_period) - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), surface_normal) - x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) - xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') - x_label, y_label = ('xyz'[a] for a in surface) + x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') + x_label, y_label = ('xyz'[a] for a in surface) - fig, ax = pyplot.subplots() - mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) - fig.colorbar(mappable) - ax.set_aspect('equal', adjustable='box') - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - if finalize: - pyplot.show() + fig, ax = pyplot.subplots() + mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) + fig.colorbar(mappable) + ax.set_aspect('equal', adjustable='box') + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + if finalize: + pyplot.show() - return fig, ax + return fig, ax -def visualize_isosurface( - self, - cell_data: NDArray, - level: float | None = None, - which_shifts: int = 0, - sample_period: int = 1, - show_edges: bool = True, - finalize: bool = True, - ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: - """ - Draw an isosurface plot of the device. + def visualize_isosurface( + self, + cell_data: NDArray, + level: float | None = None, + which_shifts: int = 0, + sample_period: int = 1, + show_edges: bool = True, + finalize: bool = True, + ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: + """ + Draw an isosurface plot of the device. - Args: - cell_data: Cell data to visualize - level: Value at which to find isosurface. Default (None) uses mean value in grid. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - show_edges: Whether to draw triangle edges. Default `True` - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + Args: + cell_data: Cell data to visualize + level: Value at which to find isosurface. Default (None) uses mean value in grid. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + show_edges: Whether to draw triangle edges. Default `True` + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - Returns: - (Figure, Axes) - """ - from matplotlib import pyplot - import skimage.measure - # Claims to be unused, but needed for subplot(projection='3d') - from mpl_toolkits.mplot3d import Axes3D - del Axes3D # imported for side effects only + Returns: + (Figure, Axes) + """ + from matplotlib import pyplot + import skimage.measure + # Claims to be unused, but needed for subplot(projection='3d') + from mpl_toolkits.mplot3d import Axes3D + del Axes3D # imported for side effects only - # Get data from cell_data - grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] - if level is None: - level = grid.mean() + # Get data from cell_data + grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] + if level is None: + level = grid.mean() - # Find isosurface with marching cubes - verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) + # Find isosurface with marching cubes + verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) - # Convert vertices from index to position - pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) - for i in range(verts.shape[0])], dtype=float) - xs, ys, zs = (pos_verts[:, a] for a in range(3)) + # Convert vertices from index to position + pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) + for i in range(verts.shape[0])], dtype=float) + xs, ys, zs = (pos_verts[:, a] for a in range(3)) - # Draw the plot - fig = pyplot.figure() - ax = fig.add_subplot(111, projection='3d') - if show_edges: - ax.plot_trisurf(xs, ys, faces, zs) - else: - ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') + # Draw the plot + fig = pyplot.figure() + ax = fig.add_subplot(111, projection='3d') + if show_edges: + ax.plot_trisurf(xs, ys, faces, zs) + else: + ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') - # Add a fake plot of a cube to force the axes to be equal lengths - max_range = numpy.array([xs.max() - xs.min(), - ys.max() - ys.min(), - zs.max() - zs.min()], dtype=float).max() - mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] - xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min()) - ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min()) - zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min()) - # Comment or uncomment following both lines to test the fake bounding box: - for xb, yb, zb in zip(xbs, ybs, zbs, strict=True): - ax.plot([xb], [yb], [zb], 'w') + # Add a fake plot of a cube to force the axes to be equal lengths + max_range = numpy.array([xs.max() - xs.min(), + ys.max() - ys.min(), + zs.max() - zs.min()], dtype=float).max() + mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] + xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min()) + ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min()) + zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min()) + # Comment or uncomment following both lines to test the fake bounding box: + for xb, yb, zb in zip(xbs, ybs, zbs, strict=True): + ax.plot([xb], [yb], [zb], 'w') - if finalize: - pyplot.show() + if finalize: + pyplot.show() - return fig, ax + return fig, ax