diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 0042015..0000000 --- a/.flake8 +++ /dev/null @@ -1,29 +0,0 @@ -[flake8] -ignore = - # E501 line too long - E501, - # W391 newlines at EOF - W391, - # E241 multiple spaces after comma - E241, - # E302 expected 2 newlines - E302, - # W503 line break before binary operator (to be deprecated) - W503, - # E265 block comment should start with '# ' - E265, - # E123 closing bracket does not match indentation of opening bracket's line - E123, - # E124 closing bracket does not match visual indentation - E124, - # E221 multiple spaces before operator - E221, - # E201 whitespace after '[' - E201, - # E741 ambiguous variable name 'I' - E741, - - -per-file-ignores = - # F401 import without use - */__init__.py: F401, diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..c28ab72 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include README.md +include LICENSE.md diff --git a/README.md b/README.md index e4f43d5..e882b2e 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ the coordinates of the boundary points along each axis). ## Installation Requirements: -* python >3.11 (written and tested with 3.12) +* python 3 (written and tested with 3.9) * numpy * [float_raster](https://mpxd.net/code/jan/float_raster) * matplotlib (optional, used for visualization functions) diff --git a/gridlock/LICENSE.md b/gridlock/LICENSE.md deleted file mode 120000 index 7eabdb1..0000000 --- a/gridlock/LICENSE.md +++ /dev/null @@ -1 +0,0 @@ -../LICENSE.md \ No newline at end of file diff --git a/gridlock/README.md b/gridlock/README.md deleted file mode 120000 index 32d46ee..0000000 --- a/gridlock/README.md +++ /dev/null @@ -1 +0,0 @@ -../README.md \ No newline at end of file diff --git a/gridlock/VERSION.py b/gridlock/VERSION.py new file mode 100644 index 0000000..8e6abf6 --- /dev/null +++ b/gridlock/VERSION.py @@ -0,0 +1,4 @@ +""" VERSION defintion. THIS FILE IS MANUALLY PARSED BY setup.py and REQUIRES A SPECIFIC FORMAT """ +__version__ = ''' +1.0 +'''.strip() diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 759d1c1..d547794 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -15,25 +15,10 @@ Dependencies: - mpl_toolkits.mplot3d [Grid.visualize_isosurface()] - skimage [Grid.visualize_isosurface()] """ -from .utils import ( - GridError as GridError, - - Extent as Extent, - ExtentProtocol as ExtentProtocol, - ExtentDict as ExtentDict, - - Slab as Slab, - SlabProtocol as SlabProtocol, - SlabDict as SlabDict, - - Plane as Plane, - PlaneProtocol as PlaneProtocol, - PlaneDict as PlaneDict, - ) -from .grid import Grid as Grid -from .data import GridData as GridData - +from .error import GridError +from .grid import Grid __author__ = 'Jan Petykiewicz' -__version__ = '2.2' + +from .VERSION import __version__ version = __version__ diff --git a/gridlock/base.py b/gridlock/base.py deleted file mode 100644 index e68d955..0000000 --- a/gridlock/base.py +++ /dev/null @@ -1,192 +0,0 @@ -from typing import Protocol - -import numpy -from numpy.typing import NDArray - -from . import GridError - - -class GridBase(Protocol): - exyz: list[NDArray] - """Cell edges. Monotonically increasing without duplicates.""" - - periodic: list[bool] - """For each axis, determines how far the rightmost boundary gets shifted. """ - - shifts: NDArray - """Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`""" - - @property - def dxyz(self) -> list[NDArray]: - """ - Cell sizes for each axis, no shifts applied - - Returns: - List of 3 ndarrays of cell sizes - """ - return [numpy.diff(ee) for ee in self.exyz] - - @property - def xyz(self) -> list[NDArray]: - """ - Cell centers for each axis, no shifts applied - - Returns: - List of 3 ndarrays of cell edges - """ - return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] - - @property - def shape(self) -> NDArray[numpy.intp]: - """ - The number of cells in x, y, and z - - Returns: - ndarray of [x_centers.size, y_centers.size, z_centers.size] - """ - return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int) - - @property - def num_grids(self) -> int: - """ - The number of grids (number of shifts) - """ - return self.shifts.shape[0] - - @property - def cell_data_shape(self) -> NDArray[numpy.intp]: - """ - The shape of the cell_data ndarray (num_grids, *self.shape). - """ - return numpy.hstack((self.num_grids, self.shape)) - - @property - def dxyz_with_ghost(self) -> list[NDArray]: - """ - Gives dxyz with an additional 'ghost' cell at the end, whose value depends - on whether or not the axis has periodic boundary conditions. See main description - above to learn why this is necessary. - - If periodic, final edge shifts same amount as first - Otherwise, final edge shifts same amount as second-to-last - - Returns: - list of [dxs, dys, dzs] with each element same length as elements of `self.xyz` - """ - el = [0 if p else -1 for p in self.periodic] - return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] - - def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: - if which_shifts is None: - return self.dxyz_with_ghost - - shifts = self.shifts[which_shifts, :] - edge_dxyz = [] - for a in range(3): - if shifts[a] < 0: - ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0] - edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a]))) - else: - ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1] - edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost))) - return edge_dxyz - - @property - def center(self) -> NDArray[numpy.float64]: - """ - Center position of the entire grid, no shifts applied - - Returns: - ndarray of [x_center, y_center, z_center] - """ - # center is just average of first and last xyz, which is just the average of the - # first two and last two exyz - centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)] - return numpy.array(centers, dtype=float) - - @property - def dxyz_limits(self) -> tuple[NDArray, NDArray]: - """ - Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element - ndarrays. No shifts are applied, so these are extreme bounds on these values (as a - weighted average is performed when shifting). - - Returns: - Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]` - """ - d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float) - d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float) - return d_min, d_max - - def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]: - """ - Returns edges for which_shifts. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell edges - """ - if which_shifts is None: - return self.exyz - edge_dxyz = self._shifted_edge_dxyz(which_shifts) - shifts = self.shifts[which_shifts, :] - return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)] - - def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: - """ - Returns cell sizes for `which_shifts`. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell sizes - """ - if which_shifts is None: - return self.dxyz - return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)] - - def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: - """ - Returns cell centers for `which_shifts`. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell centers - """ - if which_shifts is None: - return self.xyz - exyz = self.shifted_exyz(which_shifts) - dxyz = self.shifted_dxyz(which_shifts) - return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] - - def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]: - """ - Return cell widths, with each dimension shifted by the corresponding shifts. - - Returns: - `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]` - """ - if self.num_grids != 3: - raise GridError('Autoshifting requires exactly 3 grids') - return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] - - def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray: - """ - Allocate an ndarray for storing grid data. - - Args: - fill_value: Value to initialize the grid to. If None, an - uninitialized array is returned. - dtype: Numpy dtype for the array. Default is `numpy.float32`. - - Returns: - The allocated array - """ - if fill_value is None: - return numpy.empty(self.cell_data_shape, dtype=dtype) - return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) diff --git a/gridlock/data.py b/gridlock/data.py deleted file mode 100644 index 5e6faa5..0000000 --- a/gridlock/data.py +++ /dev/null @@ -1,176 +0,0 @@ -from dataclasses import dataclass -from typing import Self -from collections.abc import Sequence - -import numpy -from numpy.typing import NDArray, ArrayLike - -from .draw import foreground_t -from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload -from .utils import ( - ExtentDict, - ExtentProtocol, - GridError, - PlaneDict, - PlaneProtocol, - SlabDict, - SlabProtocol, -) - - -@dataclass(slots=True) -class GridData: - grid: Grid - cell_data: NDArray - - def __post_init__(self) -> None: - if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape): - raise GridError( - f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}' - ) - - @staticmethod - def load(filename: str) -> 'GridData': - payload = _load_payload(filename) - if _payload_scalar_str(payload, 'kind') != 'grid_data': - raise GridError('Serialized payload does not contain GridData') - if 'cell_data' not in payload: - raise GridError('Serialized GridData payload is missing cell_data') - - return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data'])) - - def save(self, filename: str) -> Self: - payload = self.grid._serialization_payload(kind='grid_data') - payload['cell_data'] = self.cell_data - _save_npz_payload(filename, payload) - return self - - def copy(self) -> Self: - return GridData(self.grid.copy(), self.cell_data.copy()) - - def draw_polygons( - self, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - polygons: Sequence[ArrayLike], - *, - offset2d: ArrayLike = (0, 0), - ) -> Self: - self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d) - return self - - def draw_polygon( - self, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - polygon: ArrayLike, - *, - offset2d: ArrayLike = (0, 0), - ) -> Self: - self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d) - return self - - def draw_slab( - self, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - ) -> Self: - self.grid.draw_slab(self.cell_data, foreground, slab) - return self - - def draw_cuboid( - self, - foreground: Sequence[foreground_t] | foreground_t, - *, - x: ExtentProtocol | ExtentDict, - y: ExtentProtocol | ExtentDict, - z: ExtentProtocol | ExtentDict, - ) -> Self: - self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z) - return self - - def draw_cylinder( - self, - h: SlabProtocol | SlabDict, - radius: float, - num_points: int, - center2d: ArrayLike, - foreground: Sequence[foreground_t] | foreground_t, - ) -> Self: - self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground) - return self - - def draw_extrude_rectangle( - self, - rectangle: ArrayLike, - direction: int, - polarity: int, - distance: float, - ) -> Self: - self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance) - return self - - def get_slice( - self, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - ) -> NDArray: - return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period) - - def visualize_slice( - self, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - pcolormesh_args: dict[str, object] | None = None, - ax: object | None = None, - ) -> tuple[object, object]: - return self.grid.visualize_slice( - self.cell_data, - plane, - which_shifts=which_shifts, - sample_period=sample_period, - finalize=finalize, - pcolormesh_args=pcolormesh_args, - ax=ax, - ) - - def visualize_edges( - self, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - contour_args: dict[str, object] | None = None, - ax: object | None = None, - level_fraction: float = 0.7, - ) -> tuple[object, object]: - return self.grid.visualize_edges( - self.cell_data, - plane, - which_shifts=which_shifts, - sample_period=sample_period, - finalize=finalize, - contour_args=contour_args, - ax=ax, - level_fraction=level_fraction, - ) - - def visualize_isosurface( - self, - level: float | None = None, - which_shifts: int = 0, - sample_period: int = 1, - show_edges: bool = True, - finalize: bool = True, - ) -> tuple[object, object]: - return self.grid.visualize_isosurface( - self.cell_data, - level=level, - which_shifts=which_shifts, - sample_period=sample_period, - show_edges=show_edges, - finalize=finalize, - ) diff --git a/gridlock/direction.py b/gridlock/direction.py new file mode 100644 index 0000000..b93b122 --- /dev/null +++ b/gridlock/direction.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class Direction(Enum): + """ + Enum for axis->integer mapping + """ + x = 0 + y = 1 + z = 2 diff --git a/gridlock/draw.py b/gridlock/draw.py index 321ec15..6385213 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -1,14 +1,12 @@ """ Drawing-related methods for Grid class """ -from collections.abc import Sequence, Callable +from typing import List, Optional, Union, Sequence, Callable -import numpy -from numpy.typing import NDArray, ArrayLike +import numpy # type: ignore from float_raster import raster -from .utils import GridError, Slab, SlabDict, SlabProtocol, Extent, ExtentDict, ExtentProtocol -from .position import GridPosMixin +from . import GridError # NOTE: Maybe it would make sense to create a GridDrawer class @@ -17,387 +15,364 @@ from .position import GridPosMixin # without having to pass `cell_data` again each time? -foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] -foreground_t = float | foreground_callable_t +foreground_callable_t = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray] -class GridDrawMixin(GridPosMixin): - def draw_polygons( - self, - cell_data: NDArray, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - polygons: Sequence[ArrayLike], - *, - offset2d: ArrayLike = (0, 0), - ) -> None: - """ - Draw polygons on an axis-aligned slab. +def draw_polygons(self, + cell_data: numpy.ndarray, + surface_normal: int, + center: numpy.ndarray, + polygons: Sequence[numpy.ndarray], + thickness: float, + foreground: Union[Sequence[Union[float, foreground_callable_t]], float, foreground_callable_t], + ) -> None: + """ + Draw polygons on an axis-aligned plane. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list - of any of these (1 per grid). Callable values should take an ndarray the shape of the - grid and return an ndarray of equal shape containing the foreground value at the given x, y, - and z (natural, not grid coordinates). - slab: `Slab` or slab-like dict specifying the slab in which the polygons will be drawn. - polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon - (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each - polygon must have at least 3 vertices. - offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly - to the given polygon vertex coordinates. Default (0, 0). + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: 3-element ndarray or list specifying an offset applied to all the polygons + polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon + (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each + polygon must have at least 3 vertices. + thickness: Thickness of the layer to draw + foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list + of any of these (1 per grid). Callable values should take an ndarray the shape of the + grid and return an ndarray of equal shape containing the foreground value at the given x, y, + and z (natural, not grid coordinates). - Raises: - GridError - """ - if isinstance(slab, dict): - slab = Slab(**slab) + Raises: + GridError + """ + if surface_normal not in range(3): + raise GridError('Invalid surface_normal direction') - poly_list = [numpy.asarray(poly) for poly in polygons] + center = numpy.squeeze(center) - # Check polygons, and remove redundant coordinates - surface = numpy.delete(range(3), slab.axis) + # Check polygons, and remove redundant coordinates + surface = numpy.delete(range(3), surface_normal) - for ii in range(len(poly_list)): - polygon = poly_list[ii] - malformed = f'Malformed polygon: ({ii})' - if polygon.ndim != 2: - raise GridError(malformed + 'must be a 2-dimensional ndarray') - if polygon.shape[1] not in (2, 3): + for i, polygon in enumerate(polygons): + malformed = f'Malformed polygon: ({i})' + if polygon.shape[1] not in (2, 3): raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') - if polygon.shape[1] == 3: - if numpy.unique(polygon[:, slab.axis]).size != 1: - raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) - polygon = polygon[:, surface] - poly_list[ii] = polygon + if polygon.shape[1] == 3: + polygon = polygon[surface, :] - if not polygon.shape[0] > 2: - raise GridError(malformed + 'must consist of more than 2 points') + if not polygon.shape[0] > 2: + raise GridError(malformed + 'must consist of more than 2 points') + if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: + raise GridError(malformed + 'must be in plane with surface normal ' + + 'xyz'[surface_normal]) - # Broadcast foreground where necessary - foregrounds: Sequence[foreground_callable_t] | Sequence[float] - if isinstance(foreground, numpy.ndarray): - raise GridError('ndarray not supported for foreground') - if callable(foreground) or numpy.isscalar(foreground): - foregrounds = [foreground] * len(cell_data) # type: ignore[list-item] + # Broadcast foreground where necessary + if numpy.size(foreground) == 1: + foreground = [foreground] * len(cell_data) + elif isinstance(foreground, numpy.ndarray): + raise GridError('ndarray not supported for foreground') + + # ## Compute sub-domain of the grid occupied by polygons + # 1) Compute outer bounds (bd) of polygons + bd_2d_min = [0, 0] + bd_2d_max = [0, 0] + for polygon in polygons: + bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) + bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) + bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center + bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center + + # 2) Find indices (bdi) just outside bd elements + buf = 2 # size of safety buffer + # Use s_min and s_max with unshifted pos2ind to get absolute limits on + # the indices the polygons might affect + s_min = self.shifts.min(axis=0) + s_max = self.shifts.max(axis=0) + bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf + bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf + bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) + bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) + + # 3) Adjust polygons for center + polygons = [poly + center[surface] for poly in polygons] + + # ## Generate weighing function + def to_3d(vector: numpy.ndarray, val: float = 0.0) -> numpy.ndarray: + v_2d = numpy.array(vector, dtype=float) + return numpy.insert(v_2d, surface_normal, (val,)) + + # iterate over grids + for i, grid in enumerate(cell_data): + # ## Evaluate or expand foreground[i] + if callable(foreground[i]): + # meshgrid over the (shifted) domain + domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k]+1] for k in range(3)] + (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') + + # evaluate on the meshgrid + foreground_i = foreground[i](x0, y0, z0) + if not numpy.isfinite(foreground_i).all(): + raise GridError(f'Non-finite values in foreground[{i}]') + elif numpy.size(foreground[i]) != 1: + raise GridError(f'Unsupported foreground[{i}]: {type(foreground[i])}') else: - foregrounds = foreground # type: ignore + # foreground[i] is scalar non-callable + foreground_i = foreground[i] - # ## Compute sub-domain of the grid occupied by polygons - # 1) Compute outer bounds (bd) of polygons - bd_2d_min = numpy.array([0, 0]) - bd_2d_max = numpy.array([0, 0]) - for polygon in poly_list: - bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) + offset2d - bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) + offset2d - bd_min = numpy.insert(bd_2d_min, slab.axis, slab.min) - bd_max = numpy.insert(bd_2d_max, slab.axis, slab.max) + w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) - # 2) Find indices (bdi) just outside bd elements - buf = 2 # size of safety buffer - # Use s_min and s_max with unshifted pos2ind to get absolute limits on - # the indices the polygons might affect - s_min = self.shifts.min(axis=0) - s_max = self.shifts.max(axis=0) - bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf - bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf - bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) - bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) + # Draw each polygon separately + for polygon in polygons: - # 3) Adjust polygons for offset2d - poly_list = [poly + offset2d for poly in poly_list] + # Get the boundaries of the polygon + pbd_min = polygon.min(axis=0) + pbd_max = polygon.max(axis=0) - # ## Generate weighing function - def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: - v_2d = numpy.array(vector, dtype=float) - return numpy.insert(v_2d, slab.axis, (val,)) + # Find indices in w_xy just outside polygon + # using per-grid xy-weights (self.shifted_xyz()) + corner_min = self.pos2ind(to_3d(pbd_min), i, + check_bounds=False)[surface].astype(int) + corner_max = self.pos2ind(to_3d(pbd_max), i, + check_bounds=False)[surface].astype(int) - # iterate over grids - foreground_val: NDArray | float - for i, _ in enumerate(cell_data): - # ## Evaluate or expand foregrounds[i] - foregrounds_i = foregrounds[i] - if callable(foregrounds_i): - # meshgrid over the (shifted) domain - domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)] - (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') + # Find indices in w_xy which are modified by polygon + # First for the edge coordinates (+1 since we're indexing edges) + edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max)] + # Then for the pixel centers (-bdi_min since we're + # calculating weights within a subspace) + centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface], + corner_max - bdi_min[surface])) - # evaluate on the meshgrid - foreground_val = foregrounds_i(x0, y0, z0) - if not numpy.isfinite(foreground_val).all(): - raise GridError(f'Non-finite values in foreground[{i}]') - elif numpy.size(foregrounds_i) != 1: - raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}') + aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices)) + w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y) + + # Clamp overlapping polygons to 1 + w_xy = numpy.minimum(w_xy, 1.0) + + # 2) Generate weights in z-direction + w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) + + def get_zi(offset, i=i, w_z=w_z): + edges = self.shifted_exyz(i)[surface_normal] + point = center[surface_normal] + offset + grid_coord = numpy.digitize(point, edges) - 1 + w_coord = grid_coord - bdi_min[surface_normal] + + if w_coord < 0: + w_coord = 0 + f = 0 + elif w_coord >= w_z.size: + w_coord = w_z.size - 1 + f = 1 else: - # foreground[i] is scalar non-callable - foreground_val = foregrounds_i + dz = self.shifted_dxyz(i)[surface_normal][grid_coord] + f = (point - edges[grid_coord]) / dz + return f, w_coord - w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) + zi_top_f, zi_top = get_zi(+thickness / 2.0) + zi_bot_f, zi_bot = get_zi(-thickness / 2.0) - # Draw each polygon separately - for polygon in poly_list: + w_z[zi_bot + 1:zi_top] = 1 - # Get the boundaries of the polygon - pbd_min = polygon.min(axis=0) - pbd_max = polygon.max(axis=0) + if zi_bot < zi_top: + w_z[zi_top] = zi_top_f + w_z[zi_bot] = 1 - zi_bot_f + else: + w_z[zi_bot] = zi_top_f - zi_bot_f - # Find indices in w_xy just outside polygon - # using per-grid xy-weights (self.shifted_xyz()) - corner_min = self.pos2ind(to_3d(pbd_min), i, - check_bounds=False)[surface].astype(int) - corner_max = self.pos2ind(to_3d(pbd_max), i, - check_bounds=False)[surface].astype(int) + # 3) Generate total weight function + w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,))) - # Find indices in w_xy which are modified by polygon - # First for the edge coordinates (+1 since we're indexing edges) - edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)] - # Then for the pixel centers (-bdi_min since we're - # calculating weights within a subspace) - centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface], - corner_max - bdi_min[surface], strict=True)) - - aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True)) - w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y) - - # Clamp overlapping polygons to 1 - w_xy = numpy.minimum(w_xy, 1.0) - - # 2) Generate weights in z-direction - w_z = numpy.zeros(((bdi_max - bdi_min + 1)[slab.axis], )) - - def get_zi(point: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001 - edges = self.shifted_exyz(i)[slab.axis] - grid_coord = numpy.digitize(point, edges) - 1 - w_coord = grid_coord - bdi_min[slab.axis] - - if w_coord < 0: - w_coord = 0 - f = 0 - elif w_coord >= w_z.size: - w_coord = w_z.size - 1 - f = 1 - else: - dz = self.shifted_dxyz(i)[slab.axis][grid_coord] - f = (point - edges[grid_coord]) / dz - return f, w_coord - - zi_top_f, zi_top = get_zi(slab.max) - zi_bot_f, zi_bot = get_zi(slab.min) - - w_z[zi_bot + 1:zi_top] = 1 - - if zi_bot < zi_top: - w_z[zi_top] = zi_top_f - w_z[zi_bot] = 1 - zi_bot_f - else: - w_z[zi_bot] = zi_top_f - zi_bot_f - - # 3) Generate total weight function - w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], slab.axis, (2,))) - - # ## Modify the grid - g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) - cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val + # ## Modify the grid + g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) + cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_i - def draw_polygon( - self, - cell_data: NDArray, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - polygon: ArrayLike, - *, - offset2d: ArrayLike = (0, 0), - ) -> None: - """ - Draw a polygon on an axis-aligned plane. +def draw_polygon(self, + cell_data: numpy.ndarray, + surface_normal: int, + center: numpy.ndarray, + polygon: numpy.ndarray, + thickness: float, + foreground: Union[Sequence[Union[float, foreground_callable_t]], float, foreground_callable_t], + ) -> None: + """ + Draw a polygon on an axis-aligned plane. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - slab: `Slab` or slab-like dict specifying the slab in which the polygon will be drawn. - polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, - clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at - least 3 vertices. - offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly - to the given polygon vertex coordinates. Default (0, 0). - """ - self.draw_polygons( - cell_data = cell_data, - slab = slab, - polygons = [polygon], - foreground = foreground, - offset2d = offset2d, - ) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: 3-element ndarray or list specifying an offset applied to the polygon + polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, + clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at + least 3 vertices. + thickness: Thickness of the layer to draw + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) - def draw_slab( - self, - cell_data: NDArray, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - ) -> None: - """ - Draw an axis-aligned infinite slab. +def draw_slab(self, + cell_data: numpy.ndarray, + surface_normal: int, + center: numpy.ndarray, + thickness: float, + foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t], + ) -> None: + """ + Draw an axis-aligned infinite slab. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - slab: `Slab` or slab-like dict (geometrical slab specification) - """ - if isinstance(slab, dict): - slab = Slab(**slab) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: `surface_normal` coordinate value at the center of the slab + thickness: Thickness of the layer to draw + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + # Turn surface_normal into its integer representation + if surface_normal not in range(3): + raise GridError('Invalid surface_normal direction') - # Find center of slab - center_shift = self.center - center_shift[slab.axis] = slab.center + if numpy.size(center) != 1: + center = numpy.squeeze(center) + if len(center) == 3: + center = center[surface_normal] + else: + raise GridError(f'Bad center: {center}') - surface = numpy.delete(range(3), slab.axis) - u_min, u_max = self.exyz[surface[0]][[0, -1]] - v_min, v_max = self.exyz[surface[1]][[0, -1]] + # Find center of slab + center_shift = self.center + center_shift[surface_normal] = center - margin = 4 * numpy.max([self.dxyz[surface[0]].max(), - self.dxyz[surface[1]].max()]) + surface = numpy.delete(range(3), surface_normal) - p = numpy.array([[u_min - margin, v_max + margin], - [u_max + margin, v_max + margin], - [u_max + margin, v_min - margin], - [u_min - margin, v_min - margin]], dtype=float) + xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface] + xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface] - self.draw_polygon( - cell_data = cell_data, - slab = slab, - polygon = p, - foreground = foreground, - ) + dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float) + + xyz_min -= 4 * dxyz + xyz_max += 4 * dxyz + + p = numpy.array([[xyz_min[0], xyz_max[1]], + [xyz_max[0], xyz_max[1]], + [xyz_max[0], xyz_min[1]], + [xyz_min[0], xyz_min[1]]], dtype=float) + + self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) - def draw_cuboid( - self, - cell_data: NDArray, - foreground: Sequence[foreground_t] | foreground_t, - *, - x: ExtentProtocol | ExtentDict, - y: ExtentProtocol | ExtentDict, - z: ExtentProtocol | ExtentDict, - ) -> None: - """ - Draw an axis-aligned cuboid +def draw_cuboid(self, + cell_data: numpy.ndarray, + center: numpy.ndarray, + dimensions: numpy.ndarray, + foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t], + ) -> None: + """ + Draw an axis-aligned cuboid - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - x: `Extent` or extent-like dict specifying the x-extent of the cuboid. - y: `Extent` or extent-like dict specifying the y-extent of the cuboid. - z: `Extent` or extent-like dict specifying the z-extent of the cuboid. - """ - if isinstance(x, dict): - x = Extent(**x) - if isinstance(y, dict): - y = Extent(**y) - if isinstance(z, dict): - z = Extent(**z) - - p = numpy.array([[x.min, y.max], - [x.max, y.max], - [x.max, y.min], - [x.min, y.min]], dtype=float) - slab = Slab(axis=2, center=z.center, span=z.span) - self.draw_polygon(cell_data=cell_data, slab=slab, polygon=p, foreground=foreground) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + center: 3-element ndarray or list specifying the cuboid's center + dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge + sizes of the cuboid + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + p = numpy.array([[-dimensions[0], +dimensions[1]], + [+dimensions[0], +dimensions[1]], + [+dimensions[0], -dimensions[1]], + [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0 + thickness = dimensions[2] + self.draw_polygon(cell_data, 2, center, p, thickness, foreground) - def draw_cylinder( - self, - cell_data: NDArray, - h: SlabProtocol | SlabDict, - radius: float, - num_points: int, - center2d: ArrayLike, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw an axis-aligned cylinder. Approximated by a num_points-gon +def draw_cylinder(self, + cell_data: numpy.ndarray, + surface_normal: int, + center: numpy.ndarray, + radius: float, + thickness: float, + num_points: int, + foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t], + ) -> None: + """ + Draw an axis-aligned cylinder. Approximated by a num_points-gon - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - h: - radius: - num_points: The circle is approximated by a polygon with `num_points` vertices - center2d: - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - if isinstance(h, dict): - h = Slab(**h) - - theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)[:, None] - xy0 = numpy.hstack((numpy.sin(theta), numpy.cos(theta))) - polygon = radius * xy0 - self.draw_polygon(cell_data=cell_data, slab=h, polygon=polygon, foreground=foreground, offset2d=center2d) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: 3-element ndarray or list specifying the cylinder's center + radius: cylinder radius + thickness: Thickness of the layer to draw + num_points: The circle is approximated by a polygon with `num_points` vertices + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + theta = numpy.linspace(0, 2*numpy.pi, num_points, endpoint=False) + x = radius * numpy.sin(theta) + y = radius * numpy.cos(theta) + polygon = numpy.hstack((x[:, None], y[:, None])) + self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) - def draw_extrude_rectangle( - self, - cell_data: NDArray, - rectangle: ArrayLike, - direction: int, - polarity: int, - distance: float, - ) -> None: - """ - Extrude a rectangle of a previously-drawn structure along an axis. +def draw_extrude_rectangle(self, + cell_data: numpy.ndarray, + rectangle: numpy.ndarray, + direction: int, + polarity: int, + distance: float, + ) -> None: + """ + Extrude a rectangle of a previously-drawn structure along an axis. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - rectangle: 2x3 ndarray or list specifying the rectangle's corners - direction: Direction to extrude in. Integer in `range(3)`. - polarity: +1 or -1, direction along axis to extrude in - distance: How far to extrude - """ - sgn = numpy.sign(polarity) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + rectangle: 2x3 ndarray or list specifying the rectangle's corners + direction: Direction to extrude in. Integer in `range(3)`. + polarity: +1 or -1, direction along axis to extrude in + distance: How far to extrude + """ + s = numpy.sign(polarity) - rectangle = numpy.asarray(rectangle, dtype=float) - if sgn == 0: - raise GridError('0 is not a valid polarity') - if direction not in range(3): - raise GridError(f'Invalid direction: {direction}') - if rectangle[0, direction] != rectangle[1, direction]: - raise GridError('Rectangle entries along extrusion direction do not match.') + rectangle = numpy.array(rectangle, dtype=float) + if s == 0: + raise GridError('0 is not a valid polarity') + if direction not in range(3): + raise GridError(f'Invalid direction: {direction}') + if rectangle[0, direction] != rectangle[1, direction]: + raise GridError('Rectangle entries along extrusion direction do not match.') - center = rectangle.sum(axis=0) / 2.0 - center[direction] += sgn * distance / 2.0 + center = rectangle.sum(axis=0) / 2.0 + center[direction] += s * distance / 2.0 - surface = numpy.delete(range(3), direction) + surface = numpy.delete(range(3), direction) - dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] - poly = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, - numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T - thickness = distance + dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] + p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0]/2.0, + numpy.array([-1, 1, 1, -1], dtype=float) * dim[1]/2.0)).T + thickness = distance - foreground_func = [] - for ii, grid in enumerate(cell_data): - zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction] - fpart = zz - numpy.floor(zz) - low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1)) - high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1)) + foreground_func = [] + for i, grid in enumerate(cell_data): + z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] - low_ind = [low if dd == direction else slice(None) for dd in range(3)] - high_ind = [high if dd == direction else slice(None) for dd in range(3)] + ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] - if low == high: - foreground = grid[tuple(low_ind)] - else: - mult = [1 - fpart, fpart][::sgn] # reverses if s negative - foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)] + fpart = z - numpy.floor(z) + mult = [1-fpart, fpart][::s] # reverses if s negative - def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 - # transform from natural position to index - xyzi = numpy.array([self.pos2ind(qrs, which_shifts=ii) - for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64) - # reshape to original shape and keep only in-plane components - qi, ri = (numpy.reshape(xyzi[:, kk], xs.shape) for kk in surface) - return foreground[qi, ri] + foreground = mult[0] * grid[tuple(ind)] + ind[direction] += 1 + foreground += mult[1] * grid[tuple(ind)] - foreground_func.append(f_foreground) + def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> numpy.ndarray: + # transform from natural position to index + xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) + for qrs in zip(xs.flat, ys.flat, zs.flat)], dtype=int) + # reshape to original shape and keep only in-plane components + qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) + return foreground[qi, ri] + + foreground_func.append(f_foreground) + + self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func) - slab = Slab(axis=direction, center=center[direction], span=thickness) - self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface]) diff --git a/gridlock/error.py b/gridlock/error.py new file mode 100644 index 0000000..3974e9c --- /dev/null +++ b/gridlock/error.py @@ -0,0 +1,2 @@ +class GridError(Exception): + pass diff --git a/gridlock/examples/ex0.py b/gridlock/examples/ex0.py index 4ff2fb9..ca2ef55 100644 --- a/gridlock/examples/ex0.py +++ b/gridlock/examples/ex0.py @@ -1,4 +1,4 @@ -import numpy +import numpy # type: ignore from gridlock import Grid @@ -6,18 +6,18 @@ if __name__ == '__main__': # xyz = [numpy.arange(-5.0, 6.0), numpy.arange(-4.0, 5.0), [-1.0, 1.0]] # eg = Grid(xyz) # egc = Grid.allocate(0.0) - # # eg.draw_slab(egc, slab=dict(axis=2, center=0, span=10), foreground=2) - # eg.draw_cylinder(egc, h=slab(axis=2, center=0, span=10), - # center2d=[0, 0], radius=4, thickness=10, num_points=1000, foreground=1) - # eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2) + # # eg.draw_slab(egc, surface_normal=2, center=0, thickness=10, foreground=2) + # eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=4, + # thickness=10, num_points=1000, foreground=1) + # eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) # xyz2 = [numpy.arange(-5.0, 6.0), [-1.0, 1.0], numpy.arange(-4.0, 5.0)] # eg2 = Grid(xyz2) # eg2c = Grid.allocate(0.0) - # # eg2.draw_slab(eg2c, slab=dict(axis=2, center=0, span=10), foreground=2) - # eg2.draw_cylinder(eg2c, h=slab(axis=1, center=0, span=10), center2d=[0, 0], - # radius=4, num_points=1000, foreground=1.0) - # eg2.visualize_slice(eg2c, plane=dict(y=0), which_shifts=1) + # # eg2.draw_slab(eg2c, surface_normal=2, center=0, thickness=10, foreground=2) + # eg2.draw_cylinder(eg2c, surface_normal=1, center=[0, 0, 0], + # radius=4, thickness=10, num_points=1000, foreground=1.0) + # eg2.visualize_slice(eg2c, surface_normal=1, center=0, which_shifts=1) # n = 20 # m = 3 @@ -29,27 +29,16 @@ if __name__ == '__main__': # numpy.linspace(-5.5, 5.5, 10)] half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5] - xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x, dtype=float), - numpy.linspace(-5.5, 5.5, 10, dtype=float), - numpy.linspace(-5.5, 5.5, 10, dtype=float)] + xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x, + numpy.linspace(-5.5, 5.5, 10), + numpy.linspace(-5.5, 5.5, 10)] eg = Grid(xyz3) egc = eg.allocate(0) # eg.draw_slab(Direction.z, 0, 10, 2) eg.save('/home/jan/Desktop/test.pickle') - eg.draw_cylinder( - egc, - h=dict(axis='z', center=0, span=10), - center2d=[0, 0], - radius=2.0, - num_points=1000, - foreground=1, - ) - eg.draw_extrude_rectangle( - egc, - rectangle=[[-2, 1, -1], [0, 1, 1]], - direction=1, - polarity=+1, - distance=5, - ) - eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2) + eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0, + thickness=10, num_poitns=1000, foreground=1) + eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]], + direction=1, poalarity=+1, distance=5) + eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) eg.visualize_isosurface(egc, which_shifts=2) diff --git a/gridlock/grid.py b/gridlock/grid.py index eeb9708..e320854 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,93 +1,20 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Self -from collections.abc import Callable, Sequence +from typing import List, Tuple, Callable, Dict, Optional, Union, Sequence, ClassVar, TypeVar -import numpy -from numpy.typing import NDArray, ArrayLike +import numpy # type: ignore +from numpy import diff, floor, ceil, zeros, hstack, newaxis import pickle import warnings import copy from . import GridError -from .draw import GridDrawMixin -from .read import GridReadMixin -from .position import GridPosMixin - -if TYPE_CHECKING: - from .data import GridData -foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] -_FORMAT_VERSION = 1 +foreground_callable_type = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray] +T = TypeVar('T', bound='Grid') -def _is_npz_file(filename: str) -> bool: - with open(filename, 'rb') as f: - return f.read(2) == b'PK' - - -def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None: - with open(filename, 'wb') as f: - numpy.savez_compressed(f, **payload) - - -def _load_payload(filename: str) -> dict[str, Any]: - if _is_npz_file(filename): - with numpy.load(filename, allow_pickle=False) as payload: - return {key: payload[key] for key in payload.files} - - with open(filename, 'rb') as f: - legacy = pickle.load(f) - - if isinstance(legacy, Grid): - return legacy._serialization_payload(kind='grid') - if isinstance(legacy, dict): - grid = Grid([[-1, 1]] * 3) - grid.__dict__.update(legacy) - return grid._serialization_payload(kind='grid') - raise GridError('Unsupported serialized Grid payload') - - -def _payload_scalar_str(payload: dict[str, Any], key: str) -> str: - if key not in payload: - raise GridError(f'Missing serialized key: {key}') - - value = numpy.asarray(payload[key]) - if value.size != 1: - raise GridError(f'Serialized key {key} must be scalar') - return str(value.reshape(())) - - -def _payload_scalar_int(payload: dict[str, Any], key: str) -> int: - if key not in payload: - raise GridError(f'Missing serialized key: {key}') - - value = numpy.asarray(payload[key]) - if value.size != 1: - raise GridError(f'Serialized key {key} must be scalar') - return int(value.reshape(())) - - -def _grid_from_payload(payload: dict[str, Any]) -> 'Grid': - if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION: - raise GridError('Unsupported serialized Grid format version') - - exyz = [] - for axis in range(3): - key = f'exyz_{axis}' - if key not in payload: - raise GridError(f'Missing serialized key: {key}') - exyz.append(numpy.array(payload[key], dtype=float)) - - if 'shifts' not in payload or 'periodic' not in payload: - raise GridError('Serialized Grid payload is missing shifts or periodic data') - - shifts = numpy.array(payload['shifts'], dtype=float) - periodic = numpy.array(payload['periodic'], dtype=bool).tolist() - return Grid(exyz, shifts=shifts, periodic=periodic) - - -class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): +class Grid: """ Simulation grid metadata for finite-difference simulations. @@ -121,35 +48,217 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Because of this, we either assume this 'ghost' cell is the same size as the last real cell, or, if `self.periodic[a]` is set to `True`, the same size as the first cell. """ - exyz: list[NDArray] + exyz: List[numpy.ndarray] """Cell edges. Monotonically increasing without duplicates.""" - periodic: list[bool] + periodic: List[bool] """For each axis, determines how far the rightmost boundary gets shifted. """ - shifts: NDArray + shifts: numpy.ndarray """Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`""" - Yee_Shifts_E: ClassVar[NDArray] = 0.5 * numpy.array([ - [1, 0, 0], - [0, 1, 0], - [0, 0, 1], - ], dtype=float) + Yee_Shifts_E: ClassVar[numpy.ndarray] = 0.5 * numpy.array([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], dtype=float) """Default shifts for Yee grid E-field""" - Yee_Shifts_H: ClassVar[NDArray] = 0.5 * numpy.array([ - [0, 1, 1], - [1, 0, 1], - [1, 1, 0], - ], dtype=float) + Yee_Shifts_H: ClassVar[numpy.ndarray] = 0.5 * numpy.array([[0, 1, 1], + [1, 0, 1], + [1, 1, 0]], dtype=float) """Default shifts for Yee grid H-field""" - def __init__( - self, - pixel_edge_coordinates: Sequence[ArrayLike], - shifts: ArrayLike = Yee_Shifts_E, - periodic: bool | Sequence[bool] = False, - ) -> None: + from .draw import ( + draw_polygons, draw_polygon, draw_slab, draw_cuboid, + draw_cylinder, draw_extrude_rectangle, + ) + from .read import get_slice, visualize_slice, visualize_isosurface + from .position import ind2pos, pos2ind + + @property + def dxyz(self) -> List[numpy.ndarray]: + """ + Cell sizes for each axis, no shifts applied + + Returns: + List of 3 ndarrays of cell sizes + """ + return [numpy.diff(ee) for ee in self.exyz] + + @property + def xyz(self) -> List[numpy.ndarray]: + """ + Cell centers for each axis, no shifts applied + + Returns: + List of 3 ndarrays of cell edges + """ + return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] + + @property + def shape(self) -> numpy.ndarray: + """ + The number of cells in x, y, and z + + Returns: + ndarray of [x_centers.size, y_centers.size, z_centers.size] + """ + return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int) + + @property + def num_grids(self) -> int: + """ + The number of grids (number of shifts) + """ + return self.shifts.shape[0] + + @property + def cell_data_shape(self): + """ + The shape of the cell_data ndarray (num_grids, *self.shape). + """ + return numpy.hstack((self.num_grids, self.shape)) + + @property + def dxyz_with_ghost(self) -> List[numpy.ndarray]: + """ + Gives dxyz with an additional 'ghost' cell at the end, whose value depends + on whether or not the axis has periodic boundary conditions. See main description + above to learn why this is necessary. + + If periodic, final edge shifts same amount as first + Otherwise, final edge shifts same amount as second-to-last + + Returns: + list of [dxs, dys, dzs] with each element same length as elements of `self.xyz` + """ + el = [0 if p else -1 for p in self.periodic] + return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)] + + @property + def center(self) -> numpy.ndarray: + """ + Center position of the entire grid, no shifts applied + + Returns: + ndarray of [x_center, y_center, z_center] + """ + # center is just average of first and last xyz, which is just the average of the + # first two and last two exyz + centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)] + return numpy.array(centers, dtype=float) + + @property + def dxyz_limits(self) -> Tuple[numpy.ndarray, numpy.ndarray]: + """ + Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element + ndarrays. No shifts are applied, so these are extreme bounds on these values (as a + weighted average is performed when shifting). + + Returns: + Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]` + """ + d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float) + d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float) + return d_min, d_max + + def shifted_exyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]: + """ + Returns edges for which_shifts. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell edges + """ + if which_shifts is None: + return self.exyz + dxyz = self.dxyz_with_ghost + shifts = self.shifts[which_shifts, :] + + # If shift is negative, use left cell's dx to determine shift + for a in range(3): + if shifts[a] < 0: + dxyz[a] = numpy.roll(dxyz[a], 1) + + return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] + + def shifted_dxyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]: + """ + Returns cell sizes for `which_shifts`. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell sizes + """ + if which_shifts is None: + return self.dxyz + shifts = self.shifts[which_shifts, :] + dxyz = self.dxyz_with_ghost + + # If shift is negative, use left cell's dx to determine size + sdxyz = [] + for a in range(3): + if shifts[a] < 0: + roll_dxyz = numpy.roll(dxyz[a], 1) + abs_shift = numpy.abs(shifts[a]) + sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) + else: + sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) + + return sdxyz + + def shifted_xyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]: + """ + Returns cell centers for `which_shifts`. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell centers + """ + if which_shifts is None: + return self.xyz + exyz = self.shifted_exyz(which_shifts) + dxyz = self.shifted_dxyz(which_shifts) + return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] + + def autoshifted_dxyz(self) -> List[numpy.ndarray]: + """ + Return cell widths, with each dimension shifted by the corresponding shifts. + + Returns: + `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]` + """ + if self.num_grids != 3: + raise GridError('Autoshifting requires exactly 3 grids') + return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] + + def allocate(self, fill_value: Optional[float] = 1.0, dtype=numpy.float32) -> numpy.ndarray: + """ + Allocate an ndarray for storing grid data. + + Args: + fill_value: Value to initialize the grid to. If None, an + uninitialized array is returned. + dtype: Numpy dtype for the array. Default is `numpy.float32`. + + Returns: + The allocated array + """ + if fill_value is None: + return numpy.empty(self.cell_data_shape, dtype=dtype) + else: + return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) + + def __init__(self, + pixel_edge_coordinates: Sequence[numpy.ndarray], + shifts: numpy.ndarray = Yee_Shifts_E, + periodic: Union[bool, Sequence[bool]] = False, + ) -> None: """ Args: pixel_edge_coordinates: 3-element list of (ndarrays or lists) specifying the @@ -164,24 +273,17 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Raises: `GridError` on invalid input """ - edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates] - if len(edge_arrs) != 3: - raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays') - self.exyz = [numpy.unique(edges) for edges in edge_arrs] + self.exyz = [numpy.unique(pixel_edge_coordinates[i]) for i in range(3)] self.shifts = numpy.array(shifts, dtype=float) for i in range(3): - if self.exyz[i].size != edge_arrs[i].size: + if len(self.exyz[i]) != len(pixel_edge_coordinates[i]): warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2) if isinstance(periodic, bool): self.periodic = [periodic] * 3 else: self.periodic = list(periodic) - if len(self.periodic) != 3: - raise GridError('periodic must be a bool or a sequence of length 3') - if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic): - raise GridError('periodic sequence entries must be bool values') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' @@ -193,16 +295,9 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): if (numpy.abs(self.shifts) > 1).any(): raise GridError('Only shifts in the range [-1, 1] are currently supported') - def _serialization_payload(self, *, kind: str) -> dict[str, Any]: - payload: dict[str, Any] = { - 'kind': numpy.array(kind), - 'format_version': numpy.array(_FORMAT_VERSION, dtype=int), - 'shifts': self.shifts, - 'periodic': numpy.array(self.periodic, dtype=bool), - } - for axis, exyz in enumerate(self.exyz): - payload[f'exyz_{axis}'] = exyz - return payload + if (self.shifts < 0).any(): + # TODO: Test negative shifts + warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2) @staticmethod def load(filename: str) -> 'Grid': @@ -212,13 +307,14 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Args: filename: Filename to load from. """ - payload = _load_payload(filename) - kind = _payload_scalar_str(payload, 'kind') - if kind not in ('grid', 'grid_data'): - raise GridError(f'Unsupported serialized kind: {kind}') - return _grid_from_payload(payload) + with open(filename, 'rb') as f: + tmp_dict = pickle.load(f) - def save(self, filename: str) -> Self: + g = Grid([[-1, 1]] * 3) + g.__dict__.update(tmp_dict) + return g + + def save(self: T, filename: str) -> T: """ Save to file. @@ -228,19 +324,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Returns: self """ - _save_npz_payload(filename, self._serialization_payload(kind='grid')) + with open(filename, 'wb') as f: + pickle.dump(self.__dict__, f, protocol=2) return self - def with_data( - self, - fill_value: float | None = 1.0, - dtype: type[numpy.number] = numpy.float32, - ) -> 'GridData': - from .data import GridData - - return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype)) - - def copy(self) -> Self: + def copy(self: T) -> T: """ Returns: Deep copy of the grid. diff --git a/gridlock/position.py b/gridlock/position.py index 6344ea4..1224a12 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -1,118 +1,115 @@ """ Position-related methods for Grid class """ -import numpy -from numpy.typing import NDArray, ArrayLike +from typing import List, Optional + +import numpy # type: ignore from . import GridError -from .base import GridBase -class GridPosMixin(GridBase): - def ind2pos( - self, - ind: NDArray, - which_shifts: int | None = None, +def ind2pos(self, + ind: numpy.ndarray, + which_shifts: Optional[int] = None, round_ind: bool = True, check_bounds: bool = True - ) -> NDArray[numpy.float64]: - """ - Returns the natural position corresponding to the specified cell center indices. - The resulting position is clipped to the bounds of the grid - (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`) + ) -> numpy.ndarray: + """ + Returns the natural position corresponding to the specified cell center indices. + The resulting position is clipped to the bounds of the grid + (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`) - Args: - ind: Indices of the position. Can be fractional. (3-element ndarray or list) - which_shifts: which grid number (`shifts`) to use - round_ind: Whether to round ind to the nearest integer position before indexing - (default `True`) - check_bounds: Whether to raise an `GridError` if the provided ind is outside of - the grid, as defined above (centers if `round_ind`, else edges) (default `True`) + Args: + ind: Indices of the position. Can be fractional. (3-element ndarray or list) + which_shifts: which grid number (`shifts`) to use + round_ind: Whether to round ind to the nearest integer position before indexing + (default `True`) + check_bounds: Whether to raise an `GridError` if the provided ind is outside of + the grid, as defined above (centers if `round_ind`, else edges) (default `True`) - Returns: - 3-element ndarray specifying the natural position + Returns: + 3-element ndarray specifying the natural position - Raises: - `GridError` if invalid `which_shifts` - `GridError` if `check_bounds` and out of bounds - """ - if which_shifts is not None and which_shifts >= self.shifts.shape[0]: - raise GridError('Invalid shifts') - ind = numpy.array(ind, dtype=float) - - if check_bounds: - if round_ind: - low_bound = 0.0 - high_bound = -1.0 - else: - low_bound = -0.5 - high_bound = -0.5 - if (ind < low_bound).any() or (ind > self.shape + high_bound).any(): - raise GridError(f'Position outside of grid: {ind}') + Raises: + `GridError` if invalid `which_shifts` + `GridError` if `check_bounds` and out of bounds + """ + if which_shifts is not None and which_shifts >= self.shifts.shape[0]: + raise GridError('Invalid shifts') + ind = numpy.array(ind, dtype=float) + if check_bounds: if round_ind: - rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) - sxyz = self.shifted_xyz(which_shifts) - position = [sxyz[a][rind[a]] for a in range(3)] + low_bound = 0.0 + high_bound = -1.0 else: - sexyz = self.shifted_exyz(which_shifts) - position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) - for a in range(3)] - return numpy.array(position, dtype=float) + low_bound = -0.5 + high_bound = -0.5 + if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): + raise GridError(f'Position outside of grid: {ind}') + + if round_ind: + rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) + sxyz = self.shifted_xyz(which_shifts) + position = [sxyz[a][rind[a]].astype(int) for a in range(3)] + else: + sexyz = self.shifted_exyz(which_shifts) + position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) + for a in range(3)] + return numpy.array(position, dtype=float) - def pos2ind( - self, - r: ArrayLike, - which_shifts: int | None, +def pos2ind(self, + r: numpy.ndarray, + which_shifts: Optional[int], round_ind: bool = True, check_bounds: bool = True - ) -> NDArray[numpy.float64]: - """ - Returns the cell-center indices corresponding to the specified natural position. - The resulting position is clipped to within the outer centers of the grid. + ) -> numpy.ndarray: + """ + Returns the cell-center indices corresponding to the specified natural position. + The resulting position is clipped to within the outer centers of the grid. - Args: - r: Natural position that we will convert into indices (3-element ndarray or list) - which_shifts: which grid number (`shifts`) to use - round_ind: Whether to round the returned indices to the nearest integers. - check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges + Args: + r: Natural position that we will convert into indices (3-element ndarray or list) + which_shifts: which grid number (`shifts`) to use + round_ind: Whether to round the returned indices to the nearest integers. + check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges - Returns: - 3-element ndarray specifying the indices + Returns: + 3-element ndarray specifying the indices - Raises: - `GridError` if invalid `which_shifts` - `GridError` if `check_bounds` and out of bounds - """ - r = numpy.squeeze(r) - if r.size != 3: - raise GridError(f'r must be 3-element vector: {r}') + Raises: + `GridError` if invalid `which_shifts` + `GridError` if `check_bounds` and out of bounds + """ + r = numpy.squeeze(r) + if r.size != 3: + raise GridError(f'r must be 3-element vector: {r}') - if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]): - raise GridError(f'Invalid which_shifts: {which_shifts}') + if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]): + raise GridError(f'Invalid which_shifts: {which_shifts}') - sexyz = self.shifted_exyz(which_shifts) + sexyz = self.shifted_exyz(which_shifts) - if check_bounds: - for a in range(3): - if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]): - raise GridError(f'Position[{a}] outside of grid!') - - grid_pos = numpy.zeros((3,)) + if check_bounds: for a in range(3): - xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in - xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds + if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]): + raise GridError(f'Position[{a}] outside of grid!') - # No need to interpolate if round_ind is true or we were outside the grid - if round_ind or xi != xi_clipped: - grid_pos[a] = xi_clipped - else: - # Interpolate - x = self.shifted_xyz(which_shifts)[a][xi] - dx = self.shifted_dxyz(which_shifts)[a][xi] - f = (r[a] - x) / dx + grid_pos = numpy.zeros((3,)) + for a in range(3): + xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in + xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds - # Clip to centers - grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1) - return grid_pos + # No need to interpolate if round_ind is true or we were outside the grid + if round_ind or xi != xi_clipped: + grid_pos[a] = xi_clipped + else: + # Interpolate + x = self.shifted_xyz(which_shifts)[a][xi] + dx = self.shifted_dxyz(which_shifts)[a][xi] + f = (r[a] - x) / dx + + # Clip to centers + grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1) + return grid_pos diff --git a/gridlock/read.py b/gridlock/read.py index f8a40a1..aa059d5 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -1,318 +1,183 @@ """ Readback and visualization methods for Grid class """ -from typing import Any, TYPE_CHECKING +from typing import Dict, Optional, Union, Any -import numpy -from numpy.typing import NDArray - -from .utils import GridError, Plane, PlaneDict, PlaneProtocol -from .position import GridPosMixin - -if TYPE_CHECKING: - import matplotlib.axes - import matplotlib.figure +import numpy # type: ignore +from . import GridError # .visualize_* uses matplotlib # .visualize_isosurface uses skimage # .visualize_isosurface uses mpl_toolkits.mplot3d -class GridReadMixin(GridPosMixin): - @staticmethod - def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]: - if centers.size > 1: - midpoints = 0.5 * (centers[:-1] + centers[1:]) - first = centers[0] - 0.5 * (centers[1] - centers[0]) - last = centers[-1] + 0.5 * (centers[-1] - centers[-2]) - return numpy.hstack(([first], midpoints, [last])) - return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float) +def get_slice(self, + cell_data: numpy.ndarray, + surface_normal: int, + center: float, + which_shifts: int = 0, + sample_period: int = 1 + ) -> numpy.ndarray: + """ + Retrieve a slice of a grid. + Interpolates if given a position between two planes. - def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]: - if sample_period <= 1: - return self.shifted_exyz(which_shifts) + Args: + cell_data: Cell data to slice + surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. + center: Scalar specifying position along surface_normal axis. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) - shifted_xyz = self.shifted_xyz(which_shifts) - shifted_exyz = self.shifted_exyz(which_shifts) - return [ - self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a]) - for a in range(3) - ] + Returns: + Array containing the portion of the grid. + """ + if numpy.size(center) != 1 or not numpy.isreal(center): + raise GridError('center must be a real scalar') - def get_slice( - self, - cell_data: NDArray, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1 - ) -> NDArray: - """ - Retrieve a slice of a grid. - Interpolates if given a position between two grid planes. + sp = round(sample_period) + if sp <= 0: + raise GridError('sample_period must be positive') - Args: - cell_data: Cell data to slice - plane: Axis and position (`Plane`) of the plane to read. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) + if numpy.size(which_shifts) != 1 or which_shifts < 0: + raise GridError('Invalid which_shifts') - Returns: - Array containing the portion of the grid. - """ - if isinstance(plane, dict): - plane = Plane(**plane) + if surface_normal not in range(3): + raise GridError('Invalid surface_normal direction') - sp = round(sample_period) - if sp <= 0: - raise GridError('sample_period must be positive') + surface = numpy.delete(range(3), surface_normal) - if numpy.size(which_shifts) != 1 or which_shifts < 0: - raise GridError('Invalid which_shifts') + # Extract indices and weights of planes + center3 = numpy.insert([0, 0], surface_normal, (center,)) + center_index = self.pos2ind(center3, which_shifts, + round_ind=False, check_bounds=False)[surface_normal] + centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) + if len(centers) == 2: + fpart = center_index - numpy.floor(center_index) + w = [1 - fpart, fpart] # longer distance -> less weight + else: + w = [1] - surface = numpy.delete(range(3), plane.axis) + c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1]) + if center < c_min or center > c_max: + raise GridError('Coordinate of selected plane must be within simulation domain') - # Extract indices and weights of planes - center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,)) - center_index = self.pos2ind(center3, which_shifts, - round_ind=False, check_bounds=False)[plane.axis] - centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) - if len(centers) == 2: - fpart = center_index - numpy.floor(center_index) - w = [1 - fpart, fpart] # longer distance -> less weight - else: - w = [1] + # Extract grid values from planes above and below visualized slice + sliced_grid = numpy.zeros(self.shape[surface]) + for ci, weight in zip(centers, w): + s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) + sliced_grid += weight * cell_data[which_shifts][tuple(s)] - c_min, c_max = (self.shifted_xyz(which_shifts)[plane.axis][i] for i in [0, -1]) - if plane.pos < c_min or plane.pos > c_max: - raise GridError('Coordinate of selected plane must be within simulation domain') + # Remove extra dimensions + sliced_grid = numpy.squeeze(sliced_grid) - # Extract grid values from planes above and below visualized slice - sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface) - sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float)) - for ci, weight in zip(centers, w, strict=True): - s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3)) - sliced_grid += weight * cell_data[which_shifts][tuple(s)] - - # Remove extra dimensions - sliced_grid = numpy.squeeze(sliced_grid) - - return sliced_grid + return sliced_grid - def visualize_slice( - self, - cell_data: NDArray, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - pcolormesh_args: dict[str, Any] | None = None, - ax: 'matplotlib.axes.Axes | None' = None, - ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: - """ - Visualize a slice of a grid. - Interpolates if given a position between two grid planes. +def visualize_slice(self, + cell_data: numpy.ndarray, + surface_normal: int, + center: float, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + pcolormesh_args: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Visualize a slice of a grid. + Interpolates if given a position between two planes. - Args: - cell_data: Cell data to visualize - plane: Axis and position (`Plane`) of the plane to read. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - pcolormesh_args: Args passed through to matplotlib `pcolormesh()` - ax: If provided, plot to these axes (instead of creating a new figure & axes) + Args: + surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. + center: Scalar specifying position along surface_normal axis. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + """ + from matplotlib import pyplot - Returns: - (Figure, Axes) - """ - from matplotlib import pyplot + if pcolormesh_args is None: + pcolormesh_args = {} - if isinstance(plane, dict): - plane = Plane(**plane) + grid_slice = self.get_slice(cell_data=cell_data, + surface_normal=surface_normal, + center=center, + which_shifts=which_shifts, + sample_period=sample_period) - if pcolormesh_args is None: - pcolormesh_args = {} + surface = numpy.delete(range(3), surface_normal) - grid_slice = self.get_slice( - cell_data = cell_data, - plane = plane, - which_shifts = which_shifts, - sample_period = sample_period, - ) + x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') + x_label, y_label = ('xyz'[a] for a in surface) - surface = numpy.delete(range(3), plane.axis) - - if sample_period == 1: - x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) - else: - x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) - pcolormesh_args.setdefault('shading', 'nearest') - xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') - x_label, y_label = ('xyz'[a] for a in surface) - - if ax is None: - fig, ax = pyplot.subplots() - else: - fig = ax.figure - mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) - fig.colorbar(mappable) - ax.set_aspect('equal', adjustable='box') - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - - if finalize: - pyplot.show() - - return fig, ax + pyplot.figure() + pyplot.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) + pyplot.colorbar() + pyplot.gca().set_aspect('equal', adjustable='box') + pyplot.xlabel(x_label) + pyplot.ylabel(y_label) + if finalize: + pyplot.show() - def visualize_edges( - self, - cell_data: NDArray, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - contour_args: dict[str, Any] | None = None, - ax: 'matplotlib.axes.Axes | None' = None, - level_fraction: float = 0.7, - ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: - """ - Visualize the edges of a grid slice. - This is intended as an overlay on top of visualize_slice (e.g. showing epsilon boundaries - on an E-field plot). +def visualize_isosurface(self, + cell_data: numpy.ndarray, + level: Optional[float] = None, + which_shifts: int = 0, + sample_period: int = 1, + show_edges: bool = True, + finalize: bool = True, + ) -> None: + """ + Draw an isosurface plot of the device. - Interpolates if given a position between two grid planes. + Args: + cell_data: Cell data to visualize + level: Value at which to find isosurface. Default (None) uses mean value in grid. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + show_edges: Whether to draw triangle edges. Default `True` + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + """ + from matplotlib import pyplot + import skimage.measure + # Claims to be unused, but needed for subplot(projection='3d') + from mpl_toolkits.mplot3d import Axes3D - Args: - cell_data: Cell data to visualize - plane: Axis and position (`Plane`) of the plane to read. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - contour_args: Args passed through to matplotlib `pcolormesh()` - ax: If provided, plot to these axes (instead of creating a new figure & axes) - level_fraction: Value between 0 and 1 which tunes how many contours are generated. - 1 indicates that every possible step should have its own contour. + # Get data from cell_data + grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] + if level is None: + level = grid.mean() - Returns: - (Figure, Axes) - """ - from matplotlib import pyplot + # Find isosurface with marching cubes + verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) - if level_fraction > 1: - raise GridError(f'{level_fraction=} must be between 0 and 1') + # Convert vertices from index to position + pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) + for i in range(verts.shape[0])], dtype=float) + xs, ys, zs = (pos_verts[:, a] for a in range(3)) - if isinstance(plane, dict): - plane = Plane(**plane) + # Draw the plot + fig = pyplot.figure() + ax = fig.add_subplot(111, projection='3d') + if show_edges: + ax.plot_trisurf(xs, ys, faces, zs) + else: + ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') - if contour_args is None: - contour_args = dict(alpha=0.8, colors='gray') + # Add a fake plot of a cube to force the axes to be equal lengths + max_range = numpy.array([xs.max() - xs.min(), + ys.max() - ys.min(), + zs.max() - zs.min()], dtype=float).max() + mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] + xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min()) + ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min()) + zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min()) + # Comment or uncomment following both lines to test the fake bounding box: + for xb, yb, zb in zip(xbs, ybs, zbs): + ax.plot([xb], [yb], [zb], 'w') - grid_slice = self.get_slice( - cell_data = cell_data, - plane = plane, - which_shifts = which_shifts, - sample_period = sample_period, - ) - cvals, cval_counts = numpy.unique(grid_slice, return_counts=True) - if cvals.size == 1: - levels = [cvals[0] + 1] - else: - cval_order = numpy.argsort(cval_counts)[::-1] - level_count = 2 - while cval_counts[cval_order[:level_count]].sum() < level_fraction: - level_count += 1 - ctr_levels = cvals[cval_order[:level_count]] - levels = numpy.diff(ctr_levels[::-1]) + ctr_levels[:0:-1] - - surface = numpy.delete(range(3), plane.axis) - - if ax is None: - fig, ax = pyplot.subplots() - else: - fig = ax.figure - xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) - xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') - - ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) - - if finalize: - pyplot.show() - - return fig, ax - - - def visualize_isosurface( - self, - cell_data: NDArray, - level: float | None = None, - which_shifts: int = 0, - sample_period: int = 1, - show_edges: bool = True, - finalize: bool = True, - ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: - """ - Draw an isosurface plot of the device. - - Args: - cell_data: Cell data to visualize - level: Value at which to find isosurface. Default (None) uses mean value in grid. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - show_edges: Whether to draw triangle edges. Default `True` - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - - Returns: - (Figure, Axes) - """ - from matplotlib import pyplot - import skimage.measure - # Claims to be unused, but needed for subplot(projection='3d') - from mpl_toolkits.mplot3d import Axes3D - del Axes3D # imported for side effects only - - # Get data from cell_data - grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] - if level is None: - level = grid.mean() - - # Find isosurface with marching cubes - verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) - - # Convert vertices from index to position - preview_exyz = self._sampled_exyz(which_shifts, sample_period) - pos_verts = numpy.array([ - [ - numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a]) - for a in range(3) - ] - for i in range(verts.shape[0]) - ], dtype=float) - xs, ys, zs = (pos_verts[:, a] for a in range(3)) - - # Draw the plot - fig = pyplot.figure() - ax = fig.add_subplot(111, projection='3d') - if show_edges: - ax.plot_trisurf(xs, ys, faces, zs) # type: ignore - else: - ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') # type: ignore - - # Add a fake plot of a cube to force the axes to be equal lengths - max_range = numpy.array([xs.max() - xs.min(), - ys.max() - ys.min(), - zs.max() - zs.min()], dtype=float).max() - mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] - xbs = 0.5 * max_range * mg[0].ravel() + 0.5 * (xs.max() + xs.min()) - ybs = 0.5 * max_range * mg[1].ravel() + 0.5 * (ys.max() + ys.min()) - zbs = 0.5 * max_range * mg[2].ravel() + 0.5 * (zs.max() + zs.min()) - # Comment or uncomment following both lines to test the fake bounding box: - for xb, yb, zb in zip(xbs, ybs, zbs, strict=True): - ax.plot([xb], [yb], [zb], 'w') - - if finalize: - pyplot.show() - - return fig, ax + if finalize: + pyplot.show() diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index ae0a73a..fc54030 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,9 +1,8 @@ -import pytest -import numpy -from numpy.testing import assert_allclose #, assert_array_equal -import pickle +import pytest # type: ignore +import numpy # type: ignore +from numpy.testing import assert_allclose, assert_array_equal # type: ignore -from .. import Grid, GridData, Extent, GridError, Plane, Slab +from .. import Grid def test_draw_oncenter_2x2() -> None: @@ -13,13 +12,7 @@ def test_draw_oncenter_2x2() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid( - arr, - x=dict(center=0, span=1), - y=Extent(center=0, span=1), - z=dict(center=0, span=10), - foreground=1, - ) + grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[1, 1, 10], foreground=1) correct = numpy.array([[0.25, 0.25], [0.25, 0.25]])[None, :, :, None] @@ -34,13 +27,7 @@ def test_draw_ongrid_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid( - arr, - x=dict(center=0, span=2), - y=dict(min=-1, max=1), - z=dict(center=0, min=-5), - foreground=1, - ) + grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[2, 2, 10], foreground=1) correct = numpy.array([[0, 0, 0, 0], [0, 1, 1, 0], @@ -57,13 +44,7 @@ def test_draw_xshift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid( - arr, - x=dict(center=0.5, span=1.5), - y=dict(min=-1, max=1), - z=dict(center=0, span=10), - foreground=1, - ) + grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 2, 10], foreground=1) correct = numpy.array([[0, 0, 0, 0], [0, 0.25, 0.25, 0], @@ -80,13 +61,7 @@ def test_draw_yshift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid( - arr, - x=dict(min=-1, max=1), - y=dict(center=0.5, span=1.5), - z=dict(center=0, span=10), - foreground=1, - ) + grid.draw_cuboid(arr, center=[0, 0.5, 0], dimensions=[2, 1.5, 10], foreground=1) correct = numpy.array([[0, 0, 0, 0], [0, 0.25, 1, 0.25], @@ -103,13 +78,7 @@ def test_draw_2shift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid( - arr, - x=dict(center=0.5, span=1.5), - y=dict(min=-0.5, max=0.5), - z=dict(center=0, span=10), - foreground=1, - ) + grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 1, 10], foreground=1) correct = numpy.array([[0, 0, 0, 0], [0, 0.125, 0.125, 0], @@ -117,350 +86,3 @@ def test_draw_2shift_4x4() -> None: [0, 0.125, 0.125, 0]])[None, :, :, None] assert_allclose(arr, correct) - - -def test_ind2pos_round_preserves_float_centers() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) - - pos = grid.ind2pos(numpy.array([1, 0, 0]), which_shifts=0) - - assert_allclose(pos, [2.0, 1.0, 0.5]) - - -def test_ind2pos_enforces_bounds_for_rounded_and_fractional_indices() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) - - with pytest.raises(GridError): - grid.ind2pos(numpy.array([2, 0, 0]), which_shifts=0, check_bounds=True) - - edge_pos = grid.ind2pos(numpy.array([1.5, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) - assert_allclose(edge_pos, [3.0, 2.0, 1.0]) - - with pytest.raises(GridError): - grid.ind2pos(numpy.array([1.6, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) - - -def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) - arr_2d = grid.allocate(0) - arr_3d = grid.allocate(0) - slab = dict(axis='z', center=0.5, span=1.0) - - polygon_2d = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float) - polygon_3d = numpy.array([[0, 0, 0.5], - [1, 0, 0.5], - [1, 1, 0.5], - [0, 1, 0.5]], dtype=float) - - grid.draw_polygon(arr_2d, slab=slab, polygon=polygon_2d, foreground=1) - grid.draw_polygon(arr_3d, slab=slab, polygon=polygon_3d, foreground=1) - - assert_allclose(arr_3d, arr_2d) - - -def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) - arr = grid.allocate(0) - polygon = numpy.array([[0, 0, 0.5], - [1, 0, 0.5], - [1, 1, 0.75], - [0, 1, 0.5]], dtype=float) - - with pytest.raises(GridError): - grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1) - - -def test_get_slice_supports_sampling() -> None: - grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) - cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) - - grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2) - - assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0]) - - -def test_sampled_visualization_helpers_do_not_error() -> None: - matplotlib = pytest.importorskip('matplotlib') - matplotlib.use('Agg') - from matplotlib import pyplot - - grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) - cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) - - fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False) - fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False) - - assert fig_slice is ax_slice.figure - assert fig_edges is ax_edges.figure - - pyplot.close(fig_slice) - pyplot.close(fig_edges) - - -def test_grid_constructor_rejects_invalid_coordinate_count() -> None: - with pytest.raises(GridError): - Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]]) - - with pytest.raises(GridError): - Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) - - -def test_grid_constructor_rejects_invalid_periodic_length() -> None: - with pytest.raises(GridError): - Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False]) - - -def test_extent_and_slab_reject_inverted_geometry() -> None: - with pytest.raises(GridError): - Extent(center=0, min=1) - - with pytest.raises(GridError): - Extent(min=2, max=1) - - with pytest.raises(GridError): - Slab(axis='z', center=1, max=0) - - -def test_extent_accepts_scalar_like_inputs() -> None: - extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0])) - - assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) - - -def test_get_slice_uses_shifted_grid_bounds() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0.5, 0, 0]]) - cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) - - grid_slice = grid.get_slice(cell_data, Plane(x=2.0), which_shifts=0) - - assert_allclose(grid_slice, cell_data[0, 1, :, :]) - - with pytest.raises(GridError): - grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0) - - -def test_draw_extrude_rectangle_uses_boundary_slice() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) - cell_data = grid.allocate(0) - source = numpy.array([[1, 2], - [3, 4]], dtype=float) - cell_data[0, :, :, 1] = source - - grid.draw_extrude_rectangle( - cell_data, - rectangle=[[0, 0, 2], [2, 2, 2]], - direction=2, - polarity=-1, - distance=2, - ) - - assert_allclose(cell_data[0, :, :, 0], source) - assert_allclose(cell_data[0, :, :, 1], source) - - -def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None: - grid = Grid([[0, 1, 3, 6, 10], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) - - sampled_exyz = grid._sampled_exyz(0, 2) - - assert_allclose(sampled_exyz[0], [-1.5, 2.5, 6.5]) - - -def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: - matplotlib = pytest.importorskip('matplotlib') - matplotlib.use('Agg') - skimage_measure = pytest.importorskip('skimage.measure') - from matplotlib import pyplot - from mpl_toolkits.mplot3d.axes3d import Axes3D - - captured: dict[str, numpy.ndarray] = {} - - def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]: - verts = numpy.array([[0.5, 0.5, 0.5], - [0.5, 1.5, 0.5], - [1.5, 0.5, 0.5]], dtype=float) - faces = numpy.array([[0, 1, 2]], dtype=int) - return verts, faces, None, None - - def fake_plot_trisurf( # noqa: ANN202 - _self: object, - xs: numpy.ndarray, - ys: numpy.ndarray, - faces: numpy.ndarray, - zs: numpy.ndarray, - *_args: object, - **_kwargs: object, - ) -> object: - captured['xs'] = numpy.asarray(xs) - captured['ys'] = numpy.asarray(ys) - captured['faces'] = numpy.asarray(faces) - captured['zs'] = numpy.asarray(zs) - return object() - - monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes) - monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf) - - grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]]) - cell_data = numpy.zeros(grid.cell_data_shape) - - fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False) - - assert_allclose(captured['xs'], [1.5, 1.5, 3.5]) - assert_allclose(captured['ys'], [1.5, 3.5, 1.5]) - assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) - - pyplot.close(fig) - - -def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) - path = tmp_path / 'grid.state' - - grid.save(str(path)) - loaded = Grid.load(str(path)) - - assert path.exists() - for original, restored in zip(grid.exyz, loaded.exyz, strict=True): - assert_allclose(restored, original) - assert_allclose(loaded.shifts, grid.shifts) - assert loaded.periodic == grid.periodic - - -def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) - path = tmp_path / 'grid.pickle' - with open(path, 'wb') as f: - pickle.dump(grid.__dict__, f, protocol=2) - - loaded = Grid.load(str(path)) - - for original, restored in zip(grid.exyz, loaded.exyz, strict=True): - assert_allclose(restored, original) - assert_allclose(loaded.shifts, grid.shifts) - assert loaded.periodic == grid.periodic - - -def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: - data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0) - data.cell_data[0, 1, 0, 0] = 5.0 - path = tmp_path / 'griddata.state' - - data.save(str(path)) - loaded = GridData.load(str(path)) - - assert path.exists() - assert_allclose(loaded.cell_data, data.cell_data) - assert_allclose(loaded.grid.shifts, data.grid.shifts) - assert loaded.grid.periodic == data.grid.periodic - - -def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None: - path = tmp_path / 'grid.state' - Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path)) - - with pytest.raises(GridError): - GridData.load(str(path)) - - -def test_negative_shift_nonperiodic_edges_and_widths() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - - assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0]) - assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5]) - assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25]) - - -def test_negative_shift_periodic_edges_and_widths() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False]) - - assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0]) - assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5]) - - -def test_negative_shift_coordinate_round_trip() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - - ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False) - pos = grid.ind2pos(ind, 0, round_ind=False) - - assert_allclose(ind, [1.0, 0.0, 0.0]) - assert_allclose(pos, [1.25, 1.0, 0.5]) - - -def test_negative_shift_draw_cuboid_fractional_fill() -> None: - grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - arr = grid.allocate(0) - - grid.draw_cuboid( - arr, - x=dict(min=0, max=1), - y=dict(min=0, max=1), - z=dict(min=0, max=1), - foreground=1, - ) - - assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3]) - - -def test_negative_shift_get_slice_uses_shifted_centers() -> None: - grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - cell_data = numpy.zeros(grid.cell_data_shape) - cell_data[0, 1, :, 0] = [7, 9] - x_center = float(grid.shifted_xyz(0)[0][1]) - - grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0) - - assert_allclose(grid_slice, [7, 9]) - - -def test_grid_with_data_returns_griddata() -> None: - grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) - data = grid.with_data(fill_value=2.0) - - assert isinstance(data, GridData) - assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32)) - - -def test_griddata_constructor_validates_shape() -> None: - grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) - - with pytest.raises(GridError): - GridData(grid, numpy.zeros((1, 1, 1))) - - -def test_griddata_draw_methods_are_chainable() -> None: - data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) - - chained = data.draw_cuboid( - foreground=1, - x=dict(min=0, max=1), - y=dict(min=0, max=1), - z=dict(min=0, max=1), - ).draw_polygon( - foreground=0.5, - slab=dict(axis='z', center=0.5, span=1.0), - polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float), - ) - - assert chained is data - assert data.cell_data.sum() > 0 - - -def test_griddata_read_methods_delegate() -> None: - data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) - data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float) - - assert_allclose( - data.get_slice(Plane(z=0.5)), - data.grid.get_slice(data.cell_data, Plane(z=0.5)), - ) - - -def test_griddata_copy_is_independent() -> None: - data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0) - cloned = data.copy() - cloned.cell_data[0, 0, 0, 0] = 5.0 - - assert data is not cloned - assert data.grid is not cloned.grid - assert data.cell_data[0, 0, 0, 0] == 1.0 diff --git a/gridlock/utils.py b/gridlock/utils.py deleted file mode 100644 index 585b999..0000000 --- a/gridlock/utils.py +++ /dev/null @@ -1,248 +0,0 @@ -from typing import Protocol, TypedDict, runtime_checkable, cast -from dataclasses import dataclass - -import numpy - - -class GridError(Exception): - """ Base error type for `gridlock` """ - pass - - -def _coerce_scalar(name: str, value: object) -> float: - arr = numpy.asarray(value) - if arr.size != 1: - raise GridError(f'{name} must be a scalar value') - - try: - return float(arr.reshape(())) - except (TypeError, ValueError) as exc: - raise GridError(f'{name} must be a real scalar value') from exc - - -class ExtentDict(TypedDict, total=False): - """ - Geometrical definition of an extent (1D bounded region) - Must contain exactly two of `min`, `max`, `center`, or `span`. - """ - min: float - center: float - max: float - span: float - - -@runtime_checkable -class ExtentProtocol(Protocol): - """ - Anything that looks like an `Extent` - """ - center: float - span: float - - @property - def max(self) -> float: ... - - @property - def min(self) -> float: ... - - -@dataclass(init=False, slots=True) -class Extent(ExtentProtocol): - """ - Geometrical definition of an extent (1D bounded region) - May be constructed with any two of `min`, `max`, `center`, or `span`. - """ - center: float - span: float - - @property - def max(self) -> float: - return self.center + self.span / 2 - - @property - def min(self) -> float: - return self.center - self.span / 2 - - def __init__( - self, - *, - min: float | None = None, - center: float | None = None, - max: float | None = None, - span: float | None = None, - ) -> None: - values = { - 'min': None if min is None else _coerce_scalar('min', min), - 'center': None if center is None else _coerce_scalar('center', center), - 'max': None if max is None else _coerce_scalar('max', max), - 'span': None if span is None else _coerce_scalar('span', span), - } - if sum(value is not None for value in values.values()) != 2: - raise GridError('Exactly two of min, center, max, span must be provided') - - min_v = values['min'] - center_v = values['center'] - max_v = values['max'] - span_v = values['span'] - - if span_v is not None and span_v < 0: - raise GridError('span must be non-negative') - - if min_v is not None and max_v is not None: - if max_v < min_v: - raise GridError('max must be greater than or equal to min') - center_v = 0.5 * (max_v + min_v) - span_v = max_v - min_v - elif center_v is not None and min_v is not None: - span_v = 2 * (center_v - min_v) - if span_v < 0: - raise GridError('min must be less than or equal to center') - elif center_v is not None and max_v is not None: - span_v = 2 * (max_v - center_v) - if span_v < 0: - raise GridError('center must be less than or equal to max') - elif min_v is not None and span_v is not None: - center_v = min_v + 0.5 * span_v - elif max_v is not None and span_v is not None: - center_v = max_v - 0.5 * span_v - - if center_v is None or span_v is None: - raise GridError('Unable to construct extent from the provided values') - - self.center = center_v - self.span = span_v - - -class SlabDict(TypedDict, total=False): - """ - Geometrical definition of a slab (3D region bounded on one axis only) - Must contain `axis` plus any two of `min`, `max`, `center`, or `span`. - """ - min: float - center: float - max: float - span: float - axis: int | str - - -@runtime_checkable -class SlabProtocol(ExtentProtocol, Protocol): - """ - Anything that looks like a `Slab` - """ - axis: int - center: float - span: float - - @property - def max(self) -> float: ... - - @property - def min(self) -> float: ... - - -@dataclass(init=False, slots=True) -class Slab(Extent, SlabProtocol): - """ - Geometrical definition of a slab (3D region bounded on one axis only) - May be constructed with `axis` (bounded axis) plus any two of `min`, `max`, `center`, or `span`. - """ - axis: int - - def __init__( - self, - axis: int | str, - *, - min: float | None = None, - center: float | None = None, - max: float | None = None, - span: float | None = None, - ) -> None: - Extent.__init__(self, min=min, center=center, max=max, span=span) - - if isinstance(axis, str): - axis_int = 'xyz'.find(axis.lower()) - else: - axis_int = axis - if axis_int not in range(3): - raise GridError(f'Invalid axis (slab normal direction): {axis}') - self.axis = axis_int - - def as_plane(self, where: str) -> 'Plane': - if where == 'center': - return Plane(axis=self.axis, pos=self.center) - if where == 'min': - return Plane(axis=self.axis, pos=self.min) - if where == 'max': - return Plane(axis=self.axis, pos=self.max) - raise GridError(f'Invalid {where=}') - - -class PlaneDict(TypedDict, total=False): - """ - Geometrical definition of a plane (2D unbounded region in 3D space) - Must contain exactly one of `x`, `y`, `z`, or both `axis` and `pos` - """ - x: float - y: float - z: float - axis: int - pos: float - - -@runtime_checkable -class PlaneProtocol(Protocol): - """ - Anything that looks like a `Plane` - """ - axis: int - pos: float - - -@dataclass(init=False, slots=True) -class Plane(PlaneProtocol): - """ - Geometrical definition of a plane (2D unbounded region in 3D space) - May be constructed with any of `x=4`, `y=5`, `z=-5`, or `axis=2, pos=-5`. - """ - axis: int - pos: float - - def __init__( - self, - *, - axis: int | str | None = None, - pos: float | None = None, - x: float | None = None, - y: float | None = None, - z: float | None = None, - ) -> None: - xx = x - yy = y - zz = z - - if sum(aa is not None for aa in (pos, xx, yy, zz)) != 1: - raise GridError('Exactly one of pos, x, y, z must be non-None!') - if (axis is None) != (pos is None): - raise GridError('Either both or neither of `axis` and `pos` must be defined.') - - if isinstance(axis, str): - axis_int = 'xyz'.find(axis.lower()) - elif axis is None: - axis_int = (xx is None, yy is None, zz is None).index(False) - else: - axis_int = axis - - if axis_int not in range(3): - raise GridError(f'Invalid axis (slab normal direction): {axis=} {x=} {y=} {z=}') - self.axis = axis_int - - if pos is not None: - cpos = pos - else: - cpos = cast('float', (xx, yy, zz)[axis_int]) - assert cpos is not None - - if hasattr(cpos, '__len__'): - assert len(cpos) == 1 - self.pos = cpos diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 03d0d19..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,98 +0,0 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[project] -name = "gridlock" -description = "Coupled gridding library" -readme = "README.md" -license = { file = "LICENSE.md" } -authors = [ - { name="Jan Petykiewicz", email="jan@mpxd.net" }, - ] -homepage = "https://mpxd.net/code/jan/gridlock" -repository = "https://mpxd.net/code/jan/gridlock" -keywords = [ - "FDTD", - "gridding", - "simulation", - "nonuniform", - "FDFD", - "finite", - "difference", - ] -classifiers = [ - "Programming Language :: Python :: 3", - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: GNU Affero General Public License v3", - "Topic :: Multimedia :: Graphics :: 3D Rendering", - "Topic :: Scientific/Engineering :: Electronic Design Automation (EDA)", - "Topic :: Scientific/Engineering :: Physics", - "Topic :: Scientific/Engineering :: Visualization", - ] -requires-python = ">=3.11" -include = [ - "LICENSE.md" - ] -dynamic = ["version"] -dependencies = [ - "numpy>=1.26", - "float_raster>=0.8", - ] - - -[tool.hatch.version] -path = "gridlock/__init__.py" - -[project.optional-dependencies] -visualization = ["matplotlib"] -visualization-isosurface = [ - "matplotlib", - "skimage>=0.13", - "mpl_toolkits", - ] - - -[tool.ruff] -exclude = [ - ".git", - "dist", - ] -line-length = 145 -indent-width = 4 -lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" -lint.select = [ - "NPY", "E", "F", "W", "B", "ANN", "UP", "SLOT", "SIM", "LOG", - "C4", "ISC", "PIE", "PT", "RET", "TCH", "PTH", "INT", - "ARG", "PL", "R", "TRY", - "G010", "G101", "G201", "G202", - "Q002", "Q003", "Q004", - ] -lint.ignore = [ - #"ANN001", # No annotation - "ANN002", # *args - "ANN003", # **kwargs - "ANN401", # Any - "SIM108", # single-line if / else assignment - "RET504", # x=y+z; return x - "PIE790", # unnecessary pass - "ISC003", # non-implicit string concatenation - "C408", # dict(x=y) instead of {'x': y} - "PLR09", # Too many xxx - "PLR2004", # magic number - "PLC0414", # import x as x - "TRY003", # Long exception message - "PTH123", # open() - ] - - -[[tool.mypy.overrides]] -module = [ - "matplotlib", - "matplotlib.axes", - "matplotlib.figure", - "mpl_toolkits.mplot3d", - ] -ignore_missing_imports = true diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..a424b97 --- /dev/null +++ b/setup.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 + +from setuptools import setup, find_packages + +with open('README.md', 'r') as f: + long_description = f.read() + +with open('gridlock/VERSION.py', 'rt') as f: + version = f.readlines()[2].strip() + +setup(name='gridlock', + version=version, + description='Coupled gridding library', + long_description=long_description, + long_description_content_type='text/markdown', + author='Jan Petykiewicz', + author_email='jan@mpxd.net', + url='https://mpxd.net/code/jan/gridlock', + packages=find_packages(), + package_data={ + 'gridlock': ['py.typed'], + }, + install_requires=[ + 'numpy', + 'float_raster', + ], + extras_require={ + 'visualization': ['matplotlib'], + 'visualization-isosurface': [ + 'matplotlib', + 'skimage>=0.13', + 'mpl_toolkits', + ], + }, + classifiers=[ + 'Programming Language :: Python :: 3', + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: GNU Affero General Public License v3', + 'Topic :: Multimedia :: Graphics :: 3D Rendering', + 'Topic :: Scientific/Engineering :: Electronic Design Automation (EDA)', + 'Topic :: Scientific/Engineering :: Physics', + 'Topic :: Scientific/Engineering :: Visualization', + ], + ) +