From e1303b8a5c5a52cb026fe4ddb59b5f014b73c83a Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Aug 2022 13:07:46 -0700 Subject: [PATCH 01/48] move to hatch-based builds --- MANIFEST.in | 2 -- gridlock/LICENSE.md | 1 + gridlock/README.md | 1 + gridlock/VERSION.py | 4 ---- gridlock/__init__.py | 3 +-- pyproject.toml | 55 ++++++++++++++++++++++++++++++++++++++++++++ setup.py | 47 ------------------------------------- 7 files changed, 58 insertions(+), 55 deletions(-) delete mode 100644 MANIFEST.in create mode 120000 gridlock/LICENSE.md create mode 120000 gridlock/README.md delete mode 100644 gridlock/VERSION.py create mode 100644 pyproject.toml delete mode 100644 setup.py diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index c28ab72..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,2 +0,0 @@ -include README.md -include LICENSE.md diff --git a/gridlock/LICENSE.md b/gridlock/LICENSE.md new file mode 120000 index 0000000..7eabdb1 --- /dev/null +++ b/gridlock/LICENSE.md @@ -0,0 +1 @@ +../LICENSE.md \ No newline at end of file diff --git a/gridlock/README.md b/gridlock/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/gridlock/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/gridlock/VERSION.py b/gridlock/VERSION.py deleted file mode 100644 index 8e6abf6..0000000 --- a/gridlock/VERSION.py +++ /dev/null @@ -1,4 +0,0 @@ -""" VERSION defintion. THIS FILE IS MANUALLY PARSED BY setup.py and REQUIRES A SPECIFIC FORMAT """ -__version__ = ''' -1.0 -'''.strip() diff --git a/gridlock/__init__.py b/gridlock/__init__.py index d547794..591ad12 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -19,6 +19,5 @@ from .error import GridError from .grid import Grid __author__ = 'Jan Petykiewicz' - -from .VERSION import __version__ +__version__ = '1.0' version = __version__ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e9ac6be --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,55 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "gridlock" +description = "Coupled gridding library" +readme = "README.md" +license = { file = "LICENSE.md" } +authors = [ + { name="Jan Petykiewicz", email="jan@mpxd.net" }, + ] +homepage = "https://mpxd.net/code/jan/gridlock" +repository = "https://mpxd.net/code/jan/gridlock" +keywords = [ + "FDTD", + "gridding", + "simulation", + "nonuniform", + "FDFD", + "finite", + "difference", + ] +classifiers = [ + "Programming Language :: Python :: 3", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: GNU Affero General Public License v3", + "Topic :: Multimedia :: Graphics :: 3D Rendering", + "Topic :: Scientific/Engineering :: Electronic Design Automation (EDA)", + "Topic :: Scientific/Engineering :: Physics", + "Topic :: Scientific/Engineering :: Visualization", + ] +requires-python = ">=3.8" +include = [ + "LICENSE.md" + ] +dynamic = ["version"] +dependencies = [ + "numpy~=1.21", + "float_raster", + ] + + +[tool.hatch.version] +path = "gridlock/__init__.py" + +[project.optional-dependencies] +visualization = ["matplotlib"] +visualization-isosurface = [ + "matplotlib", + "skimage>=0.13", + "mpl_toolkits", + ] diff --git a/setup.py b/setup.py deleted file mode 100644 index a424b97..0000000 --- a/setup.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 - -from setuptools import setup, find_packages - -with open('README.md', 'r') as f: - long_description = f.read() - -with open('gridlock/VERSION.py', 'rt') as f: - version = f.readlines()[2].strip() - -setup(name='gridlock', - version=version, - description='Coupled gridding library', - long_description=long_description, - long_description_content_type='text/markdown', - author='Jan Petykiewicz', - author_email='jan@mpxd.net', - url='https://mpxd.net/code/jan/gridlock', - packages=find_packages(), - package_data={ - 'gridlock': ['py.typed'], - }, - install_requires=[ - 'numpy', - 'float_raster', - ], - extras_require={ - 'visualization': ['matplotlib'], - 'visualization-isosurface': [ - 'matplotlib', - 'skimage>=0.13', - 'mpl_toolkits', - ], - }, - classifiers=[ - 'Programming Language :: Python :: 3', - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: GNU Affero General Public License v3', - 'Topic :: Multimedia :: Graphics :: 3D Rendering', - 'Topic :: Scientific/Engineering :: Electronic Design Automation (EDA)', - 'Topic :: Scientific/Engineering :: Physics', - 'Topic :: Scientific/Engineering :: Visualization', - ], - ) - From 7d3b2272bcd6c2bfd4422dac324d09ba9b30237f Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Aug 2022 13:11:12 -0700 Subject: [PATCH 02/48] bump version to v1.1 --- gridlock/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 591ad12..abcadc2 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -19,5 +19,5 @@ from .error import GridError from .grid import Grid __author__ = 'Jan Petykiewicz' -__version__ = '1.0' +__version__ = '1.1' version = __version__ From ec5c77e018ff13b052bbbea851c92d95b527ed21 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 18 Oct 2022 19:44:30 -0700 Subject: [PATCH 03/48] typing and formatting updates --- gridlock/draw.py | 113 ++++++++++++++++++++----------------- gridlock/examples/ex0.py | 2 +- gridlock/grid.py | 58 ++++++++++--------- gridlock/position.py | 31 +++++----- gridlock/read.py | 54 ++++++++++-------- gridlock/test/test_grid.py | 4 +- 6 files changed, 142 insertions(+), 120 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 6385213..5390871 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -3,7 +3,8 @@ Drawing-related methods for Grid class """ from typing import List, Optional, Union, Sequence, Callable -import numpy # type: ignore +import numpy +from numpy.typing import NDArray, ArrayLike from float_raster import raster from . import GridError @@ -15,17 +16,19 @@ from . import GridError # without having to pass `cell_data` again each time? -foreground_callable_t = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray] +foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] +foreground_t = Union[float, foreground_callable_t] -def draw_polygons(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: numpy.ndarray, - polygons: Sequence[numpy.ndarray], - thickness: float, - foreground: Union[Sequence[Union[float, foreground_callable_t]], float, foreground_callable_t], - ) -> None: +def draw_polygons( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + polygons: Sequence[NDArray], + thickness: float, + foreground: Union[Sequence[foreground_t], foreground_t], + ) -> None: """ Draw polygons on an axis-aligned plane. @@ -74,8 +77,8 @@ def draw_polygons(self, # ## Compute sub-domain of the grid occupied by polygons # 1) Compute outer bounds (bd) of polygons - bd_2d_min = [0, 0] - bd_2d_max = [0, 0] + bd_2d_min = numpy.array([0, 0]) + bd_2d_max = numpy.array([0, 0]) for polygon in polygons: bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) @@ -97,7 +100,7 @@ def draw_polygons(self, polygons = [poly + center[surface] for poly in polygons] # ## Generate weighing function - def to_3d(vector: numpy.ndarray, val: float = 0.0) -> numpy.ndarray: + def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: v_2d = numpy.array(vector, dtype=float) return numpy.insert(v_2d, surface_normal, (val,)) @@ -188,14 +191,15 @@ def draw_polygons(self, cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_i -def draw_polygon(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: numpy.ndarray, - polygon: numpy.ndarray, - thickness: float, - foreground: Union[Sequence[Union[float, foreground_callable_t]], float, foreground_callable_t], - ) -> None: +def draw_polygon( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + polygon: ArrayLike, + thickness: float, + foreground: Union[Sequence[foreground_t], foreground_t], + ) -> None: """ Draw a polygon on an axis-aligned plane. @@ -212,13 +216,14 @@ def draw_polygon(self, self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) -def draw_slab(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: numpy.ndarray, - thickness: float, - foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t], - ) -> None: +def draw_slab( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + thickness: float, + foreground: Union[Sequence[foreground_t], foreground_t], + ) -> None: """ Draw an axis-aligned infinite slab. @@ -262,12 +267,13 @@ def draw_slab(self, self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) -def draw_cuboid(self, - cell_data: numpy.ndarray, - center: numpy.ndarray, - dimensions: numpy.ndarray, - foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t], - ) -> None: +def draw_cuboid( + self, + cell_data: NDArray, + center: ArrayLike, + dimensions: ArrayLike, + foreground: Union[Sequence[foreground_t], foreground_t], + ) -> None: """ Draw an axis-aligned cuboid @@ -278,6 +284,7 @@ def draw_cuboid(self, sizes of the cuboid foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ + dimensions = numpy.array(dimensions, copy=False) p = numpy.array([[-dimensions[0], +dimensions[1]], [+dimensions[0], +dimensions[1]], [+dimensions[0], -dimensions[1]], @@ -286,15 +293,16 @@ def draw_cuboid(self, self.draw_polygon(cell_data, 2, center, p, thickness, foreground) -def draw_cylinder(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: numpy.ndarray, - radius: float, - thickness: float, - num_points: int, - foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t], - ) -> None: +def draw_cylinder( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + radius: float, + thickness: float, + num_points: int, + foreground: Union[Sequence[foreground_t], foreground_t], + ) -> None: """ Draw an axis-aligned cylinder. Approximated by a num_points-gon @@ -314,13 +322,14 @@ def draw_cylinder(self, self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) -def draw_extrude_rectangle(self, - cell_data: numpy.ndarray, - rectangle: numpy.ndarray, - direction: int, - polarity: int, - distance: float, - ) -> None: +def draw_extrude_rectangle( + self, + cell_data: NDArray, + rectangle: ArrayLike, + direction: int, + polarity: int, + distance: float, + ) -> None: """ Extrude a rectangle of a previously-drawn structure along an axis. @@ -361,10 +370,10 @@ def draw_extrude_rectangle(self, mult = [1-fpart, fpart][::s] # reverses if s negative foreground = mult[0] * grid[tuple(ind)] - ind[direction] += 1 + ind[direction] += 1 # type: ignore #(known safe) foreground += mult[1] * grid[tuple(ind)] - def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> numpy.ndarray: + def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]: # transform from natural position to index xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) for qrs in zip(xs.flat, ys.flat, zs.flat)], dtype=int) diff --git a/gridlock/examples/ex0.py b/gridlock/examples/ex0.py index ca2ef55..59756bc 100644 --- a/gridlock/examples/ex0.py +++ b/gridlock/examples/ex0.py @@ -1,4 +1,4 @@ -import numpy # type: ignore +import numpy from gridlock import Grid diff --git a/gridlock/grid.py b/gridlock/grid.py index e320854..5f132a0 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,6 +1,7 @@ from typing import List, Tuple, Callable, Dict, Optional, Union, Sequence, ClassVar, TypeVar -import numpy # type: ignore +import numpy +from numpy.typing import NDArray, ArrayLike from numpy import diff, floor, ceil, zeros, hstack, newaxis import pickle @@ -10,7 +11,7 @@ import copy from . import GridError -foreground_callable_type = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray] +foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] T = TypeVar('T', bound='Grid') @@ -48,23 +49,27 @@ class Grid: Because of this, we either assume this 'ghost' cell is the same size as the last real cell, or, if `self.periodic[a]` is set to `True`, the same size as the first cell. """ - exyz: List[numpy.ndarray] + exyz: List[NDArray] """Cell edges. Monotonically increasing without duplicates.""" periodic: List[bool] """For each axis, determines how far the rightmost boundary gets shifted. """ - shifts: numpy.ndarray + shifts: NDArray """Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`""" - Yee_Shifts_E: ClassVar[numpy.ndarray] = 0.5 * numpy.array([[1, 0, 0], - [0, 1, 0], - [0, 0, 1]], dtype=float) + Yee_Shifts_E: ClassVar[NDArray] = 0.5 * numpy.array([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], dtype=float) """Default shifts for Yee grid E-field""" - Yee_Shifts_H: ClassVar[numpy.ndarray] = 0.5 * numpy.array([[0, 1, 1], - [1, 0, 1], - [1, 1, 0]], dtype=float) + Yee_Shifts_H: ClassVar[NDArray] = 0.5 * numpy.array([ + [0, 1, 1], + [1, 0, 1], + [1, 1, 0], + ], dtype=float) """Default shifts for Yee grid H-field""" from .draw import ( @@ -75,7 +80,7 @@ class Grid: from .position import ind2pos, pos2ind @property - def dxyz(self) -> List[numpy.ndarray]: + def dxyz(self) -> List[NDArray]: """ Cell sizes for each axis, no shifts applied @@ -85,7 +90,7 @@ class Grid: return [numpy.diff(ee) for ee in self.exyz] @property - def xyz(self) -> List[numpy.ndarray]: + def xyz(self) -> List[NDArray]: """ Cell centers for each axis, no shifts applied @@ -95,7 +100,7 @@ class Grid: return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] @property - def shape(self) -> numpy.ndarray: + def shape(self) -> NDArray[numpy.int_]: """ The number of cells in x, y, and z @@ -119,7 +124,7 @@ class Grid: return numpy.hstack((self.num_grids, self.shape)) @property - def dxyz_with_ghost(self) -> List[numpy.ndarray]: + def dxyz_with_ghost(self) -> List[NDArray]: """ Gives dxyz with an additional 'ghost' cell at the end, whose value depends on whether or not the axis has periodic boundary conditions. See main description @@ -135,7 +140,7 @@ class Grid: return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)] @property - def center(self) -> numpy.ndarray: + def center(self) -> NDArray[numpy.float64]: """ Center position of the entire grid, no shifts applied @@ -148,7 +153,7 @@ class Grid: return numpy.array(centers, dtype=float) @property - def dxyz_limits(self) -> Tuple[numpy.ndarray, numpy.ndarray]: + def dxyz_limits(self) -> Tuple[NDArray, NDArray]: """ Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element ndarrays. No shifts are applied, so these are extreme bounds on these values (as a @@ -161,7 +166,7 @@ class Grid: d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float) return d_min, d_max - def shifted_exyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]: + def shifted_exyz(self, which_shifts: Optional[int]) -> List[NDArray]: """ Returns edges for which_shifts. @@ -183,7 +188,7 @@ class Grid: return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] - def shifted_dxyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]: + def shifted_dxyz(self, which_shifts: Optional[int]) -> List[NDArray]: """ Returns cell sizes for `which_shifts`. @@ -210,7 +215,7 @@ class Grid: return sdxyz - def shifted_xyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]: + def shifted_xyz(self, which_shifts: Optional[int]) -> List[NDArray[numpy.float64]]: """ Returns cell centers for `which_shifts`. @@ -226,7 +231,7 @@ class Grid: dxyz = self.shifted_dxyz(which_shifts) return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] - def autoshifted_dxyz(self) -> List[numpy.ndarray]: + def autoshifted_dxyz(self) -> List[NDArray[numpy.float64]]: """ Return cell widths, with each dimension shifted by the corresponding shifts. @@ -237,7 +242,7 @@ class Grid: raise GridError('Autoshifting requires exactly 3 grids') return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] - def allocate(self, fill_value: Optional[float] = 1.0, dtype=numpy.float32) -> numpy.ndarray: + def allocate(self, fill_value: Optional[float] = 1.0, dtype=numpy.float32) -> NDArray: """ Allocate an ndarray for storing grid data. @@ -254,11 +259,12 @@ class Grid: else: return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) - def __init__(self, - pixel_edge_coordinates: Sequence[numpy.ndarray], - shifts: numpy.ndarray = Yee_Shifts_E, - periodic: Union[bool, Sequence[bool]] = False, - ) -> None: + def __init__( + self, + pixel_edge_coordinates: Sequence[ArrayLike], + shifts: ArrayLike = Yee_Shifts_E, + periodic: Union[bool, Sequence[bool]] = False, + ) -> None: """ Args: pixel_edge_coordinates: 3-element list of (ndarrays or lists) specifying the diff --git a/gridlock/position.py b/gridlock/position.py index 1224a12..35ec1df 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -1,19 +1,21 @@ """ Position-related methods for Grid class """ -from typing import List, Optional +from typing import List, Optional, Sequence -import numpy # type: ignore +import numpy +from numpy.typing import NDArray, ArrayLike from . import GridError -def ind2pos(self, - ind: numpy.ndarray, - which_shifts: Optional[int] = None, - round_ind: bool = True, - check_bounds: bool = True - ) -> numpy.ndarray: +def ind2pos( + self, + ind: NDArray, + which_shifts: Optional[int] = None, + round_ind: bool = True, + check_bounds: bool = True + ) -> NDArray[numpy.float64]: """ Returns the natural position corresponding to the specified cell center indices. The resulting position is clipped to the bounds of the grid @@ -59,12 +61,13 @@ def ind2pos(self, return numpy.array(position, dtype=float) -def pos2ind(self, - r: numpy.ndarray, - which_shifts: Optional[int], - round_ind: bool = True, - check_bounds: bool = True - ) -> numpy.ndarray: +def pos2ind( + self, + r: ArrayLike, + which_shifts: Optional[int], + round_ind: bool = True, + check_bounds: bool = True + ) -> NDArray[numpy.float64]: """ Returns the cell-center indices corresponding to the specified natural position. The resulting position is clipped to within the outer centers of the grid. diff --git a/gridlock/read.py b/gridlock/read.py index aa059d5..055670c 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -3,7 +3,8 @@ Readback and visualization methods for Grid class """ from typing import Dict, Optional, Union, Any -import numpy # type: ignore +import numpy +from numpy.typing import NDArray, ArrayLike from . import GridError @@ -12,13 +13,14 @@ from . import GridError # .visualize_isosurface uses mpl_toolkits.mplot3d -def get_slice(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: float, - which_shifts: int = 0, - sample_period: int = 1 - ) -> numpy.ndarray: +def get_slice( + self, + cell_data: NDArray, + surface_normal: int, + center: float, + which_shifts: int = 0, + sample_period: int = 1 + ) -> NDArray: """ Retrieve a slice of a grid. Interpolates if given a position between two planes. @@ -75,15 +77,16 @@ def get_slice(self, return sliced_grid -def visualize_slice(self, - cell_data: numpy.ndarray, - surface_normal: int, - center: float, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - pcolormesh_args: Optional[Dict[str, Any]] = None, - ) -> None: +def visualize_slice( + self, + cell_data: NDArray, + surface_normal: int, + center: float, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + pcolormesh_args: Optional[Dict[str, Any]] = None, + ) -> None: """ Visualize a slice of a grid. Interpolates if given a position between two planes. @@ -122,14 +125,15 @@ def visualize_slice(self, pyplot.show() -def visualize_isosurface(self, - cell_data: numpy.ndarray, - level: Optional[float] = None, - which_shifts: int = 0, - sample_period: int = 1, - show_edges: bool = True, - finalize: bool = True, - ) -> None: +def visualize_isosurface( + self, + cell_data: NDArray, + level: Optional[float] = None, + which_shifts: int = 0, + sample_period: int = 1, + show_edges: bool = True, + finalize: bool = True, + ) -> None: """ Draw an isosurface plot of the device. diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index fc54030..1e30bf3 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,6 +1,6 @@ import pytest # type: ignore -import numpy # type: ignore -from numpy.testing import assert_allclose, assert_array_equal # type: ignore +import numpy +from numpy.testing import assert_allclose, assert_array_equal from .. import Grid From 73d07bbfe0845ed72d76a2235c9df99e9e7605ea Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 18 Oct 2022 19:44:47 -0700 Subject: [PATCH 04/48] disaambiguate some variables for typing purposes --- gridlock/draw.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 5390871..b2a3245 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -70,10 +70,13 @@ def draw_polygons( + 'xyz'[surface_normal]) # Broadcast foreground where necessary - if numpy.size(foreground) == 1: - foreground = [foreground] * len(cell_data) + foregrounds: Union[Sequence[foreground_callable_t], Sequence[float]] + if numpy.size(foreground) == 1: # type: ignore + foregrounds = [foreground] * len(cell_data) # type: ignore elif isinstance(foreground, numpy.ndarray): raise GridError('ndarray not supported for foreground') + else: + foregrounds = foreground # type: ignore # ## Compute sub-domain of the grid occupied by polygons # 1) Compute outer bounds (bd) of polygons @@ -105,22 +108,23 @@ def draw_polygons( return numpy.insert(v_2d, surface_normal, (val,)) # iterate over grids - for i, grid in enumerate(cell_data): - # ## Evaluate or expand foreground[i] - if callable(foreground[i]): + for i, _ in enumerate(cell_data): + # ## Evaluate or expand foregrounds[i] + foregrounds_i = foregrounds[i] + if callable(foregrounds_i): # meshgrid over the (shifted) domain domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k]+1] for k in range(3)] (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') # evaluate on the meshgrid - foreground_i = foreground[i](x0, y0, z0) - if not numpy.isfinite(foreground_i).all(): + foreground_val = foregrounds_i(x0, y0, z0) + if not numpy.isfinite(foreground_val).all(): raise GridError(f'Non-finite values in foreground[{i}]') - elif numpy.size(foreground[i]) != 1: - raise GridError(f'Unsupported foreground[{i}]: {type(foreground[i])}') + elif numpy.size(foregrounds_i) != 1: + raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}') else: # foreground[i] is scalar non-callable - foreground_i = foreground[i] + foreground_val = foregrounds_i w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) @@ -188,7 +192,7 @@ def draw_polygons( # ## Modify the grid g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) - cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_i + cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val def draw_polygon( From a94c2cae67d7b98da5c5c78f564800f275e3f6cf Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Thu, 18 Jul 2024 00:17:20 -0700 Subject: [PATCH 05/48] type hint modernization --- gridlock/draw.py | 14 +++++++------- gridlock/grid.py | 31 +++++++++++++++---------------- gridlock/position.py | 6 ++---- gridlock/read.py | 9 +++++++-- 4 files changed, 31 insertions(+), 29 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index b2a3245..2322a66 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -1,7 +1,7 @@ """ Drawing-related methods for Grid class """ -from typing import List, Optional, Union, Sequence, Callable +from typing import Union, Sequence, Callable import numpy from numpy.typing import NDArray, ArrayLike @@ -27,7 +27,7 @@ def draw_polygons( center: ArrayLike, polygons: Sequence[NDArray], thickness: float, - foreground: Union[Sequence[foreground_t], foreground_t], + foreground: Sequence[foreground_t] | foreground_t, ) -> None: """ Draw polygons on an axis-aligned plane. @@ -70,7 +70,7 @@ def draw_polygons( + 'xyz'[surface_normal]) # Broadcast foreground where necessary - foregrounds: Union[Sequence[foreground_callable_t], Sequence[float]] + foregrounds: Sequence[foreground_callable_t] | Sequence[float] if numpy.size(foreground) == 1: # type: ignore foregrounds = [foreground] * len(cell_data) # type: ignore elif isinstance(foreground, numpy.ndarray): @@ -202,7 +202,7 @@ def draw_polygon( center: ArrayLike, polygon: ArrayLike, thickness: float, - foreground: Union[Sequence[foreground_t], foreground_t], + foreground: Sequence[foreground_t] | foreground_t, ) -> None: """ Draw a polygon on an axis-aligned plane. @@ -226,7 +226,7 @@ def draw_slab( surface_normal: int, center: ArrayLike, thickness: float, - foreground: Union[Sequence[foreground_t], foreground_t], + foreground: Sequence[foreground_t] | foreground_t, ) -> None: """ Draw an axis-aligned infinite slab. @@ -276,7 +276,7 @@ def draw_cuboid( cell_data: NDArray, center: ArrayLike, dimensions: ArrayLike, - foreground: Union[Sequence[foreground_t], foreground_t], + foreground: Sequence[foreground_t] | foreground_t, ) -> None: """ Draw an axis-aligned cuboid @@ -305,7 +305,7 @@ def draw_cylinder( radius: float, thickness: float, num_points: int, - foreground: Union[Sequence[foreground_t], foreground_t], + foreground: Sequence[foreground_t] | foreground_t, ) -> None: """ Draw an axis-aligned cylinder. Approximated by a num_points-gon diff --git a/gridlock/grid.py b/gridlock/grid.py index 5f132a0..36e0608 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Callable, Dict, Optional, Union, Sequence, ClassVar, TypeVar +from typing import Callable, Sequence, ClassVar, Self import numpy from numpy.typing import NDArray, ArrayLike @@ -12,7 +12,6 @@ from . import GridError foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] -T = TypeVar('T', bound='Grid') class Grid: @@ -49,10 +48,10 @@ class Grid: Because of this, we either assume this 'ghost' cell is the same size as the last real cell, or, if `self.periodic[a]` is set to `True`, the same size as the first cell. """ - exyz: List[NDArray] + exyz: list[NDArray] """Cell edges. Monotonically increasing without duplicates.""" - periodic: List[bool] + periodic: list[bool] """For each axis, determines how far the rightmost boundary gets shifted. """ shifts: NDArray @@ -80,7 +79,7 @@ class Grid: from .position import ind2pos, pos2ind @property - def dxyz(self) -> List[NDArray]: + def dxyz(self) -> list[NDArray]: """ Cell sizes for each axis, no shifts applied @@ -90,7 +89,7 @@ class Grid: return [numpy.diff(ee) for ee in self.exyz] @property - def xyz(self) -> List[NDArray]: + def xyz(self) -> list[NDArray]: """ Cell centers for each axis, no shifts applied @@ -124,7 +123,7 @@ class Grid: return numpy.hstack((self.num_grids, self.shape)) @property - def dxyz_with_ghost(self) -> List[NDArray]: + def dxyz_with_ghost(self) -> list[NDArray]: """ Gives dxyz with an additional 'ghost' cell at the end, whose value depends on whether or not the axis has periodic boundary conditions. See main description @@ -153,7 +152,7 @@ class Grid: return numpy.array(centers, dtype=float) @property - def dxyz_limits(self) -> Tuple[NDArray, NDArray]: + def dxyz_limits(self) -> tuple[NDArray, NDArray]: """ Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element ndarrays. No shifts are applied, so these are extreme bounds on these values (as a @@ -166,7 +165,7 @@ class Grid: d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float) return d_min, d_max - def shifted_exyz(self, which_shifts: Optional[int]) -> List[NDArray]: + def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]: """ Returns edges for which_shifts. @@ -188,7 +187,7 @@ class Grid: return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] - def shifted_dxyz(self, which_shifts: Optional[int]) -> List[NDArray]: + def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: """ Returns cell sizes for `which_shifts`. @@ -215,7 +214,7 @@ class Grid: return sdxyz - def shifted_xyz(self, which_shifts: Optional[int]) -> List[NDArray[numpy.float64]]: + def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: """ Returns cell centers for `which_shifts`. @@ -231,7 +230,7 @@ class Grid: dxyz = self.shifted_dxyz(which_shifts) return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] - def autoshifted_dxyz(self) -> List[NDArray[numpy.float64]]: + def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]: """ Return cell widths, with each dimension shifted by the corresponding shifts. @@ -242,7 +241,7 @@ class Grid: raise GridError('Autoshifting requires exactly 3 grids') return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] - def allocate(self, fill_value: Optional[float] = 1.0, dtype=numpy.float32) -> NDArray: + def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray: """ Allocate an ndarray for storing grid data. @@ -263,7 +262,7 @@ class Grid: self, pixel_edge_coordinates: Sequence[ArrayLike], shifts: ArrayLike = Yee_Shifts_E, - periodic: Union[bool, Sequence[bool]] = False, + periodic: bool | Sequence[bool] = False, ) -> None: """ Args: @@ -320,7 +319,7 @@ class Grid: g.__dict__.update(tmp_dict) return g - def save(self: T, filename: str) -> T: + def save(self, filename: str) -> Self: """ Save to file. @@ -334,7 +333,7 @@ class Grid: pickle.dump(self.__dict__, f, protocol=2) return self - def copy(self: T) -> T: + def copy(self) -> Self: """ Returns: Deep copy of the grid. diff --git a/gridlock/position.py b/gridlock/position.py index 35ec1df..b674716 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -1,8 +1,6 @@ """ Position-related methods for Grid class """ -from typing import List, Optional, Sequence - import numpy from numpy.typing import NDArray, ArrayLike @@ -12,7 +10,7 @@ from . import GridError def ind2pos( self, ind: NDArray, - which_shifts: Optional[int] = None, + which_shifts: int | None = None, round_ind: bool = True, check_bounds: bool = True ) -> NDArray[numpy.float64]: @@ -64,7 +62,7 @@ def ind2pos( def pos2ind( self, r: ArrayLike, - which_shifts: Optional[int], + which_shifts: int | None, round_ind: bool = True, check_bounds: bool = True ) -> NDArray[numpy.float64]: diff --git a/gridlock/read.py b/gridlock/read.py index 055670c..8ea4b50 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -1,13 +1,18 @@ """ Readback and visualization methods for Grid class """ -from typing import Dict, Optional, Union, Any +from typing import Any, TYPE_CHECKING import numpy from numpy.typing import NDArray, ArrayLike from . import GridError +if TYPE_CHECKING: + import matplotlib.axes + import matplotlib.figure + + # .visualize_* uses matplotlib # .visualize_isosurface uses skimage # .visualize_isosurface uses mpl_toolkits.mplot3d @@ -128,7 +133,7 @@ def visualize_slice( def visualize_isosurface( self, cell_data: NDArray, - level: Optional[float] = None, + level: float | None = None, which_shifts: int = 0, sample_period: int = 1, show_edges: bool = True, From 3e4e6eead3bfe5af7a2cbe3ec892a315c6cc4d8e Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Thu, 18 Jul 2024 00:17:45 -0700 Subject: [PATCH 06/48] flake8 fixup --- .flake8 | 29 +++++++++++++++++++++++++++++ gridlock/draw.py | 14 +++++++------- gridlock/grid.py | 1 - gridlock/position.py | 2 +- gridlock/read.py | 2 +- 5 files changed, 38 insertions(+), 10 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..0042015 --- /dev/null +++ b/.flake8 @@ -0,0 +1,29 @@ +[flake8] +ignore = + # E501 line too long + E501, + # W391 newlines at EOF + W391, + # E241 multiple spaces after comma + E241, + # E302 expected 2 newlines + E302, + # W503 line break before binary operator (to be deprecated) + W503, + # E265 block comment should start with '# ' + E265, + # E123 closing bracket does not match indentation of opening bracket's line + E123, + # E124 closing bracket does not match visual indentation + E124, + # E221 multiple spaces before operator + E221, + # E201 whitespace after '[' + E201, + # E741 ambiguous variable name 'I' + E741, + + +per-file-ignores = + # F401 import without use + */__init__.py: F401, diff --git a/gridlock/draw.py b/gridlock/draw.py index 2322a66..75eccc6 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -59,7 +59,7 @@ def draw_polygons( for i, polygon in enumerate(polygons): malformed = f'Malformed polygon: ({i})' if polygon.shape[1] not in (2, 3): - raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') + raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') if polygon.shape[1] == 3: polygon = polygon[surface, :] @@ -72,7 +72,7 @@ def draw_polygons( # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] if numpy.size(foreground) == 1: # type: ignore - foregrounds = [foreground] * len(cell_data) # type: ignore + foregrounds = [foreground] * len(cell_data) # type: ignore elif isinstance(foreground, numpy.ndarray): raise GridError('ndarray not supported for foreground') else: @@ -113,7 +113,7 @@ def draw_polygons( foregrounds_i = foregrounds[i] if callable(foregrounds_i): # meshgrid over the (shifted) domain - domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k]+1] for k in range(3)] + domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)] (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') # evaluate on the meshgrid @@ -319,7 +319,7 @@ def draw_cylinder( num_points: The circle is approximated by a polygon with `num_points` vertices foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ - theta = numpy.linspace(0, 2*numpy.pi, num_points, endpoint=False) + theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False) x = radius * numpy.sin(theta) y = radius * numpy.cos(theta) polygon = numpy.hstack((x[:, None], y[:, None])) @@ -360,8 +360,8 @@ def draw_extrude_rectangle( surface = numpy.delete(range(3), direction) dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] - p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0]/2.0, - numpy.array([-1, 1, 1, -1], dtype=float) * dim[1]/2.0)).T + p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, + numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T thickness = distance foreground_func = [] @@ -371,7 +371,7 @@ def draw_extrude_rectangle( ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] fpart = z - numpy.floor(z) - mult = [1-fpart, fpart][::s] # reverses if s negative + mult = [1 - fpart, fpart][::s] # reverses if s negative foreground = mult[0] * grid[tuple(ind)] ind[direction] += 1 # type: ignore #(known safe) diff --git a/gridlock/grid.py b/gridlock/grid.py index 36e0608..d8e9653 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -2,7 +2,6 @@ from typing import Callable, Sequence, ClassVar, Self import numpy from numpy.typing import NDArray, ArrayLike -from numpy import diff, floor, ceil, zeros, hstack, newaxis import pickle import warnings diff --git a/gridlock/position.py b/gridlock/position.py index b674716..5928174 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -99,7 +99,7 @@ def pos2ind( grid_pos = numpy.zeros((3,)) for a in range(3): - xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in + xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds # No need to interpolate if round_ind is true or we were outside the grid diff --git a/gridlock/read.py b/gridlock/read.py index 8ea4b50..bbd4f39 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -4,7 +4,7 @@ Readback and visualization methods for Grid class from typing import Any, TYPE_CHECKING import numpy -from numpy.typing import NDArray, ArrayLike +from numpy.typing import NDArray from . import GridError From d44e02e2f74c032963ca58b5e5bc27d11b0b2ebf Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Thu, 18 Jul 2024 00:17:58 -0700 Subject: [PATCH 07/48] return figure and axes after plotting --- gridlock/read.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index bbd4f39..2fe35d5 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -90,8 +90,8 @@ def visualize_slice( which_shifts: int = 0, sample_period: int = 1, finalize: bool = True, - pcolormesh_args: Optional[Dict[str, Any]] = None, - ) -> None: + pcolormesh_args: dict[str, Any] | None = None, + ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: """ Visualize a slice of a grid. Interpolates if given a position between two planes. @@ -102,6 +102,9 @@ def visualize_slice( which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + + Returns: + (Figure, Axes) """ from matplotlib import pyplot @@ -120,15 +123,17 @@ def visualize_slice( xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) - pyplot.figure() - pyplot.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) - pyplot.colorbar() - pyplot.gca().set_aspect('equal', adjustable='box') - pyplot.xlabel(x_label) - pyplot.ylabel(y_label) + fig, ax = pyplot.subplots() + mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) + fig.colorbar(mappable) + ax.set_aspect('equal', adjustable='box') + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) if finalize: pyplot.show() + return fig, ax + def visualize_isosurface( self, @@ -138,7 +143,7 @@ def visualize_isosurface( sample_period: int = 1, show_edges: bool = True, finalize: bool = True, - ) -> None: + ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: """ Draw an isosurface plot of the device. @@ -149,6 +154,9 @@ def visualize_isosurface( sample_period: Period for down-sampling the image. Default 1 (disabled) show_edges: Whether to draw triangle edges. Default `True` finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + + Returns: + (Figure, Axes) """ from matplotlib import pyplot import skimage.measure @@ -190,3 +198,5 @@ def visualize_isosurface( if finalize: pyplot.show() + + return fig, ax From 9ab97e763cc5662df4297c51bc2331652698c345 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Thu, 18 Jul 2024 00:20:18 -0700 Subject: [PATCH 08/48] bump min python version to 3.11 due to Self type --- README.md | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e882b2e..e4f43d5 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ the coordinates of the boundary points along each axis). ## Installation Requirements: -* python 3 (written and tested with 3.9) +* python >3.11 (written and tested with 3.12) * numpy * [float_raster](https://mpxd.net/code/jan/float_raster) * matplotlib (optional, used for visualization functions) diff --git a/pyproject.toml b/pyproject.toml index e9ac6be..3d24be1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,13 +32,13 @@ classifiers = [ "Topic :: Scientific/Engineering :: Physics", "Topic :: Scientific/Engineering :: Visualization", ] -requires-python = ">=3.8" +requires-python = ">=3.11" include = [ "LICENSE.md" ] dynamic = ["version"] dependencies = [ - "numpy~=1.21", + "numpy~=1.26", "float_raster", ] From a15e4bc05e43262f948177d3b1a7e43ae1172346 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 00:44:26 -0700 Subject: [PATCH 09/48] repeat re-exported names --- gridlock/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index abcadc2..45d1aa8 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -15,8 +15,8 @@ Dependencies: - mpl_toolkits.mplot3d [Grid.visualize_isosurface()] - skimage [Grid.visualize_isosurface()] """ -from .error import GridError -from .grid import Grid +from .error import GridError as GridError +from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' __version__ = '1.1' From e29c0901bdc6c9096b35bf5f88d839d2c613e124 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 00:46:21 -0700 Subject: [PATCH 10/48] use strict zip --- gridlock/draw.py | 8 ++++---- gridlock/grid.py | 2 +- gridlock/read.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 75eccc6..b4b2176 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -144,13 +144,13 @@ def draw_polygons( # Find indices in w_xy which are modified by polygon # First for the edge coordinates (+1 since we're indexing edges) - edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max)] + edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)] # Then for the pixel centers (-bdi_min since we're # calculating weights within a subspace) centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface], - corner_max - bdi_min[surface])) + corner_max - bdi_min[surface], strict=True)) - aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices)) + aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True)) w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y) # Clamp overlapping polygons to 1 @@ -380,7 +380,7 @@ def draw_extrude_rectangle( def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]: # transform from natural position to index xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) - for qrs in zip(xs.flat, ys.flat, zs.flat)], dtype=int) + for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=int) # reshape to original shape and keep only in-plane components qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) return foreground[qi, ri] diff --git a/gridlock/grid.py b/gridlock/grid.py index d8e9653..2fb721b 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -135,7 +135,7 @@ class Grid: list of [dxs, dys, dzs] with each element same length as elements of `self.xyz` """ el = [0 if p else -1 for p in self.periodic] - return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)] + return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] @property def center(self) -> NDArray[numpy.float64]: diff --git a/gridlock/read.py b/gridlock/read.py index 2fe35d5..54afa36 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -72,7 +72,7 @@ def get_slice( # Extract grid values from planes above and below visualized slice sliced_grid = numpy.zeros(self.shape[surface]) - for ci, weight in zip(centers, w): + for ci, weight in zip(centers, w, strict=True): s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) sliced_grid += weight * cell_data[which_shifts][tuple(s)] @@ -193,7 +193,7 @@ def visualize_isosurface( ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min()) zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min()) # Comment or uncomment following both lines to test the fake bounding box: - for xb, yb, zb in zip(xbs, ybs, zbs): + for xb, yb, zb in zip(xbs, ybs, zbs, strict=True): ax.plot([xb], [yb], [zb], 'w') if finalize: From 5a20339eab6bb0096ed9a2090b55f44ee9aa8fa1 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 00:49:59 -0700 Subject: [PATCH 11/48] del axes3d to clarify it's unused on purpose --- gridlock/read.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gridlock/read.py b/gridlock/read.py index 54afa36..31b583d 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -162,6 +162,7 @@ def visualize_isosurface( import skimage.measure # Claims to be unused, but needed for subplot(projection='3d') from mpl_toolkits.mplot3d import Axes3D + del Axes3D # imported for side effects only # Get data from cell_data grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] From f84a75f35afd1f9790dff7e46b3c899c38ca39c1 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 00:50:08 -0700 Subject: [PATCH 12/48] comment unused import --- gridlock/test/test_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 1e30bf3..183c9cf 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,6 +1,6 @@ import pytest # type: ignore import numpy -from numpy.testing import assert_allclose, assert_array_equal +from numpy.testing import assert_allclose #, assert_array_equal from .. import Grid From 8c33a39c02b5451f6ff631a75ff42a86a335e583 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 01:37:58 -0700 Subject: [PATCH 13/48] refactor to avoid class-scoped imports --- gridlock/base.py | 198 +++++++++++++ gridlock/draw.py | 681 ++++++++++++++++++++++--------------------- gridlock/grid.py | 201 +------------ gridlock/position.py | 190 ++++++------ gridlock/read.py | 305 +++++++++---------- 5 files changed, 801 insertions(+), 774 deletions(-) create mode 100644 gridlock/base.py diff --git a/gridlock/base.py b/gridlock/base.py new file mode 100644 index 0000000..6bd5fb8 --- /dev/null +++ b/gridlock/base.py @@ -0,0 +1,198 @@ +from typing import ClassVar, Self, Protocol +from collections.abc import Callable, Sequence + +import numpy +from numpy.typing import NDArray, ArrayLike + +from . import GridError + + +class GridBase(Protocol): + exyz: list[NDArray] + """Cell edges. Monotonically increasing without duplicates.""" + + periodic: list[bool] + """For each axis, determines how far the rightmost boundary gets shifted. """ + + shifts: NDArray + """Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`""" + + @property + def dxyz(self) -> list[NDArray]: + """ + Cell sizes for each axis, no shifts applied + + Returns: + List of 3 ndarrays of cell sizes + """ + return [numpy.diff(ee) for ee in self.exyz] + + @property + def xyz(self) -> list[NDArray]: + """ + Cell centers for each axis, no shifts applied + + Returns: + List of 3 ndarrays of cell edges + """ + return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] + + @property + def shape(self) -> NDArray[numpy.int_]: + """ + The number of cells in x, y, and z + + Returns: + ndarray of [x_centers.size, y_centers.size, z_centers.size] + """ + return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int) + + @property + def num_grids(self) -> int: + """ + The number of grids (number of shifts) + """ + return self.shifts.shape[0] + + @property + def cell_data_shape(self): + """ + The shape of the cell_data ndarray (num_grids, *self.shape). + """ + return numpy.hstack((self.num_grids, self.shape)) + + @property + def dxyz_with_ghost(self) -> list[NDArray]: + """ + Gives dxyz with an additional 'ghost' cell at the end, whose value depends + on whether or not the axis has periodic boundary conditions. See main description + above to learn why this is necessary. + + If periodic, final edge shifts same amount as first + Otherwise, final edge shifts same amount as second-to-last + + Returns: + list of [dxs, dys, dzs] with each element same length as elements of `self.xyz` + """ + el = [0 if p else -1 for p in self.periodic] + return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] + + @property + def center(self) -> NDArray[numpy.float64]: + """ + Center position of the entire grid, no shifts applied + + Returns: + ndarray of [x_center, y_center, z_center] + """ + # center is just average of first and last xyz, which is just the average of the + # first two and last two exyz + centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)] + return numpy.array(centers, dtype=float) + + @property + def dxyz_limits(self) -> tuple[NDArray, NDArray]: + """ + Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element + ndarrays. No shifts are applied, so these are extreme bounds on these values (as a + weighted average is performed when shifting). + + Returns: + Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]` + """ + d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float) + d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float) + return d_min, d_max + + def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]: + """ + Returns edges for which_shifts. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell edges + """ + if which_shifts is None: + return self.exyz + dxyz = self.dxyz_with_ghost + shifts = self.shifts[which_shifts, :] + + # If shift is negative, use left cell's dx to determine shift + for a in range(3): + if shifts[a] < 0: + dxyz[a] = numpy.roll(dxyz[a], 1) + + return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] + + def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: + """ + Returns cell sizes for `which_shifts`. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell sizes + """ + if which_shifts is None: + return self.dxyz + shifts = self.shifts[which_shifts, :] + dxyz = self.dxyz_with_ghost + + # If shift is negative, use left cell's dx to determine size + sdxyz = [] + for a in range(3): + if shifts[a] < 0: + roll_dxyz = numpy.roll(dxyz[a], 1) + abs_shift = numpy.abs(shifts[a]) + sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) + else: + sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) + + return sdxyz + + def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: + """ + Returns cell centers for `which_shifts`. + + Args: + which_shifts: Which grid (which shifts) to use, or `None` for unshifted + + Returns: + List of 3 ndarrays of cell centers + """ + if which_shifts is None: + return self.xyz + exyz = self.shifted_exyz(which_shifts) + dxyz = self.shifted_dxyz(which_shifts) + return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] + + def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]: + """ + Return cell widths, with each dimension shifted by the corresponding shifts. + + Returns: + `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]` + """ + if self.num_grids != 3: + raise GridError('Autoshifting requires exactly 3 grids') + return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] + + def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray: + """ + Allocate an ndarray for storing grid data. + + Args: + fill_value: Value to initialize the grid to. If None, an + uninitialized array is returned. + dtype: Numpy dtype for the array. Default is `numpy.float32`. + + Returns: + The allocated array + """ + if fill_value is None: + return numpy.empty(self.cell_data_shape, dtype=dtype) + else: + return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) diff --git a/gridlock/draw.py b/gridlock/draw.py index b4b2176..7146e15 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -1,13 +1,16 @@ """ Drawing-related methods for Grid class """ -from typing import Union, Sequence, Callable +from typing import Union +from collections.abc import Sequence, Callable import numpy from numpy.typing import NDArray, ArrayLike from float_raster import raster from . import GridError +from .base import GridBase +from .position import GridPosMixin # NOTE: Maybe it would make sense to create a GridDrawer class @@ -20,372 +23,374 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_t = Union[float, foreground_callable_t] -def draw_polygons( - self, - cell_data: NDArray, - surface_normal: int, - center: ArrayLike, - polygons: Sequence[NDArray], - thickness: float, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw polygons on an axis-aligned plane. +class GridDrawMixin(GridPosMixin): + def draw_polygons( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + polygons: Sequence[ArrayLike], + thickness: float, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw polygons on an axis-aligned plane. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying an offset applied to all the polygons - polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon - (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each - polygon must have at least 3 vertices. - thickness: Thickness of the layer to draw - foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list - of any of these (1 per grid). Callable values should take an ndarray the shape of the - grid and return an ndarray of equal shape containing the foreground value at the given x, y, - and z (natural, not grid coordinates). + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: 3-element ndarray or list specifying an offset applied to all the polygons + polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon + (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each + polygon must have at least 3 vertices. + thickness: Thickness of the layer to draw + foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list + of any of these (1 per grid). Callable values should take an ndarray the shape of the + grid and return an ndarray of equal shape containing the foreground value at the given x, y, + and z (natural, not grid coordinates). - Raises: - GridError - """ - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') - - center = numpy.squeeze(center) - - # Check polygons, and remove redundant coordinates - surface = numpy.delete(range(3), surface_normal) - - for i, polygon in enumerate(polygons): - malformed = f'Malformed polygon: ({i})' - if polygon.shape[1] not in (2, 3): - raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') - if polygon.shape[1] == 3: - polygon = polygon[surface, :] - - if not polygon.shape[0] > 2: - raise GridError(malformed + 'must consist of more than 2 points') - if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: - raise GridError(malformed + 'must be in plane with surface normal ' - + 'xyz'[surface_normal]) - - # Broadcast foreground where necessary - foregrounds: Sequence[foreground_callable_t] | Sequence[float] - if numpy.size(foreground) == 1: # type: ignore - foregrounds = [foreground] * len(cell_data) # type: ignore - elif isinstance(foreground, numpy.ndarray): - raise GridError('ndarray not supported for foreground') - else: - foregrounds = foreground # type: ignore - - # ## Compute sub-domain of the grid occupied by polygons - # 1) Compute outer bounds (bd) of polygons - bd_2d_min = numpy.array([0, 0]) - bd_2d_max = numpy.array([0, 0]) - for polygon in polygons: - bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) - bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) - bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center - bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center - - # 2) Find indices (bdi) just outside bd elements - buf = 2 # size of safety buffer - # Use s_min and s_max with unshifted pos2ind to get absolute limits on - # the indices the polygons might affect - s_min = self.shifts.min(axis=0) - s_max = self.shifts.max(axis=0) - bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf - bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf - bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) - bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) - - # 3) Adjust polygons for center - polygons = [poly + center[surface] for poly in polygons] - - # ## Generate weighing function - def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: - v_2d = numpy.array(vector, dtype=float) - return numpy.insert(v_2d, surface_normal, (val,)) - - # iterate over grids - for i, _ in enumerate(cell_data): - # ## Evaluate or expand foregrounds[i] - foregrounds_i = foregrounds[i] - if callable(foregrounds_i): - # meshgrid over the (shifted) domain - domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)] - (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') - - # evaluate on the meshgrid - foreground_val = foregrounds_i(x0, y0, z0) - if not numpy.isfinite(foreground_val).all(): - raise GridError(f'Non-finite values in foreground[{i}]') - elif numpy.size(foregrounds_i) != 1: - raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}') - else: - # foreground[i] is scalar non-callable - foreground_val = foregrounds_i - - w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) - - # Draw each polygon separately - for polygon in polygons: - - # Get the boundaries of the polygon - pbd_min = polygon.min(axis=0) - pbd_max = polygon.max(axis=0) - - # Find indices in w_xy just outside polygon - # using per-grid xy-weights (self.shifted_xyz()) - corner_min = self.pos2ind(to_3d(pbd_min), i, - check_bounds=False)[surface].astype(int) - corner_max = self.pos2ind(to_3d(pbd_max), i, - check_bounds=False)[surface].astype(int) - - # Find indices in w_xy which are modified by polygon - # First for the edge coordinates (+1 since we're indexing edges) - edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)] - # Then for the pixel centers (-bdi_min since we're - # calculating weights within a subspace) - centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface], - corner_max - bdi_min[surface], strict=True)) - - aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True)) - w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y) - - # Clamp overlapping polygons to 1 - w_xy = numpy.minimum(w_xy, 1.0) - - # 2) Generate weights in z-direction - w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) - - def get_zi(offset, i=i, w_z=w_z): - edges = self.shifted_exyz(i)[surface_normal] - point = center[surface_normal] + offset - grid_coord = numpy.digitize(point, edges) - 1 - w_coord = grid_coord - bdi_min[surface_normal] - - if w_coord < 0: - w_coord = 0 - f = 0 - elif w_coord >= w_z.size: - w_coord = w_z.size - 1 - f = 1 - else: - dz = self.shifted_dxyz(i)[surface_normal][grid_coord] - f = (point - edges[grid_coord]) / dz - return f, w_coord - - zi_top_f, zi_top = get_zi(+thickness / 2.0) - zi_bot_f, zi_bot = get_zi(-thickness / 2.0) - - w_z[zi_bot + 1:zi_top] = 1 - - if zi_bot < zi_top: - w_z[zi_top] = zi_top_f - w_z[zi_bot] = 1 - zi_bot_f - else: - w_z[zi_bot] = zi_top_f - zi_bot_f - - # 3) Generate total weight function - w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,))) - - # ## Modify the grid - g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) - cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val - - -def draw_polygon( - self, - cell_data: NDArray, - surface_normal: int, - center: ArrayLike, - polygon: ArrayLike, - thickness: float, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw a polygon on an axis-aligned plane. - - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying an offset applied to the polygon - polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, - clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at - least 3 vertices. - thickness: Thickness of the layer to draw - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) - - -def draw_slab( - self, - cell_data: NDArray, - surface_normal: int, - center: ArrayLike, - thickness: float, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw an axis-aligned infinite slab. - - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: `surface_normal` coordinate value at the center of the slab - thickness: Thickness of the layer to draw - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - # Turn surface_normal into its integer representation - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') - - if numpy.size(center) != 1: + Raises: + GridError + """ + if surface_normal not in range(3): + raise GridError('Invalid surface_normal direction') center = numpy.squeeze(center) - if len(center) == 3: - center = center[surface_normal] + poly_list = [numpy.array(poly, copy=False) for poly in polygons] + + # Check polygons, and remove redundant coordinates + surface = numpy.delete(range(3), surface_normal) + + for i, polygon in enumerate(poly_list): + malformed = f'Malformed polygon: ({i})' + if polygon.shape[1] not in (2, 3): + raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') + if polygon.shape[1] == 3: + polygon = polygon[surface, :] + + if not polygon.shape[0] > 2: + raise GridError(malformed + 'must consist of more than 2 points') + if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: + raise GridError(malformed + 'must be in plane with surface normal ' + + 'xyz'[surface_normal]) + + # Broadcast foreground where necessary + foregrounds: Sequence[foreground_callable_t] | Sequence[float] + if numpy.size(foreground) == 1: # type: ignore + foregrounds = [foreground] * len(cell_data) # type: ignore + elif isinstance(foreground, numpy.ndarray): + raise GridError('ndarray not supported for foreground') else: - raise GridError(f'Bad center: {center}') + foregrounds = foreground # type: ignore - # Find center of slab - center_shift = self.center - center_shift[surface_normal] = center + # ## Compute sub-domain of the grid occupied by polygons + # 1) Compute outer bounds (bd) of polygons + bd_2d_min = numpy.array([0, 0]) + bd_2d_max = numpy.array([0, 0]) + for polygon in poly_list: + bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) + bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) + bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center + bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center - surface = numpy.delete(range(3), surface_normal) + # 2) Find indices (bdi) just outside bd elements + buf = 2 # size of safety buffer + # Use s_min and s_max with unshifted pos2ind to get absolute limits on + # the indices the polygons might affect + s_min = self.shifts.min(axis=0) + s_max = self.shifts.max(axis=0) + bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf + bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf + bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) + bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) - xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface] - xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface] + # 3) Adjust polygons for center + poly_list = [poly + center[surface] for poly in poly_list] - dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float) + # ## Generate weighing function + def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: + v_2d = numpy.array(vector, dtype=float) + return numpy.insert(v_2d, surface_normal, (val,)) - xyz_min -= 4 * dxyz - xyz_max += 4 * dxyz + # iterate over grids + foreground_val: NDArray | float + for i, _ in enumerate(cell_data): + # ## Evaluate or expand foregrounds[i] + foregrounds_i = foregrounds[i] + if callable(foregrounds_i): + # meshgrid over the (shifted) domain + domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)] + (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') - p = numpy.array([[xyz_min[0], xyz_max[1]], - [xyz_max[0], xyz_max[1]], - [xyz_max[0], xyz_min[1]], - [xyz_min[0], xyz_min[1]]], dtype=float) + # evaluate on the meshgrid + foreground_val = foregrounds_i(x0, y0, z0) + if not numpy.isfinite(foreground_val).all(): + raise GridError(f'Non-finite values in foreground[{i}]') + elif numpy.size(foregrounds_i) != 1: + raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}') + else: + # foreground[i] is scalar non-callable + foreground_val = foregrounds_i - self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) + w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) + + # Draw each polygon separately + for polygon in poly_list: + + # Get the boundaries of the polygon + pbd_min = polygon.min(axis=0) + pbd_max = polygon.max(axis=0) + + # Find indices in w_xy just outside polygon + # using per-grid xy-weights (self.shifted_xyz()) + corner_min = self.pos2ind(to_3d(pbd_min), i, + check_bounds=False)[surface].astype(int) + corner_max = self.pos2ind(to_3d(pbd_max), i, + check_bounds=False)[surface].astype(int) + + # Find indices in w_xy which are modified by polygon + # First for the edge coordinates (+1 since we're indexing edges) + edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)] + # Then for the pixel centers (-bdi_min since we're + # calculating weights within a subspace) + centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface], + corner_max - bdi_min[surface], strict=True)) + + aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True)) + w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y) + + # Clamp overlapping polygons to 1 + w_xy = numpy.minimum(w_xy, 1.0) + + # 2) Generate weights in z-direction + w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) + + def get_zi(offset, i=i, w_z=w_z): + edges = self.shifted_exyz(i)[surface_normal] + point = center[surface_normal] + offset + grid_coord = numpy.digitize(point, edges) - 1 + w_coord = grid_coord - bdi_min[surface_normal] + + if w_coord < 0: + w_coord = 0 + f = 0 + elif w_coord >= w_z.size: + w_coord = w_z.size - 1 + f = 1 + else: + dz = self.shifted_dxyz(i)[surface_normal][grid_coord] + f = (point - edges[grid_coord]) / dz + return f, w_coord + + zi_top_f, zi_top = get_zi(+thickness / 2.0) + zi_bot_f, zi_bot = get_zi(-thickness / 2.0) + + w_z[zi_bot + 1:zi_top] = 1 + + if zi_bot < zi_top: + w_z[zi_top] = zi_top_f + w_z[zi_bot] = 1 - zi_bot_f + else: + w_z[zi_bot] = zi_top_f - zi_bot_f + + # 3) Generate total weight function + w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,))) + + # ## Modify the grid + g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) + cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val -def draw_cuboid( - self, - cell_data: NDArray, - center: ArrayLike, - dimensions: ArrayLike, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw an axis-aligned cuboid + def draw_polygon( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + polygon: ArrayLike, + thickness: float, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw a polygon on an axis-aligned plane. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - center: 3-element ndarray or list specifying the cuboid's center - dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge - sizes of the cuboid - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - dimensions = numpy.array(dimensions, copy=False) - p = numpy.array([[-dimensions[0], +dimensions[1]], - [+dimensions[0], +dimensions[1]], - [+dimensions[0], -dimensions[1]], - [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0 - thickness = dimensions[2] - self.draw_polygon(cell_data, 2, center, p, thickness, foreground) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: 3-element ndarray or list specifying an offset applied to the polygon + polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, + clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at + least 3 vertices. + thickness: Thickness of the layer to draw + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) -def draw_cylinder( - self, - cell_data: NDArray, - surface_normal: int, - center: ArrayLike, - radius: float, - thickness: float, - num_points: int, - foreground: Sequence[foreground_t] | foreground_t, - ) -> None: - """ - Draw an axis-aligned cylinder. Approximated by a num_points-gon + def draw_slab( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + thickness: float, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw an axis-aligned infinite slab. - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying the cylinder's center - radius: cylinder radius - thickness: Thickness of the layer to draw - num_points: The circle is approximated by a polygon with `num_points` vertices - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - """ - theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False) - x = radius * numpy.sin(theta) - y = radius * numpy.cos(theta) - polygon = numpy.hstack((x[:, None], y[:, None])) - self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: `surface_normal` coordinate value at the center of the slab + thickness: Thickness of the layer to draw + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + # Turn surface_normal into its integer representation + if surface_normal not in range(3): + raise GridError('Invalid surface_normal direction') + + if numpy.size(center) != 1: + center = numpy.squeeze(center) + if len(center) == 3: + center = center[surface_normal] + else: + raise GridError(f'Bad center: {center}') + + # Find center of slab + center_shift = self.center + center_shift[surface_normal] = center + + surface = numpy.delete(range(3), surface_normal) + + xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface] + xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface] + + dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float) + + xyz_min -= 4 * dxyz + xyz_max += 4 * dxyz + + p = numpy.array([[xyz_min[0], xyz_max[1]], + [xyz_max[0], xyz_max[1]], + [xyz_max[0], xyz_min[1]], + [xyz_min[0], xyz_min[1]]], dtype=float) + + self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) -def draw_extrude_rectangle( - self, - cell_data: NDArray, - rectangle: ArrayLike, - direction: int, - polarity: int, - distance: float, - ) -> None: - """ - Extrude a rectangle of a previously-drawn structure along an axis. + def draw_cuboid( + self, + cell_data: NDArray, + center: ArrayLike, + dimensions: ArrayLike, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw an axis-aligned cuboid - Args: - cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - rectangle: 2x3 ndarray or list specifying the rectangle's corners - direction: Direction to extrude in. Integer in `range(3)`. - polarity: +1 or -1, direction along axis to extrude in - distance: How far to extrude - """ - s = numpy.sign(polarity) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + center: 3-element ndarray or list specifying the cuboid's center + dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge + sizes of the cuboid + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + dimensions = numpy.array(dimensions, copy=False) + p = numpy.array([[-dimensions[0], +dimensions[1]], + [+dimensions[0], +dimensions[1]], + [+dimensions[0], -dimensions[1]], + [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0 + thickness = dimensions[2] + self.draw_polygon(cell_data, 2, center, p, thickness, foreground) - rectangle = numpy.array(rectangle, dtype=float) - if s == 0: - raise GridError('0 is not a valid polarity') - if direction not in range(3): - raise GridError(f'Invalid direction: {direction}') - if rectangle[0, direction] != rectangle[1, direction]: - raise GridError('Rectangle entries along extrusion direction do not match.') - center = rectangle.sum(axis=0) / 2.0 - center[direction] += s * distance / 2.0 + def draw_cylinder( + self, + cell_data: NDArray, + surface_normal: int, + center: ArrayLike, + radius: float, + thickness: float, + num_points: int, + foreground: Sequence[foreground_t] | foreground_t, + ) -> None: + """ + Draw an axis-aligned cylinder. Approximated by a num_points-gon - surface = numpy.delete(range(3), direction) + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. + center: 3-element ndarray or list specifying the cylinder's center + radius: cylinder radius + thickness: Thickness of the layer to draw + num_points: The circle is approximated by a polygon with `num_points` vertices + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + """ + theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False) + x = radius * numpy.sin(theta) + y = radius * numpy.cos(theta) + polygon = numpy.hstack((x[:, None], y[:, None])) + self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) - dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] - p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, - numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T - thickness = distance - foreground_func = [] - for i, grid in enumerate(cell_data): - z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] + def draw_extrude_rectangle( + self, + cell_data: NDArray, + rectangle: ArrayLike, + direction: int, + polarity: int, + distance: float, + ) -> None: + """ + Extrude a rectangle of a previously-drawn structure along an axis. - ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] + Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + rectangle: 2x3 ndarray or list specifying the rectangle's corners + direction: Direction to extrude in. Integer in `range(3)`. + polarity: +1 or -1, direction along axis to extrude in + distance: How far to extrude + """ + s = numpy.sign(polarity) - fpart = z - numpy.floor(z) - mult = [1 - fpart, fpart][::s] # reverses if s negative + rectangle = numpy.array(rectangle, dtype=float) + if s == 0: + raise GridError('0 is not a valid polarity') + if direction not in range(3): + raise GridError(f'Invalid direction: {direction}') + if rectangle[0, direction] != rectangle[1, direction]: + raise GridError('Rectangle entries along extrusion direction do not match.') - foreground = mult[0] * grid[tuple(ind)] - ind[direction] += 1 # type: ignore #(known safe) - foreground += mult[1] * grid[tuple(ind)] + center = rectangle.sum(axis=0) / 2.0 + center[direction] += s * distance / 2.0 - def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]: - # transform from natural position to index - xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) - for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=int) - # reshape to original shape and keep only in-plane components - qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) - return foreground[qi, ri] + surface = numpy.delete(range(3), direction) - foreground_func.append(f_foreground) + dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] + p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, + numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T + thickness = distance - self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func) + foreground_func = [] + for i, grid in enumerate(cell_data): + z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] + + ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] + + fpart = z - numpy.floor(z) + mult = [1 - fpart, fpart][::s] # reverses if s negative + + foreground = mult[0] * grid[tuple(ind)] + ind[direction] += 1 # type: ignore #(known safe) + foreground += mult[1] * grid[tuple(ind)] + + def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]: + # transform from natural position to index + xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) + for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=int) + # reshape to original shape and keep only in-plane components + qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) + return foreground[qi, ri] + + foreground_func.append(f_foreground) + + self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func) diff --git a/gridlock/grid.py b/gridlock/grid.py index 2fb721b..55b1abb 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,4 +1,5 @@ -from typing import Callable, Sequence, ClassVar, Self +from typing import ClassVar, Self +from collections.abc import Callable, Sequence import numpy from numpy.typing import NDArray, ArrayLike @@ -8,12 +9,16 @@ import warnings import copy from . import GridError +from .base import GridBase +from .draw import GridDrawMixin +from .read import GridReadMixin +from .position import GridPosMixin foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] -class Grid: +class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): """ Simulation grid metadata for finite-difference simulations. @@ -70,193 +75,6 @@ class Grid: ], dtype=float) """Default shifts for Yee grid H-field""" - from .draw import ( - draw_polygons, draw_polygon, draw_slab, draw_cuboid, - draw_cylinder, draw_extrude_rectangle, - ) - from .read import get_slice, visualize_slice, visualize_isosurface - from .position import ind2pos, pos2ind - - @property - def dxyz(self) -> list[NDArray]: - """ - Cell sizes for each axis, no shifts applied - - Returns: - List of 3 ndarrays of cell sizes - """ - return [numpy.diff(ee) for ee in self.exyz] - - @property - def xyz(self) -> list[NDArray]: - """ - Cell centers for each axis, no shifts applied - - Returns: - List of 3 ndarrays of cell edges - """ - return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] - - @property - def shape(self) -> NDArray[numpy.int_]: - """ - The number of cells in x, y, and z - - Returns: - ndarray of [x_centers.size, y_centers.size, z_centers.size] - """ - return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int) - - @property - def num_grids(self) -> int: - """ - The number of grids (number of shifts) - """ - return self.shifts.shape[0] - - @property - def cell_data_shape(self): - """ - The shape of the cell_data ndarray (num_grids, *self.shape). - """ - return numpy.hstack((self.num_grids, self.shape)) - - @property - def dxyz_with_ghost(self) -> list[NDArray]: - """ - Gives dxyz with an additional 'ghost' cell at the end, whose value depends - on whether or not the axis has periodic boundary conditions. See main description - above to learn why this is necessary. - - If periodic, final edge shifts same amount as first - Otherwise, final edge shifts same amount as second-to-last - - Returns: - list of [dxs, dys, dzs] with each element same length as elements of `self.xyz` - """ - el = [0 if p else -1 for p in self.periodic] - return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] - - @property - def center(self) -> NDArray[numpy.float64]: - """ - Center position of the entire grid, no shifts applied - - Returns: - ndarray of [x_center, y_center, z_center] - """ - # center is just average of first and last xyz, which is just the average of the - # first two and last two exyz - centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)] - return numpy.array(centers, dtype=float) - - @property - def dxyz_limits(self) -> tuple[NDArray, NDArray]: - """ - Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element - ndarrays. No shifts are applied, so these are extreme bounds on these values (as a - weighted average is performed when shifting). - - Returns: - Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]` - """ - d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float) - d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float) - return d_min, d_max - - def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]: - """ - Returns edges for which_shifts. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell edges - """ - if which_shifts is None: - return self.exyz - dxyz = self.dxyz_with_ghost - shifts = self.shifts[which_shifts, :] - - # If shift is negative, use left cell's dx to determine shift - for a in range(3): - if shifts[a] < 0: - dxyz[a] = numpy.roll(dxyz[a], 1) - - return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] - - def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: - """ - Returns cell sizes for `which_shifts`. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell sizes - """ - if which_shifts is None: - return self.dxyz - shifts = self.shifts[which_shifts, :] - dxyz = self.dxyz_with_ghost - - # If shift is negative, use left cell's dx to determine size - sdxyz = [] - for a in range(3): - if shifts[a] < 0: - roll_dxyz = numpy.roll(dxyz[a], 1) - abs_shift = numpy.abs(shifts[a]) - sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) - else: - sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) - - return sdxyz - - def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: - """ - Returns cell centers for `which_shifts`. - - Args: - which_shifts: Which grid (which shifts) to use, or `None` for unshifted - - Returns: - List of 3 ndarrays of cell centers - """ - if which_shifts is None: - return self.xyz - exyz = self.shifted_exyz(which_shifts) - dxyz = self.shifted_dxyz(which_shifts) - return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] - - def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]: - """ - Return cell widths, with each dimension shifted by the corresponding shifts. - - Returns: - `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]` - """ - if self.num_grids != 3: - raise GridError('Autoshifting requires exactly 3 grids') - return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] - - def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray: - """ - Allocate an ndarray for storing grid data. - - Args: - fill_value: Value to initialize the grid to. If None, an - uninitialized array is returned. - dtype: Numpy dtype for the array. Default is `numpy.float32`. - - Returns: - The allocated array - """ - if fill_value is None: - return numpy.empty(self.cell_data_shape, dtype=dtype) - else: - return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) - def __init__( self, pixel_edge_coordinates: Sequence[ArrayLike], @@ -277,11 +95,12 @@ class Grid: Raises: `GridError` on invalid input """ - self.exyz = [numpy.unique(pixel_edge_coordinates[i]) for i in range(3)] + edge_arrs = [numpy.array(cc, copy=False) for cc in pixel_edge_coordinates] + self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) for i in range(3): - if len(self.exyz[i]) != len(pixel_edge_coordinates[i]): + if self.exyz[i].size != edge_arrs[i].size: warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2) if isinstance(periodic, bool): diff --git a/gridlock/position.py b/gridlock/position.py index 5928174..b705b99 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -5,112 +5,114 @@ import numpy from numpy.typing import NDArray, ArrayLike from . import GridError +from .base import GridBase -def ind2pos( - self, - ind: NDArray, - which_shifts: int | None = None, - round_ind: bool = True, - check_bounds: bool = True - ) -> NDArray[numpy.float64]: - """ - Returns the natural position corresponding to the specified cell center indices. - The resulting position is clipped to the bounds of the grid - (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`) +class GridPosMixin(GridBase): + def ind2pos( + self, + ind: NDArray, + which_shifts: int | None = None, + round_ind: bool = True, + check_bounds: bool = True + ) -> NDArray[numpy.float64]: + """ + Returns the natural position corresponding to the specified cell center indices. + The resulting position is clipped to the bounds of the grid + (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`) - Args: - ind: Indices of the position. Can be fractional. (3-element ndarray or list) - which_shifts: which grid number (`shifts`) to use - round_ind: Whether to round ind to the nearest integer position before indexing - (default `True`) - check_bounds: Whether to raise an `GridError` if the provided ind is outside of - the grid, as defined above (centers if `round_ind`, else edges) (default `True`) + Args: + ind: Indices of the position. Can be fractional. (3-element ndarray or list) + which_shifts: which grid number (`shifts`) to use + round_ind: Whether to round ind to the nearest integer position before indexing + (default `True`) + check_bounds: Whether to raise an `GridError` if the provided ind is outside of + the grid, as defined above (centers if `round_ind`, else edges) (default `True`) - Returns: - 3-element ndarray specifying the natural position + Returns: + 3-element ndarray specifying the natural position - Raises: - `GridError` if invalid `which_shifts` - `GridError` if `check_bounds` and out of bounds - """ - if which_shifts is not None and which_shifts >= self.shifts.shape[0]: - raise GridError('Invalid shifts') - ind = numpy.array(ind, dtype=float) + Raises: + `GridError` if invalid `which_shifts` + `GridError` if `check_bounds` and out of bounds + """ + if which_shifts is not None and which_shifts >= self.shifts.shape[0]: + raise GridError('Invalid shifts') + ind = numpy.array(ind, dtype=float) + + if check_bounds: + if round_ind: + low_bound = 0.0 + high_bound = -1.0 + else: + low_bound = -0.5 + high_bound = -0.5 + if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): + raise GridError(f'Position outside of grid: {ind}') - if check_bounds: if round_ind: - low_bound = 0.0 - high_bound = -1.0 + rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) + sxyz = self.shifted_xyz(which_shifts) + position = [sxyz[a][rind[a]].astype(int) for a in range(3)] else: - low_bound = -0.5 - high_bound = -0.5 - if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): - raise GridError(f'Position outside of grid: {ind}') + sexyz = self.shifted_exyz(which_shifts) + position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) + for a in range(3)] + return numpy.array(position, dtype=float) + + + def pos2ind( + self, + r: ArrayLike, + which_shifts: int | None, + round_ind: bool = True, + check_bounds: bool = True + ) -> NDArray[numpy.float64]: + """ + Returns the cell-center indices corresponding to the specified natural position. + The resulting position is clipped to within the outer centers of the grid. + + Args: + r: Natural position that we will convert into indices (3-element ndarray or list) + which_shifts: which grid number (`shifts`) to use + round_ind: Whether to round the returned indices to the nearest integers. + check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges + + Returns: + 3-element ndarray specifying the indices + + Raises: + `GridError` if invalid `which_shifts` + `GridError` if `check_bounds` and out of bounds + """ + r = numpy.squeeze(r) + if r.size != 3: + raise GridError(f'r must be 3-element vector: {r}') + + if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]): + raise GridError(f'Invalid which_shifts: {which_shifts}') - if round_ind: - rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) - sxyz = self.shifted_xyz(which_shifts) - position = [sxyz[a][rind[a]].astype(int) for a in range(3)] - else: sexyz = self.shifted_exyz(which_shifts) - position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) - for a in range(3)] - return numpy.array(position, dtype=float) + if check_bounds: + for a in range(3): + if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]): + raise GridError(f'Position[{a}] outside of grid!') -def pos2ind( - self, - r: ArrayLike, - which_shifts: int | None, - round_ind: bool = True, - check_bounds: bool = True - ) -> NDArray[numpy.float64]: - """ - Returns the cell-center indices corresponding to the specified natural position. - The resulting position is clipped to within the outer centers of the grid. - - Args: - r: Natural position that we will convert into indices (3-element ndarray or list) - which_shifts: which grid number (`shifts`) to use - round_ind: Whether to round the returned indices to the nearest integers. - check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges - - Returns: - 3-element ndarray specifying the indices - - Raises: - `GridError` if invalid `which_shifts` - `GridError` if `check_bounds` and out of bounds - """ - r = numpy.squeeze(r) - if r.size != 3: - raise GridError(f'r must be 3-element vector: {r}') - - if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]): - raise GridError(f'Invalid which_shifts: {which_shifts}') - - sexyz = self.shifted_exyz(which_shifts) - - if check_bounds: + grid_pos = numpy.zeros((3,)) for a in range(3): - if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]): - raise GridError(f'Position[{a}] outside of grid!') + xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in + xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds - grid_pos = numpy.zeros((3,)) - for a in range(3): - xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in - xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds + # No need to interpolate if round_ind is true or we were outside the grid + if round_ind or xi != xi_clipped: + grid_pos[a] = xi_clipped + else: + # Interpolate + x = self.shifted_xyz(which_shifts)[a][xi] + dx = self.shifted_dxyz(which_shifts)[a][xi] + f = (r[a] - x) / dx - # No need to interpolate if round_ind is true or we were outside the grid - if round_ind or xi != xi_clipped: - grid_pos[a] = xi_clipped - else: - # Interpolate - x = self.shifted_xyz(which_shifts)[a][xi] - dx = self.shifted_dxyz(which_shifts)[a][xi] - f = (r[a] - x) / dx - - # Clip to centers - grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1) - return grid_pos + # Clip to centers + grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1) + return grid_pos diff --git a/gridlock/read.py b/gridlock/read.py index 31b583d..e82ffcc 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -7,6 +7,8 @@ import numpy from numpy.typing import NDArray from . import GridError +from .base import GridBase +from .position import GridPosMixin if TYPE_CHECKING: import matplotlib.axes @@ -18,186 +20,187 @@ if TYPE_CHECKING: # .visualize_isosurface uses mpl_toolkits.mplot3d -def get_slice( - self, - cell_data: NDArray, - surface_normal: int, - center: float, - which_shifts: int = 0, - sample_period: int = 1 - ) -> NDArray: - """ - Retrieve a slice of a grid. - Interpolates if given a position between two planes. +class GridReadMixin(GridPosMixin): + def get_slice( + self, + cell_data: NDArray, + surface_normal: int, + center: float, + which_shifts: int = 0, + sample_period: int = 1 + ) -> NDArray: + """ + Retrieve a slice of a grid. + Interpolates if given a position between two planes. - Args: - cell_data: Cell data to slice - surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. - center: Scalar specifying position along surface_normal axis. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) + Args: + cell_data: Cell data to slice + surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. + center: Scalar specifying position along surface_normal axis. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) - Returns: - Array containing the portion of the grid. - """ - if numpy.size(center) != 1 or not numpy.isreal(center): - raise GridError('center must be a real scalar') + Returns: + Array containing the portion of the grid. + """ + if numpy.size(center) != 1 or not numpy.isreal(center): + raise GridError('center must be a real scalar') - sp = round(sample_period) - if sp <= 0: - raise GridError('sample_period must be positive') + sp = round(sample_period) + if sp <= 0: + raise GridError('sample_period must be positive') - if numpy.size(which_shifts) != 1 or which_shifts < 0: - raise GridError('Invalid which_shifts') + if numpy.size(which_shifts) != 1 or which_shifts < 0: + raise GridError('Invalid which_shifts') - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') + if surface_normal not in range(3): + raise GridError('Invalid surface_normal direction') - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), surface_normal) - # Extract indices and weights of planes - center3 = numpy.insert([0, 0], surface_normal, (center,)) - center_index = self.pos2ind(center3, which_shifts, - round_ind=False, check_bounds=False)[surface_normal] - centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) - if len(centers) == 2: - fpart = center_index - numpy.floor(center_index) - w = [1 - fpart, fpart] # longer distance -> less weight - else: - w = [1] + # Extract indices and weights of planes + center3 = numpy.insert([0, 0], surface_normal, (center,)) + center_index = self.pos2ind(center3, which_shifts, + round_ind=False, check_bounds=False)[surface_normal] + centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) + if len(centers) == 2: + fpart = center_index - numpy.floor(center_index) + w = [1 - fpart, fpart] # longer distance -> less weight + else: + w = [1] - c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1]) - if center < c_min or center > c_max: - raise GridError('Coordinate of selected plane must be within simulation domain') + c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1]) + if center < c_min or center > c_max: + raise GridError('Coordinate of selected plane must be within simulation domain') - # Extract grid values from planes above and below visualized slice - sliced_grid = numpy.zeros(self.shape[surface]) - for ci, weight in zip(centers, w, strict=True): - s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) - sliced_grid += weight * cell_data[which_shifts][tuple(s)] + # Extract grid values from planes above and below visualized slice + sliced_grid = numpy.zeros(self.shape[surface]) + for ci, weight in zip(centers, w, strict=True): + s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) + sliced_grid += weight * cell_data[which_shifts][tuple(s)] - # Remove extra dimensions - sliced_grid = numpy.squeeze(sliced_grid) + # Remove extra dimensions + sliced_grid = numpy.squeeze(sliced_grid) - return sliced_grid + return sliced_grid -def visualize_slice( - self, - cell_data: NDArray, - surface_normal: int, - center: float, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - pcolormesh_args: dict[str, Any] | None = None, - ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: - """ - Visualize a slice of a grid. - Interpolates if given a position between two planes. + def visualize_slice( + self, + cell_data: NDArray, + surface_normal: int, + center: float, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + pcolormesh_args: dict[str, Any] | None = None, + ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: + """ + Visualize a slice of a grid. + Interpolates if given a position between two planes. - Args: - surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. - center: Scalar specifying position along surface_normal axis. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + Args: + surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. + center: Scalar specifying position along surface_normal axis. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - Returns: - (Figure, Axes) - """ - from matplotlib import pyplot + Returns: + (Figure, Axes) + """ + from matplotlib import pyplot - if pcolormesh_args is None: - pcolormesh_args = {} + if pcolormesh_args is None: + pcolormesh_args = {} - grid_slice = self.get_slice(cell_data=cell_data, - surface_normal=surface_normal, - center=center, - which_shifts=which_shifts, - sample_period=sample_period) + grid_slice = self.get_slice(cell_data=cell_data, + surface_normal=surface_normal, + center=center, + which_shifts=which_shifts, + sample_period=sample_period) - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), surface_normal) - x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) - xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') - x_label, y_label = ('xyz'[a] for a in surface) + x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') + x_label, y_label = ('xyz'[a] for a in surface) - fig, ax = pyplot.subplots() - mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) - fig.colorbar(mappable) - ax.set_aspect('equal', adjustable='box') - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - if finalize: - pyplot.show() + fig, ax = pyplot.subplots() + mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) + fig.colorbar(mappable) + ax.set_aspect('equal', adjustable='box') + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + if finalize: + pyplot.show() - return fig, ax + return fig, ax -def visualize_isosurface( - self, - cell_data: NDArray, - level: float | None = None, - which_shifts: int = 0, - sample_period: int = 1, - show_edges: bool = True, - finalize: bool = True, - ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: - """ - Draw an isosurface plot of the device. + def visualize_isosurface( + self, + cell_data: NDArray, + level: float | None = None, + which_shifts: int = 0, + sample_period: int = 1, + show_edges: bool = True, + finalize: bool = True, + ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: + """ + Draw an isosurface plot of the device. - Args: - cell_data: Cell data to visualize - level: Value at which to find isosurface. Default (None) uses mean value in grid. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - show_edges: Whether to draw triangle edges. Default `True` - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + Args: + cell_data: Cell data to visualize + level: Value at which to find isosurface. Default (None) uses mean value in grid. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + show_edges: Whether to draw triangle edges. Default `True` + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - Returns: - (Figure, Axes) - """ - from matplotlib import pyplot - import skimage.measure - # Claims to be unused, but needed for subplot(projection='3d') - from mpl_toolkits.mplot3d import Axes3D - del Axes3D # imported for side effects only + Returns: + (Figure, Axes) + """ + from matplotlib import pyplot + import skimage.measure + # Claims to be unused, but needed for subplot(projection='3d') + from mpl_toolkits.mplot3d import Axes3D + del Axes3D # imported for side effects only - # Get data from cell_data - grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] - if level is None: - level = grid.mean() + # Get data from cell_data + grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] + if level is None: + level = grid.mean() - # Find isosurface with marching cubes - verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) + # Find isosurface with marching cubes + verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) - # Convert vertices from index to position - pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) - for i in range(verts.shape[0])], dtype=float) - xs, ys, zs = (pos_verts[:, a] for a in range(3)) + # Convert vertices from index to position + pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) + for i in range(verts.shape[0])], dtype=float) + xs, ys, zs = (pos_verts[:, a] for a in range(3)) - # Draw the plot - fig = pyplot.figure() - ax = fig.add_subplot(111, projection='3d') - if show_edges: - ax.plot_trisurf(xs, ys, faces, zs) - else: - ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') + # Draw the plot + fig = pyplot.figure() + ax = fig.add_subplot(111, projection='3d') + if show_edges: + ax.plot_trisurf(xs, ys, faces, zs) + else: + ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') - # Add a fake plot of a cube to force the axes to be equal lengths - max_range = numpy.array([xs.max() - xs.min(), - ys.max() - ys.min(), - zs.max() - zs.min()], dtype=float).max() - mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] - xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min()) - ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min()) - zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min()) - # Comment or uncomment following both lines to test the fake bounding box: - for xb, yb, zb in zip(xbs, ybs, zbs, strict=True): - ax.plot([xb], [yb], [zb], 'w') + # Add a fake plot of a cube to force the axes to be equal lengths + max_range = numpy.array([xs.max() - xs.min(), + ys.max() - ys.min(), + zs.max() - zs.min()], dtype=float).max() + mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] + xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min()) + ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min()) + zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min()) + # Comment or uncomment following both lines to test the fake bounding box: + for xb, yb, zb in zip(xbs, ybs, zbs, strict=True): + ax.plot([xb], [yb], [zb], 'w') - if finalize: - pyplot.show() + if finalize: + pyplot.show() - return fig, ax + return fig, ax From c32d94ed856811930621821cc38efcb63860491f Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 01:38:42 -0700 Subject: [PATCH 14/48] fix typos in arg names in example --- gridlock/examples/ex0.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gridlock/examples/ex0.py b/gridlock/examples/ex0.py index 59756bc..7dd4355 100644 --- a/gridlock/examples/ex0.py +++ b/gridlock/examples/ex0.py @@ -37,8 +37,8 @@ if __name__ == '__main__': # eg.draw_slab(Direction.z, 0, 10, 2) eg.save('/home/jan/Desktop/test.pickle') eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0, - thickness=10, num_poitns=1000, foreground=1) + thickness=10, num_points=1000, foreground=1) eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]], - direction=1, poalarity=+1, distance=5) + direction=1, polarity=+1, distance=5) eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) eg.visualize_isosurface(egc, which_shifts=2) From e256f56f2b85640ac5368377fe21bd683a969da9 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 01:47:12 -0700 Subject: [PATCH 15/48] fix handling of 3d polys --- gridlock/draw.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 7146e15..e68fdc6 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -60,12 +60,14 @@ class GridDrawMixin(GridPosMixin): # Check polygons, and remove redundant coordinates surface = numpy.delete(range(3), surface_normal) - for i, polygon in enumerate(poly_list): - malformed = f'Malformed polygon: ({i})' + for ii in range(len(poly_list)): + polygon = poly_list[ii] + malformed = f'Malformed polygon: ({ii})' if polygon.shape[1] not in (2, 3): raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') if polygon.shape[1] == 3: polygon = polygon[surface, :] + poly_list[ii] = polygon if not polygon.shape[0] > 2: raise GridError(malformed + 'must consist of more than 2 points') From 646911c4b5f707fc35174a1addcada4c5e4caadf Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 01:57:39 -0700 Subject: [PATCH 16/48] type annotation improvements --- gridlock/base.py | 14 ++++++-------- gridlock/draw.py | 9 ++++----- gridlock/examples/ex0.py | 2 +- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/gridlock/base.py b/gridlock/base.py index 6bd5fb8..aca9c69 100644 --- a/gridlock/base.py +++ b/gridlock/base.py @@ -1,8 +1,7 @@ -from typing import ClassVar, Self, Protocol -from collections.abc import Callable, Sequence +from typing import Protocol import numpy -from numpy.typing import NDArray, ArrayLike +from numpy.typing import NDArray from . import GridError @@ -38,7 +37,7 @@ class GridBase(Protocol): return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] @property - def shape(self) -> NDArray[numpy.int_]: + def shape(self) -> NDArray[numpy.intp]: """ The number of cells in x, y, and z @@ -55,7 +54,7 @@ class GridBase(Protocol): return self.shifts.shape[0] @property - def cell_data_shape(self): + def cell_data_shape(self) -> NDArray[numpy.intp]: """ The shape of the cell_data ndarray (num_grids, *self.shape). """ @@ -180,7 +179,7 @@ class GridBase(Protocol): raise GridError('Autoshifting requires exactly 3 grids') return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] - def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray: + def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray: """ Allocate an ndarray for storing grid data. @@ -194,5 +193,4 @@ class GridBase(Protocol): """ if fill_value is None: return numpy.empty(self.cell_data_shape, dtype=dtype) - else: - return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) + return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) diff --git a/gridlock/draw.py b/gridlock/draw.py index e68fdc6..aa45200 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -1,7 +1,6 @@ """ Drawing-related methods for Grid class """ -from typing import Union from collections.abc import Sequence, Callable import numpy @@ -20,7 +19,7 @@ from .position import GridPosMixin foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] -foreground_t = Union[float, foreground_callable_t] +foreground_t = float | foreground_callable_t class GridDrawMixin(GridPosMixin): @@ -166,7 +165,7 @@ class GridDrawMixin(GridPosMixin): # 2) Generate weights in z-direction w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) - def get_zi(offset, i=i, w_z=w_z): + def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001 edges = self.shifted_exyz(i)[surface_normal] point = center[surface_normal] + offset grid_coord = numpy.digitize(point, edges) - 1 @@ -384,10 +383,10 @@ class GridDrawMixin(GridPosMixin): ind[direction] += 1 # type: ignore #(known safe) foreground += mult[1] * grid[tuple(ind)] - def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]: + def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 # transform from natural position to index xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) - for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=int) + for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64) # reshape to original shape and keep only in-plane components qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) return foreground[qi, ri] diff --git a/gridlock/examples/ex0.py b/gridlock/examples/ex0.py index 7dd4355..b96cdca 100644 --- a/gridlock/examples/ex0.py +++ b/gridlock/examples/ex0.py @@ -29,7 +29,7 @@ if __name__ == '__main__': # numpy.linspace(-5.5, 5.5, 10)] half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5] - xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x, + xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x), numpy.linspace(-5.5, 5.5, 10), numpy.linspace(-5.5, 5.5, 10)] eg = Grid(xyz3) From e5fdc3ce23557fa6bf56c2e24fe4e2e0927776e0 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 01:57:48 -0700 Subject: [PATCH 17/48] drop unused imports --- gridlock/draw.py | 1 - gridlock/grid.py | 1 - gridlock/read.py | 1 - 3 files changed, 3 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index aa45200..2fb0b47 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -8,7 +8,6 @@ from numpy.typing import NDArray, ArrayLike from float_raster import raster from . import GridError -from .base import GridBase from .position import GridPosMixin diff --git a/gridlock/grid.py b/gridlock/grid.py index 55b1abb..91f3c6b 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -9,7 +9,6 @@ import warnings import copy from . import GridError -from .base import GridBase from .draw import GridDrawMixin from .read import GridReadMixin from .position import GridPosMixin diff --git a/gridlock/read.py b/gridlock/read.py index e82ffcc..0f46836 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -7,7 +7,6 @@ import numpy from numpy.typing import NDArray from . import GridError -from .base import GridBase from .position import GridPosMixin if TYPE_CHECKING: From 8e7e0edb1f8c070249e301427f0ad36ff3699ec5 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 01:57:57 -0700 Subject: [PATCH 18/48] add ruff and mypy configs --- pyproject.toml | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 3d24be1..672910c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,3 +53,47 @@ visualization-isosurface = [ "skimage>=0.13", "mpl_toolkits", ] + + +[tool.ruff] +exclude = [ + ".git", + "dist", + ] +line-length = 145 +indent-width = 4 +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +lint.select = [ + "NPY", "E", "F", "W", "B", "ANN", "UP", "SLOT", "SIM", "LOG", + "C4", "ISC", "PIE", "PT", "RET", "TCH", "PTH", "INT", + "ARG", "PL", "R", "TRY", + "G010", "G101", "G201", "G202", + "Q002", "Q003", "Q004", + ] +lint.ignore = [ + #"ANN001", # No annotation + "ANN002", # *args + "ANN003", # **kwargs + "ANN401", # Any + "ANN101", # self: Self + "SIM108", # single-line if / else assignment + "RET504", # x=y+z; return x + "PIE790", # unnecessary pass + "ISC003", # non-implicit string concatenation + "C408", # dict(x=y) instead of {'x': y} + "PLR09", # Too many xxx + "PLR2004", # magic number + "PLC0414", # import x as x + "TRY003", # Long exception message + "PTH123", # open() + ] + + +[[tool.mypy.overrides]] +module = [ + "matplotlib", + "matplotlib.axes", + "matplotlib.figure", + "mpl_toolkits.mplot3d", + ] +ignore_missing_imports = true From e1e6134ec0a2eebc0bab323ab8e7e3be76e499e9 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 02:11:04 -0700 Subject: [PATCH 19/48] use asarray (since copy=False meaning changes in numpy 2.0) --- gridlock/draw.py | 4 ++-- gridlock/grid.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 2fb0b47..e10d4ff 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -53,7 +53,7 @@ class GridDrawMixin(GridPosMixin): if surface_normal not in range(3): raise GridError('Invalid surface_normal direction') center = numpy.squeeze(center) - poly_list = [numpy.array(poly, copy=False) for poly in polygons] + poly_list = [numpy.asarray(poly) for poly in polygons] # Check polygons, and remove redundant coordinates surface = numpy.delete(range(3), surface_normal) @@ -293,7 +293,7 @@ class GridDrawMixin(GridPosMixin): sizes of the cuboid foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ - dimensions = numpy.array(dimensions, copy=False) + dimensions = numpy.asarray(dimensions) p = numpy.array([[-dimensions[0], +dimensions[1]], [+dimensions[0], +dimensions[1]], [+dimensions[0], -dimensions[1]], diff --git a/gridlock/grid.py b/gridlock/grid.py index 91f3c6b..c8fe998 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -94,7 +94,7 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Raises: `GridError` on invalid input """ - edge_arrs = [numpy.array(cc, copy=False) for cc in pixel_edge_coordinates] + edge_arrs = [numpy.asarray(cc) for cc in pixel_edge_coordinates] self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) From 045b0c0228860f7983d3f867a47cb357080e40e6 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 02:12:37 -0700 Subject: [PATCH 20/48] enable numpy 2.0 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 672910c..1df2e8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,8 +38,8 @@ include = [ ] dynamic = ["version"] dependencies = [ - "numpy~=1.26", - "float_raster", + "numpy>=1.26", + "float_raster>=0.8", ] From c95341c9b9945a96fb03d6a237ea0cb1e631cff3 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 31 Jul 2024 22:50:34 -0700 Subject: [PATCH 21/48] be clearer about floats --- gridlock/draw.py | 2 +- gridlock/examples/ex0.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index e10d4ff..52e02f5 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -297,7 +297,7 @@ class GridDrawMixin(GridPosMixin): p = numpy.array([[-dimensions[0], +dimensions[1]], [+dimensions[0], +dimensions[1]], [+dimensions[0], -dimensions[1]], - [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0 + [-dimensions[0], -dimensions[1]]], dtype=float) * 0.5 thickness = dimensions[2] self.draw_polygon(cell_data, 2, center, p, thickness, foreground) diff --git a/gridlock/examples/ex0.py b/gridlock/examples/ex0.py index b96cdca..d13a8cf 100644 --- a/gridlock/examples/ex0.py +++ b/gridlock/examples/ex0.py @@ -29,9 +29,9 @@ if __name__ == '__main__': # numpy.linspace(-5.5, 5.5, 10)] half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5] - xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x), - numpy.linspace(-5.5, 5.5, 10), - numpy.linspace(-5.5, 5.5, 10)] + xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x, dtype=float), + numpy.linspace(-5.5, 5.5, 10, dtype=float), + numpy.linspace(-5.5, 5.5, 10, dtype=float)] eg = Grid(xyz3) egc = eg.allocate(0) # eg.draw_slab(Direction.z, 0, 10, 2) From 4218f529ea4c90980323d69d9a8d73859df086ac Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 31 Jul 2024 22:51:07 -0700 Subject: [PATCH 22/48] Copy pixel edge coordinates --- gridlock/grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/grid.py b/gridlock/grid.py index c8fe998..5790dbd 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -94,7 +94,7 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Raises: `GridError` on invalid input """ - edge_arrs = [numpy.asarray(cc) for cc in pixel_edge_coordinates] + edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates] self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) From c7ad0f0e37cfcc0b1377637e8a9eb3845c2fe6bc Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 31 Jul 2024 22:51:19 -0700 Subject: [PATCH 23/48] comment out unused pytest import --- gridlock/test/test_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 183c9cf..7a70211 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,4 +1,4 @@ -import pytest # type: ignore +# import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal From c5989785430ff79942da5abd6f328936cf307d55 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 31 Jul 2024 22:51:34 -0700 Subject: [PATCH 24/48] bump version to v1.2 --- gridlock/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 45d1aa8..281c2b1 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -19,5 +19,5 @@ from .error import GridError as GridError from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' -__version__ = '1.1' +__version__ = '1.2' version = __version__ From 34f80202ba6b189a5886d1d75dc373339165a895 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 28 Jan 2025 19:36:59 -0800 Subject: [PATCH 25/48] Major rework of arguments using Extent/Slab/Plane --- gridlock/__init__.py | 17 ++- gridlock/direction.py | 10 -- gridlock/draw.py | 209 +++++++++++++++++++------------------ gridlock/error.py | 2 - gridlock/examples/ex0.py | 37 ++++--- gridlock/read.py | 68 ++++++------ gridlock/test/test_grid.py | 42 ++++++-- gridlock/utils.py | 201 +++++++++++++++++++++++++++++++++++ 8 files changed, 416 insertions(+), 170 deletions(-) delete mode 100644 gridlock/direction.py delete mode 100644 gridlock/error.py create mode 100644 gridlock/utils.py diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 281c2b1..120291f 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -15,9 +15,24 @@ Dependencies: - mpl_toolkits.mplot3d [Grid.visualize_isosurface()] - skimage [Grid.visualize_isosurface()] """ -from .error import GridError as GridError +from .utils import ( + GridError as GridError, + + Extent as Extent, + ExtentProtocol as ExtentProtocol, + ExtentDict as ExtentDict, + + Slab as Slab, + SlabProtocol as SlabProtocol, + SlabDict as SlabDict, + + Plane as Plane, + PlaneProtocol as PlaneProtocol, + PlaneDict as PlaneDict, + ) from .grid import Grid as Grid + __author__ = 'Jan Petykiewicz' __version__ = '1.2' version = __version__ diff --git a/gridlock/direction.py b/gridlock/direction.py deleted file mode 100644 index b93b122..0000000 --- a/gridlock/direction.py +++ /dev/null @@ -1,10 +0,0 @@ -from enum import Enum - - -class Direction(Enum): - """ - Enum for axis->integer mapping - """ - x = 0 - y = 1 - z = 2 diff --git a/gridlock/draw.py b/gridlock/draw.py index 52e02f5..4930231 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -7,7 +7,7 @@ import numpy from numpy.typing import NDArray, ArrayLike from float_raster import raster -from . import GridError +from .utils import GridError, Slab, SlabDict, SlabProtocol, Extent, ExtentDict, ExtentProtocol from .position import GridPosMixin @@ -21,27 +21,26 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_t = float | foreground_callable_t + class GridDrawMixin(GridPosMixin): def draw_polygons( self, cell_data: NDArray, - surface_normal: int, - center: ArrayLike, + slab: SlabProtocol | SlabDict, polygons: Sequence[ArrayLike], - thickness: float, foreground: Sequence[foreground_t] | foreground_t, + *, + offset2d: ArrayLike = (0, 0), ) -> None: """ Draw polygons on an axis-aligned plane. Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. center: 3-element ndarray or list specifying an offset applied to all the polygons polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon - (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each + (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each polygon must have at least 3 vertices. - thickness: Thickness of the layer to draw foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list of any of these (1 per grid). Callable values should take an ndarray the shape of the grid and return an ndarray of equal shape containing the foreground value at the given x, y, @@ -50,13 +49,13 @@ class GridDrawMixin(GridPosMixin): Raises: GridError """ - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') - center = numpy.squeeze(center) + if isinstance(slab, dict): + slab = Slab(**slab) + poly_list = [numpy.asarray(poly) for poly in polygons] # Check polygons, and remove redundant coordinates - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), slab.axis) for ii in range(len(poly_list)): polygon = poly_list[ii] @@ -69,9 +68,8 @@ class GridDrawMixin(GridPosMixin): if not polygon.shape[0] > 2: raise GridError(malformed + 'must consist of more than 2 points') - if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: - raise GridError(malformed + 'must be in plane with surface normal ' - + 'xyz'[surface_normal]) + if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1: + raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] @@ -87,10 +85,10 @@ class GridDrawMixin(GridPosMixin): bd_2d_min = numpy.array([0, 0]) bd_2d_max = numpy.array([0, 0]) for polygon in poly_list: - bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) - bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) - bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center - bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center + bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) + offset2d + bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) + offset2d + bd_min = numpy.insert(bd_2d_min, slab.axis, slab.min) + bd_max = numpy.insert(bd_2d_max, slab.axis, slab.max) # 2) Find indices (bdi) just outside bd elements buf = 2 # size of safety buffer @@ -103,13 +101,13 @@ class GridDrawMixin(GridPosMixin): bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) - # 3) Adjust polygons for center - poly_list = [poly + center[surface] for poly in poly_list] + # 3) Adjust polygons for offset2d + poly_list = [poly + offset2d for poly in poly_list] # ## Generate weighing function def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: v_2d = numpy.array(vector, dtype=float) - return numpy.insert(v_2d, surface_normal, (val,)) + return numpy.insert(v_2d, slab.axis, (val,)) # iterate over grids foreground_val: NDArray | float @@ -162,13 +160,12 @@ class GridDrawMixin(GridPosMixin): w_xy = numpy.minimum(w_xy, 1.0) # 2) Generate weights in z-direction - w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) + w_z = numpy.zeros(((bdi_max - bdi_min + 1)[slab.axis], )) - def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001 - edges = self.shifted_exyz(i)[surface_normal] - point = center[surface_normal] + offset + def get_zi(point: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001 + edges = self.shifted_exyz(i)[slab.axis] grid_coord = numpy.digitize(point, edges) - 1 - w_coord = grid_coord - bdi_min[surface_normal] + w_coord = grid_coord - bdi_min[slab.axis] if w_coord < 0: w_coord = 0 @@ -177,12 +174,12 @@ class GridDrawMixin(GridPosMixin): w_coord = w_z.size - 1 f = 1 else: - dz = self.shifted_dxyz(i)[surface_normal][grid_coord] + dz = self.shifted_dxyz(i)[slab.axis][grid_coord] f = (point - edges[grid_coord]) / dz return f, w_coord - zi_top_f, zi_top = get_zi(+thickness / 2.0) - zi_bot_f, zi_bot = get_zi(-thickness / 2.0) + zi_top_f, zi_top = get_zi(slab.max) + zi_bot_f, zi_bot = get_zi(slab.min) w_z[zi_bot + 1:zi_top] = 1 @@ -193,7 +190,7 @@ class GridDrawMixin(GridPosMixin): w_z[zi_bot] = zi_top_f - zi_bot_f # 3) Generate total weight function - w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,))) + w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], slab.axis, (2,))) # ## Modify the grid g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) @@ -203,34 +200,36 @@ class GridDrawMixin(GridPosMixin): def draw_polygon( self, cell_data: NDArray, - surface_normal: int, - center: ArrayLike, + slab: SlabProtocol | SlabDict, polygon: ArrayLike, - thickness: float, foreground: Sequence[foreground_t] | foreground_t, + *, + offset2d: ArrayLike = (0, 0), ) -> None: """ Draw a polygon on an axis-aligned plane. Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying an offset applied to the polygon + slab: `Slab` in which to draw polygons. polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, - clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at + clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at least 3 vertices. - thickness: Thickness of the layer to draw foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ - self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) + self.draw_polygons( + cell_data = cell_data, + slab = slab, + polygons = [polygon], + foreground = foreground, + offset2d = offset2d, + ) def draw_slab( self, cell_data: NDArray, - surface_normal: int, - center: ArrayLike, - thickness: float, + slab: SlabProtocol | SlabDict, foreground: Sequence[foreground_t] | foreground_t, ) -> None: """ @@ -238,50 +237,45 @@ class GridDrawMixin(GridPosMixin): Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: `surface_normal` coordinate value at the center of the slab + slab: thickness: Thickness of the layer to draw foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ - # Turn surface_normal into its integer representation - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') - - if numpy.size(center) != 1: - center = numpy.squeeze(center) - if len(center) == 3: - center = center[surface_normal] - else: - raise GridError(f'Bad center: {center}') + if isinstance(slab, dict): + slab = Slab(**slab) # Find center of slab center_shift = self.center - center_shift[surface_normal] = center + center_shift[slab.axis] = slab.center - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), slab.axis) + u_min, u_max = self.exyz[surface[0]][[0, -1]] + v_min, v_max = self.exyz[surface[1]][[0, -1]] - xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface] - xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface] + margin = 4 * numpy.max(self.dxyz[surface[0]].max(), + self.dxyz[surface[1]].max()) - dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float) + p = numpy.array([[u_min - margin, v_max + margin], + [u_max + margin, v_max + margin], + [u_max + margin, v_min - margin], + [u_min - margin, v_min - margin]], dtype=float) - xyz_min -= 4 * dxyz - xyz_max += 4 * dxyz - - p = numpy.array([[xyz_min[0], xyz_max[1]], - [xyz_max[0], xyz_max[1]], - [xyz_max[0], xyz_min[1]], - [xyz_min[0], xyz_min[1]]], dtype=float) - - self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) + self.draw_polygon( + cell_data = cell_data, + slab = slab, + polygon = p, + foreground = foreground, + ) def draw_cuboid( self, cell_data: NDArray, - center: ArrayLike, - dimensions: ArrayLike, foreground: Sequence[foreground_t] | foreground_t, + *, + x: ExtentProtocol | ExtentDict, + y: ExtentProtocol | ExtentDict, + z: ExtentProtocol | ExtentDict, ) -> None: """ Draw an axis-aligned cuboid @@ -293,23 +287,30 @@ class GridDrawMixin(GridPosMixin): sizes of the cuboid foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ - dimensions = numpy.asarray(dimensions) - p = numpy.array([[-dimensions[0], +dimensions[1]], - [+dimensions[0], +dimensions[1]], - [+dimensions[0], -dimensions[1]], - [-dimensions[0], -dimensions[1]]], dtype=float) * 0.5 - thickness = dimensions[2] - self.draw_polygon(cell_data, 2, center, p, thickness, foreground) + if isinstance(x, dict): + x = Extent(**x) + if isinstance(y, dict): + y = Extent(**y) + if isinstance(z, dict): + z = Extent(**z) + + center = numpy.asarray([x.center, y.center, z.center]) + + p = numpy.array([[x.min, y.max], + [x.max, y.max], + [x.max, y.min], + [x.min, y.min]], dtype=float) + slab = Slab(axis=2, center=z.center, span=z.span) + self.draw_polygon(cell_data=cell_data, slab=slab, polygon=p, foreground=foreground) def draw_cylinder( self, cell_data: NDArray, - surface_normal: int, - center: ArrayLike, + h: SlabProtocol | SlabDict, radius: float, - thickness: float, num_points: int, + center2d: ArrayLike, foreground: Sequence[foreground_t] | foreground_t, ) -> None: """ @@ -317,18 +318,19 @@ class GridDrawMixin(GridPosMixin): Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying the cylinder's center - radius: cylinder radius - thickness: Thickness of the layer to draw + h: + radius: num_points: The circle is approximated by a polygon with `num_points` vertices + center2d: foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ - theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False) - x = radius * numpy.sin(theta) - y = radius * numpy.cos(theta) - polygon = numpy.hstack((x[:, None], y[:, None])) - self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) + if isinstance(h, dict): + h = Slab(**h) + + theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)[:, None] + xy0 = numpy.hstack((numpy.sin(theta), numpy.cos(theta))) + polygon = radius * xy0 + self.draw_polygon(cell_data=cell_data, slab=h, polygon=polygon, foreground=foreground, offset2d=center2d) def draw_extrude_rectangle( @@ -349,10 +351,10 @@ class GridDrawMixin(GridPosMixin): polarity: +1 or -1, direction along axis to extrude in distance: How far to extrude """ - s = numpy.sign(polarity) + sgn = numpy.sign(polarity) - rectangle = numpy.array(rectangle, dtype=float) - if s == 0: + rectangle = numpy.asarray(rectangle, dtype=float) + if sgn == 0: raise GridError('0 is not a valid polarity') if direction not in range(3): raise GridError(f'Invalid direction: {direction}') @@ -360,37 +362,38 @@ class GridDrawMixin(GridPosMixin): raise GridError('Rectangle entries along extrusion direction do not match.') center = rectangle.sum(axis=0) / 2.0 - center[direction] += s * distance / 2.0 + center[direction] += sgn * distance / 2.0 surface = numpy.delete(range(3), direction) dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] - p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, - numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T + poly = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, + numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T thickness = distance foreground_func = [] - for i, grid in enumerate(cell_data): - z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] + for ii, grid in enumerate(cell_data): + zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction] - ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] + ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)] - fpart = z - numpy.floor(z) - mult = [1 - fpart, fpart][::s] # reverses if s negative + fpart = zz - numpy.floor(zz) + mult = [1 - fpart, fpart][::sgn] # reverses if s negative foreground = mult[0] * grid[tuple(ind)] ind[direction] += 1 # type: ignore #(known safe) foreground += mult[1] * grid[tuple(ind)] - def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 + def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 # transform from natural position to index - xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) + xyzi = numpy.array([self.pos2ind(qrs, which_shifts=ii) for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64) # reshape to original shape and keep only in-plane components - qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) + qi, ri = (numpy.reshape(xyzi[:, kk], xs.shape) for kk in surface) return foreground[qi, ri] foreground_func.append(f_foreground) - self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func) + slab = Slab(axis=direction, center=center[direction], span=thickness) + self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface]) diff --git a/gridlock/error.py b/gridlock/error.py deleted file mode 100644 index 3974e9c..0000000 --- a/gridlock/error.py +++ /dev/null @@ -1,2 +0,0 @@ -class GridError(Exception): - pass diff --git a/gridlock/examples/ex0.py b/gridlock/examples/ex0.py index d13a8cf..4ff2fb9 100644 --- a/gridlock/examples/ex0.py +++ b/gridlock/examples/ex0.py @@ -6,18 +6,18 @@ if __name__ == '__main__': # xyz = [numpy.arange(-5.0, 6.0), numpy.arange(-4.0, 5.0), [-1.0, 1.0]] # eg = Grid(xyz) # egc = Grid.allocate(0.0) - # # eg.draw_slab(egc, surface_normal=2, center=0, thickness=10, foreground=2) - # eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=4, - # thickness=10, num_points=1000, foreground=1) - # eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) + # # eg.draw_slab(egc, slab=dict(axis=2, center=0, span=10), foreground=2) + # eg.draw_cylinder(egc, h=slab(axis=2, center=0, span=10), + # center2d=[0, 0], radius=4, thickness=10, num_points=1000, foreground=1) + # eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2) # xyz2 = [numpy.arange(-5.0, 6.0), [-1.0, 1.0], numpy.arange(-4.0, 5.0)] # eg2 = Grid(xyz2) # eg2c = Grid.allocate(0.0) - # # eg2.draw_slab(eg2c, surface_normal=2, center=0, thickness=10, foreground=2) - # eg2.draw_cylinder(eg2c, surface_normal=1, center=[0, 0, 0], - # radius=4, thickness=10, num_points=1000, foreground=1.0) - # eg2.visualize_slice(eg2c, surface_normal=1, center=0, which_shifts=1) + # # eg2.draw_slab(eg2c, slab=dict(axis=2, center=0, span=10), foreground=2) + # eg2.draw_cylinder(eg2c, h=slab(axis=1, center=0, span=10), center2d=[0, 0], + # radius=4, num_points=1000, foreground=1.0) + # eg2.visualize_slice(eg2c, plane=dict(y=0), which_shifts=1) # n = 20 # m = 3 @@ -36,9 +36,20 @@ if __name__ == '__main__': egc = eg.allocate(0) # eg.draw_slab(Direction.z, 0, 10, 2) eg.save('/home/jan/Desktop/test.pickle') - eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0, - thickness=10, num_points=1000, foreground=1) - eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]], - direction=1, polarity=+1, distance=5) - eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) + eg.draw_cylinder( + egc, + h=dict(axis='z', center=0, span=10), + center2d=[0, 0], + radius=2.0, + num_points=1000, + foreground=1, + ) + eg.draw_extrude_rectangle( + egc, + rectangle=[[-2, 1, -1], [0, 1, 1]], + direction=1, + polarity=+1, + distance=5, + ) + eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2) eg.visualize_isosurface(egc, which_shifts=2) diff --git a/gridlock/read.py b/gridlock/read.py index 0f46836..707251a 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -6,7 +6,7 @@ from typing import Any, TYPE_CHECKING import numpy from numpy.typing import NDArray -from . import GridError +from .utils import GridError, Plane, PlaneDict, PlaneProtocol from .position import GridPosMixin if TYPE_CHECKING: @@ -23,27 +23,25 @@ class GridReadMixin(GridPosMixin): def get_slice( self, cell_data: NDArray, - surface_normal: int, - center: float, + plane: PlaneProtocol | PlaneDict, which_shifts: int = 0, sample_period: int = 1 ) -> NDArray: """ Retrieve a slice of a grid. - Interpolates if given a position between two planes. + Interpolates if given a position between two grid planes. Args: cell_data: Cell data to slice - surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. - center: Scalar specifying position along surface_normal axis. + plane: Axis and position (`Plane`) of the plane to read. which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) Returns: Array containing the portion of the grid. """ - if numpy.size(center) != 1 or not numpy.isreal(center): - raise GridError('center must be a real scalar') + if isinstance(plane, dict): + plane = Plane(**plane) sp = round(sample_period) if sp <= 0: @@ -52,15 +50,12 @@ class GridReadMixin(GridPosMixin): if numpy.size(which_shifts) != 1 or which_shifts < 0: raise GridError('Invalid which_shifts') - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') - - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), plane.axis) # Extract indices and weights of planes - center3 = numpy.insert([0, 0], surface_normal, (center,)) + center3 = numpy.insert([0, 0], plane.axis, (plane.pos,)) center_index = self.pos2ind(center3, which_shifts, - round_ind=False, check_bounds=False)[surface_normal] + round_ind=False, check_bounds=False)[plane.axis] centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) if len(centers) == 2: fpart = center_index - numpy.floor(center_index) @@ -68,14 +63,14 @@ class GridReadMixin(GridPosMixin): else: w = [1] - c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1]) - if center < c_min or center > c_max: + c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1]) + if plane.pos < c_min or plane.pos > c_max: raise GridError('Coordinate of selected plane must be within simulation domain') # Extract grid values from planes above and below visualized slice sliced_grid = numpy.zeros(self.shape[surface]) for ci, weight in zip(centers, w, strict=True): - s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) + s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3)) sliced_grid += weight * cell_data[which_shifts][tuple(s)] # Remove extra dimensions @@ -87,20 +82,19 @@ class GridReadMixin(GridPosMixin): def visualize_slice( self, cell_data: NDArray, - surface_normal: int, - center: float, + plane: PlaneProtocol | PlaneDict, which_shifts: int = 0, sample_period: int = 1, finalize: bool = True, pcolormesh_args: dict[str, Any] | None = None, - ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: + ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ Visualize a slice of a grid. - Interpolates if given a position between two planes. + Interpolates if given a position between two grid planes. Args: - surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. - center: Scalar specifying position along surface_normal axis. + cell_data: Cell data to visualize + plane: Axis and position (`Plane`) of the plane to read. which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` @@ -110,16 +104,20 @@ class GridReadMixin(GridPosMixin): """ from matplotlib import pyplot + if isinstance(plane, dict): + plane = Plane(**plane) + if pcolormesh_args is None: pcolormesh_args = {} - grid_slice = self.get_slice(cell_data=cell_data, - surface_normal=surface_normal, - center=center, - which_shifts=which_shifts, - sample_period=sample_period) + grid_slice = self.get_slice( + cell_data=cell_data, + plane=plane, + which_shifts=which_shifts, + sample_period=sample_period, + ) - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), plane.axis) x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') @@ -145,7 +143,7 @@ class GridReadMixin(GridPosMixin): sample_period: int = 1, show_edges: bool = True, finalize: bool = True, - ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: + ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ Draw an isosurface plot of the device. @@ -183,18 +181,18 @@ class GridReadMixin(GridPosMixin): fig = pyplot.figure() ax = fig.add_subplot(111, projection='3d') if show_edges: - ax.plot_trisurf(xs, ys, faces, zs) + ax.plot_trisurf(xs, ys, faces, zs) # type: ignore else: - ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') + ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') # type: ignore # Add a fake plot of a cube to force the axes to be equal lengths max_range = numpy.array([xs.max() - xs.min(), ys.max() - ys.min(), zs.max() - zs.min()], dtype=float).max() mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] - xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min()) - ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min()) - zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min()) + xbs = 0.5 * max_range * mg[0].ravel() + 0.5 * (xs.max() + xs.min()) + ybs = 0.5 * max_range * mg[1].ravel() + 0.5 * (ys.max() + ys.min()) + zbs = 0.5 * max_range * mg[2].ravel() + 0.5 * (zs.max() + zs.min()) # Comment or uncomment following both lines to test the fake bounding box: for xb, yb, zb in zip(xbs, ybs, zbs, strict=True): ax.plot([xb], [yb], [zb], 'w') diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 7a70211..8d9ca92 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -2,7 +2,7 @@ import numpy from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid +from .. import Grid, Extent #, Slab, Plane def test_draw_oncenter_2x2() -> None: @@ -12,7 +12,13 @@ def test_draw_oncenter_2x2() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[1, 1, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0, span=1), + y=Extent(center=0, span=1), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0.25, 0.25], [0.25, 0.25]])[None, :, :, None] @@ -27,7 +33,13 @@ def test_draw_ongrid_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[2, 2, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0, span=2), + y=dict(min=-1, max=1), + z=dict(center=0, min=-5), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 1, 1, 0], @@ -44,7 +56,13 @@ def test_draw_xshift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 2, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0.5, span=1.5), + y=dict(min=-1, max=1), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 0.25, 0.25, 0], @@ -61,7 +79,13 @@ def test_draw_yshift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0, 0.5, 0], dimensions=[2, 1.5, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(min=-1, max=1), + y=dict(center=0.5, span=1.5), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 0.25, 1, 0.25], @@ -78,7 +102,13 @@ def test_draw_2shift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 1, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0.5, span=1.5), + y=dict(min=-0.5, max=0.5), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 0.125, 0.125, 0], diff --git a/gridlock/utils.py b/gridlock/utils.py new file mode 100644 index 0000000..7a12035 --- /dev/null +++ b/gridlock/utils.py @@ -0,0 +1,201 @@ +from typing import Protocol, TypedDict, runtime_checkable +from dataclasses import dataclass + + +class GridError(Exception): + """ Base error type for `gridlock` """ + pass + + +class ExtentDict(TypedDict, total=False): + min: float + center: float + max: float + span: float + + +@runtime_checkable +class ExtentProtocol(Protocol): + center: float + span: float + + @property + def max(self) -> float: ... + + @property + def min(self) -> float: ... + + +@dataclass(init=False, slots=True) +class Extent(ExtentProtocol): + center: float + span: float + + @property + def max(self) -> float: + return self.center + self.span / 2 + + @property + def min(self) -> float: + return self.center - self.span / 2 + + def __init__( + self, + *, + min: float | None = None, + center: float | None = None, + max: float | None = None, + span: float | None = None, + ) -> None: + if sum(cc is None for cc in (min, center, max, span)) != 2: + raise GridError('Exactly two of min, center, max, span must be None!') + + if span is None: + if center is None: + assert min is not None + assert max is not None + assert max >= min + center = 0.5 * (max + min) + span = max - min + elif max is None: + assert min is not None + assert center is not None + span = 2 * (center - min) + elif min is None: + assert center is not None + assert max is not None + span = 2 * (max - center) + else: # noqa: PLR5501 + if center is not None: + pass + elif max is None: + assert min is not None + assert span is not None + center = min + 0.5 * span + elif min is None: + assert max is not None + assert span is not None + center = max - 0.5 * span + + assert center is not None + assert span is not None + if hasattr(center, '__len__'): + assert len(center) == 1 + if hasattr(span, '__len__'): + assert len(span) == 1 + self.center = center + self.span = span + + +class SlabDict(TypedDict, total=False): + min: float + center: float + max: float + span: float + axis: int | str + + +@runtime_checkable +class SlabProtocol(ExtentProtocol, Protocol): + axis: int + center: float + span: float + + @property + def max(self) -> float: ... + + @property + def min(self) -> float: ... + + +@dataclass(init=False, slots=True) +class Slab(Extent, SlabProtocol): + axis: int + + def __init__( + self, + axis: int | str, + *, + min: float | None = None, + center: float | None = None, + max: float | None = None, + span: float | None = None, + ) -> None: + Extent.__init__(self, min=min, center=center, max=max, span=span) + + if isinstance(axis, str): + axis_int = 'xyz'.find(axis.lower()) + else: + axis_int = axis + if axis_int not in range(3): + raise GridError(f'Invalid axis (slab normal direction): {axis}') + self.axis = axis_int + + def as_plane(self, where: str) -> 'Plane': + if where == 'center': + return Plane(axis=self.axis, pos=self.center) + if where == 'min': + return Plane(axis=self.axis, pos=self.min) + if where == 'max': + return Plane(axis=self.axis, pos=self.max) + raise GridError(f'Invalid {where=}') + + +class PlaneDict(TypedDict, total=False): + x: float + y: float + z: float + axis: int + pos: float + + +@runtime_checkable +class PlaneProtocol(Protocol): + axis: int + pos: float + + +@dataclass(init=False, slots=True) +class Plane(PlaneProtocol): + axis: int + pos: float + + def __init__( + self, + *, + axis: int | str | None = None, + pos: float | None = None, + x: float | None = None, + y: float | None = None, + z: float | None = None, + ) -> None: + xx = x + yy = y + zz = z + + if sum(aa is not None for aa in (pos, xx, yy, zz)) != 1: + raise GridError('Exactly one of pos, x, y, z must be non-None!') + if (axis is None) != (pos is None): + raise GridError('Either both or neither of `axis` and `pos` must be defined.') + + if isinstance(axis, str): + axis_int = 'xyz'.find(axis.lower()) + elif axis is None: + axis_int = (xx is None, yy is None, zz is None).index(False) + else: + axis_int = axis + + if axis_int not in range(3): + raise GridError(f'Invalid axis (slab normal direction): {axis=} {x=} {y=} {z=}') + self.axis = axis_int + + if pos is not None: + cpos = pos + else: + cpos = (xx, yy, zz)[axis_int] + assert cpos is not None + + if hasattr(cpos, '__len__'): + assert len(cpos) == 1 + self.pos = cpos + From b739534cfe35ed384116c8843935ad8fec0056da Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 12 Mar 2025 23:17:29 -0700 Subject: [PATCH 26/48] [draw] fix missing brackets --- gridlock/draw.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 4930231..0b93d20 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -252,8 +252,8 @@ class GridDrawMixin(GridPosMixin): u_min, u_max = self.exyz[surface[0]][[0, -1]] v_min, v_max = self.exyz[surface[1]][[0, -1]] - margin = 4 * numpy.max(self.dxyz[surface[0]].max(), - self.dxyz[surface[1]].max()) + margin = 4 * numpy.max([self.dxyz[surface[0]].max(), + self.dxyz[surface[1]].max()]) p = numpy.array([[u_min - margin, v_max + margin], [u_max + margin, v_max + margin], From 13c12b0a6adb2555c9fff55cb512af6701ec54de Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 16 Apr 2025 21:25:43 -0700 Subject: [PATCH 27/48] update docs to reflect new args --- gridlock/draw.py | 34 ++++++++++++++++++---------------- gridlock/utils.py | 37 +++++++++++++++++++++++++++++++++++-- pyproject.toml | 1 - 3 files changed, 53 insertions(+), 19 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 0b93d20..9ba4623 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -21,30 +21,31 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_t = float | foreground_callable_t - class GridDrawMixin(GridPosMixin): def draw_polygons( self, cell_data: NDArray, + foreground: Sequence[foreground_t] | foreground_t, slab: SlabProtocol | SlabDict, polygons: Sequence[ArrayLike], - foreground: Sequence[foreground_t] | foreground_t, *, offset2d: ArrayLike = (0, 0), ) -> None: """ - Draw polygons on an axis-aligned plane. + Draw polygons on an axis-aligned slab. Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - center: 3-element ndarray or list specifying an offset applied to all the polygons - polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon - (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each - polygon must have at least 3 vertices. foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list of any of these (1 per grid). Callable values should take an ndarray the shape of the grid and return an ndarray of equal shape containing the foreground value at the given x, y, and z (natural, not grid coordinates). + slab: `Slab` or slab-like dict specifying the slab in which the polygons will be drawn. + polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon + (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each + polygon must have at least 3 vertices. + offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly + to the given polygon vertex coordinates. Default (0, 0). Raises: GridError @@ -200,9 +201,9 @@ class GridDrawMixin(GridPosMixin): def draw_polygon( self, cell_data: NDArray, + foreground: Sequence[foreground_t] | foreground_t, slab: SlabProtocol | SlabDict, polygon: ArrayLike, - foreground: Sequence[foreground_t] | foreground_t, *, offset2d: ArrayLike = (0, 0), ) -> None: @@ -211,11 +212,13 @@ class GridDrawMixin(GridPosMixin): Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - slab: `Slab` in which to draw polygons. + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + slab: `Slab` or slab-like dict specifying the slab in which the polygon will be drawn. polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at least 3 vertices. - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly + to the given polygon vertex coordinates. Default (0, 0). """ self.draw_polygons( cell_data = cell_data, @@ -229,17 +232,16 @@ class GridDrawMixin(GridPosMixin): def draw_slab( self, cell_data: NDArray, - slab: SlabProtocol | SlabDict, foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, ) -> None: """ Draw an axis-aligned infinite slab. Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - slab: - thickness: Thickness of the layer to draw foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + slab: `Slab` or slab-like dict (geometrical slab specification) """ if isinstance(slab, dict): slab = Slab(**slab) @@ -282,10 +284,10 @@ class GridDrawMixin(GridPosMixin): Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - center: 3-element ndarray or list specifying the cuboid's center - dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge - sizes of the cuboid foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + x: `Extent` or extent-like dict specifying the x-extent of the cuboid. + y: `Extent` or extent-like dict specifying the y-extent of the cuboid. + z: `Extent` or extent-like dict specifying the z-extent of the cuboid. """ if isinstance(x, dict): x = Extent(**x) diff --git a/gridlock/utils.py b/gridlock/utils.py index 7a12035..8a8f11d 100644 --- a/gridlock/utils.py +++ b/gridlock/utils.py @@ -1,4 +1,4 @@ -from typing import Protocol, TypedDict, runtime_checkable +from typing import Protocol, TypedDict, runtime_checkable, cast from dataclasses import dataclass @@ -8,6 +8,10 @@ class GridError(Exception): class ExtentDict(TypedDict, total=False): + """ + Geometrical definition of an extent (1D bounded region) + Must contain exactly two of `min`, `max`, `center`, or `span`. + """ min: float center: float max: float @@ -16,6 +20,9 @@ class ExtentDict(TypedDict, total=False): @runtime_checkable class ExtentProtocol(Protocol): + """ + Anything that looks like an `Extent` + """ center: float span: float @@ -28,6 +35,10 @@ class ExtentProtocol(Protocol): @dataclass(init=False, slots=True) class Extent(ExtentProtocol): + """ + Geometrical definition of an extent (1D bounded region) + May be constructed with any two of `min`, `max`, `center`, or `span`. + """ center: float span: float @@ -88,6 +99,10 @@ class Extent(ExtentProtocol): class SlabDict(TypedDict, total=False): + """ + Geometrical definition of a slab (3D region bounded on one axis only) + Must contain `axis` plus any two of `min`, `max`, `center`, or `span`. + """ min: float center: float max: float @@ -97,6 +112,9 @@ class SlabDict(TypedDict, total=False): @runtime_checkable class SlabProtocol(ExtentProtocol, Protocol): + """ + Anything that looks like a `Slab` + """ axis: int center: float span: float @@ -110,6 +128,10 @@ class SlabProtocol(ExtentProtocol, Protocol): @dataclass(init=False, slots=True) class Slab(Extent, SlabProtocol): + """ + Geometrical definition of a slab (3D region bounded on one axis only) + May be constructed with `axis` (bounded axis) plus any two of `min`, `max`, `center`, or `span`. + """ axis: int def __init__( @@ -142,6 +164,10 @@ class Slab(Extent, SlabProtocol): class PlaneDict(TypedDict, total=False): + """ + Geometrical definition of a plane (2D unbounded region in 3D space) + Must contain exactly one of `x`, `y`, `z`, or both `axis` and `pos` + """ x: float y: float z: float @@ -151,12 +177,19 @@ class PlaneDict(TypedDict, total=False): @runtime_checkable class PlaneProtocol(Protocol): + """ + Anything that looks like a `Plane` + """ axis: int pos: float @dataclass(init=False, slots=True) class Plane(PlaneProtocol): + """ + Geometrical definition of a plane (2D unbounded region in 3D space) + May be constructed with any of `x=4`, `y=5`, `z=-5`, or `axis=2, pos=-5`. + """ axis: int pos: float @@ -192,7 +225,7 @@ class Plane(PlaneProtocol): if pos is not None: cpos = pos else: - cpos = (xx, yy, zz)[axis_int] + cpos = cast('float', (xx, yy, zz)[axis_int]) assert cpos is not None if hasattr(cpos, '__len__'): diff --git a/pyproject.toml b/pyproject.toml index 1df2e8b..03d0d19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,6 @@ lint.ignore = [ "ANN002", # *args "ANN003", # **kwargs "ANN401", # Any - "ANN101", # self: Self "SIM108", # single-line if / else assignment "RET504", # x=y+z; return x "PIE790", # unnecessary pass From be7c26c1d1668564f5d2018d02554616bbbebec4 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 16 Apr 2025 21:28:43 -0700 Subject: [PATCH 28/48] bump version to v2.0 -- major arg rework for drawing/reading --- gridlock/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 120291f..2f39696 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -34,5 +34,5 @@ from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' -__version__ = '1.2' +__version__ = '2.0' version = __version__ From 6802e57fa9fb26d5017b55b9bd46725db38ae04b Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:18:08 -0700 Subject: [PATCH 29/48] [read] add missing arg to docstring --- gridlock/read.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gridlock/read.py b/gridlock/read.py index 707251a..cfc8f3d 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -98,6 +98,7 @@ class GridReadMixin(GridPosMixin): which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + pcolormesh_args: Args passed through to matplotlib `pcolormesh()` Returns: (Figure, Axes) From 21304f0dbfce8b089a44a7dfb13c072b57839bfa Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:18:49 -0700 Subject: [PATCH 30/48] [read] add option to visualize on preexisting axes --- gridlock/read.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index cfc8f3d..b5840e7 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -87,6 +87,7 @@ class GridReadMixin(GridPosMixin): sample_period: int = 1, finalize: bool = True, pcolormesh_args: dict[str, Any] | None = None, + ax: 'matplotlib.axes.Axes' | None = None, ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ Visualize a slice of a grid. @@ -99,6 +100,7 @@ class GridReadMixin(GridPosMixin): sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` pcolormesh_args: Args passed through to matplotlib `pcolormesh()` + ax: If provided, plot to these axes (instead of creating a new figure & axes) Returns: (Figure, Axes) @@ -112,10 +114,10 @@ class GridReadMixin(GridPosMixin): pcolormesh_args = {} grid_slice = self.get_slice( - cell_data=cell_data, - plane=plane, - which_shifts=which_shifts, - sample_period=sample_period, + cell_data = cell_data, + plane = plane, + which_shifts = which_shifts, + sample_period = sample_period, ) surface = numpy.delete(range(3), plane.axis) @@ -124,7 +126,10 @@ class GridReadMixin(GridPosMixin): xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) - fig, ax = pyplot.subplots() + if ax is None: + fig, ax = pyplot.subplots() + else: + fig = ax.figure mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) fig.colorbar(mappable) ax.set_aspect('equal', adjustable='box') From 68520b871018c3d86254fa1fa87faf4838351d11 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:19:28 -0700 Subject: [PATCH 31/48] [read] add visualize_edges() --- gridlock/read.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/gridlock/read.py b/gridlock/read.py index b5840e7..600227d 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -135,12 +135,86 @@ class GridReadMixin(GridPosMixin): ax.set_aspect('equal', adjustable='box') ax.set_xlabel(x_label) ax.set_ylabel(y_label) + if finalize: pyplot.show() return fig, ax + def visualize_edges( + self, + cell_data: NDArray, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + finalize: bool = True, + contour_args: dict[str, Any] | None = None, + ax: 'matplotlib.axes.Axes' | None = None, + level_fraction: float = 0.7, + ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: + """ + Visualize the edges of a grid slice. + This is intended as an overlay on top of visualize_slice (e.g. showing epsilon boundaries + on an E-field plot). + + Interpolates if given a position between two grid planes. + + Args: + cell_data: Cell data to visualize + plane: Axis and position (`Plane`) of the plane to read. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + pcolormesh_args: Args passed through to matplotlib `pcolormesh()` + ax: If provided, plot to these axes (instead of creating a new figure & axes) + level_fraction: Value between 0 and 1 which tunes how many contours are generated. + 1 indicates that every possible step should have its own contour. + + Returns: + (Figure, Axes) + """ + from matplotlib import pyplot + + if level_fraction > 1: + raise GridError(f'{level_fraction=} must be between 0 and 1') + + if isinstance(plane, dict): + plane = Plane(**plane) + + if pcolormesh_args is None: + pcolormesh_args = {} + + grid_slice = self.get_slice( + cell_data = cell_data, + plane = plane, + which_shifts = which_shifts, + sample_period = sample_period, + ) + cvals, cval_counts = numpy.unique(grid_slice, return_counts=True) + if cvals.size == 1: + levels = [cvals[0] + 1] + else: + cval_order = numpy.argsort(cval_counts)[::-1] + level_count = 2 + while cval_counts[cval_order[:level_count]].sum() < level_fraction: + level_count += 1 + ctr_levels = cvals[cval_order[:level_count]] + levels = numpy.diff(ctr_levels[::-1]) + ctr_levels[:0:-1] + + surface = numpy.delete(range(3), plane.axis) + + if ax is None: + fig, ax = pyplot.subplots() + else: + fig = ax.figure + xc, yc = (self.shifted_exyz(which_shifts)[a] for a in surface) + xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') + + mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) + + return fig, ax + + def visualize_isosurface( self, cell_data: NDArray, From 7cac73bcb400021289350a52215083e88c337cd7 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:25:39 -0700 Subject: [PATCH 32/48] [draw] add missing code for finalize --- gridlock/read.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gridlock/read.py b/gridlock/read.py index 600227d..4f39432 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -212,6 +212,9 @@ class GridReadMixin(GridPosMixin): mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) + if finalize: + pyplot.show() + return fig, ax From 16a76e0122845d029c7a4a3c2896ab0606553576 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:26:59 -0700 Subject: [PATCH 33/48] [read] make visualize_edges more friendly for overlay by default --- gridlock/read.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/read.py b/gridlock/read.py index 4f39432..28afdd2 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -182,7 +182,7 @@ class GridReadMixin(GridPosMixin): plane = Plane(**plane) if pcolormesh_args is None: - pcolormesh_args = {} + pcolormesh_args = dict(alpha=0.8, colors='gray') grid_slice = self.get_slice( cell_data = cell_data, From 64752873fbde2392b91150f8b26740dfdc633e66 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:28:40 -0700 Subject: [PATCH 34/48] [read] fix type spec --- gridlock/read.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 28afdd2..44f1c8a 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -87,7 +87,7 @@ class GridReadMixin(GridPosMixin): sample_period: int = 1, finalize: bool = True, pcolormesh_args: dict[str, Any] | None = None, - ax: 'matplotlib.axes.Axes' | None = None, + ax: 'matplotlib.axes.Axes | None' = None, ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ Visualize a slice of a grid. @@ -149,7 +149,7 @@ class GridReadMixin(GridPosMixin): which_shifts: int = 0, finalize: bool = True, contour_args: dict[str, Any] | None = None, - ax: 'matplotlib.axes.Axes' | None = None, + ax: 'matplotlib.axes.Axes | None' = None, level_fraction: float = 0.7, ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ From f4818fd55450463a63cfa8bcc26ebbcecfce7a88 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 20:32:01 -0700 Subject: [PATCH 35/48] [draw] fix arg naming --- gridlock/read.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 44f1c8a..7b4de1e 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -165,7 +165,7 @@ class GridReadMixin(GridPosMixin): which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - pcolormesh_args: Args passed through to matplotlib `pcolormesh()` + contour_args: Args passed through to matplotlib `pcolormesh()` ax: If provided, plot to these axes (instead of creating a new figure & axes) level_fraction: Value between 0 and 1 which tunes how many contours are generated. 1 indicates that every possible step should have its own contour. @@ -181,8 +181,8 @@ class GridReadMixin(GridPosMixin): if isinstance(plane, dict): plane = Plane(**plane) - if pcolormesh_args is None: - pcolormesh_args = dict(alpha=0.8, colors='gray') + if contour_args is None: + contour_args = dict(alpha=0.8, colors='gray') grid_slice = self.get_slice( cell_data = cell_data, From 32b6c207dcf703f6c695dc9fae3f7615bd5e8f15 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 22:24:15 -0700 Subject: [PATCH 36/48] [read] more fixup for visualize_edges --- gridlock/read.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gridlock/read.py b/gridlock/read.py index 7b4de1e..503e996 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -147,6 +147,7 @@ class GridReadMixin(GridPosMixin): cell_data: NDArray, plane: PlaneProtocol | PlaneDict, which_shifts: int = 0, + sample_period: int = 1, finalize: bool = True, contour_args: dict[str, Any] | None = None, ax: 'matplotlib.axes.Axes | None' = None, @@ -207,7 +208,7 @@ class GridReadMixin(GridPosMixin): fig, ax = pyplot.subplots() else: fig = ax.figure - xc, yc = (self.shifted_exyz(which_shifts)[a] for a in surface) + xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface) xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) From 43d5fa8b4f2e1d48a2d35de02ebc2a7adf591168 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:21:22 -0700 Subject: [PATCH 37/48] [draw] fix handling of Nx3 vertex arrays --- gridlock/draw.py | 11 +++++------ gridlock/test/test_grid.py | 35 +++++++++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 9ba4623..864468f 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -61,16 +61,18 @@ class GridDrawMixin(GridPosMixin): for ii in range(len(poly_list)): polygon = poly_list[ii] malformed = f'Malformed polygon: ({ii})' + if polygon.ndim != 2: + raise GridError(malformed + 'must be a 2-dimensional ndarray') if polygon.shape[1] not in (2, 3): raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') if polygon.shape[1] == 3: - polygon = polygon[surface, :] + if numpy.unique(polygon[:, slab.axis]).size != 1: + raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) + polygon = polygon[:, surface] poly_list[ii] = polygon if not polygon.shape[0] > 2: raise GridError(malformed + 'must consist of more than 2 points') - if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1: - raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] @@ -296,8 +298,6 @@ class GridDrawMixin(GridPosMixin): if isinstance(z, dict): z = Extent(**z) - center = numpy.asarray([x.center, y.center, z.center]) - p = numpy.array([[x.min, y.max], [x.max, y.max], [x.max, y.min], @@ -398,4 +398,3 @@ class GridDrawMixin(GridPosMixin): slab = Slab(axis=direction, center=center[direction], span=thickness) self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface]) - diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 8d9ca92..6cb9edc 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,8 +1,8 @@ -# import pytest +import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid, Extent #, Slab, Plane +from .. import Grid, Extent, GridError, Plane def test_draw_oncenter_2x2() -> None: @@ -116,3 +116,34 @@ def test_draw_2shift_4x4() -> None: [0, 0.125, 0.125, 0]])[None, :, :, None] assert_allclose(arr, correct) + + +def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) + arr_2d = grid.allocate(0) + arr_3d = grid.allocate(0) + slab = dict(axis='z', center=0.5, span=1.0) + + polygon_2d = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float) + polygon_3d = numpy.array([[0, 0, 0.5], + [1, 0, 0.5], + [1, 1, 0.5], + [0, 1, 0.5]], dtype=float) + + grid.draw_polygon(arr_2d, slab=slab, polygon=polygon_2d, foreground=1) + grid.draw_polygon(arr_3d, slab=slab, polygon=polygon_3d, foreground=1) + + assert_allclose(arr_3d, arr_2d) + + +def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) + arr = grid.allocate(0) + polygon = numpy.array([[0, 0, 0.5], + [1, 0, 0.5], + [1, 1, 0.75], + [0, 1, 0.5]], dtype=float) + + with pytest.raises(GridError): + grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1) + From 1cc47da386d69a56938c4d62629f74afd2d20966 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:25:13 -0700 Subject: [PATCH 38/48] [ind2pos] fix rounding and bounds --- gridlock/position.py | 4 ++-- gridlock/test/test_grid.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/gridlock/position.py b/gridlock/position.py index b705b99..6344ea4 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -47,13 +47,13 @@ class GridPosMixin(GridBase): else: low_bound = -0.5 high_bound = -0.5 - if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): + if (ind < low_bound).any() or (ind > self.shape + high_bound).any(): raise GridError(f'Position outside of grid: {ind}') if round_ind: rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) sxyz = self.shifted_xyz(which_shifts) - position = [sxyz[a][rind[a]].astype(int) for a in range(3)] + position = [sxyz[a][rind[a]] for a in range(3)] else: sexyz = self.shifted_exyz(which_shifts) position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 6cb9edc..a9e3d9e 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -118,6 +118,27 @@ def test_draw_2shift_4x4() -> None: assert_allclose(arr, correct) +def test_ind2pos_round_preserves_float_centers() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) + + pos = grid.ind2pos(numpy.array([1, 0, 0]), which_shifts=0) + + assert_allclose(pos, [2.0, 1.0, 0.5]) + + +def test_ind2pos_enforces_bounds_for_rounded_and_fractional_indices() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + grid.ind2pos(numpy.array([2, 0, 0]), which_shifts=0, check_bounds=True) + + edge_pos = grid.ind2pos(numpy.array([1.5, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) + assert_allclose(edge_pos, [3.0, 2.0, 1.0]) + + with pytest.raises(GridError): + grid.ind2pos(numpy.array([1.6, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) + + def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None: grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) arr_2d = grid.allocate(0) From 526b9e1666b55c59cbf2fa684e9f20dc500b7ac8 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:25:41 -0700 Subject: [PATCH 39/48] [read] fix sampling --- gridlock/read.py | 13 +++++++++---- gridlock/test/test_grid.py | 26 ++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 503e996..998e79d 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -68,7 +68,8 @@ class GridReadMixin(GridPosMixin): raise GridError('Coordinate of selected plane must be within simulation domain') # Extract grid values from planes above and below visualized slice - sliced_grid = numpy.zeros(self.shape[surface]) + sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface) + sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float)) for ci, weight in zip(centers, w, strict=True): s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3)) sliced_grid += weight * cell_data[which_shifts][tuple(s)] @@ -122,7 +123,11 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) - x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + if sample_period == 1: + x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + else: + x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) + pcolormesh_args.setdefault('shading', 'nearest') xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) @@ -208,10 +213,10 @@ class GridReadMixin(GridPosMixin): fig, ax = pyplot.subplots() else: fig = ax.figure - xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface) + xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') - mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) + ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) if finalize: pyplot.show() diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index a9e3d9e..84b0f7b 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -168,3 +168,29 @@ def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None: with pytest.raises(GridError): grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1) + +def test_get_slice_supports_sampling() -> None: + grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2) + + assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0]) + + +def test_sampled_visualization_helpers_do_not_error() -> None: + matplotlib = pytest.importorskip('matplotlib') + matplotlib.use('Agg') + from matplotlib import pyplot + + grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False) + fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False) + + assert fig_slice is ax_slice.figure + assert fig_edges is ax_edges.figure + + pyplot.close(fig_slice) + pyplot.close(fig_edges) From 15c2cf83516a8fe9bce4a2a0603398f2bded0dcc Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:47:35 -0700 Subject: [PATCH 40/48] improve arg checking --- gridlock/grid.py | 4 ++ gridlock/test/test_grid.py | 34 ++++++++++++++- gridlock/utils.py | 88 ++++++++++++++++++++++---------------- 3 files changed, 88 insertions(+), 38 deletions(-) diff --git a/gridlock/grid.py b/gridlock/grid.py index 5790dbd..5bed422 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -95,6 +95,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): `GridError` on invalid input """ edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates] + if len(edge_arrs) != 3: + raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays') self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) @@ -106,6 +108,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = [periodic] * 3 else: self.periodic = list(periodic) + if len(self.periodic) != 3: + raise GridError('periodic must be a bool or a sequence of length 3') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 84b0f7b..60929e8 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -2,7 +2,7 @@ import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid, Extent, GridError, Plane +from .. import Grid, Extent, GridError, Plane, Slab def test_draw_oncenter_2x2() -> None: @@ -194,3 +194,35 @@ def test_sampled_visualization_helpers_do_not_error() -> None: pyplot.close(fig_slice) pyplot.close(fig_edges) + + +def test_grid_constructor_rejects_invalid_coordinate_count() -> None: + with pytest.raises(GridError): + Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + +def test_grid_constructor_rejects_invalid_periodic_length() -> None: + with pytest.raises(GridError): + Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False]) + + +def test_extent_and_slab_reject_inverted_geometry() -> None: + with pytest.raises(GridError): + Extent(center=0, min=1) + + with pytest.raises(GridError): + Extent(min=2, max=1) + + with pytest.raises(GridError): + Slab(axis='z', center=1, max=0) + + +def test_extent_accepts_scalar_like_inputs() -> None: + extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0])) + + assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) + + diff --git a/gridlock/utils.py b/gridlock/utils.py index 8a8f11d..585b999 100644 --- a/gridlock/utils.py +++ b/gridlock/utils.py @@ -1,12 +1,25 @@ from typing import Protocol, TypedDict, runtime_checkable, cast from dataclasses import dataclass +import numpy + class GridError(Exception): """ Base error type for `gridlock` """ pass +def _coerce_scalar(name: str, value: object) -> float: + arr = numpy.asarray(value) + if arr.size != 1: + raise GridError(f'{name} must be a scalar value') + + try: + return float(arr.reshape(())) + except (TypeError, ValueError) as exc: + raise GridError(f'{name} must be a real scalar value') from exc + + class ExtentDict(TypedDict, total=False): """ Geometrical definition of an extent (1D bounded region) @@ -58,44 +71,46 @@ class Extent(ExtentProtocol): max: float | None = None, span: float | None = None, ) -> None: - if sum(cc is None for cc in (min, center, max, span)) != 2: - raise GridError('Exactly two of min, center, max, span must be None!') + values = { + 'min': None if min is None else _coerce_scalar('min', min), + 'center': None if center is None else _coerce_scalar('center', center), + 'max': None if max is None else _coerce_scalar('max', max), + 'span': None if span is None else _coerce_scalar('span', span), + } + if sum(value is not None for value in values.values()) != 2: + raise GridError('Exactly two of min, center, max, span must be provided') - if span is None: - if center is None: - assert min is not None - assert max is not None - assert max >= min - center = 0.5 * (max + min) - span = max - min - elif max is None: - assert min is not None - assert center is not None - span = 2 * (center - min) - elif min is None: - assert center is not None - assert max is not None - span = 2 * (max - center) - else: # noqa: PLR5501 - if center is not None: - pass - elif max is None: - assert min is not None - assert span is not None - center = min + 0.5 * span - elif min is None: - assert max is not None - assert span is not None - center = max - 0.5 * span + min_v = values['min'] + center_v = values['center'] + max_v = values['max'] + span_v = values['span'] - assert center is not None - assert span is not None - if hasattr(center, '__len__'): - assert len(center) == 1 - if hasattr(span, '__len__'): - assert len(span) == 1 - self.center = center - self.span = span + if span_v is not None and span_v < 0: + raise GridError('span must be non-negative') + + if min_v is not None and max_v is not None: + if max_v < min_v: + raise GridError('max must be greater than or equal to min') + center_v = 0.5 * (max_v + min_v) + span_v = max_v - min_v + elif center_v is not None and min_v is not None: + span_v = 2 * (center_v - min_v) + if span_v < 0: + raise GridError('min must be less than or equal to center') + elif center_v is not None and max_v is not None: + span_v = 2 * (max_v - center_v) + if span_v < 0: + raise GridError('center must be less than or equal to max') + elif min_v is not None and span_v is not None: + center_v = min_v + 0.5 * span_v + elif max_v is not None and span_v is not None: + center_v = max_v - 0.5 * span_v + + if center_v is None or span_v is None: + raise GridError('Unable to construct extent from the provided values') + + self.center = center_v + self.span = span_v class SlabDict(TypedDict, total=False): @@ -231,4 +246,3 @@ class Plane(PlaneProtocol): if hasattr(cpos, '__len__'): assert len(cpos) == 1 self.pos = cpos - From ddce4fa491081bee41e4eba699e8ff1bf5669141 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:50:48 -0700 Subject: [PATCH 41/48] [isosurface] fix sampling --- gridlock/read.py | 30 +++++++++++++++++++++++-- gridlock/test/test_grid.py | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 998e79d..9df3e08 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -20,6 +20,26 @@ if TYPE_CHECKING: class GridReadMixin(GridPosMixin): + @staticmethod + def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]: + if centers.size > 1: + midpoints = 0.5 * (centers[:-1] + centers[1:]) + first = centers[0] - 0.5 * (centers[1] - centers[0]) + last = centers[-1] + 0.5 * (centers[-1] - centers[-2]) + return numpy.hstack(([first], midpoints, [last])) + return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float) + + def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]: + if sample_period <= 1: + return self.shifted_exyz(which_shifts) + + shifted_xyz = self.shifted_xyz(which_shifts) + shifted_exyz = self.shifted_exyz(which_shifts) + return [ + self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a]) + for a in range(3) + ] + def get_slice( self, cell_data: NDArray, @@ -262,8 +282,14 @@ class GridReadMixin(GridPosMixin): verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) # Convert vertices from index to position - pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) - for i in range(verts.shape[0])], dtype=float) + preview_exyz = self._sampled_exyz(which_shifts, sample_period) + pos_verts = numpy.array([ + [ + numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a]) + for a in range(3) + ] + for i in range(verts.shape[0]) + ], dtype=float) xs, ys, zs = (pos_verts[:, a] for a in range(3)) # Draw the plot diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 60929e8..9f2e4f3 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -226,3 +226,49 @@ def test_extent_accepts_scalar_like_inputs() -> None: assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) + + +def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: + matplotlib = pytest.importorskip('matplotlib') + matplotlib.use('Agg') + skimage_measure = pytest.importorskip('skimage.measure') + from matplotlib import pyplot + from mpl_toolkits.mplot3d.axes3d import Axes3D + + captured: dict[str, numpy.ndarray] = {} + + def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]: + verts = numpy.array([[0.5, 0.5, 0.5], + [0.5, 1.5, 0.5], + [1.5, 0.5, 0.5]], dtype=float) + faces = numpy.array([[0, 1, 2]], dtype=int) + return verts, faces, None, None + + def fake_plot_trisurf( # noqa: ANN202 + _self: object, + xs: numpy.ndarray, + ys: numpy.ndarray, + faces: numpy.ndarray, + zs: numpy.ndarray, + *_args: object, + **_kwargs: object, + ) -> object: + captured['xs'] = numpy.asarray(xs) + captured['ys'] = numpy.asarray(ys) + captured['faces'] = numpy.asarray(faces) + captured['zs'] = numpy.asarray(zs) + return object() + + monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes) + monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf) + + grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]]) + cell_data = numpy.zeros(grid.cell_data_shape) + + fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False) + + assert_allclose(captured['xs'], [1.5, 1.5, 3.5]) + assert_allclose(captured['ys'], [1.5, 3.5, 1.5]) + assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) + + pyplot.close(fig) From e345d1dcf8f9b52af7cd83844efe66082f0b0379 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:51:34 -0700 Subject: [PATCH 42/48] [get_slice] use shifted bounds --- gridlock/read.py | 2 +- gridlock/test/test_grid.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/gridlock/read.py b/gridlock/read.py index 9df3e08..9be52b1 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -83,7 +83,7 @@ class GridReadMixin(GridPosMixin): else: w = [1] - c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1]) + c_min, c_max = (self.shifted_xyz(which_shifts)[plane.axis][i] for i in [0, -1]) if plane.pos < c_min or plane.pos > c_max: raise GridError('Coordinate of selected plane must be within simulation domain') diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 9f2e4f3..c6c8ae7 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -226,6 +226,18 @@ def test_extent_accepts_scalar_like_inputs() -> None: assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) +def test_get_slice_uses_shifted_grid_bounds() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0.5, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + grid_slice = grid.get_slice(cell_data, Plane(x=2.0), which_shifts=0) + + assert_allclose(grid_slice, cell_data[0, 1, :, :]) + + with pytest.raises(GridError): + grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0) + + def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: From 8895b06f08df4fb43f5910cd29468f32db8866ff Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:51:59 -0700 Subject: [PATCH 43/48] fixup! [isosurface] fix sampling --- gridlock/test/test_grid.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index c6c8ae7..2cb60c5 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -240,6 +240,14 @@ def test_get_slice_uses_shifted_grid_bounds() -> None: +def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None: + grid = Grid([[0, 1, 3, 6, 10], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) + + sampled_exyz = grid._sampled_exyz(0, 2) + + assert_allclose(sampled_exyz[0], [-1.5, 2.5, 6.5]) + + def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: matplotlib = pytest.importorskip('matplotlib') matplotlib.use('Agg') From 481b56874ee9c42f8534a378fa85e89c1e523d93 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:52:45 -0700 Subject: [PATCH 44/48] [draw] fix extrude without out-of-bounds slice --- gridlock/draw.py | 23 +++++++++++++---------- gridlock/test/test_grid.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 864468f..321ec15 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -76,10 +76,10 @@ class GridDrawMixin(GridPosMixin): # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] - if numpy.size(foreground) == 1: # type: ignore - foregrounds = [foreground] * len(cell_data) # type: ignore - elif isinstance(foreground, numpy.ndarray): + if isinstance(foreground, numpy.ndarray): raise GridError('ndarray not supported for foreground') + if callable(foreground) or numpy.isscalar(foreground): + foregrounds = [foreground] * len(cell_data) # type: ignore[list-item] else: foregrounds = foreground # type: ignore @@ -376,15 +376,18 @@ class GridDrawMixin(GridPosMixin): foreground_func = [] for ii, grid in enumerate(cell_data): zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction] - - ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)] - fpart = zz - numpy.floor(zz) - mult = [1 - fpart, fpart][::sgn] # reverses if s negative + low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1)) + high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1)) - foreground = mult[0] * grid[tuple(ind)] - ind[direction] += 1 # type: ignore #(known safe) - foreground += mult[1] * grid[tuple(ind)] + low_ind = [low if dd == direction else slice(None) for dd in range(3)] + high_ind = [high if dd == direction else slice(None) for dd in range(3)] + + if low == high: + foreground = grid[tuple(low_ind)] + else: + mult = [1 - fpart, fpart][::sgn] # reverses if s negative + foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)] def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 # transform from natural position to index diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 2cb60c5..e7b3b28 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -238,6 +238,23 @@ def test_get_slice_uses_shifted_grid_bounds() -> None: grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0) +def test_draw_extrude_rectangle_uses_boundary_slice() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) + cell_data = grid.allocate(0) + source = numpy.array([[1, 2], + [3, 4]], dtype=float) + cell_data[0, :, :, 1] = source + + grid.draw_extrude_rectangle( + cell_data, + rectangle=[[0, 0, 2], [2, 2, 2]], + direction=2, + polarity=-1, + distance=2, + ) + + assert_allclose(cell_data[0, :, :, 0], source) + assert_allclose(cell_data[0, :, :, 1], source) def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None: From 96aad5a3a10ab779bbdf00da081cdbf85861096d Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 11:00:08 -0700 Subject: [PATCH 45/48] bump version to v2.1 --- gridlock/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 2f39696..e7be065 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -34,5 +34,5 @@ from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' -__version__ = '2.0' +__version__ = '2.1' version = __version__ From 066ca8f3b88cc03a30da43125358895ee0337e84 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 11:00:49 -0700 Subject: [PATCH 46/48] bump version to v2.2 2.1 had an existing tag --- gridlock/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index e7be065..3f965fd 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -34,5 +34,5 @@ from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' -__version__ = '2.1' +__version__ = '2.2' version = __version__ From 85ae6e66cd4ee97192d6bb33249b5dc69e3d5668 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 21 Apr 2026 19:58:57 -0700 Subject: [PATCH 47/48] [Grid] enable negative shifts --- gridlock/base.py | 40 +++++++++++++-------------- gridlock/read.py | 2 +- gridlock/test/test_grid.py | 55 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 23 deletions(-) diff --git a/gridlock/base.py b/gridlock/base.py index aca9c69..e68d955 100644 --- a/gridlock/base.py +++ b/gridlock/base.py @@ -76,6 +76,21 @@ class GridBase(Protocol): el = [0 if p else -1 for p in self.periodic] return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] + def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: + if which_shifts is None: + return self.dxyz_with_ghost + + shifts = self.shifts[which_shifts, :] + edge_dxyz = [] + for a in range(3): + if shifts[a] < 0: + ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0] + edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a]))) + else: + ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1] + edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost))) + return edge_dxyz + @property def center(self) -> NDArray[numpy.float64]: """ @@ -115,15 +130,9 @@ class GridBase(Protocol): """ if which_shifts is None: return self.exyz - dxyz = self.dxyz_with_ghost + edge_dxyz = self._shifted_edge_dxyz(which_shifts) shifts = self.shifts[which_shifts, :] - - # If shift is negative, use left cell's dx to determine shift - for a in range(3): - if shifts[a] < 0: - dxyz[a] = numpy.roll(dxyz[a], 1) - - return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] + return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)] def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: """ @@ -137,20 +146,7 @@ class GridBase(Protocol): """ if which_shifts is None: return self.dxyz - shifts = self.shifts[which_shifts, :] - dxyz = self.dxyz_with_ghost - - # If shift is negative, use left cell's dx to determine size - sdxyz = [] - for a in range(3): - if shifts[a] < 0: - roll_dxyz = numpy.roll(dxyz[a], 1) - abs_shift = numpy.abs(shifts[a]) - sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) - else: - sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) - - return sdxyz + return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)] def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: """ diff --git a/gridlock/read.py b/gridlock/read.py index 9be52b1..f8a40a1 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -73,7 +73,7 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) # Extract indices and weights of planes - center3 = numpy.insert([0, 0], plane.axis, (plane.pos,)) + center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,)) center_index = self.pos2ind(center3, which_shifts, round_ind=False, check_bounds=False)[plane.axis] centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index e7b3b28..b4929a4 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -309,3 +309,58 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest. assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) pyplot.close(fig) + + + + +def test_negative_shift_nonperiodic_edges_and_widths() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + + assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0]) + assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5]) + assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25]) + + +def test_negative_shift_periodic_edges_and_widths() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False]) + + assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0]) + assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5]) + + +def test_negative_shift_coordinate_round_trip() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + + ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False) + pos = grid.ind2pos(ind, 0, round_ind=False) + + assert_allclose(ind, [1.0, 0.0, 0.0]) + assert_allclose(pos, [1.25, 1.0, 0.5]) + + +def test_negative_shift_draw_cuboid_fractional_fill() -> None: + grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + arr = grid.allocate(0) + + grid.draw_cuboid( + arr, + x=dict(min=0, max=1), + y=dict(min=0, max=1), + z=dict(min=0, max=1), + foreground=1, + ) + + assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3]) + + +def test_negative_shift_get_slice_uses_shifted_centers() -> None: + grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + cell_data = numpy.zeros(grid.cell_data_shape) + cell_data[0, 1, :, 0] = [7, 9] + x_center = float(grid.shifted_xyz(0)[0][1]) + + grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0) + + assert_allclose(grid_slice, [7, 9]) + + From 22cb410d84ff4f33b727376761ebc489b59c382e Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 21 Apr 2026 20:00:52 -0700 Subject: [PATCH 48/48] [GridData / save / load] Add GridData and update save format --- gridlock/__init__.py | 1 + gridlock/data.py | 176 +++++++++++++++++++++++++++++++++++++ gridlock/grid.py | 110 ++++++++++++++++++++--- gridlock/test/test_grid.py | 102 ++++++++++++++++++++- 4 files changed, 376 insertions(+), 13 deletions(-) create mode 100644 gridlock/data.py diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 3f965fd..759d1c1 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -31,6 +31,7 @@ from .utils import ( PlaneDict as PlaneDict, ) from .grid import Grid as Grid +from .data import GridData as GridData __author__ = 'Jan Petykiewicz' diff --git a/gridlock/data.py b/gridlock/data.py new file mode 100644 index 0000000..5e6faa5 --- /dev/null +++ b/gridlock/data.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass +from typing import Self +from collections.abc import Sequence + +import numpy +from numpy.typing import NDArray, ArrayLike + +from .draw import foreground_t +from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload +from .utils import ( + ExtentDict, + ExtentProtocol, + GridError, + PlaneDict, + PlaneProtocol, + SlabDict, + SlabProtocol, +) + + +@dataclass(slots=True) +class GridData: + grid: Grid + cell_data: NDArray + + def __post_init__(self) -> None: + if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape): + raise GridError( + f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}' + ) + + @staticmethod + def load(filename: str) -> 'GridData': + payload = _load_payload(filename) + if _payload_scalar_str(payload, 'kind') != 'grid_data': + raise GridError('Serialized payload does not contain GridData') + if 'cell_data' not in payload: + raise GridError('Serialized GridData payload is missing cell_data') + + return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data'])) + + def save(self, filename: str) -> Self: + payload = self.grid._serialization_payload(kind='grid_data') + payload['cell_data'] = self.cell_data + _save_npz_payload(filename, payload) + return self + + def copy(self) -> Self: + return GridData(self.grid.copy(), self.cell_data.copy()) + + def draw_polygons( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + polygons: Sequence[ArrayLike], + *, + offset2d: ArrayLike = (0, 0), + ) -> Self: + self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d) + return self + + def draw_polygon( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + polygon: ArrayLike, + *, + offset2d: ArrayLike = (0, 0), + ) -> Self: + self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d) + return self + + def draw_slab( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + ) -> Self: + self.grid.draw_slab(self.cell_data, foreground, slab) + return self + + def draw_cuboid( + self, + foreground: Sequence[foreground_t] | foreground_t, + *, + x: ExtentProtocol | ExtentDict, + y: ExtentProtocol | ExtentDict, + z: ExtentProtocol | ExtentDict, + ) -> Self: + self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z) + return self + + def draw_cylinder( + self, + h: SlabProtocol | SlabDict, + radius: float, + num_points: int, + center2d: ArrayLike, + foreground: Sequence[foreground_t] | foreground_t, + ) -> Self: + self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground) + return self + + def draw_extrude_rectangle( + self, + rectangle: ArrayLike, + direction: int, + polarity: int, + distance: float, + ) -> Self: + self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance) + return self + + def get_slice( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + ) -> NDArray: + return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period) + + def visualize_slice( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + pcolormesh_args: dict[str, object] | None = None, + ax: object | None = None, + ) -> tuple[object, object]: + return self.grid.visualize_slice( + self.cell_data, + plane, + which_shifts=which_shifts, + sample_period=sample_period, + finalize=finalize, + pcolormesh_args=pcolormesh_args, + ax=ax, + ) + + def visualize_edges( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + contour_args: dict[str, object] | None = None, + ax: object | None = None, + level_fraction: float = 0.7, + ) -> tuple[object, object]: + return self.grid.visualize_edges( + self.cell_data, + plane, + which_shifts=which_shifts, + sample_period=sample_period, + finalize=finalize, + contour_args=contour_args, + ax=ax, + level_fraction=level_fraction, + ) + + def visualize_isosurface( + self, + level: float | None = None, + which_shifts: int = 0, + sample_period: int = 1, + show_edges: bool = True, + finalize: bool = True, + ) -> tuple[object, object]: + return self.grid.visualize_isosurface( + self.cell_data, + level=level, + which_shifts=which_shifts, + sample_period=sample_period, + show_edges=show_edges, + finalize=finalize, + ) diff --git a/gridlock/grid.py b/gridlock/grid.py index 5bed422..eeb9708 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Self +from typing import TYPE_CHECKING, Any, ClassVar, Self from collections.abc import Callable, Sequence import numpy @@ -13,8 +13,78 @@ from .draw import GridDrawMixin from .read import GridReadMixin from .position import GridPosMixin +if TYPE_CHECKING: + from .data import GridData + foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] +_FORMAT_VERSION = 1 + + +def _is_npz_file(filename: str) -> bool: + with open(filename, 'rb') as f: + return f.read(2) == b'PK' + + +def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None: + with open(filename, 'wb') as f: + numpy.savez_compressed(f, **payload) + + +def _load_payload(filename: str) -> dict[str, Any]: + if _is_npz_file(filename): + with numpy.load(filename, allow_pickle=False) as payload: + return {key: payload[key] for key in payload.files} + + with open(filename, 'rb') as f: + legacy = pickle.load(f) + + if isinstance(legacy, Grid): + return legacy._serialization_payload(kind='grid') + if isinstance(legacy, dict): + grid = Grid([[-1, 1]] * 3) + grid.__dict__.update(legacy) + return grid._serialization_payload(kind='grid') + raise GridError('Unsupported serialized Grid payload') + + +def _payload_scalar_str(payload: dict[str, Any], key: str) -> str: + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + + value = numpy.asarray(payload[key]) + if value.size != 1: + raise GridError(f'Serialized key {key} must be scalar') + return str(value.reshape(())) + + +def _payload_scalar_int(payload: dict[str, Any], key: str) -> int: + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + + value = numpy.asarray(payload[key]) + if value.size != 1: + raise GridError(f'Serialized key {key} must be scalar') + return int(value.reshape(())) + + +def _grid_from_payload(payload: dict[str, Any]) -> 'Grid': + if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION: + raise GridError('Unsupported serialized Grid format version') + + exyz = [] + for axis in range(3): + key = f'exyz_{axis}' + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + exyz.append(numpy.array(payload[key], dtype=float)) + + if 'shifts' not in payload or 'periodic' not in payload: + raise GridError('Serialized Grid payload is missing shifts or periodic data') + + shifts = numpy.array(payload['shifts'], dtype=float) + periodic = numpy.array(payload['periodic'], dtype=bool).tolist() + return Grid(exyz, shifts=shifts, periodic=periodic) class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): @@ -110,6 +180,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = list(periodic) if len(self.periodic) != 3: raise GridError('periodic must be a bool or a sequence of length 3') + if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic): + raise GridError('periodic sequence entries must be bool values') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' @@ -121,9 +193,16 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): if (numpy.abs(self.shifts) > 1).any(): raise GridError('Only shifts in the range [-1, 1] are currently supported') - if (self.shifts < 0).any(): - # TODO: Test negative shifts - warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2) + def _serialization_payload(self, *, kind: str) -> dict[str, Any]: + payload: dict[str, Any] = { + 'kind': numpy.array(kind), + 'format_version': numpy.array(_FORMAT_VERSION, dtype=int), + 'shifts': self.shifts, + 'periodic': numpy.array(self.periodic, dtype=bool), + } + for axis, exyz in enumerate(self.exyz): + payload[f'exyz_{axis}'] = exyz + return payload @staticmethod def load(filename: str) -> 'Grid': @@ -133,12 +212,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Args: filename: Filename to load from. """ - with open(filename, 'rb') as f: - tmp_dict = pickle.load(f) - - g = Grid([[-1, 1]] * 3) - g.__dict__.update(tmp_dict) - return g + payload = _load_payload(filename) + kind = _payload_scalar_str(payload, 'kind') + if kind not in ('grid', 'grid_data'): + raise GridError(f'Unsupported serialized kind: {kind}') + return _grid_from_payload(payload) def save(self, filename: str) -> Self: """ @@ -150,10 +228,18 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Returns: self """ - with open(filename, 'wb') as f: - pickle.dump(self.__dict__, f, protocol=2) + _save_npz_payload(filename, self._serialization_payload(kind='grid')) return self + def with_data( + self, + fill_value: float | None = 1.0, + dtype: type[numpy.number] = numpy.float32, + ) -> 'GridData': + from .data import GridData + + return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype)) + def copy(self) -> Self: """ Returns: diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index b4929a4..ae0a73a 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,8 +1,9 @@ import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal +import pickle -from .. import Grid, Extent, GridError, Plane, Slab +from .. import Grid, GridData, Extent, GridError, Plane, Slab def test_draw_oncenter_2x2() -> None: @@ -311,6 +312,54 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest. pyplot.close(fig) +def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) + path = tmp_path / 'grid.state' + + grid.save(str(path)) + loaded = Grid.load(str(path)) + + assert path.exists() + for original, restored in zip(grid.exyz, loaded.exyz, strict=True): + assert_allclose(restored, original) + assert_allclose(loaded.shifts, grid.shifts) + assert loaded.periodic == grid.periodic + + +def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) + path = tmp_path / 'grid.pickle' + with open(path, 'wb') as f: + pickle.dump(grid.__dict__, f, protocol=2) + + loaded = Grid.load(str(path)) + + for original, restored in zip(grid.exyz, loaded.exyz, strict=True): + assert_allclose(restored, original) + assert_allclose(loaded.shifts, grid.shifts) + assert loaded.periodic == grid.periodic + + +def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: + data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0) + data.cell_data[0, 1, 0, 0] = 5.0 + path = tmp_path / 'griddata.state' + + data.save(str(path)) + loaded = GridData.load(str(path)) + + assert path.exists() + assert_allclose(loaded.cell_data, data.cell_data) + assert_allclose(loaded.grid.shifts, data.grid.shifts) + assert loaded.grid.periodic == data.grid.periodic + + +def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None: + path = tmp_path / 'grid.state' + Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path)) + + with pytest.raises(GridError): + GridData.load(str(path)) def test_negative_shift_nonperiodic_edges_and_widths() -> None: @@ -364,3 +413,54 @@ def test_negative_shift_get_slice_uses_shifted_centers() -> None: assert_allclose(grid_slice, [7, 9]) +def test_grid_with_data_returns_griddata() -> None: + grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + data = grid.with_data(fill_value=2.0) + + assert isinstance(data, GridData) + assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32)) + + +def test_griddata_constructor_validates_shape() -> None: + grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + GridData(grid, numpy.zeros((1, 1, 1))) + + +def test_griddata_draw_methods_are_chainable() -> None: + data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) + + chained = data.draw_cuboid( + foreground=1, + x=dict(min=0, max=1), + y=dict(min=0, max=1), + z=dict(min=0, max=1), + ).draw_polygon( + foreground=0.5, + slab=dict(axis='z', center=0.5, span=1.0), + polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float), + ) + + assert chained is data + assert data.cell_data.sum() > 0 + + +def test_griddata_read_methods_delegate() -> None: + data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) + data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float) + + assert_allclose( + data.get_slice(Plane(z=0.5)), + data.grid.get_slice(data.cell_data, Plane(z=0.5)), + ) + + +def test_griddata_copy_is_independent() -> None: + data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0) + cloned = data.copy() + cloned.cell_data[0, 0, 0, 0] = 5.0 + + assert data is not cloned + assert data.grid is not cloned.grid + assert data.cell_data[0, 0, 0, 0] == 1.0