diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..0042015 --- /dev/null +++ b/.flake8 @@ -0,0 +1,29 @@ +[flake8] +ignore = + # E501 line too long + E501, + # W391 newlines at EOF + W391, + # E241 multiple spaces after comma + E241, + # E302 expected 2 newlines + E302, + # W503 line break before binary operator (to be deprecated) + W503, + # E265 block comment should start with '# ' + E265, + # E123 closing bracket does not match indentation of opening bracket's line + E123, + # E124 closing bracket does not match visual indentation + E124, + # E221 multiple spaces before operator + E221, + # E201 whitespace after '[' + E201, + # E741 ambiguous variable name 'I' + E741, + + +per-file-ignores = + # F401 import without use + */__init__.py: F401, diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index c28ab72..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,2 +0,0 @@ -include README.md -include LICENSE.md diff --git a/README.md b/README.md index e882b2e..e4f43d5 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ the coordinates of the boundary points along each axis). ## Installation Requirements: -* python 3 (written and tested with 3.9) +* python >3.11 (written and tested with 3.12) * numpy * [float_raster](https://mpxd.net/code/jan/float_raster) * matplotlib (optional, used for visualization functions) diff --git a/gridlock/LICENSE.md b/gridlock/LICENSE.md new file mode 120000 index 0000000..7eabdb1 --- /dev/null +++ b/gridlock/LICENSE.md @@ -0,0 +1 @@ +../LICENSE.md \ No newline at end of file diff --git a/gridlock/README.md b/gridlock/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/gridlock/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/gridlock/VERSION.py b/gridlock/VERSION.py deleted file mode 100644 index 8e6abf6..0000000 --- a/gridlock/VERSION.py +++ /dev/null @@ -1,4 +0,0 @@ -""" VERSION defintion. THIS FILE IS MANUALLY PARSED BY setup.py and REQUIRES A SPECIFIC FORMAT """ -__version__ = ''' -1.0 -'''.strip() diff --git a/gridlock/__init__.py b/gridlock/__init__.py index d547794..2f39696 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -15,10 +15,24 @@ Dependencies: - mpl_toolkits.mplot3d [Grid.visualize_isosurface()] - skimage [Grid.visualize_isosurface()] """ -from .error import GridError -from .grid import Grid +from .utils import ( + GridError as GridError, + + Extent as Extent, + ExtentProtocol as ExtentProtocol, + ExtentDict as ExtentDict, + + Slab as Slab, + SlabProtocol as SlabProtocol, + SlabDict as SlabDict, + + Plane as Plane, + PlaneProtocol as PlaneProtocol, + PlaneDict as PlaneDict, + ) +from .grid import Grid as Grid + __author__ = 'Jan Petykiewicz' - -from .VERSION import __version__ +__version__ = '2.0' version = __version__ diff --git a/gridlock/base.py b/gridlock/base.py new file mode 100644 index 0000000..aca9c69 --- /dev/null +++ b/gridlock/base.py @@ -0,0 +1,196 @@ +from typing import Protocol + +import numpy +from numpy.typing import NDArray + +from . import GridError + + +class GridBase(Protocol): + exyz: list[NDArray] + """Cell edges. Monotonically increasing without duplicates.""" + + periodic: list[bool] + """For each axis, determines how far the rightmost boundary gets shifted. """ + + shifts: NDArray + """Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`""" + + @property + def dxyz(self) -> list[NDArray]: + """ + Cell sizes for each axis, no shifts applied + + Returns: + List of 3 ndarrays of cell sizes + """ + return [numpy.diff(ee) for ee in self.exyz] + + @property + def xyz(self) -> list[NDArray]: + """ + Cell centers for each axis, no shifts applied + + Returns: + List of 3 ndarrays of cell edges + """ + return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] + + @property + def shape(self) -> NDArray[numpy.intp]: + """ + The number of cells in x, y, and z + + Returns: + ndarray of [x_centers.size, y_centers.size, z_centers.size] + """ + return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int) + + @property + def num_grids(self) -> int: + """ + The number of grids (number of shifts) + """ + return self.shifts.shape[0] + + @property + def cell_data_shape(self) -> NDArray[numpy.intp]: + """ + The shape of the cell_data ndarray (num_grids, *self.shape). + """ + return numpy.hstack((self.num_grids, self.shape)) + + @property + def dxyz_with_ghost(self) -> list[NDArray]: + """ + Gives dxyz with an additional 'ghost' cell at the end, whose value depends + on whether or not the axis has periodic boundary conditions. See main description + above to learn why this is necessary. + + If periodic, final edge shifts same amount as first + Otherwise, final edge shifts same amount as second-to-last + + Returns: + list of [dxs, dys, dzs] with each element same length as elements of `self.xyz` + """ + el = [0 if p else -1 for p in self.periodic] + return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] + + @property + def center(self) -> NDArray[numpy.float64]: + """ + Center position of the entire grid, no shifts applied + + Returns: + ndarray of [x_center, y_center, z_center] + """ + # center is just average of first and last xyz, which is just the average of the + # first two and last two exyz + centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)] + return numpy.array(centers, dtype=float) + + @property + def dxyz_limits(self) -> tuple[NDArray, NDArray]: + """ + Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element + ndarrays. No shifts are applied, so these are extreme bounds on these values (as a + weighted average is performed when shifting). + + Returns: + Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]` + """ + d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float) + d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float) + return d_min, d_max + + def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]: + """ + Returns edges for which_shifts. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell edges + """ + if which_shifts is None: + return self.exyz + dxyz = self.dxyz_with_ghost + shifts = self.shifts[which_shifts, :] + + # If shift is negative, use left cell's dx to determine shift + for a in range(3): + if shifts[a] < 0: + dxyz[a] = numpy.roll(dxyz[a], 1) + + return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] + + def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: + """ + Returns cell sizes for `which_shifts`. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell sizes + """ + if which_shifts is None: + return self.dxyz + shifts = self.shifts[which_shifts, :] + dxyz = self.dxyz_with_ghost + + # If shift is negative, use left cell's dx to determine size + sdxyz = [] + for a in range(3): + if shifts[a] < 0: + roll_dxyz = numpy.roll(dxyz[a], 1) + abs_shift = numpy.abs(shifts[a]) + sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) + else: + sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) + + return sdxyz + + def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: + """ + Returns cell centers for `which_shifts`. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell centers + """ + if which_shifts is None: + return self.xyz + exyz = self.shifted_exyz(which_shifts) + dxyz = self.shifted_dxyz(which_shifts) + return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] + + def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]: + """ + Return cell widths, with each dimension shifted by the corresponding shifts. + + Returns: + `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]` + """ + if self.num_grids != 3: + raise GridError('Autoshifting requires exactly 3 grids') + return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] + + def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray: + """ + Allocate an ndarray for storing grid data. + + Args: + fill_value: Value to initialize the grid to. If None, an + uninitialized array is returned. + dtype: Numpy dtype for the array. Default is `numpy.float32`. + + Returns: + The allocated array + """ + if fill_value is None: + return numpy.empty(self.cell_data_shape, dtype=dtype) + return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) diff --git a/gridlock/direction.py b/gridlock/direction.py deleted file mode 100644 index b93b122..0000000 --- a/gridlock/direction.py +++ /dev/null @@ -1,10 +0,0 @@ -from enum import Enum - - -class Direction(Enum): - """ - Enum for axis->integer mapping - """ - x = 0 - y = 1 - z = 2 diff --git a/gridlock/draw.py b/gridlock/draw.py index 6385213..9ba4623 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -1,12 +1,14 @@ """ Drawing-related methods for Grid class """ -from typing import List, Optional, Union, Sequence, Callable +from collections.abc import Sequence, Callable -import numpy # type: ignore +import numpy +from numpy.typing import NDArray, ArrayLike from float_raster import raster -from . import GridError +from .utils import GridError, Slab, SlabDict, SlabProtocol, Extent, ExtentDict, ExtentProtocol +from .position import GridPosMixin # NOTE: Maybe it would make sense to create a GridDrawer class @@ -15,364 +17,385 @@ from . import GridError # without having to pass `cell_data` again each time? -foreground_callable_t = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray] +foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] +foreground_t = float | foreground_callable_t -def draw_polygons(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: numpy.ndarray, - polygons: Sequence[numpy.ndarray], - thickness: float, - foreground: Union[Sequence[Union[float, foreground_callable_t]], float, foreground_callable_t], - ) -> None: - """ - Draw polygons on an axis-aligned plane. +class GridDrawMixin(GridPosMixin): + def draw_polygons( + self, + cell_data: NDArray, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + polygons: Sequence[ArrayLike], + *, + offset2d: ArrayLike = (0, 0), + ) -> None: + """ + Draw polygons on an axis-aligned slab. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying an offset applied to all the polygons - polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon - (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each - polygon must have at least 3 vertices. - thickness: Thickness of the layer to draw - foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list - of any of these (1 per grid). Callable values should take an ndarray the shape of the - grid and return an ndarray of equal shape containing the foreground value at the given x, y, - and z (natural, not grid coordinates). + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list + of any of these (1 per grid). Callable values should take an ndarray the shape of the + grid and return an ndarray of equal shape containing the foreground value at the given x, y, + and z (natural, not grid coordinates). + slab: `Slab` or slab-like dict specifying the slab in which the polygons will be drawn. + polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon + (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each + polygon must have at least 3 vertices. + offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly + to the given polygon vertex coordinates. Default (0, 0). - Raises: - GridError - """ - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') + Raises: + GridError + """ + if isinstance(slab, dict): + slab = Slab(**slab) - center = numpy.squeeze(center) + poly_list = [numpy.asarray(poly) for poly in polygons] - # Check polygons, and remove redundant coordinates - surface = numpy.delete(range(3), surface_normal) + # Check polygons, and remove redundant coordinates + surface = numpy.delete(range(3), slab.axis) - for i, polygon in enumerate(polygons): - malformed = f'Malformed polygon: ({i})' - if polygon.shape[1] not in (2, 3): + for ii in range(len(poly_list)): + polygon = poly_list[ii] + malformed = f'Malformed polygon: ({ii})' + if polygon.shape[1] not in (2, 3): raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') - if polygon.shape[1] == 3: - polygon = polygon[surface, :] + if polygon.shape[1] == 3: + polygon = polygon[surface, :] + poly_list[ii] = polygon - if not polygon.shape[0] > 2: - raise GridError(malformed + 'must consist of more than 2 points') - if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: - raise GridError(malformed + 'must be in plane with surface normal ' - + 'xyz'[surface_normal]) + if not polygon.shape[0] > 2: + raise GridError(malformed + 'must consist of more than 2 points') + if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1: + raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) - # Broadcast foreground where necessary - if numpy.size(foreground) == 1: - foreground = [foreground] * len(cell_data) - elif isinstance(foreground, numpy.ndarray): - raise GridError('ndarray not supported for foreground') - - # ## Compute sub-domain of the grid occupied by polygons - # 1) Compute outer bounds (bd) of polygons - bd_2d_min = [0, 0] - bd_2d_max = [0, 0] - for polygon in polygons: - bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) - bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) - bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center - bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center - - # 2) Find indices (bdi) just outside bd elements - buf = 2 # size of safety buffer - # Use s_min and s_max with unshifted pos2ind to get absolute limits on - # the indices the polygons might affect - s_min = self.shifts.min(axis=0) - s_max = self.shifts.max(axis=0) - bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf - bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf - bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) - bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) - - # 3) Adjust polygons for center - polygons = [poly + center[surface] for poly in polygons] - - # ## Generate weighing function - def to_3d(vector: numpy.ndarray, val: float = 0.0) -> numpy.ndarray: - v_2d = numpy.array(vector, dtype=float) - return numpy.insert(v_2d, surface_normal, (val,)) - - # iterate over grids - for i, grid in enumerate(cell_data): - # ## Evaluate or expand foreground[i] - if callable(foreground[i]): - # meshgrid over the (shifted) domain - domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k]+1] for k in range(3)] - (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') - - # evaluate on the meshgrid - foreground_i = foreground[i](x0, y0, z0) - if not numpy.isfinite(foreground_i).all(): - raise GridError(f'Non-finite values in foreground[{i}]') - elif numpy.size(foreground[i]) != 1: - raise GridError(f'Unsupported foreground[{i}]: {type(foreground[i])}') + # Broadcast foreground where necessary + foregrounds: Sequence[foreground_callable_t] | Sequence[float] + if numpy.size(foreground) == 1: # type: ignore + foregrounds = [foreground] * len(cell_data) # type: ignore + elif isinstance(foreground, numpy.ndarray): + raise GridError('ndarray not supported for foreground') else: - # foreground[i] is scalar non-callable - foreground_i = foreground[i] + foregrounds = foreground # type: ignore - w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) + # ## Compute sub-domain of the grid occupied by polygons + # 1) Compute outer bounds (bd) of polygons + bd_2d_min = numpy.array([0, 0]) + bd_2d_max = numpy.array([0, 0]) + for polygon in poly_list: + bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) + offset2d + bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) + offset2d + bd_min = numpy.insert(bd_2d_min, slab.axis, slab.min) + bd_max = numpy.insert(bd_2d_max, slab.axis, slab.max) - # Draw each polygon separately - for polygon in polygons: + # 2) Find indices (bdi) just outside bd elements + buf = 2 # size of safety buffer + # Use s_min and s_max with unshifted pos2ind to get absolute limits on + # the indices the polygons might affect + s_min = self.shifts.min(axis=0) + s_max = self.shifts.max(axis=0) + bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf + bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf + bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) + bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) - # Get the boundaries of the polygon - pbd_min = polygon.min(axis=0) - pbd_max = polygon.max(axis=0) + # 3) Adjust polygons for offset2d + poly_list = [poly + offset2d for poly in poly_list] - # Find indices in w_xy just outside polygon - # using per-grid xy-weights (self.shifted_xyz()) - corner_min = self.pos2ind(to_3d(pbd_min), i, - check_bounds=False)[surface].astype(int) - corner_max = self.pos2ind(to_3d(pbd_max), i, - check_bounds=False)[surface].astype(int) + # ## Generate weighing function + def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: + v_2d = numpy.array(vector, dtype=float) + return numpy.insert(v_2d, slab.axis, (val,)) - # Find indices in w_xy which are modified by polygon - # First for the edge coordinates (+1 since we're indexing edges) - edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max)] - # Then for the pixel centers (-bdi_min since we're - # calculating weights within a subspace) - centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface], - corner_max - bdi_min[surface])) + # iterate over grids + foreground_val: NDArray | float + for i, _ in enumerate(cell_data): + # ## Evaluate or expand foregrounds[i] + foregrounds_i = foregrounds[i] + if callable(foregrounds_i): + # meshgrid over the (shifted) domain + domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)] + (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') - aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices)) - w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y) - - # Clamp overlapping polygons to 1 - w_xy = numpy.minimum(w_xy, 1.0) - - # 2) Generate weights in z-direction - w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) - - def get_zi(offset, i=i, w_z=w_z): - edges = self.shifted_exyz(i)[surface_normal] - point = center[surface_normal] + offset - grid_coord = numpy.digitize(point, edges) - 1 - w_coord = grid_coord - bdi_min[surface_normal] - - if w_coord < 0: - w_coord = 0 - f = 0 - elif w_coord >= w_z.size: - w_coord = w_z.size - 1 - f = 1 + # evaluate on the meshgrid + foreground_val = foregrounds_i(x0, y0, z0) + if not numpy.isfinite(foreground_val).all(): + raise GridError(f'Non-finite values in foreground[{i}]') + elif numpy.size(foregrounds_i) != 1: + raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}') else: - dz = self.shifted_dxyz(i)[surface_normal][grid_coord] - f = (point - edges[grid_coord]) / dz - return f, w_coord + # foreground[i] is scalar non-callable + foreground_val = foregrounds_i - zi_top_f, zi_top = get_zi(+thickness / 2.0) - zi_bot_f, zi_bot = get_zi(-thickness / 2.0) + w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) - w_z[zi_bot + 1:zi_top] = 1 + # Draw each polygon separately + for polygon in poly_list: - if zi_bot < zi_top: - w_z[zi_top] = zi_top_f - w_z[zi_bot] = 1 - zi_bot_f - else: - w_z[zi_bot] = zi_top_f - zi_bot_f + # Get the boundaries of the polygon + pbd_min = polygon.min(axis=0) + pbd_max = polygon.max(axis=0) - # 3) Generate total weight function - w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,))) + # Find indices in w_xy just outside polygon + # using per-grid xy-weights (self.shifted_xyz()) + corner_min = self.pos2ind(to_3d(pbd_min), i, + check_bounds=False)[surface].astype(int) + corner_max = self.pos2ind(to_3d(pbd_max), i, + check_bounds=False)[surface].astype(int) - # ## Modify the grid - g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) - cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_i + # Find indices in w_xy which are modified by polygon + # First for the edge coordinates (+1 since we're indexing edges) + edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)] + # Then for the pixel centers (-bdi_min since we're + # calculating weights within a subspace) + centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface], + corner_max - bdi_min[surface], strict=True)) + + aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True)) + w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y) + + # Clamp overlapping polygons to 1 + w_xy = numpy.minimum(w_xy, 1.0) + + # 2) Generate weights in z-direction + w_z = numpy.zeros(((bdi_max - bdi_min + 1)[slab.axis], )) + + def get_zi(point: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001 + edges = self.shifted_exyz(i)[slab.axis] + grid_coord = numpy.digitize(point, edges) - 1 + w_coord = grid_coord - bdi_min[slab.axis] + + if w_coord < 0: + w_coord = 0 + f = 0 + elif w_coord >= w_z.size: + w_coord = w_z.size - 1 + f = 1 + else: + dz = self.shifted_dxyz(i)[slab.axis][grid_coord] + f = (point - edges[grid_coord]) / dz + return f, w_coord + + zi_top_f, zi_top = get_zi(slab.max) + zi_bot_f, zi_bot = get_zi(slab.min) + + w_z[zi_bot + 1:zi_top] = 1 + + if zi_bot < zi_top: + w_z[zi_top] = zi_top_f + w_z[zi_bot] = 1 - zi_bot_f + else: + w_z[zi_bot] = zi_top_f - zi_bot_f + + # 3) Generate total weight function + w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], slab.axis, (2,))) + + # ## Modify the grid + g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) + cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val -def draw_polygon(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: numpy.ndarray, - polygon: numpy.ndarray, - thickness: float, - foreground: Union[Sequence[Union[float, foreground_callable_t]], float, foreground_callable_t], - ) -> None: - """ - Draw a polygon on an axis-aligned plane. + def draw_polygon( + self, + cell_data: NDArray, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + polygon: ArrayLike, + *, + offset2d: ArrayLike = (0, 0), + ) -> None: + """ + Draw a polygon on an axis-aligned plane. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying an offset applied to the polygon - polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, - clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at - least 3 vertices. - thickness: Thickness of the layer to draw - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + slab: `Slab` or slab-like dict specifying the slab in which the polygon will be drawn. + polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, + clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at + least 3 vertices. + offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly + to the given polygon vertex coordinates. Default (0, 0). + """ + self.draw_polygons( + cell_data = cell_data, + slab = slab, + polygons = [polygon], + foreground = foreground, + offset2d = offset2d, + ) -def draw_slab(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: numpy.ndarray, - thickness: float, - foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t], - ) -> None: - """ - Draw an axis-aligned infinite slab. + def draw_slab( + self, + cell_data: NDArray, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + ) -> None: + """ + Draw an axis-aligned infinite slab. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: `surface_normal` coordinate value at the center of the slab - thickness: Thickness of the layer to draw - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - # Turn surface_normal into its integer representation - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + slab: `Slab` or slab-like dict (geometrical slab specification) + """ + if isinstance(slab, dict): + slab = Slab(**slab) - if numpy.size(center) != 1: - center = numpy.squeeze(center) - if len(center) == 3: - center = center[surface_normal] - else: - raise GridError(f'Bad center: {center}') + # Find center of slab + center_shift = self.center + center_shift[slab.axis] = slab.center - # Find center of slab - center_shift = self.center - center_shift[surface_normal] = center + surface = numpy.delete(range(3), slab.axis) + u_min, u_max = self.exyz[surface[0]][[0, -1]] + v_min, v_max = self.exyz[surface[1]][[0, -1]] - surface = numpy.delete(range(3), surface_normal) + margin = 4 * numpy.max([self.dxyz[surface[0]].max(), + self.dxyz[surface[1]].max()]) - xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface] - xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface] + p = numpy.array([[u_min - margin, v_max + margin], + [u_max + margin, v_max + margin], + [u_max + margin, v_min - margin], + [u_min - margin, v_min - margin]], dtype=float) - dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float) - - xyz_min -= 4 * dxyz - xyz_max += 4 * dxyz - - p = numpy.array([[xyz_min[0], xyz_max[1]], - [xyz_max[0], xyz_max[1]], - [xyz_max[0], xyz_min[1]], - [xyz_min[0], xyz_min[1]]], dtype=float) - - self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) + self.draw_polygon( + cell_data = cell_data, + slab = slab, + polygon = p, + foreground = foreground, + ) -def draw_cuboid(self, - cell_data: numpy.ndarray, - center: numpy.ndarray, - dimensions: numpy.ndarray, - foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t], - ) -> None: - """ - Draw an axis-aligned cuboid + def draw_cuboid( + self, + cell_data: NDArray, + foreground: Sequence[foreground_t] | foreground_t, + *, + x: ExtentProtocol | ExtentDict, + y: ExtentProtocol | ExtentDict, + z: ExtentProtocol | ExtentDict, + ) -> None: + """ + Draw an axis-aligned cuboid - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - center: 3-element ndarray or list specifying the cuboid's center - dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge - sizes of the cuboid - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - p = numpy.array([[-dimensions[0], +dimensions[1]], - [+dimensions[0], +dimensions[1]], - [+dimensions[0], -dimensions[1]], - [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0 - thickness = dimensions[2] - self.draw_polygon(cell_data, 2, center, p, thickness, foreground) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + x: `Extent` or extent-like dict specifying the x-extent of the cuboid. + y: `Extent` or extent-like dict specifying the y-extent of the cuboid. + z: `Extent` or extent-like dict specifying the z-extent of the cuboid. + """ + if isinstance(x, dict): + x = Extent(**x) + if isinstance(y, dict): + y = Extent(**y) + if isinstance(z, dict): + z = Extent(**z) + + center = numpy.asarray([x.center, y.center, z.center]) + + p = numpy.array([[x.min, y.max], + [x.max, y.max], + [x.max, y.min], + [x.min, y.min]], dtype=float) + slab = Slab(axis=2, center=z.center, span=z.span) + self.draw_polygon(cell_data=cell_data, slab=slab, polygon=p, foreground=foreground) -def draw_cylinder(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: numpy.ndarray, - radius: float, - thickness: float, - num_points: int, - foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t], - ) -> None: - """ - Draw an axis-aligned cylinder. Approximated by a num_points-gon + def draw_cylinder( + self, + cell_data: NDArray, + h: SlabProtocol | SlabDict, + radius: float, + num_points: int, + center2d: ArrayLike, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw an axis-aligned cylinder. Approximated by a num_points-gon - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying the cylinder's center - radius: cylinder radius - thickness: Thickness of the layer to draw - num_points: The circle is approximated by a polygon with `num_points` vertices - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - theta = numpy.linspace(0, 2*numpy.pi, num_points, endpoint=False) - x = radius * numpy.sin(theta) - y = radius * numpy.cos(theta) - polygon = numpy.hstack((x[:, None], y[:, None])) - self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + h: + radius: + num_points: The circle is approximated by a polygon with `num_points` vertices + center2d: + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + if isinstance(h, dict): + h = Slab(**h) + + theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)[:, None] + xy0 = numpy.hstack((numpy.sin(theta), numpy.cos(theta))) + polygon = radius * xy0 + self.draw_polygon(cell_data=cell_data, slab=h, polygon=polygon, foreground=foreground, offset2d=center2d) -def draw_extrude_rectangle(self, - cell_data: numpy.ndarray, - rectangle: numpy.ndarray, - direction: int, - polarity: int, - distance: float, - ) -> None: - """ - Extrude a rectangle of a previously-drawn structure along an axis. + def draw_extrude_rectangle( + self, + cell_data: NDArray, + rectangle: ArrayLike, + direction: int, + polarity: int, + distance: float, + ) -> None: + """ + Extrude a rectangle of a previously-drawn structure along an axis. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - rectangle: 2x3 ndarray or list specifying the rectangle's corners - direction: Direction to extrude in. Integer in `range(3)`. - polarity: +1 or -1, direction along axis to extrude in - distance: How far to extrude - """ - s = numpy.sign(polarity) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + rectangle: 2x3 ndarray or list specifying the rectangle's corners + direction: Direction to extrude in. Integer in `range(3)`. + polarity: +1 or -1, direction along axis to extrude in + distance: How far to extrude + """ + sgn = numpy.sign(polarity) - rectangle = numpy.array(rectangle, dtype=float) - if s == 0: - raise GridError('0 is not a valid polarity') - if direction not in range(3): - raise GridError(f'Invalid direction: {direction}') - if rectangle[0, direction] != rectangle[1, direction]: - raise GridError('Rectangle entries along extrusion direction do not match.') + rectangle = numpy.asarray(rectangle, dtype=float) + if sgn == 0: + raise GridError('0 is not a valid polarity') + if direction not in range(3): + raise GridError(f'Invalid direction: {direction}') + if rectangle[0, direction] != rectangle[1, direction]: + raise GridError('Rectangle entries along extrusion direction do not match.') - center = rectangle.sum(axis=0) / 2.0 - center[direction] += s * distance / 2.0 + center = rectangle.sum(axis=0) / 2.0 + center[direction] += sgn * distance / 2.0 - surface = numpy.delete(range(3), direction) + surface = numpy.delete(range(3), direction) - dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] - p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0]/2.0, - numpy.array([-1, 1, 1, -1], dtype=float) * dim[1]/2.0)).T - thickness = distance + dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] + poly = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, + numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T + thickness = distance - foreground_func = [] - for i, grid in enumerate(cell_data): - z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] + foreground_func = [] + for ii, grid in enumerate(cell_data): + zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction] - ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] + ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)] - fpart = z - numpy.floor(z) - mult = [1-fpart, fpart][::s] # reverses if s negative + fpart = zz - numpy.floor(zz) + mult = [1 - fpart, fpart][::sgn] # reverses if s negative - foreground = mult[0] * grid[tuple(ind)] - ind[direction] += 1 - foreground += mult[1] * grid[tuple(ind)] + foreground = mult[0] * grid[tuple(ind)] + ind[direction] += 1 # type: ignore #(known safe) + foreground += mult[1] * grid[tuple(ind)] - def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> numpy.ndarray: - # transform from natural position to index - xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) - for qrs in zip(xs.flat, ys.flat, zs.flat)], dtype=int) - # reshape to original shape and keep only in-plane components - qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) - return foreground[qi, ri] + def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 + # transform from natural position to index + xyzi = numpy.array([self.pos2ind(qrs, which_shifts=ii) + for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64) + # reshape to original shape and keep only in-plane components + qi, ri = (numpy.reshape(xyzi[:, kk], xs.shape) for kk in surface) + return foreground[qi, ri] - foreground_func.append(f_foreground) + foreground_func.append(f_foreground) - self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func) + slab = Slab(axis=direction, center=center[direction], span=thickness) + self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface]) diff --git a/gridlock/error.py b/gridlock/error.py deleted file mode 100644 index 3974e9c..0000000 --- a/gridlock/error.py +++ /dev/null @@ -1,2 +0,0 @@ -class GridError(Exception): - pass diff --git a/gridlock/examples/ex0.py b/gridlock/examples/ex0.py index ca2ef55..4ff2fb9 100644 --- a/gridlock/examples/ex0.py +++ b/gridlock/examples/ex0.py @@ -1,4 +1,4 @@ -import numpy # type: ignore +import numpy from gridlock import Grid @@ -6,18 +6,18 @@ if __name__ == '__main__': # xyz = [numpy.arange(-5.0, 6.0), numpy.arange(-4.0, 5.0), [-1.0, 1.0]] # eg = Grid(xyz) # egc = Grid.allocate(0.0) - # # eg.draw_slab(egc, surface_normal=2, center=0, thickness=10, foreground=2) - # eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=4, - # thickness=10, num_points=1000, foreground=1) - # eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) + # # eg.draw_slab(egc, slab=dict(axis=2, center=0, span=10), foreground=2) + # eg.draw_cylinder(egc, h=slab(axis=2, center=0, span=10), + # center2d=[0, 0], radius=4, thickness=10, num_points=1000, foreground=1) + # eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2) # xyz2 = [numpy.arange(-5.0, 6.0), [-1.0, 1.0], numpy.arange(-4.0, 5.0)] # eg2 = Grid(xyz2) # eg2c = Grid.allocate(0.0) - # # eg2.draw_slab(eg2c, surface_normal=2, center=0, thickness=10, foreground=2) - # eg2.draw_cylinder(eg2c, surface_normal=1, center=[0, 0, 0], - # radius=4, thickness=10, num_points=1000, foreground=1.0) - # eg2.visualize_slice(eg2c, surface_normal=1, center=0, which_shifts=1) + # # eg2.draw_slab(eg2c, slab=dict(axis=2, center=0, span=10), foreground=2) + # eg2.draw_cylinder(eg2c, h=slab(axis=1, center=0, span=10), center2d=[0, 0], + # radius=4, num_points=1000, foreground=1.0) + # eg2.visualize_slice(eg2c, plane=dict(y=0), which_shifts=1) # n = 20 # m = 3 @@ -29,16 +29,27 @@ if __name__ == '__main__': # numpy.linspace(-5.5, 5.5, 10)] half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5] - xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x, - numpy.linspace(-5.5, 5.5, 10), - numpy.linspace(-5.5, 5.5, 10)] + xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x, dtype=float), + numpy.linspace(-5.5, 5.5, 10, dtype=float), + numpy.linspace(-5.5, 5.5, 10, dtype=float)] eg = Grid(xyz3) egc = eg.allocate(0) # eg.draw_slab(Direction.z, 0, 10, 2) eg.save('/home/jan/Desktop/test.pickle') - eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0, - thickness=10, num_poitns=1000, foreground=1) - eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]], - direction=1, poalarity=+1, distance=5) - eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) + eg.draw_cylinder( + egc, + h=dict(axis='z', center=0, span=10), + center2d=[0, 0], + radius=2.0, + num_points=1000, + foreground=1, + ) + eg.draw_extrude_rectangle( + egc, + rectangle=[[-2, 1, -1], [0, 1, 1]], + direction=1, + polarity=+1, + distance=5, + ) + eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2) eg.visualize_isosurface(egc, which_shifts=2) diff --git a/gridlock/grid.py b/gridlock/grid.py index e320854..5790dbd 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,20 +1,23 @@ -from typing import List, Tuple, Callable, Dict, Optional, Union, Sequence, ClassVar, TypeVar +from typing import ClassVar, Self +from collections.abc import Callable, Sequence -import numpy # type: ignore -from numpy import diff, floor, ceil, zeros, hstack, newaxis +import numpy +from numpy.typing import NDArray, ArrayLike import pickle import warnings import copy from . import GridError +from .draw import GridDrawMixin +from .read import GridReadMixin +from .position import GridPosMixin -foreground_callable_type = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray] -T = TypeVar('T', bound='Grid') +foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] -class Grid: +class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): """ Simulation grid metadata for finite-difference simulations. @@ -48,217 +51,35 @@ class Grid: Because of this, we either assume this 'ghost' cell is the same size as the last real cell, or, if `self.periodic[a]` is set to `True`, the same size as the first cell. """ - exyz: List[numpy.ndarray] + exyz: list[NDArray] """Cell edges. Monotonically increasing without duplicates.""" - periodic: List[bool] + periodic: list[bool] """For each axis, determines how far the rightmost boundary gets shifted. """ - shifts: numpy.ndarray + shifts: NDArray """Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`""" - Yee_Shifts_E: ClassVar[numpy.ndarray] = 0.5 * numpy.array([[1, 0, 0], - [0, 1, 0], - [0, 0, 1]], dtype=float) + Yee_Shifts_E: ClassVar[NDArray] = 0.5 * numpy.array([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], dtype=float) """Default shifts for Yee grid E-field""" - Yee_Shifts_H: ClassVar[numpy.ndarray] = 0.5 * numpy.array([[0, 1, 1], - [1, 0, 1], - [1, 1, 0]], dtype=float) + Yee_Shifts_H: ClassVar[NDArray] = 0.5 * numpy.array([ + [0, 1, 1], + [1, 0, 1], + [1, 1, 0], + ], dtype=float) """Default shifts for Yee grid H-field""" - from .draw import ( - draw_polygons, draw_polygon, draw_slab, draw_cuboid, - draw_cylinder, draw_extrude_rectangle, - ) - from .read import get_slice, visualize_slice, visualize_isosurface - from .position import ind2pos, pos2ind - - @property - def dxyz(self) -> List[numpy.ndarray]: - """ - Cell sizes for each axis, no shifts applied - - Returns: - List of 3 ndarrays of cell sizes - """ - return [numpy.diff(ee) for ee in self.exyz] - - @property - def xyz(self) -> List[numpy.ndarray]: - """ - Cell centers for each axis, no shifts applied - - Returns: - List of 3 ndarrays of cell edges - """ - return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] - - @property - def shape(self) -> numpy.ndarray: - """ - The number of cells in x, y, and z - - Returns: - ndarray of [x_centers.size, y_centers.size, z_centers.size] - """ - return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int) - - @property - def num_grids(self) -> int: - """ - The number of grids (number of shifts) - """ - return self.shifts.shape[0] - - @property - def cell_data_shape(self): - """ - The shape of the cell_data ndarray (num_grids, *self.shape). - """ - return numpy.hstack((self.num_grids, self.shape)) - - @property - def dxyz_with_ghost(self) -> List[numpy.ndarray]: - """ - Gives dxyz with an additional 'ghost' cell at the end, whose value depends - on whether or not the axis has periodic boundary conditions. See main description - above to learn why this is necessary. - - If periodic, final edge shifts same amount as first - Otherwise, final edge shifts same amount as second-to-last - - Returns: - list of [dxs, dys, dzs] with each element same length as elements of `self.xyz` - """ - el = [0 if p else -1 for p in self.periodic] - return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)] - - @property - def center(self) -> numpy.ndarray: - """ - Center position of the entire grid, no shifts applied - - Returns: - ndarray of [x_center, y_center, z_center] - """ - # center is just average of first and last xyz, which is just the average of the - # first two and last two exyz - centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)] - return numpy.array(centers, dtype=float) - - @property - def dxyz_limits(self) -> Tuple[numpy.ndarray, numpy.ndarray]: - """ - Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element - ndarrays. No shifts are applied, so these are extreme bounds on these values (as a - weighted average is performed when shifting). - - Returns: - Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]` - """ - d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float) - d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float) - return d_min, d_max - - def shifted_exyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]: - """ - Returns edges for which_shifts. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell edges - """ - if which_shifts is None: - return self.exyz - dxyz = self.dxyz_with_ghost - shifts = self.shifts[which_shifts, :] - - # If shift is negative, use left cell's dx to determine shift - for a in range(3): - if shifts[a] < 0: - dxyz[a] = numpy.roll(dxyz[a], 1) - - return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] - - def shifted_dxyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]: - """ - Returns cell sizes for `which_shifts`. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell sizes - """ - if which_shifts is None: - return self.dxyz - shifts = self.shifts[which_shifts, :] - dxyz = self.dxyz_with_ghost - - # If shift is negative, use left cell's dx to determine size - sdxyz = [] - for a in range(3): - if shifts[a] < 0: - roll_dxyz = numpy.roll(dxyz[a], 1) - abs_shift = numpy.abs(shifts[a]) - sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) - else: - sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) - - return sdxyz - - def shifted_xyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]: - """ - Returns cell centers for `which_shifts`. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell centers - """ - if which_shifts is None: - return self.xyz - exyz = self.shifted_exyz(which_shifts) - dxyz = self.shifted_dxyz(which_shifts) - return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] - - def autoshifted_dxyz(self) -> List[numpy.ndarray]: - """ - Return cell widths, with each dimension shifted by the corresponding shifts. - - Returns: - `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]` - """ - if self.num_grids != 3: - raise GridError('Autoshifting requires exactly 3 grids') - return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] - - def allocate(self, fill_value: Optional[float] = 1.0, dtype=numpy.float32) -> numpy.ndarray: - """ - Allocate an ndarray for storing grid data. - - Args: - fill_value: Value to initialize the grid to. If None, an - uninitialized array is returned. - dtype: Numpy dtype for the array. Default is `numpy.float32`. - - Returns: - The allocated array - """ - if fill_value is None: - return numpy.empty(self.cell_data_shape, dtype=dtype) - else: - return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) - - def __init__(self, - pixel_edge_coordinates: Sequence[numpy.ndarray], - shifts: numpy.ndarray = Yee_Shifts_E, - periodic: Union[bool, Sequence[bool]] = False, - ) -> None: + def __init__( + self, + pixel_edge_coordinates: Sequence[ArrayLike], + shifts: ArrayLike = Yee_Shifts_E, + periodic: bool | Sequence[bool] = False, + ) -> None: """ Args: pixel_edge_coordinates: 3-element list of (ndarrays or lists) specifying the @@ -273,11 +94,12 @@ class Grid: Raises: `GridError` on invalid input """ - self.exyz = [numpy.unique(pixel_edge_coordinates[i]) for i in range(3)] + edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates] + self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) for i in range(3): - if len(self.exyz[i]) != len(pixel_edge_coordinates[i]): + if self.exyz[i].size != edge_arrs[i].size: warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2) if isinstance(periodic, bool): @@ -314,7 +136,7 @@ class Grid: g.__dict__.update(tmp_dict) return g - def save(self: T, filename: str) -> T: + def save(self, filename: str) -> Self: """ Save to file. @@ -328,7 +150,7 @@ class Grid: pickle.dump(self.__dict__, f, protocol=2) return self - def copy(self: T) -> T: + def copy(self) -> Self: """ Returns: Deep copy of the grid. diff --git a/gridlock/position.py b/gridlock/position.py index 1224a12..b705b99 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -1,115 +1,118 @@ """ Position-related methods for Grid class """ -from typing import List, Optional - -import numpy # type: ignore +import numpy +from numpy.typing import NDArray, ArrayLike from . import GridError +from .base import GridBase -def ind2pos(self, - ind: numpy.ndarray, - which_shifts: Optional[int] = None, +class GridPosMixin(GridBase): + def ind2pos( + self, + ind: NDArray, + which_shifts: int | None = None, round_ind: bool = True, check_bounds: bool = True - ) -> numpy.ndarray: - """ - Returns the natural position corresponding to the specified cell center indices. - The resulting position is clipped to the bounds of the grid - (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`) + ) -> NDArray[numpy.float64]: + """ + Returns the natural position corresponding to the specified cell center indices. + The resulting position is clipped to the bounds of the grid + (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`) - Args: - ind: Indices of the position. Can be fractional. (3-element ndarray or list) - which_shifts: which grid number (`shifts`) to use - round_ind: Whether to round ind to the nearest integer position before indexing - (default `True`) - check_bounds: Whether to raise an `GridError` if the provided ind is outside of - the grid, as defined above (centers if `round_ind`, else edges) (default `True`) + Args: + ind: Indices of the position. Can be fractional. (3-element ndarray or list) + which_shifts: which grid number (`shifts`) to use + round_ind: Whether to round ind to the nearest integer position before indexing + (default `True`) + check_bounds: Whether to raise an `GridError` if the provided ind is outside of + the grid, as defined above (centers if `round_ind`, else edges) (default `True`) - Returns: - 3-element ndarray specifying the natural position + Returns: + 3-element ndarray specifying the natural position - Raises: - `GridError` if invalid `which_shifts` - `GridError` if `check_bounds` and out of bounds - """ - if which_shifts is not None and which_shifts >= self.shifts.shape[0]: - raise GridError('Invalid shifts') - ind = numpy.array(ind, dtype=float) + Raises: + `GridError` if invalid `which_shifts` + `GridError` if `check_bounds` and out of bounds + """ + if which_shifts is not None and which_shifts >= self.shifts.shape[0]: + raise GridError('Invalid shifts') + ind = numpy.array(ind, dtype=float) + + if check_bounds: + if round_ind: + low_bound = 0.0 + high_bound = -1.0 + else: + low_bound = -0.5 + high_bound = -0.5 + if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): + raise GridError(f'Position outside of grid: {ind}') - if check_bounds: if round_ind: - low_bound = 0.0 - high_bound = -1.0 + rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) + sxyz = self.shifted_xyz(which_shifts) + position = [sxyz[a][rind[a]].astype(int) for a in range(3)] else: - low_bound = -0.5 - high_bound = -0.5 - if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): - raise GridError(f'Position outside of grid: {ind}') - - if round_ind: - rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) - sxyz = self.shifted_xyz(which_shifts) - position = [sxyz[a][rind[a]].astype(int) for a in range(3)] - else: - sexyz = self.shifted_exyz(which_shifts) - position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) - for a in range(3)] - return numpy.array(position, dtype=float) + sexyz = self.shifted_exyz(which_shifts) + position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) + for a in range(3)] + return numpy.array(position, dtype=float) -def pos2ind(self, - r: numpy.ndarray, - which_shifts: Optional[int], + def pos2ind( + self, + r: ArrayLike, + which_shifts: int | None, round_ind: bool = True, check_bounds: bool = True - ) -> numpy.ndarray: - """ - Returns the cell-center indices corresponding to the specified natural position. - The resulting position is clipped to within the outer centers of the grid. + ) -> NDArray[numpy.float64]: + """ + Returns the cell-center indices corresponding to the specified natural position. + The resulting position is clipped to within the outer centers of the grid. - Args: - r: Natural position that we will convert into indices (3-element ndarray or list) - which_shifts: which grid number (`shifts`) to use - round_ind: Whether to round the returned indices to the nearest integers. - check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges + Args: + r: Natural position that we will convert into indices (3-element ndarray or list) + which_shifts: which grid number (`shifts`) to use + round_ind: Whether to round the returned indices to the nearest integers. + check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges - Returns: - 3-element ndarray specifying the indices + Returns: + 3-element ndarray specifying the indices - Raises: - `GridError` if invalid `which_shifts` - `GridError` if `check_bounds` and out of bounds - """ - r = numpy.squeeze(r) - if r.size != 3: - raise GridError(f'r must be 3-element vector: {r}') + Raises: + `GridError` if invalid `which_shifts` + `GridError` if `check_bounds` and out of bounds + """ + r = numpy.squeeze(r) + if r.size != 3: + raise GridError(f'r must be 3-element vector: {r}') - if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]): - raise GridError(f'Invalid which_shifts: {which_shifts}') + if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]): + raise GridError(f'Invalid which_shifts: {which_shifts}') - sexyz = self.shifted_exyz(which_shifts) + sexyz = self.shifted_exyz(which_shifts) - if check_bounds: + if check_bounds: + for a in range(3): + if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]): + raise GridError(f'Position[{a}] outside of grid!') + + grid_pos = numpy.zeros((3,)) for a in range(3): - if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]): - raise GridError(f'Position[{a}] outside of grid!') + xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in + xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds - grid_pos = numpy.zeros((3,)) - for a in range(3): - xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in - xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds + # No need to interpolate if round_ind is true or we were outside the grid + if round_ind or xi != xi_clipped: + grid_pos[a] = xi_clipped + else: + # Interpolate + x = self.shifted_xyz(which_shifts)[a][xi] + dx = self.shifted_dxyz(which_shifts)[a][xi] + f = (r[a] - x) / dx - # No need to interpolate if round_ind is true or we were outside the grid - if round_ind or xi != xi_clipped: - grid_pos[a] = xi_clipped - else: - # Interpolate - x = self.shifted_xyz(which_shifts)[a][xi] - dx = self.shifted_dxyz(which_shifts)[a][xi] - f = (r[a] - x) / dx - - # Clip to centers - grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1) - return grid_pos + # Clip to centers + grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1) + return grid_pos diff --git a/gridlock/read.py b/gridlock/read.py index aa059d5..707251a 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -1,183 +1,203 @@ """ Readback and visualization methods for Grid class """ -from typing import Dict, Optional, Union, Any +from typing import Any, TYPE_CHECKING -import numpy # type: ignore +import numpy +from numpy.typing import NDArray + +from .utils import GridError, Plane, PlaneDict, PlaneProtocol +from .position import GridPosMixin + +if TYPE_CHECKING: + import matplotlib.axes + import matplotlib.figure -from . import GridError # .visualize_* uses matplotlib # .visualize_isosurface uses skimage # .visualize_isosurface uses mpl_toolkits.mplot3d -def get_slice(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: float, - which_shifts: int = 0, - sample_period: int = 1 - ) -> numpy.ndarray: - """ - Retrieve a slice of a grid. - Interpolates if given a position between two planes. +class GridReadMixin(GridPosMixin): + def get_slice( + self, + cell_data: NDArray, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1 + ) -> NDArray: + """ + Retrieve a slice of a grid. + Interpolates if given a position between two grid planes. - Args: - cell_data: Cell data to slice - surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. - center: Scalar specifying position along surface_normal axis. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) + Args: + cell_data: Cell data to slice + plane: Axis and position (`Plane`) of the plane to read. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) - Returns: - Array containing the portion of the grid. - """ - if numpy.size(center) != 1 or not numpy.isreal(center): - raise GridError('center must be a real scalar') + Returns: + Array containing the portion of the grid. + """ + if isinstance(plane, dict): + plane = Plane(**plane) - sp = round(sample_period) - if sp <= 0: - raise GridError('sample_period must be positive') + sp = round(sample_period) + if sp <= 0: + raise GridError('sample_period must be positive') - if numpy.size(which_shifts) != 1 or which_shifts < 0: - raise GridError('Invalid which_shifts') + if numpy.size(which_shifts) != 1 or which_shifts < 0: + raise GridError('Invalid which_shifts') - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') + surface = numpy.delete(range(3), plane.axis) - surface = numpy.delete(range(3), surface_normal) + # Extract indices and weights of planes + center3 = numpy.insert([0, 0], plane.axis, (plane.pos,)) + center_index = self.pos2ind(center3, which_shifts, + round_ind=False, check_bounds=False)[plane.axis] + centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) + if len(centers) == 2: + fpart = center_index - numpy.floor(center_index) + w = [1 - fpart, fpart] # longer distance -> less weight + else: + w = [1] - # Extract indices and weights of planes - center3 = numpy.insert([0, 0], surface_normal, (center,)) - center_index = self.pos2ind(center3, which_shifts, - round_ind=False, check_bounds=False)[surface_normal] - centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) - if len(centers) == 2: - fpart = center_index - numpy.floor(center_index) - w = [1 - fpart, fpart] # longer distance -> less weight - else: - w = [1] + c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1]) + if plane.pos < c_min or plane.pos > c_max: + raise GridError('Coordinate of selected plane must be within simulation domain') - c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1]) - if center < c_min or center > c_max: - raise GridError('Coordinate of selected plane must be within simulation domain') + # Extract grid values from planes above and below visualized slice + sliced_grid = numpy.zeros(self.shape[surface]) + for ci, weight in zip(centers, w, strict=True): + s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3)) + sliced_grid += weight * cell_data[which_shifts][tuple(s)] - # Extract grid values from planes above and below visualized slice - sliced_grid = numpy.zeros(self.shape[surface]) - for ci, weight in zip(centers, w): - s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) - sliced_grid += weight * cell_data[which_shifts][tuple(s)] + # Remove extra dimensions + sliced_grid = numpy.squeeze(sliced_grid) - # Remove extra dimensions - sliced_grid = numpy.squeeze(sliced_grid) - - return sliced_grid + return sliced_grid -def visualize_slice(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: float, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - pcolormesh_args: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Visualize a slice of a grid. - Interpolates if given a position between two planes. + def visualize_slice( + self, + cell_data: NDArray, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + pcolormesh_args: dict[str, Any] | None = None, + ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: + """ + Visualize a slice of a grid. + Interpolates if given a position between two grid planes. - Args: - surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. - center: Scalar specifying position along surface_normal axis. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - """ - from matplotlib import pyplot + Args: + cell_data: Cell data to visualize + plane: Axis and position (`Plane`) of the plane to read. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - if pcolormesh_args is None: - pcolormesh_args = {} + Returns: + (Figure, Axes) + """ + from matplotlib import pyplot - grid_slice = self.get_slice(cell_data=cell_data, - surface_normal=surface_normal, - center=center, - which_shifts=which_shifts, - sample_period=sample_period) + if isinstance(plane, dict): + plane = Plane(**plane) - surface = numpy.delete(range(3), surface_normal) + if pcolormesh_args is None: + pcolormesh_args = {} - x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) - xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') - x_label, y_label = ('xyz'[a] for a in surface) + grid_slice = self.get_slice( + cell_data=cell_data, + plane=plane, + which_shifts=which_shifts, + sample_period=sample_period, + ) - pyplot.figure() - pyplot.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) - pyplot.colorbar() - pyplot.gca().set_aspect('equal', adjustable='box') - pyplot.xlabel(x_label) - pyplot.ylabel(y_label) - if finalize: - pyplot.show() + surface = numpy.delete(range(3), plane.axis) + + x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') + x_label, y_label = ('xyz'[a] for a in surface) + + fig, ax = pyplot.subplots() + mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) + fig.colorbar(mappable) + ax.set_aspect('equal', adjustable='box') + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + if finalize: + pyplot.show() + + return fig, ax -def visualize_isosurface(self, - cell_data: numpy.ndarray, - level: Optional[float] = None, - which_shifts: int = 0, - sample_period: int = 1, - show_edges: bool = True, - finalize: bool = True, - ) -> None: - """ - Draw an isosurface plot of the device. + def visualize_isosurface( + self, + cell_data: NDArray, + level: float | None = None, + which_shifts: int = 0, + sample_period: int = 1, + show_edges: bool = True, + finalize: bool = True, + ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: + """ + Draw an isosurface plot of the device. - Args: - cell_data: Cell data to visualize - level: Value at which to find isosurface. Default (None) uses mean value in grid. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - show_edges: Whether to draw triangle edges. Default `True` - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - """ - from matplotlib import pyplot - import skimage.measure - # Claims to be unused, but needed for subplot(projection='3d') - from mpl_toolkits.mplot3d import Axes3D + Args: + cell_data: Cell data to visualize + level: Value at which to find isosurface. Default (None) uses mean value in grid. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + show_edges: Whether to draw triangle edges. Default `True` + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - # Get data from cell_data - grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] - if level is None: - level = grid.mean() + Returns: + (Figure, Axes) + """ + from matplotlib import pyplot + import skimage.measure + # Claims to be unused, but needed for subplot(projection='3d') + from mpl_toolkits.mplot3d import Axes3D + del Axes3D # imported for side effects only - # Find isosurface with marching cubes - verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) + # Get data from cell_data + grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] + if level is None: + level = grid.mean() - # Convert vertices from index to position - pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) - for i in range(verts.shape[0])], dtype=float) - xs, ys, zs = (pos_verts[:, a] for a in range(3)) + # Find isosurface with marching cubes + verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) - # Draw the plot - fig = pyplot.figure() - ax = fig.add_subplot(111, projection='3d') - if show_edges: - ax.plot_trisurf(xs, ys, faces, zs) - else: - ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') + # Convert vertices from index to position + pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) + for i in range(verts.shape[0])], dtype=float) + xs, ys, zs = (pos_verts[:, a] for a in range(3)) - # Add a fake plot of a cube to force the axes to be equal lengths - max_range = numpy.array([xs.max() - xs.min(), - ys.max() - ys.min(), - zs.max() - zs.min()], dtype=float).max() - mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] - xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min()) - ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min()) - zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min()) - # Comment or uncomment following both lines to test the fake bounding box: - for xb, yb, zb in zip(xbs, ybs, zbs): - ax.plot([xb], [yb], [zb], 'w') + # Draw the plot + fig = pyplot.figure() + ax = fig.add_subplot(111, projection='3d') + if show_edges: + ax.plot_trisurf(xs, ys, faces, zs) # type: ignore + else: + ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') # type: ignore - if finalize: - pyplot.show() + # Add a fake plot of a cube to force the axes to be equal lengths + max_range = numpy.array([xs.max() - xs.min(), + ys.max() - ys.min(), + zs.max() - zs.min()], dtype=float).max() + mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] + xbs = 0.5 * max_range * mg[0].ravel() + 0.5 * (xs.max() + xs.min()) + ybs = 0.5 * max_range * mg[1].ravel() + 0.5 * (ys.max() + ys.min()) + zbs = 0.5 * max_range * mg[2].ravel() + 0.5 * (zs.max() + zs.min()) + # Comment or uncomment following both lines to test the fake bounding box: + for xb, yb, zb in zip(xbs, ybs, zbs, strict=True): + ax.plot([xb], [yb], [zb], 'w') + + if finalize: + pyplot.show() + + return fig, ax diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index fc54030..8d9ca92 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,8 +1,8 @@ -import pytest # type: ignore -import numpy # type: ignore -from numpy.testing import assert_allclose, assert_array_equal # type: ignore +# import pytest +import numpy +from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid +from .. import Grid, Extent #, Slab, Plane def test_draw_oncenter_2x2() -> None: @@ -12,7 +12,13 @@ def test_draw_oncenter_2x2() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[1, 1, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0, span=1), + y=Extent(center=0, span=1), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0.25, 0.25], [0.25, 0.25]])[None, :, :, None] @@ -27,7 +33,13 @@ def test_draw_ongrid_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[2, 2, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0, span=2), + y=dict(min=-1, max=1), + z=dict(center=0, min=-5), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 1, 1, 0], @@ -44,7 +56,13 @@ def test_draw_xshift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 2, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0.5, span=1.5), + y=dict(min=-1, max=1), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 0.25, 0.25, 0], @@ -61,7 +79,13 @@ def test_draw_yshift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0, 0.5, 0], dimensions=[2, 1.5, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(min=-1, max=1), + y=dict(center=0.5, span=1.5), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 0.25, 1, 0.25], @@ -78,7 +102,13 @@ def test_draw_2shift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 1, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0.5, span=1.5), + y=dict(min=-0.5, max=0.5), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 0.125, 0.125, 0], diff --git a/gridlock/utils.py b/gridlock/utils.py new file mode 100644 index 0000000..8a8f11d --- /dev/null +++ b/gridlock/utils.py @@ -0,0 +1,234 @@ +from typing import Protocol, TypedDict, runtime_checkable, cast +from dataclasses import dataclass + + +class GridError(Exception): + """ Base error type for `gridlock` """ + pass + + +class ExtentDict(TypedDict, total=False): + """ + Geometrical definition of an extent (1D bounded region) + Must contain exactly two of `min`, `max`, `center`, or `span`. + """ + min: float + center: float + max: float + span: float + + +@runtime_checkable +class ExtentProtocol(Protocol): + """ + Anything that looks like an `Extent` + """ + center: float + span: float + + @property + def max(self) -> float: ... + + @property + def min(self) -> float: ... + + +@dataclass(init=False, slots=True) +class Extent(ExtentProtocol): + """ + Geometrical definition of an extent (1D bounded region) + May be constructed with any two of `min`, `max`, `center`, or `span`. + """ + center: float + span: float + + @property + def max(self) -> float: + return self.center + self.span / 2 + + @property + def min(self) -> float: + return self.center - self.span / 2 + + def __init__( + self, + *, + min: float | None = None, + center: float | None = None, + max: float | None = None, + span: float | None = None, + ) -> None: + if sum(cc is None for cc in (min, center, max, span)) != 2: + raise GridError('Exactly two of min, center, max, span must be None!') + + if span is None: + if center is None: + assert min is not None + assert max is not None + assert max >= min + center = 0.5 * (max + min) + span = max - min + elif max is None: + assert min is not None + assert center is not None + span = 2 * (center - min) + elif min is None: + assert center is not None + assert max is not None + span = 2 * (max - center) + else: # noqa: PLR5501 + if center is not None: + pass + elif max is None: + assert min is not None + assert span is not None + center = min + 0.5 * span + elif min is None: + assert max is not None + assert span is not None + center = max - 0.5 * span + + assert center is not None + assert span is not None + if hasattr(center, '__len__'): + assert len(center) == 1 + if hasattr(span, '__len__'): + assert len(span) == 1 + self.center = center + self.span = span + + +class SlabDict(TypedDict, total=False): + """ + Geometrical definition of a slab (3D region bounded on one axis only) + Must contain `axis` plus any two of `min`, `max`, `center`, or `span`. + """ + min: float + center: float + max: float + span: float + axis: int | str + + +@runtime_checkable +class SlabProtocol(ExtentProtocol, Protocol): + """ + Anything that looks like a `Slab` + """ + axis: int + center: float + span: float + + @property + def max(self) -> float: ... + + @property + def min(self) -> float: ... + + +@dataclass(init=False, slots=True) +class Slab(Extent, SlabProtocol): + """ + Geometrical definition of a slab (3D region bounded on one axis only) + May be constructed with `axis` (bounded axis) plus any two of `min`, `max`, `center`, or `span`. + """ + axis: int + + def __init__( + self, + axis: int | str, + *, + min: float | None = None, + center: float | None = None, + max: float | None = None, + span: float | None = None, + ) -> None: + Extent.__init__(self, min=min, center=center, max=max, span=span) + + if isinstance(axis, str): + axis_int = 'xyz'.find(axis.lower()) + else: + axis_int = axis + if axis_int not in range(3): + raise GridError(f'Invalid axis (slab normal direction): {axis}') + self.axis = axis_int + + def as_plane(self, where: str) -> 'Plane': + if where == 'center': + return Plane(axis=self.axis, pos=self.center) + if where == 'min': + return Plane(axis=self.axis, pos=self.min) + if where == 'max': + return Plane(axis=self.axis, pos=self.max) + raise GridError(f'Invalid {where=}') + + +class PlaneDict(TypedDict, total=False): + """ + Geometrical definition of a plane (2D unbounded region in 3D space) + Must contain exactly one of `x`, `y`, `z`, or both `axis` and `pos` + """ + x: float + y: float + z: float + axis: int + pos: float + + +@runtime_checkable +class PlaneProtocol(Protocol): + """ + Anything that looks like a `Plane` + """ + axis: int + pos: float + + +@dataclass(init=False, slots=True) +class Plane(PlaneProtocol): + """ + Geometrical definition of a plane (2D unbounded region in 3D space) + May be constructed with any of `x=4`, `y=5`, `z=-5`, or `axis=2, pos=-5`. + """ + axis: int + pos: float + + def __init__( + self, + *, + axis: int | str | None = None, + pos: float | None = None, + x: float | None = None, + y: float | None = None, + z: float | None = None, + ) -> None: + xx = x + yy = y + zz = z + + if sum(aa is not None for aa in (pos, xx, yy, zz)) != 1: + raise GridError('Exactly one of pos, x, y, z must be non-None!') + if (axis is None) != (pos is None): + raise GridError('Either both or neither of `axis` and `pos` must be defined.') + + if isinstance(axis, str): + axis_int = 'xyz'.find(axis.lower()) + elif axis is None: + axis_int = (xx is None, yy is None, zz is None).index(False) + else: + axis_int = axis + + if axis_int not in range(3): + raise GridError(f'Invalid axis (slab normal direction): {axis=} {x=} {y=} {z=}') + self.axis = axis_int + + if pos is not None: + cpos = pos + else: + cpos = cast('float', (xx, yy, zz)[axis_int]) + assert cpos is not None + + if hasattr(cpos, '__len__'): + assert len(cpos) == 1 + self.pos = cpos + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..03d0d19 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,98 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "gridlock" +description = "Coupled gridding library" +readme = "README.md" +license = { file = "LICENSE.md" } +authors = [ + { name="Jan Petykiewicz", email="jan@mpxd.net" }, + ] +homepage = "https://mpxd.net/code/jan/gridlock" +repository = "https://mpxd.net/code/jan/gridlock" +keywords = [ + "FDTD", + "gridding", + "simulation", + "nonuniform", + "FDFD", + "finite", + "difference", + ] +classifiers = [ + "Programming Language :: Python :: 3", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: GNU Affero General Public License v3", + "Topic :: Multimedia :: Graphics :: 3D Rendering", + "Topic :: Scientific/Engineering :: Electronic Design Automation (EDA)", + "Topic :: Scientific/Engineering :: Physics", + "Topic :: Scientific/Engineering :: Visualization", + ] +requires-python = ">=3.11" +include = [ + "LICENSE.md" + ] +dynamic = ["version"] +dependencies = [ + "numpy>=1.26", + "float_raster>=0.8", + ] + + +[tool.hatch.version] +path = "gridlock/__init__.py" + +[project.optional-dependencies] +visualization = ["matplotlib"] +visualization-isosurface = [ + "matplotlib", + "skimage>=0.13", + "mpl_toolkits", + ] + + +[tool.ruff] +exclude = [ + ".git", + "dist", + ] +line-length = 145 +indent-width = 4 +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +lint.select = [ + "NPY", "E", "F", "W", "B", "ANN", "UP", "SLOT", "SIM", "LOG", + "C4", "ISC", "PIE", "PT", "RET", "TCH", "PTH", "INT", + "ARG", "PL", "R", "TRY", + "G010", "G101", "G201", "G202", + "Q002", "Q003", "Q004", + ] +lint.ignore = [ + #"ANN001", # No annotation + "ANN002", # *args + "ANN003", # **kwargs + "ANN401", # Any + "SIM108", # single-line if / else assignment + "RET504", # x=y+z; return x + "PIE790", # unnecessary pass + "ISC003", # non-implicit string concatenation + "C408", # dict(x=y) instead of {'x': y} + "PLR09", # Too many xxx + "PLR2004", # magic number + "PLC0414", # import x as x + "TRY003", # Long exception message + "PTH123", # open() + ] + + +[[tool.mypy.overrides]] +module = [ + "matplotlib", + "matplotlib.axes", + "matplotlib.figure", + "mpl_toolkits.mplot3d", + ] +ignore_missing_imports = true diff --git a/setup.py b/setup.py deleted file mode 100644 index a424b97..0000000 --- a/setup.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 - -from setuptools import setup, find_packages - -with open('README.md', 'r') as f: - long_description = f.read() - -with open('gridlock/VERSION.py', 'rt') as f: - version = f.readlines()[2].strip() - -setup(name='gridlock', - version=version, - description='Coupled gridding library', - long_description=long_description, - long_description_content_type='text/markdown', - author='Jan Petykiewicz', - author_email='jan@mpxd.net', - url='https://mpxd.net/code/jan/gridlock', - packages=find_packages(), - package_data={ - 'gridlock': ['py.typed'], - }, - install_requires=[ - 'numpy', - 'float_raster', - ], - extras_require={ - 'visualization': ['matplotlib'], - 'visualization-isosurface': [ - 'matplotlib', - 'skimage>=0.13', - 'mpl_toolkits', - ], - }, - classifiers=[ - 'Programming Language :: Python :: 3', - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: GNU Affero General Public License v3', - 'Topic :: Multimedia :: Graphics :: 3D Rendering', - 'Topic :: Scientific/Engineering :: Electronic Design Automation (EDA)', - 'Topic :: Scientific/Engineering :: Physics', - 'Topic :: Scientific/Engineering :: Visualization', - ], - ) -