From 13c12b0a6adb2555c9fff55cb512af6701ec54de Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 16 Apr 2025 21:25:43 -0700 Subject: [PATCH 01/22] update docs to reflect new args --- gridlock/draw.py | 34 ++++++++++++++++++---------------- gridlock/utils.py | 37 +++++++++++++++++++++++++++++++++++-- pyproject.toml | 1 - 3 files changed, 53 insertions(+), 19 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 0b93d20..9ba4623 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -21,30 +21,31 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_t = float | foreground_callable_t - class GridDrawMixin(GridPosMixin): def draw_polygons( self, cell_data: NDArray, + foreground: Sequence[foreground_t] | foreground_t, slab: SlabProtocol | SlabDict, polygons: Sequence[ArrayLike], - foreground: Sequence[foreground_t] | foreground_t, *, offset2d: ArrayLike = (0, 0), ) -> None: """ - Draw polygons on an axis-aligned plane. + Draw polygons on an axis-aligned slab. Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - center: 3-element ndarray or list specifying an offset applied to all the polygons - polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon - (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each - polygon must have at least 3 vertices. foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list of any of these (1 per grid). Callable values should take an ndarray the shape of the grid and return an ndarray of equal shape containing the foreground value at the given x, y, and z (natural, not grid coordinates). + slab: `Slab` or slab-like dict specifying the slab in which the polygons will be drawn. + polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon + (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each + polygon must have at least 3 vertices. + offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly + to the given polygon vertex coordinates. Default (0, 0). Raises: GridError @@ -200,9 +201,9 @@ class GridDrawMixin(GridPosMixin): def draw_polygon( self, cell_data: NDArray, + foreground: Sequence[foreground_t] | foreground_t, slab: SlabProtocol | SlabDict, polygon: ArrayLike, - foreground: Sequence[foreground_t] | foreground_t, *, offset2d: ArrayLike = (0, 0), ) -> None: @@ -211,11 +212,13 @@ class GridDrawMixin(GridPosMixin): Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - slab: `Slab` in which to draw polygons. + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + slab: `Slab` or slab-like dict specifying the slab in which the polygon will be drawn. polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at least 3 vertices. - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly + to the given polygon vertex coordinates. Default (0, 0). """ self.draw_polygons( cell_data = cell_data, @@ -229,17 +232,16 @@ class GridDrawMixin(GridPosMixin): def draw_slab( self, cell_data: NDArray, - slab: SlabProtocol | SlabDict, foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, ) -> None: """ Draw an axis-aligned infinite slab. Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - slab: - thickness: Thickness of the layer to draw foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + slab: `Slab` or slab-like dict (geometrical slab specification) """ if isinstance(slab, dict): slab = Slab(**slab) @@ -282,10 +284,10 @@ class GridDrawMixin(GridPosMixin): Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - center: 3-element ndarray or list specifying the cuboid's center - dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge - sizes of the cuboid foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. + x: `Extent` or extent-like dict specifying the x-extent of the cuboid. + y: `Extent` or extent-like dict specifying the y-extent of the cuboid. + z: `Extent` or extent-like dict specifying the z-extent of the cuboid. """ if isinstance(x, dict): x = Extent(**x) diff --git a/gridlock/utils.py b/gridlock/utils.py index 7a12035..8a8f11d 100644 --- a/gridlock/utils.py +++ b/gridlock/utils.py @@ -1,4 +1,4 @@ -from typing import Protocol, TypedDict, runtime_checkable +from typing import Protocol, TypedDict, runtime_checkable, cast from dataclasses import dataclass @@ -8,6 +8,10 @@ class GridError(Exception): class ExtentDict(TypedDict, total=False): + """ + Geometrical definition of an extent (1D bounded region) + Must contain exactly two of `min`, `max`, `center`, or `span`. + """ min: float center: float max: float @@ -16,6 +20,9 @@ class ExtentDict(TypedDict, total=False): @runtime_checkable class ExtentProtocol(Protocol): + """ + Anything that looks like an `Extent` + """ center: float span: float @@ -28,6 +35,10 @@ class ExtentProtocol(Protocol): @dataclass(init=False, slots=True) class Extent(ExtentProtocol): + """ + Geometrical definition of an extent (1D bounded region) + May be constructed with any two of `min`, `max`, `center`, or `span`. + """ center: float span: float @@ -88,6 +99,10 @@ class Extent(ExtentProtocol): class SlabDict(TypedDict, total=False): + """ + Geometrical definition of a slab (3D region bounded on one axis only) + Must contain `axis` plus any two of `min`, `max`, `center`, or `span`. + """ min: float center: float max: float @@ -97,6 +112,9 @@ class SlabDict(TypedDict, total=False): @runtime_checkable class SlabProtocol(ExtentProtocol, Protocol): + """ + Anything that looks like a `Slab` + """ axis: int center: float span: float @@ -110,6 +128,10 @@ class SlabProtocol(ExtentProtocol, Protocol): @dataclass(init=False, slots=True) class Slab(Extent, SlabProtocol): + """ + Geometrical definition of a slab (3D region bounded on one axis only) + May be constructed with `axis` (bounded axis) plus any two of `min`, `max`, `center`, or `span`. + """ axis: int def __init__( @@ -142,6 +164,10 @@ class Slab(Extent, SlabProtocol): class PlaneDict(TypedDict, total=False): + """ + Geometrical definition of a plane (2D unbounded region in 3D space) + Must contain exactly one of `x`, `y`, `z`, or both `axis` and `pos` + """ x: float y: float z: float @@ -151,12 +177,19 @@ class PlaneDict(TypedDict, total=False): @runtime_checkable class PlaneProtocol(Protocol): + """ + Anything that looks like a `Plane` + """ axis: int pos: float @dataclass(init=False, slots=True) class Plane(PlaneProtocol): + """ + Geometrical definition of a plane (2D unbounded region in 3D space) + May be constructed with any of `x=4`, `y=5`, `z=-5`, or `axis=2, pos=-5`. + """ axis: int pos: float @@ -192,7 +225,7 @@ class Plane(PlaneProtocol): if pos is not None: cpos = pos else: - cpos = (xx, yy, zz)[axis_int] + cpos = cast('float', (xx, yy, zz)[axis_int]) assert cpos is not None if hasattr(cpos, '__len__'): diff --git a/pyproject.toml b/pyproject.toml index 1df2e8b..03d0d19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,6 @@ lint.ignore = [ "ANN002", # *args "ANN003", # **kwargs "ANN401", # Any - "ANN101", # self: Self "SIM108", # single-line if / else assignment "RET504", # x=y+z; return x "PIE790", # unnecessary pass From be7c26c1d1668564f5d2018d02554616bbbebec4 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Wed, 16 Apr 2025 21:28:43 -0700 Subject: [PATCH 02/22] bump version to v2.0 -- major arg rework for drawing/reading --- gridlock/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 120291f..2f39696 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -34,5 +34,5 @@ from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' -__version__ = '1.2' +__version__ = '2.0' version = __version__ From 6802e57fa9fb26d5017b55b9bd46725db38ae04b Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:18:08 -0700 Subject: [PATCH 03/22] [read] add missing arg to docstring --- gridlock/read.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gridlock/read.py b/gridlock/read.py index 707251a..cfc8f3d 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -98,6 +98,7 @@ class GridReadMixin(GridPosMixin): which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + pcolormesh_args: Args passed through to matplotlib `pcolormesh()` Returns: (Figure, Axes) From 21304f0dbfce8b089a44a7dfb13c072b57839bfa Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:18:49 -0700 Subject: [PATCH 04/22] [read] add option to visualize on preexisting axes --- gridlock/read.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index cfc8f3d..b5840e7 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -87,6 +87,7 @@ class GridReadMixin(GridPosMixin): sample_period: int = 1, finalize: bool = True, pcolormesh_args: dict[str, Any] | None = None, + ax: 'matplotlib.axes.Axes' | None = None, ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ Visualize a slice of a grid. @@ -99,6 +100,7 @@ class GridReadMixin(GridPosMixin): sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` pcolormesh_args: Args passed through to matplotlib `pcolormesh()` + ax: If provided, plot to these axes (instead of creating a new figure & axes) Returns: (Figure, Axes) @@ -112,10 +114,10 @@ class GridReadMixin(GridPosMixin): pcolormesh_args = {} grid_slice = self.get_slice( - cell_data=cell_data, - plane=plane, - which_shifts=which_shifts, - sample_period=sample_period, + cell_data = cell_data, + plane = plane, + which_shifts = which_shifts, + sample_period = sample_period, ) surface = numpy.delete(range(3), plane.axis) @@ -124,7 +126,10 @@ class GridReadMixin(GridPosMixin): xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) - fig, ax = pyplot.subplots() + if ax is None: + fig, ax = pyplot.subplots() + else: + fig = ax.figure mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) fig.colorbar(mappable) ax.set_aspect('equal', adjustable='box') From 68520b871018c3d86254fa1fa87faf4838351d11 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:19:28 -0700 Subject: [PATCH 05/22] [read] add visualize_edges() --- gridlock/read.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/gridlock/read.py b/gridlock/read.py index b5840e7..600227d 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -135,12 +135,86 @@ class GridReadMixin(GridPosMixin): ax.set_aspect('equal', adjustable='box') ax.set_xlabel(x_label) ax.set_ylabel(y_label) + if finalize: pyplot.show() return fig, ax + def visualize_edges( + self, + cell_data: NDArray, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + finalize: bool = True, + contour_args: dict[str, Any] | None = None, + ax: 'matplotlib.axes.Axes' | None = None, + level_fraction: float = 0.7, + ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: + """ + Visualize the edges of a grid slice. + This is intended as an overlay on top of visualize_slice (e.g. showing epsilon boundaries + on an E-field plot). + + Interpolates if given a position between two grid planes. + + Args: + cell_data: Cell data to visualize + plane: Axis and position (`Plane`) of the plane to read. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + pcolormesh_args: Args passed through to matplotlib `pcolormesh()` + ax: If provided, plot to these axes (instead of creating a new figure & axes) + level_fraction: Value between 0 and 1 which tunes how many contours are generated. + 1 indicates that every possible step should have its own contour. + + Returns: + (Figure, Axes) + """ + from matplotlib import pyplot + + if level_fraction > 1: + raise GridError(f'{level_fraction=} must be between 0 and 1') + + if isinstance(plane, dict): + plane = Plane(**plane) + + if pcolormesh_args is None: + pcolormesh_args = {} + + grid_slice = self.get_slice( + cell_data = cell_data, + plane = plane, + which_shifts = which_shifts, + sample_period = sample_period, + ) + cvals, cval_counts = numpy.unique(grid_slice, return_counts=True) + if cvals.size == 1: + levels = [cvals[0] + 1] + else: + cval_order = numpy.argsort(cval_counts)[::-1] + level_count = 2 + while cval_counts[cval_order[:level_count]].sum() < level_fraction: + level_count += 1 + ctr_levels = cvals[cval_order[:level_count]] + levels = numpy.diff(ctr_levels[::-1]) + ctr_levels[:0:-1] + + surface = numpy.delete(range(3), plane.axis) + + if ax is None: + fig, ax = pyplot.subplots() + else: + fig = ax.figure + xc, yc = (self.shifted_exyz(which_shifts)[a] for a in surface) + xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') + + mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) + + return fig, ax + + def visualize_isosurface( self, cell_data: NDArray, From 7cac73bcb400021289350a52215083e88c337cd7 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:25:39 -0700 Subject: [PATCH 06/22] [draw] add missing code for finalize --- gridlock/read.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gridlock/read.py b/gridlock/read.py index 600227d..4f39432 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -212,6 +212,9 @@ class GridReadMixin(GridPosMixin): mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) + if finalize: + pyplot.show() + return fig, ax From 16a76e0122845d029c7a4a3c2896ab0606553576 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:26:59 -0700 Subject: [PATCH 07/22] [read] make visualize_edges more friendly for overlay by default --- gridlock/read.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/read.py b/gridlock/read.py index 4f39432..28afdd2 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -182,7 +182,7 @@ class GridReadMixin(GridPosMixin): plane = Plane(**plane) if pcolormesh_args is None: - pcolormesh_args = {} + pcolormesh_args = dict(alpha=0.8, colors='gray') grid_slice = self.get_slice( cell_data = cell_data, From 64752873fbde2392b91150f8b26740dfdc633e66 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 19:28:40 -0700 Subject: [PATCH 08/22] [read] fix type spec --- gridlock/read.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 28afdd2..44f1c8a 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -87,7 +87,7 @@ class GridReadMixin(GridPosMixin): sample_period: int = 1, finalize: bool = True, pcolormesh_args: dict[str, Any] | None = None, - ax: 'matplotlib.axes.Axes' | None = None, + ax: 'matplotlib.axes.Axes | None' = None, ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ Visualize a slice of a grid. @@ -149,7 +149,7 @@ class GridReadMixin(GridPosMixin): which_shifts: int = 0, finalize: bool = True, contour_args: dict[str, Any] | None = None, - ax: 'matplotlib.axes.Axes' | None = None, + ax: 'matplotlib.axes.Axes | None' = None, level_fraction: float = 0.7, ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ From f4818fd55450463a63cfa8bcc26ebbcecfce7a88 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 20:32:01 -0700 Subject: [PATCH 09/22] [draw] fix arg naming --- gridlock/read.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 44f1c8a..7b4de1e 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -165,7 +165,7 @@ class GridReadMixin(GridPosMixin): which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - pcolormesh_args: Args passed through to matplotlib `pcolormesh()` + contour_args: Args passed through to matplotlib `pcolormesh()` ax: If provided, plot to these axes (instead of creating a new figure & axes) level_fraction: Value between 0 and 1 which tunes how many contours are generated. 1 indicates that every possible step should have its own contour. @@ -181,8 +181,8 @@ class GridReadMixin(GridPosMixin): if isinstance(plane, dict): plane = Plane(**plane) - if pcolormesh_args is None: - pcolormesh_args = dict(alpha=0.8, colors='gray') + if contour_args is None: + contour_args = dict(alpha=0.8, colors='gray') grid_slice = self.get_slice( cell_data = cell_data, From 32b6c207dcf703f6c695dc9fae3f7615bd5e8f15 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 22 Sep 2025 22:24:15 -0700 Subject: [PATCH 10/22] [read] more fixup for visualize_edges --- gridlock/read.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gridlock/read.py b/gridlock/read.py index 7b4de1e..503e996 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -147,6 +147,7 @@ class GridReadMixin(GridPosMixin): cell_data: NDArray, plane: PlaneProtocol | PlaneDict, which_shifts: int = 0, + sample_period: int = 1, finalize: bool = True, contour_args: dict[str, Any] | None = None, ax: 'matplotlib.axes.Axes | None' = None, @@ -207,7 +208,7 @@ class GridReadMixin(GridPosMixin): fig, ax = pyplot.subplots() else: fig = ax.figure - xc, yc = (self.shifted_exyz(which_shifts)[a] for a in surface) + xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface) xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) From 43d5fa8b4f2e1d48a2d35de02ebc2a7adf591168 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:21:22 -0700 Subject: [PATCH 11/22] [draw] fix handling of Nx3 vertex arrays --- gridlock/draw.py | 11 +++++------ gridlock/test/test_grid.py | 35 +++++++++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 9ba4623..864468f 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -61,16 +61,18 @@ class GridDrawMixin(GridPosMixin): for ii in range(len(poly_list)): polygon = poly_list[ii] malformed = f'Malformed polygon: ({ii})' + if polygon.ndim != 2: + raise GridError(malformed + 'must be a 2-dimensional ndarray') if polygon.shape[1] not in (2, 3): raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') if polygon.shape[1] == 3: - polygon = polygon[surface, :] + if numpy.unique(polygon[:, slab.axis]).size != 1: + raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) + polygon = polygon[:, surface] poly_list[ii] = polygon if not polygon.shape[0] > 2: raise GridError(malformed + 'must consist of more than 2 points') - if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1: - raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] @@ -296,8 +298,6 @@ class GridDrawMixin(GridPosMixin): if isinstance(z, dict): z = Extent(**z) - center = numpy.asarray([x.center, y.center, z.center]) - p = numpy.array([[x.min, y.max], [x.max, y.max], [x.max, y.min], @@ -398,4 +398,3 @@ class GridDrawMixin(GridPosMixin): slab = Slab(axis=direction, center=center[direction], span=thickness) self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface]) - diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 8d9ca92..6cb9edc 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,8 +1,8 @@ -# import pytest +import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid, Extent #, Slab, Plane +from .. import Grid, Extent, GridError, Plane def test_draw_oncenter_2x2() -> None: @@ -116,3 +116,34 @@ def test_draw_2shift_4x4() -> None: [0, 0.125, 0.125, 0]])[None, :, :, None] assert_allclose(arr, correct) + + +def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) + arr_2d = grid.allocate(0) + arr_3d = grid.allocate(0) + slab = dict(axis='z', center=0.5, span=1.0) + + polygon_2d = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float) + polygon_3d = numpy.array([[0, 0, 0.5], + [1, 0, 0.5], + [1, 1, 0.5], + [0, 1, 0.5]], dtype=float) + + grid.draw_polygon(arr_2d, slab=slab, polygon=polygon_2d, foreground=1) + grid.draw_polygon(arr_3d, slab=slab, polygon=polygon_3d, foreground=1) + + assert_allclose(arr_3d, arr_2d) + + +def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) + arr = grid.allocate(0) + polygon = numpy.array([[0, 0, 0.5], + [1, 0, 0.5], + [1, 1, 0.75], + [0, 1, 0.5]], dtype=float) + + with pytest.raises(GridError): + grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1) + From 1cc47da386d69a56938c4d62629f74afd2d20966 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:25:13 -0700 Subject: [PATCH 12/22] [ind2pos] fix rounding and bounds --- gridlock/position.py | 4 ++-- gridlock/test/test_grid.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/gridlock/position.py b/gridlock/position.py index b705b99..6344ea4 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -47,13 +47,13 @@ class GridPosMixin(GridBase): else: low_bound = -0.5 high_bound = -0.5 - if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): + if (ind < low_bound).any() or (ind > self.shape + high_bound).any(): raise GridError(f'Position outside of grid: {ind}') if round_ind: rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) sxyz = self.shifted_xyz(which_shifts) - position = [sxyz[a][rind[a]].astype(int) for a in range(3)] + position = [sxyz[a][rind[a]] for a in range(3)] else: sexyz = self.shifted_exyz(which_shifts) position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 6cb9edc..a9e3d9e 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -118,6 +118,27 @@ def test_draw_2shift_4x4() -> None: assert_allclose(arr, correct) +def test_ind2pos_round_preserves_float_centers() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) + + pos = grid.ind2pos(numpy.array([1, 0, 0]), which_shifts=0) + + assert_allclose(pos, [2.0, 1.0, 0.5]) + + +def test_ind2pos_enforces_bounds_for_rounded_and_fractional_indices() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + grid.ind2pos(numpy.array([2, 0, 0]), which_shifts=0, check_bounds=True) + + edge_pos = grid.ind2pos(numpy.array([1.5, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) + assert_allclose(edge_pos, [3.0, 2.0, 1.0]) + + with pytest.raises(GridError): + grid.ind2pos(numpy.array([1.6, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) + + def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None: grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) arr_2d = grid.allocate(0) From 526b9e1666b55c59cbf2fa684e9f20dc500b7ac8 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:25:41 -0700 Subject: [PATCH 13/22] [read] fix sampling --- gridlock/read.py | 13 +++++++++---- gridlock/test/test_grid.py | 26 ++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 503e996..998e79d 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -68,7 +68,8 @@ class GridReadMixin(GridPosMixin): raise GridError('Coordinate of selected plane must be within simulation domain') # Extract grid values from planes above and below visualized slice - sliced_grid = numpy.zeros(self.shape[surface]) + sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface) + sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float)) for ci, weight in zip(centers, w, strict=True): s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3)) sliced_grid += weight * cell_data[which_shifts][tuple(s)] @@ -122,7 +123,11 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) - x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + if sample_period == 1: + x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + else: + x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) + pcolormesh_args.setdefault('shading', 'nearest') xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) @@ -208,10 +213,10 @@ class GridReadMixin(GridPosMixin): fig, ax = pyplot.subplots() else: fig = ax.figure - xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface) + xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') - mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) + ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) if finalize: pyplot.show() diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index a9e3d9e..84b0f7b 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -168,3 +168,29 @@ def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None: with pytest.raises(GridError): grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1) + +def test_get_slice_supports_sampling() -> None: + grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2) + + assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0]) + + +def test_sampled_visualization_helpers_do_not_error() -> None: + matplotlib = pytest.importorskip('matplotlib') + matplotlib.use('Agg') + from matplotlib import pyplot + + grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False) + fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False) + + assert fig_slice is ax_slice.figure + assert fig_edges is ax_edges.figure + + pyplot.close(fig_slice) + pyplot.close(fig_edges) From 15c2cf83516a8fe9bce4a2a0603398f2bded0dcc Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:47:35 -0700 Subject: [PATCH 14/22] improve arg checking --- gridlock/grid.py | 4 ++ gridlock/test/test_grid.py | 34 ++++++++++++++- gridlock/utils.py | 88 ++++++++++++++++++++++---------------- 3 files changed, 88 insertions(+), 38 deletions(-) diff --git a/gridlock/grid.py b/gridlock/grid.py index 5790dbd..5bed422 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -95,6 +95,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): `GridError` on invalid input """ edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates] + if len(edge_arrs) != 3: + raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays') self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) @@ -106,6 +108,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = [periodic] * 3 else: self.periodic = list(periodic) + if len(self.periodic) != 3: + raise GridError('periodic must be a bool or a sequence of length 3') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 84b0f7b..60929e8 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -2,7 +2,7 @@ import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid, Extent, GridError, Plane +from .. import Grid, Extent, GridError, Plane, Slab def test_draw_oncenter_2x2() -> None: @@ -194,3 +194,35 @@ def test_sampled_visualization_helpers_do_not_error() -> None: pyplot.close(fig_slice) pyplot.close(fig_edges) + + +def test_grid_constructor_rejects_invalid_coordinate_count() -> None: + with pytest.raises(GridError): + Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + +def test_grid_constructor_rejects_invalid_periodic_length() -> None: + with pytest.raises(GridError): + Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False]) + + +def test_extent_and_slab_reject_inverted_geometry() -> None: + with pytest.raises(GridError): + Extent(center=0, min=1) + + with pytest.raises(GridError): + Extent(min=2, max=1) + + with pytest.raises(GridError): + Slab(axis='z', center=1, max=0) + + +def test_extent_accepts_scalar_like_inputs() -> None: + extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0])) + + assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) + + diff --git a/gridlock/utils.py b/gridlock/utils.py index 8a8f11d..585b999 100644 --- a/gridlock/utils.py +++ b/gridlock/utils.py @@ -1,12 +1,25 @@ from typing import Protocol, TypedDict, runtime_checkable, cast from dataclasses import dataclass +import numpy + class GridError(Exception): """ Base error type for `gridlock` """ pass +def _coerce_scalar(name: str, value: object) -> float: + arr = numpy.asarray(value) + if arr.size != 1: + raise GridError(f'{name} must be a scalar value') + + try: + return float(arr.reshape(())) + except (TypeError, ValueError) as exc: + raise GridError(f'{name} must be a real scalar value') from exc + + class ExtentDict(TypedDict, total=False): """ Geometrical definition of an extent (1D bounded region) @@ -58,44 +71,46 @@ class Extent(ExtentProtocol): max: float | None = None, span: float | None = None, ) -> None: - if sum(cc is None for cc in (min, center, max, span)) != 2: - raise GridError('Exactly two of min, center, max, span must be None!') + values = { + 'min': None if min is None else _coerce_scalar('min', min), + 'center': None if center is None else _coerce_scalar('center', center), + 'max': None if max is None else _coerce_scalar('max', max), + 'span': None if span is None else _coerce_scalar('span', span), + } + if sum(value is not None for value in values.values()) != 2: + raise GridError('Exactly two of min, center, max, span must be provided') - if span is None: - if center is None: - assert min is not None - assert max is not None - assert max >= min - center = 0.5 * (max + min) - span = max - min - elif max is None: - assert min is not None - assert center is not None - span = 2 * (center - min) - elif min is None: - assert center is not None - assert max is not None - span = 2 * (max - center) - else: # noqa: PLR5501 - if center is not None: - pass - elif max is None: - assert min is not None - assert span is not None - center = min + 0.5 * span - elif min is None: - assert max is not None - assert span is not None - center = max - 0.5 * span + min_v = values['min'] + center_v = values['center'] + max_v = values['max'] + span_v = values['span'] - assert center is not None - assert span is not None - if hasattr(center, '__len__'): - assert len(center) == 1 - if hasattr(span, '__len__'): - assert len(span) == 1 - self.center = center - self.span = span + if span_v is not None and span_v < 0: + raise GridError('span must be non-negative') + + if min_v is not None and max_v is not None: + if max_v < min_v: + raise GridError('max must be greater than or equal to min') + center_v = 0.5 * (max_v + min_v) + span_v = max_v - min_v + elif center_v is not None and min_v is not None: + span_v = 2 * (center_v - min_v) + if span_v < 0: + raise GridError('min must be less than or equal to center') + elif center_v is not None and max_v is not None: + span_v = 2 * (max_v - center_v) + if span_v < 0: + raise GridError('center must be less than or equal to max') + elif min_v is not None and span_v is not None: + center_v = min_v + 0.5 * span_v + elif max_v is not None and span_v is not None: + center_v = max_v - 0.5 * span_v + + if center_v is None or span_v is None: + raise GridError('Unable to construct extent from the provided values') + + self.center = center_v + self.span = span_v class SlabDict(TypedDict, total=False): @@ -231,4 +246,3 @@ class Plane(PlaneProtocol): if hasattr(cpos, '__len__'): assert len(cpos) == 1 self.pos = cpos - From ddce4fa491081bee41e4eba699e8ff1bf5669141 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:50:48 -0700 Subject: [PATCH 15/22] [isosurface] fix sampling --- gridlock/read.py | 30 +++++++++++++++++++++++-- gridlock/test/test_grid.py | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 998e79d..9df3e08 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -20,6 +20,26 @@ if TYPE_CHECKING: class GridReadMixin(GridPosMixin): + @staticmethod + def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]: + if centers.size > 1: + midpoints = 0.5 * (centers[:-1] + centers[1:]) + first = centers[0] - 0.5 * (centers[1] - centers[0]) + last = centers[-1] + 0.5 * (centers[-1] - centers[-2]) + return numpy.hstack(([first], midpoints, [last])) + return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float) + + def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]: + if sample_period <= 1: + return self.shifted_exyz(which_shifts) + + shifted_xyz = self.shifted_xyz(which_shifts) + shifted_exyz = self.shifted_exyz(which_shifts) + return [ + self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a]) + for a in range(3) + ] + def get_slice( self, cell_data: NDArray, @@ -262,8 +282,14 @@ class GridReadMixin(GridPosMixin): verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) # Convert vertices from index to position - pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) - for i in range(verts.shape[0])], dtype=float) + preview_exyz = self._sampled_exyz(which_shifts, sample_period) + pos_verts = numpy.array([ + [ + numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a]) + for a in range(3) + ] + for i in range(verts.shape[0]) + ], dtype=float) xs, ys, zs = (pos_verts[:, a] for a in range(3)) # Draw the plot diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 60929e8..9f2e4f3 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -226,3 +226,49 @@ def test_extent_accepts_scalar_like_inputs() -> None: assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) + + +def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: + matplotlib = pytest.importorskip('matplotlib') + matplotlib.use('Agg') + skimage_measure = pytest.importorskip('skimage.measure') + from matplotlib import pyplot + from mpl_toolkits.mplot3d.axes3d import Axes3D + + captured: dict[str, numpy.ndarray] = {} + + def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]: + verts = numpy.array([[0.5, 0.5, 0.5], + [0.5, 1.5, 0.5], + [1.5, 0.5, 0.5]], dtype=float) + faces = numpy.array([[0, 1, 2]], dtype=int) + return verts, faces, None, None + + def fake_plot_trisurf( # noqa: ANN202 + _self: object, + xs: numpy.ndarray, + ys: numpy.ndarray, + faces: numpy.ndarray, + zs: numpy.ndarray, + *_args: object, + **_kwargs: object, + ) -> object: + captured['xs'] = numpy.asarray(xs) + captured['ys'] = numpy.asarray(ys) + captured['faces'] = numpy.asarray(faces) + captured['zs'] = numpy.asarray(zs) + return object() + + monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes) + monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf) + + grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]]) + cell_data = numpy.zeros(grid.cell_data_shape) + + fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False) + + assert_allclose(captured['xs'], [1.5, 1.5, 3.5]) + assert_allclose(captured['ys'], [1.5, 3.5, 1.5]) + assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) + + pyplot.close(fig) From e345d1dcf8f9b52af7cd83844efe66082f0b0379 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:51:34 -0700 Subject: [PATCH 16/22] [get_slice] use shifted bounds --- gridlock/read.py | 2 +- gridlock/test/test_grid.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/gridlock/read.py b/gridlock/read.py index 9df3e08..9be52b1 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -83,7 +83,7 @@ class GridReadMixin(GridPosMixin): else: w = [1] - c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1]) + c_min, c_max = (self.shifted_xyz(which_shifts)[plane.axis][i] for i in [0, -1]) if plane.pos < c_min or plane.pos > c_max: raise GridError('Coordinate of selected plane must be within simulation domain') diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 9f2e4f3..c6c8ae7 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -226,6 +226,18 @@ def test_extent_accepts_scalar_like_inputs() -> None: assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) +def test_get_slice_uses_shifted_grid_bounds() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0.5, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + grid_slice = grid.get_slice(cell_data, Plane(x=2.0), which_shifts=0) + + assert_allclose(grid_slice, cell_data[0, 1, :, :]) + + with pytest.raises(GridError): + grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0) + + def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: From 8895b06f08df4fb43f5910cd29468f32db8866ff Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:51:59 -0700 Subject: [PATCH 17/22] fixup! [isosurface] fix sampling --- gridlock/test/test_grid.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index c6c8ae7..2cb60c5 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -240,6 +240,14 @@ def test_get_slice_uses_shifted_grid_bounds() -> None: +def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None: + grid = Grid([[0, 1, 3, 6, 10], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) + + sampled_exyz = grid._sampled_exyz(0, 2) + + assert_allclose(sampled_exyz[0], [-1.5, 2.5, 6.5]) + + def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: matplotlib = pytest.importorskip('matplotlib') matplotlib.use('Agg') From 481b56874ee9c42f8534a378fa85e89c1e523d93 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:52:45 -0700 Subject: [PATCH 18/22] [draw] fix extrude without out-of-bounds slice --- gridlock/draw.py | 23 +++++++++++++---------- gridlock/test/test_grid.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 864468f..321ec15 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -76,10 +76,10 @@ class GridDrawMixin(GridPosMixin): # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] - if numpy.size(foreground) == 1: # type: ignore - foregrounds = [foreground] * len(cell_data) # type: ignore - elif isinstance(foreground, numpy.ndarray): + if isinstance(foreground, numpy.ndarray): raise GridError('ndarray not supported for foreground') + if callable(foreground) or numpy.isscalar(foreground): + foregrounds = [foreground] * len(cell_data) # type: ignore[list-item] else: foregrounds = foreground # type: ignore @@ -376,15 +376,18 @@ class GridDrawMixin(GridPosMixin): foreground_func = [] for ii, grid in enumerate(cell_data): zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction] - - ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)] - fpart = zz - numpy.floor(zz) - mult = [1 - fpart, fpart][::sgn] # reverses if s negative + low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1)) + high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1)) - foreground = mult[0] * grid[tuple(ind)] - ind[direction] += 1 # type: ignore #(known safe) - foreground += mult[1] * grid[tuple(ind)] + low_ind = [low if dd == direction else slice(None) for dd in range(3)] + high_ind = [high if dd == direction else slice(None) for dd in range(3)] + + if low == high: + foreground = grid[tuple(low_ind)] + else: + mult = [1 - fpart, fpart][::sgn] # reverses if s negative + foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)] def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 # transform from natural position to index diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 2cb60c5..e7b3b28 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -238,6 +238,23 @@ def test_get_slice_uses_shifted_grid_bounds() -> None: grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0) +def test_draw_extrude_rectangle_uses_boundary_slice() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) + cell_data = grid.allocate(0) + source = numpy.array([[1, 2], + [3, 4]], dtype=float) + cell_data[0, :, :, 1] = source + + grid.draw_extrude_rectangle( + cell_data, + rectangle=[[0, 0, 2], [2, 2, 2]], + direction=2, + polarity=-1, + distance=2, + ) + + assert_allclose(cell_data[0, :, :, 0], source) + assert_allclose(cell_data[0, :, :, 1], source) def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None: From 96aad5a3a10ab779bbdf00da081cdbf85861096d Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 11:00:08 -0700 Subject: [PATCH 19/22] bump version to v2.1 --- gridlock/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 2f39696..e7be065 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -34,5 +34,5 @@ from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' -__version__ = '2.0' +__version__ = '2.1' version = __version__ From 066ca8f3b88cc03a30da43125358895ee0337e84 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 11:00:49 -0700 Subject: [PATCH 20/22] bump version to v2.2 2.1 had an existing tag --- gridlock/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index e7be065..3f965fd 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -34,5 +34,5 @@ from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' -__version__ = '2.1' +__version__ = '2.2' version = __version__ From 85ae6e66cd4ee97192d6bb33249b5dc69e3d5668 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 21 Apr 2026 19:58:57 -0700 Subject: [PATCH 21/22] [Grid] enable negative shifts --- gridlock/base.py | 40 +++++++++++++-------------- gridlock/read.py | 2 +- gridlock/test/test_grid.py | 55 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 23 deletions(-) diff --git a/gridlock/base.py b/gridlock/base.py index aca9c69..e68d955 100644 --- a/gridlock/base.py +++ b/gridlock/base.py @@ -76,6 +76,21 @@ class GridBase(Protocol): el = [0 if p else -1 for p in self.periodic] return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] + def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: + if which_shifts is None: + return self.dxyz_with_ghost + + shifts = self.shifts[which_shifts, :] + edge_dxyz = [] + for a in range(3): + if shifts[a] < 0: + ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0] + edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a]))) + else: + ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1] + edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost))) + return edge_dxyz + @property def center(self) -> NDArray[numpy.float64]: """ @@ -115,15 +130,9 @@ class GridBase(Protocol): """ if which_shifts is None: return self.exyz - dxyz = self.dxyz_with_ghost + edge_dxyz = self._shifted_edge_dxyz(which_shifts) shifts = self.shifts[which_shifts, :] - - # If shift is negative, use left cell's dx to determine shift - for a in range(3): - if shifts[a] < 0: - dxyz[a] = numpy.roll(dxyz[a], 1) - - return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] + return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)] def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: """ @@ -137,20 +146,7 @@ class GridBase(Protocol): """ if which_shifts is None: return self.dxyz - shifts = self.shifts[which_shifts, :] - dxyz = self.dxyz_with_ghost - - # If shift is negative, use left cell's dx to determine size - sdxyz = [] - for a in range(3): - if shifts[a] < 0: - roll_dxyz = numpy.roll(dxyz[a], 1) - abs_shift = numpy.abs(shifts[a]) - sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) - else: - sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) - - return sdxyz + return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)] def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: """ diff --git a/gridlock/read.py b/gridlock/read.py index 9be52b1..f8a40a1 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -73,7 +73,7 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) # Extract indices and weights of planes - center3 = numpy.insert([0, 0], plane.axis, (plane.pos,)) + center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,)) center_index = self.pos2ind(center3, which_shifts, round_ind=False, check_bounds=False)[plane.axis] centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index e7b3b28..b4929a4 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -309,3 +309,58 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest. assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) pyplot.close(fig) + + + + +def test_negative_shift_nonperiodic_edges_and_widths() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + + assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0]) + assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5]) + assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25]) + + +def test_negative_shift_periodic_edges_and_widths() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False]) + + assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0]) + assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5]) + + +def test_negative_shift_coordinate_round_trip() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + + ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False) + pos = grid.ind2pos(ind, 0, round_ind=False) + + assert_allclose(ind, [1.0, 0.0, 0.0]) + assert_allclose(pos, [1.25, 1.0, 0.5]) + + +def test_negative_shift_draw_cuboid_fractional_fill() -> None: + grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + arr = grid.allocate(0) + + grid.draw_cuboid( + arr, + x=dict(min=0, max=1), + y=dict(min=0, max=1), + z=dict(min=0, max=1), + foreground=1, + ) + + assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3]) + + +def test_negative_shift_get_slice_uses_shifted_centers() -> None: + grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + cell_data = numpy.zeros(grid.cell_data_shape) + cell_data[0, 1, :, 0] = [7, 9] + x_center = float(grid.shifted_xyz(0)[0][1]) + + grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0) + + assert_allclose(grid_slice, [7, 9]) + + From 22cb410d84ff4f33b727376761ebc489b59c382e Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 21 Apr 2026 20:00:52 -0700 Subject: [PATCH 22/22] [GridData / save / load] Add GridData and update save format --- gridlock/__init__.py | 1 + gridlock/data.py | 176 +++++++++++++++++++++++++++++++++++++ gridlock/grid.py | 110 ++++++++++++++++++++--- gridlock/test/test_grid.py | 102 ++++++++++++++++++++- 4 files changed, 376 insertions(+), 13 deletions(-) create mode 100644 gridlock/data.py diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 3f965fd..759d1c1 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -31,6 +31,7 @@ from .utils import ( PlaneDict as PlaneDict, ) from .grid import Grid as Grid +from .data import GridData as GridData __author__ = 'Jan Petykiewicz' diff --git a/gridlock/data.py b/gridlock/data.py new file mode 100644 index 0000000..5e6faa5 --- /dev/null +++ b/gridlock/data.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass +from typing import Self +from collections.abc import Sequence + +import numpy +from numpy.typing import NDArray, ArrayLike + +from .draw import foreground_t +from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload +from .utils import ( + ExtentDict, + ExtentProtocol, + GridError, + PlaneDict, + PlaneProtocol, + SlabDict, + SlabProtocol, +) + + +@dataclass(slots=True) +class GridData: + grid: Grid + cell_data: NDArray + + def __post_init__(self) -> None: + if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape): + raise GridError( + f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}' + ) + + @staticmethod + def load(filename: str) -> 'GridData': + payload = _load_payload(filename) + if _payload_scalar_str(payload, 'kind') != 'grid_data': + raise GridError('Serialized payload does not contain GridData') + if 'cell_data' not in payload: + raise GridError('Serialized GridData payload is missing cell_data') + + return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data'])) + + def save(self, filename: str) -> Self: + payload = self.grid._serialization_payload(kind='grid_data') + payload['cell_data'] = self.cell_data + _save_npz_payload(filename, payload) + return self + + def copy(self) -> Self: + return GridData(self.grid.copy(), self.cell_data.copy()) + + def draw_polygons( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + polygons: Sequence[ArrayLike], + *, + offset2d: ArrayLike = (0, 0), + ) -> Self: + self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d) + return self + + def draw_polygon( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + polygon: ArrayLike, + *, + offset2d: ArrayLike = (0, 0), + ) -> Self: + self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d) + return self + + def draw_slab( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + ) -> Self: + self.grid.draw_slab(self.cell_data, foreground, slab) + return self + + def draw_cuboid( + self, + foreground: Sequence[foreground_t] | foreground_t, + *, + x: ExtentProtocol | ExtentDict, + y: ExtentProtocol | ExtentDict, + z: ExtentProtocol | ExtentDict, + ) -> Self: + self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z) + return self + + def draw_cylinder( + self, + h: SlabProtocol | SlabDict, + radius: float, + num_points: int, + center2d: ArrayLike, + foreground: Sequence[foreground_t] | foreground_t, + ) -> Self: + self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground) + return self + + def draw_extrude_rectangle( + self, + rectangle: ArrayLike, + direction: int, + polarity: int, + distance: float, + ) -> Self: + self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance) + return self + + def get_slice( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + ) -> NDArray: + return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period) + + def visualize_slice( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + pcolormesh_args: dict[str, object] | None = None, + ax: object | None = None, + ) -> tuple[object, object]: + return self.grid.visualize_slice( + self.cell_data, + plane, + which_shifts=which_shifts, + sample_period=sample_period, + finalize=finalize, + pcolormesh_args=pcolormesh_args, + ax=ax, + ) + + def visualize_edges( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + contour_args: dict[str, object] | None = None, + ax: object | None = None, + level_fraction: float = 0.7, + ) -> tuple[object, object]: + return self.grid.visualize_edges( + self.cell_data, + plane, + which_shifts=which_shifts, + sample_period=sample_period, + finalize=finalize, + contour_args=contour_args, + ax=ax, + level_fraction=level_fraction, + ) + + def visualize_isosurface( + self, + level: float | None = None, + which_shifts: int = 0, + sample_period: int = 1, + show_edges: bool = True, + finalize: bool = True, + ) -> tuple[object, object]: + return self.grid.visualize_isosurface( + self.cell_data, + level=level, + which_shifts=which_shifts, + sample_period=sample_period, + show_edges=show_edges, + finalize=finalize, + ) diff --git a/gridlock/grid.py b/gridlock/grid.py index 5bed422..eeb9708 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Self +from typing import TYPE_CHECKING, Any, ClassVar, Self from collections.abc import Callable, Sequence import numpy @@ -13,8 +13,78 @@ from .draw import GridDrawMixin from .read import GridReadMixin from .position import GridPosMixin +if TYPE_CHECKING: + from .data import GridData + foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] +_FORMAT_VERSION = 1 + + +def _is_npz_file(filename: str) -> bool: + with open(filename, 'rb') as f: + return f.read(2) == b'PK' + + +def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None: + with open(filename, 'wb') as f: + numpy.savez_compressed(f, **payload) + + +def _load_payload(filename: str) -> dict[str, Any]: + if _is_npz_file(filename): + with numpy.load(filename, allow_pickle=False) as payload: + return {key: payload[key] for key in payload.files} + + with open(filename, 'rb') as f: + legacy = pickle.load(f) + + if isinstance(legacy, Grid): + return legacy._serialization_payload(kind='grid') + if isinstance(legacy, dict): + grid = Grid([[-1, 1]] * 3) + grid.__dict__.update(legacy) + return grid._serialization_payload(kind='grid') + raise GridError('Unsupported serialized Grid payload') + + +def _payload_scalar_str(payload: dict[str, Any], key: str) -> str: + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + + value = numpy.asarray(payload[key]) + if value.size != 1: + raise GridError(f'Serialized key {key} must be scalar') + return str(value.reshape(())) + + +def _payload_scalar_int(payload: dict[str, Any], key: str) -> int: + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + + value = numpy.asarray(payload[key]) + if value.size != 1: + raise GridError(f'Serialized key {key} must be scalar') + return int(value.reshape(())) + + +def _grid_from_payload(payload: dict[str, Any]) -> 'Grid': + if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION: + raise GridError('Unsupported serialized Grid format version') + + exyz = [] + for axis in range(3): + key = f'exyz_{axis}' + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + exyz.append(numpy.array(payload[key], dtype=float)) + + if 'shifts' not in payload or 'periodic' not in payload: + raise GridError('Serialized Grid payload is missing shifts or periodic data') + + shifts = numpy.array(payload['shifts'], dtype=float) + periodic = numpy.array(payload['periodic'], dtype=bool).tolist() + return Grid(exyz, shifts=shifts, periodic=periodic) class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): @@ -110,6 +180,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = list(periodic) if len(self.periodic) != 3: raise GridError('periodic must be a bool or a sequence of length 3') + if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic): + raise GridError('periodic sequence entries must be bool values') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' @@ -121,9 +193,16 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): if (numpy.abs(self.shifts) > 1).any(): raise GridError('Only shifts in the range [-1, 1] are currently supported') - if (self.shifts < 0).any(): - # TODO: Test negative shifts - warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2) + def _serialization_payload(self, *, kind: str) -> dict[str, Any]: + payload: dict[str, Any] = { + 'kind': numpy.array(kind), + 'format_version': numpy.array(_FORMAT_VERSION, dtype=int), + 'shifts': self.shifts, + 'periodic': numpy.array(self.periodic, dtype=bool), + } + for axis, exyz in enumerate(self.exyz): + payload[f'exyz_{axis}'] = exyz + return payload @staticmethod def load(filename: str) -> 'Grid': @@ -133,12 +212,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Args: filename: Filename to load from. """ - with open(filename, 'rb') as f: - tmp_dict = pickle.load(f) - - g = Grid([[-1, 1]] * 3) - g.__dict__.update(tmp_dict) - return g + payload = _load_payload(filename) + kind = _payload_scalar_str(payload, 'kind') + if kind not in ('grid', 'grid_data'): + raise GridError(f'Unsupported serialized kind: {kind}') + return _grid_from_payload(payload) def save(self, filename: str) -> Self: """ @@ -150,10 +228,18 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Returns: self """ - with open(filename, 'wb') as f: - pickle.dump(self.__dict__, f, protocol=2) + _save_npz_payload(filename, self._serialization_payload(kind='grid')) return self + def with_data( + self, + fill_value: float | None = 1.0, + dtype: type[numpy.number] = numpy.float32, + ) -> 'GridData': + from .data import GridData + + return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype)) + def copy(self) -> Self: """ Returns: diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index b4929a4..ae0a73a 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,8 +1,9 @@ import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal +import pickle -from .. import Grid, Extent, GridError, Plane, Slab +from .. import Grid, GridData, Extent, GridError, Plane, Slab def test_draw_oncenter_2x2() -> None: @@ -311,6 +312,54 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest. pyplot.close(fig) +def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) + path = tmp_path / 'grid.state' + + grid.save(str(path)) + loaded = Grid.load(str(path)) + + assert path.exists() + for original, restored in zip(grid.exyz, loaded.exyz, strict=True): + assert_allclose(restored, original) + assert_allclose(loaded.shifts, grid.shifts) + assert loaded.periodic == grid.periodic + + +def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) + path = tmp_path / 'grid.pickle' + with open(path, 'wb') as f: + pickle.dump(grid.__dict__, f, protocol=2) + + loaded = Grid.load(str(path)) + + for original, restored in zip(grid.exyz, loaded.exyz, strict=True): + assert_allclose(restored, original) + assert_allclose(loaded.shifts, grid.shifts) + assert loaded.periodic == grid.periodic + + +def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: + data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0) + data.cell_data[0, 1, 0, 0] = 5.0 + path = tmp_path / 'griddata.state' + + data.save(str(path)) + loaded = GridData.load(str(path)) + + assert path.exists() + assert_allclose(loaded.cell_data, data.cell_data) + assert_allclose(loaded.grid.shifts, data.grid.shifts) + assert loaded.grid.periodic == data.grid.periodic + + +def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None: + path = tmp_path / 'grid.state' + Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path)) + + with pytest.raises(GridError): + GridData.load(str(path)) def test_negative_shift_nonperiodic_edges_and_widths() -> None: @@ -364,3 +413,54 @@ def test_negative_shift_get_slice_uses_shifted_centers() -> None: assert_allclose(grid_slice, [7, 9]) +def test_grid_with_data_returns_griddata() -> None: + grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + data = grid.with_data(fill_value=2.0) + + assert isinstance(data, GridData) + assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32)) + + +def test_griddata_constructor_validates_shape() -> None: + grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + GridData(grid, numpy.zeros((1, 1, 1))) + + +def test_griddata_draw_methods_are_chainable() -> None: + data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) + + chained = data.draw_cuboid( + foreground=1, + x=dict(min=0, max=1), + y=dict(min=0, max=1), + z=dict(min=0, max=1), + ).draw_polygon( + foreground=0.5, + slab=dict(axis='z', center=0.5, span=1.0), + polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float), + ) + + assert chained is data + assert data.cell_data.sum() > 0 + + +def test_griddata_read_methods_delegate() -> None: + data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) + data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float) + + assert_allclose( + data.get_slice(Plane(z=0.5)), + data.grid.get_slice(data.cell_data, Plane(z=0.5)), + ) + + +def test_griddata_copy_is_independent() -> None: + data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0) + cloned = data.copy() + cloned.cell_data[0, 0, 0, 0] = 5.0 + + assert data is not cloned + assert data.grid is not cloned.grid + assert data.cell_data[0, 0, 0, 0] == 1.0