From 43d5fa8b4f2e1d48a2d35de02ebc2a7adf591168 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:21:22 -0700 Subject: [PATCH 01/12] [draw] fix handling of Nx3 vertex arrays --- gridlock/draw.py | 11 +++++------ gridlock/test/test_grid.py | 35 +++++++++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 9ba4623..864468f 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -61,16 +61,18 @@ class GridDrawMixin(GridPosMixin): for ii in range(len(poly_list)): polygon = poly_list[ii] malformed = f'Malformed polygon: ({ii})' + if polygon.ndim != 2: + raise GridError(malformed + 'must be a 2-dimensional ndarray') if polygon.shape[1] not in (2, 3): raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') if polygon.shape[1] == 3: - polygon = polygon[surface, :] + if numpy.unique(polygon[:, slab.axis]).size != 1: + raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) + polygon = polygon[:, surface] poly_list[ii] = polygon if not polygon.shape[0] > 2: raise GridError(malformed + 'must consist of more than 2 points') - if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1: - raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] @@ -296,8 +298,6 @@ class GridDrawMixin(GridPosMixin): if isinstance(z, dict): z = Extent(**z) - center = numpy.asarray([x.center, y.center, z.center]) - p = numpy.array([[x.min, y.max], [x.max, y.max], [x.max, y.min], @@ -398,4 +398,3 @@ class GridDrawMixin(GridPosMixin): slab = Slab(axis=direction, center=center[direction], span=thickness) self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface]) - diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 8d9ca92..6cb9edc 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,8 +1,8 @@ -# import pytest +import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid, Extent #, Slab, Plane +from .. import Grid, Extent, GridError, Plane def test_draw_oncenter_2x2() -> None: @@ -116,3 +116,34 @@ def test_draw_2shift_4x4() -> None: [0, 0.125, 0.125, 0]])[None, :, :, None] assert_allclose(arr, correct) + + +def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) + arr_2d = grid.allocate(0) + arr_3d = grid.allocate(0) + slab = dict(axis='z', center=0.5, span=1.0) + + polygon_2d = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float) + polygon_3d = numpy.array([[0, 0, 0.5], + [1, 0, 0.5], + [1, 1, 0.5], + [0, 1, 0.5]], dtype=float) + + grid.draw_polygon(arr_2d, slab=slab, polygon=polygon_2d, foreground=1) + grid.draw_polygon(arr_3d, slab=slab, polygon=polygon_3d, foreground=1) + + assert_allclose(arr_3d, arr_2d) + + +def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) + arr = grid.allocate(0) + polygon = numpy.array([[0, 0, 0.5], + [1, 0, 0.5], + [1, 1, 0.75], + [0, 1, 0.5]], dtype=float) + + with pytest.raises(GridError): + grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1) + From 1cc47da386d69a56938c4d62629f74afd2d20966 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:25:13 -0700 Subject: [PATCH 02/12] [ind2pos] fix rounding and bounds --- gridlock/position.py | 4 ++-- gridlock/test/test_grid.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/gridlock/position.py b/gridlock/position.py index b705b99..6344ea4 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -47,13 +47,13 @@ class GridPosMixin(GridBase): else: low_bound = -0.5 high_bound = -0.5 - if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): + if (ind < low_bound).any() or (ind > self.shape + high_bound).any(): raise GridError(f'Position outside of grid: {ind}') if round_ind: rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) sxyz = self.shifted_xyz(which_shifts) - position = [sxyz[a][rind[a]].astype(int) for a in range(3)] + position = [sxyz[a][rind[a]] for a in range(3)] else: sexyz = self.shifted_exyz(which_shifts) position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 6cb9edc..a9e3d9e 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -118,6 +118,27 @@ def test_draw_2shift_4x4() -> None: assert_allclose(arr, correct) +def test_ind2pos_round_preserves_float_centers() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) + + pos = grid.ind2pos(numpy.array([1, 0, 0]), which_shifts=0) + + assert_allclose(pos, [2.0, 1.0, 0.5]) + + +def test_ind2pos_enforces_bounds_for_rounded_and_fractional_indices() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + grid.ind2pos(numpy.array([2, 0, 0]), which_shifts=0, check_bounds=True) + + edge_pos = grid.ind2pos(numpy.array([1.5, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) + assert_allclose(edge_pos, [3.0, 2.0, 1.0]) + + with pytest.raises(GridError): + grid.ind2pos(numpy.array([1.6, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) + + def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None: grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) arr_2d = grid.allocate(0) From 526b9e1666b55c59cbf2fa684e9f20dc500b7ac8 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:25:41 -0700 Subject: [PATCH 03/12] [read] fix sampling --- gridlock/read.py | 13 +++++++++---- gridlock/test/test_grid.py | 26 ++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 503e996..998e79d 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -68,7 +68,8 @@ class GridReadMixin(GridPosMixin): raise GridError('Coordinate of selected plane must be within simulation domain') # Extract grid values from planes above and below visualized slice - sliced_grid = numpy.zeros(self.shape[surface]) + sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface) + sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float)) for ci, weight in zip(centers, w, strict=True): s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3)) sliced_grid += weight * cell_data[which_shifts][tuple(s)] @@ -122,7 +123,11 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) - x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + if sample_period == 1: + x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + else: + x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) + pcolormesh_args.setdefault('shading', 'nearest') xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) @@ -208,10 +213,10 @@ class GridReadMixin(GridPosMixin): fig, ax = pyplot.subplots() else: fig = ax.figure - xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface) + xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') - mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) + ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) if finalize: pyplot.show() diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index a9e3d9e..84b0f7b 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -168,3 +168,29 @@ def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None: with pytest.raises(GridError): grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1) + +def test_get_slice_supports_sampling() -> None: + grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2) + + assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0]) + + +def test_sampled_visualization_helpers_do_not_error() -> None: + matplotlib = pytest.importorskip('matplotlib') + matplotlib.use('Agg') + from matplotlib import pyplot + + grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False) + fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False) + + assert fig_slice is ax_slice.figure + assert fig_edges is ax_edges.figure + + pyplot.close(fig_slice) + pyplot.close(fig_edges) From 15c2cf83516a8fe9bce4a2a0603398f2bded0dcc Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:47:35 -0700 Subject: [PATCH 04/12] improve arg checking --- gridlock/grid.py | 4 ++ gridlock/test/test_grid.py | 34 ++++++++++++++- gridlock/utils.py | 88 ++++++++++++++++++++++---------------- 3 files changed, 88 insertions(+), 38 deletions(-) diff --git a/gridlock/grid.py b/gridlock/grid.py index 5790dbd..5bed422 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -95,6 +95,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): `GridError` on invalid input """ edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates] + if len(edge_arrs) != 3: + raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays') self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) @@ -106,6 +108,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = [periodic] * 3 else: self.periodic = list(periodic) + if len(self.periodic) != 3: + raise GridError('periodic must be a bool or a sequence of length 3') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 84b0f7b..60929e8 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -2,7 +2,7 @@ import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid, Extent, GridError, Plane +from .. import Grid, Extent, GridError, Plane, Slab def test_draw_oncenter_2x2() -> None: @@ -194,3 +194,35 @@ def test_sampled_visualization_helpers_do_not_error() -> None: pyplot.close(fig_slice) pyplot.close(fig_edges) + + +def test_grid_constructor_rejects_invalid_coordinate_count() -> None: + with pytest.raises(GridError): + Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + +def test_grid_constructor_rejects_invalid_periodic_length() -> None: + with pytest.raises(GridError): + Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False]) + + +def test_extent_and_slab_reject_inverted_geometry() -> None: + with pytest.raises(GridError): + Extent(center=0, min=1) + + with pytest.raises(GridError): + Extent(min=2, max=1) + + with pytest.raises(GridError): + Slab(axis='z', center=1, max=0) + + +def test_extent_accepts_scalar_like_inputs() -> None: + extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0])) + + assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) + + diff --git a/gridlock/utils.py b/gridlock/utils.py index 8a8f11d..585b999 100644 --- a/gridlock/utils.py +++ b/gridlock/utils.py @@ -1,12 +1,25 @@ from typing import Protocol, TypedDict, runtime_checkable, cast from dataclasses import dataclass +import numpy + class GridError(Exception): """ Base error type for `gridlock` """ pass +def _coerce_scalar(name: str, value: object) -> float: + arr = numpy.asarray(value) + if arr.size != 1: + raise GridError(f'{name} must be a scalar value') + + try: + return float(arr.reshape(())) + except (TypeError, ValueError) as exc: + raise GridError(f'{name} must be a real scalar value') from exc + + class ExtentDict(TypedDict, total=False): """ Geometrical definition of an extent (1D bounded region) @@ -58,44 +71,46 @@ class Extent(ExtentProtocol): max: float | None = None, span: float | None = None, ) -> None: - if sum(cc is None for cc in (min, center, max, span)) != 2: - raise GridError('Exactly two of min, center, max, span must be None!') + values = { + 'min': None if min is None else _coerce_scalar('min', min), + 'center': None if center is None else _coerce_scalar('center', center), + 'max': None if max is None else _coerce_scalar('max', max), + 'span': None if span is None else _coerce_scalar('span', span), + } + if sum(value is not None for value in values.values()) != 2: + raise GridError('Exactly two of min, center, max, span must be provided') - if span is None: - if center is None: - assert min is not None - assert max is not None - assert max >= min - center = 0.5 * (max + min) - span = max - min - elif max is None: - assert min is not None - assert center is not None - span = 2 * (center - min) - elif min is None: - assert center is not None - assert max is not None - span = 2 * (max - center) - else: # noqa: PLR5501 - if center is not None: - pass - elif max is None: - assert min is not None - assert span is not None - center = min + 0.5 * span - elif min is None: - assert max is not None - assert span is not None - center = max - 0.5 * span + min_v = values['min'] + center_v = values['center'] + max_v = values['max'] + span_v = values['span'] - assert center is not None - assert span is not None - if hasattr(center, '__len__'): - assert len(center) == 1 - if hasattr(span, '__len__'): - assert len(span) == 1 - self.center = center - self.span = span + if span_v is not None and span_v < 0: + raise GridError('span must be non-negative') + + if min_v is not None and max_v is not None: + if max_v < min_v: + raise GridError('max must be greater than or equal to min') + center_v = 0.5 * (max_v + min_v) + span_v = max_v - min_v + elif center_v is not None and min_v is not None: + span_v = 2 * (center_v - min_v) + if span_v < 0: + raise GridError('min must be less than or equal to center') + elif center_v is not None and max_v is not None: + span_v = 2 * (max_v - center_v) + if span_v < 0: + raise GridError('center must be less than or equal to max') + elif min_v is not None and span_v is not None: + center_v = min_v + 0.5 * span_v + elif max_v is not None and span_v is not None: + center_v = max_v - 0.5 * span_v + + if center_v is None or span_v is None: + raise GridError('Unable to construct extent from the provided values') + + self.center = center_v + self.span = span_v class SlabDict(TypedDict, total=False): @@ -231,4 +246,3 @@ class Plane(PlaneProtocol): if hasattr(cpos, '__len__'): assert len(cpos) == 1 self.pos = cpos - From ddce4fa491081bee41e4eba699e8ff1bf5669141 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:50:48 -0700 Subject: [PATCH 05/12] [isosurface] fix sampling --- gridlock/read.py | 30 +++++++++++++++++++++++-- gridlock/test/test_grid.py | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 998e79d..9df3e08 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -20,6 +20,26 @@ if TYPE_CHECKING: class GridReadMixin(GridPosMixin): + @staticmethod + def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]: + if centers.size > 1: + midpoints = 0.5 * (centers[:-1] + centers[1:]) + first = centers[0] - 0.5 * (centers[1] - centers[0]) + last = centers[-1] + 0.5 * (centers[-1] - centers[-2]) + return numpy.hstack(([first], midpoints, [last])) + return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float) + + def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]: + if sample_period <= 1: + return self.shifted_exyz(which_shifts) + + shifted_xyz = self.shifted_xyz(which_shifts) + shifted_exyz = self.shifted_exyz(which_shifts) + return [ + self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a]) + for a in range(3) + ] + def get_slice( self, cell_data: NDArray, @@ -262,8 +282,14 @@ class GridReadMixin(GridPosMixin): verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) # Convert vertices from index to position - pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) - for i in range(verts.shape[0])], dtype=float) + preview_exyz = self._sampled_exyz(which_shifts, sample_period) + pos_verts = numpy.array([ + [ + numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a]) + for a in range(3) + ] + for i in range(verts.shape[0]) + ], dtype=float) xs, ys, zs = (pos_verts[:, a] for a in range(3)) # Draw the plot diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 60929e8..9f2e4f3 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -226,3 +226,49 @@ def test_extent_accepts_scalar_like_inputs() -> None: assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) + + +def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: + matplotlib = pytest.importorskip('matplotlib') + matplotlib.use('Agg') + skimage_measure = pytest.importorskip('skimage.measure') + from matplotlib import pyplot + from mpl_toolkits.mplot3d.axes3d import Axes3D + + captured: dict[str, numpy.ndarray] = {} + + def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]: + verts = numpy.array([[0.5, 0.5, 0.5], + [0.5, 1.5, 0.5], + [1.5, 0.5, 0.5]], dtype=float) + faces = numpy.array([[0, 1, 2]], dtype=int) + return verts, faces, None, None + + def fake_plot_trisurf( # noqa: ANN202 + _self: object, + xs: numpy.ndarray, + ys: numpy.ndarray, + faces: numpy.ndarray, + zs: numpy.ndarray, + *_args: object, + **_kwargs: object, + ) -> object: + captured['xs'] = numpy.asarray(xs) + captured['ys'] = numpy.asarray(ys) + captured['faces'] = numpy.asarray(faces) + captured['zs'] = numpy.asarray(zs) + return object() + + monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes) + monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf) + + grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]]) + cell_data = numpy.zeros(grid.cell_data_shape) + + fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False) + + assert_allclose(captured['xs'], [1.5, 1.5, 3.5]) + assert_allclose(captured['ys'], [1.5, 3.5, 1.5]) + assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) + + pyplot.close(fig) From e345d1dcf8f9b52af7cd83844efe66082f0b0379 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:51:34 -0700 Subject: [PATCH 06/12] [get_slice] use shifted bounds --- gridlock/read.py | 2 +- gridlock/test/test_grid.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/gridlock/read.py b/gridlock/read.py index 9df3e08..9be52b1 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -83,7 +83,7 @@ class GridReadMixin(GridPosMixin): else: w = [1] - c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1]) + c_min, c_max = (self.shifted_xyz(which_shifts)[plane.axis][i] for i in [0, -1]) if plane.pos < c_min or plane.pos > c_max: raise GridError('Coordinate of selected plane must be within simulation domain') diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 9f2e4f3..c6c8ae7 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -226,6 +226,18 @@ def test_extent_accepts_scalar_like_inputs() -> None: assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) +def test_get_slice_uses_shifted_grid_bounds() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0.5, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + grid_slice = grid.get_slice(cell_data, Plane(x=2.0), which_shifts=0) + + assert_allclose(grid_slice, cell_data[0, 1, :, :]) + + with pytest.raises(GridError): + grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0) + + def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: From 8895b06f08df4fb43f5910cd29468f32db8866ff Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:51:59 -0700 Subject: [PATCH 07/12] fixup! [isosurface] fix sampling --- gridlock/test/test_grid.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index c6c8ae7..2cb60c5 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -240,6 +240,14 @@ def test_get_slice_uses_shifted_grid_bounds() -> None: +def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None: + grid = Grid([[0, 1, 3, 6, 10], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) + + sampled_exyz = grid._sampled_exyz(0, 2) + + assert_allclose(sampled_exyz[0], [-1.5, 2.5, 6.5]) + + def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: matplotlib = pytest.importorskip('matplotlib') matplotlib.use('Agg') From 481b56874ee9c42f8534a378fa85e89c1e523d93 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:52:45 -0700 Subject: [PATCH 08/12] [draw] fix extrude without out-of-bounds slice --- gridlock/draw.py | 23 +++++++++++++---------- gridlock/test/test_grid.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 864468f..321ec15 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -76,10 +76,10 @@ class GridDrawMixin(GridPosMixin): # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] - if numpy.size(foreground) == 1: # type: ignore - foregrounds = [foreground] * len(cell_data) # type: ignore - elif isinstance(foreground, numpy.ndarray): + if isinstance(foreground, numpy.ndarray): raise GridError('ndarray not supported for foreground') + if callable(foreground) or numpy.isscalar(foreground): + foregrounds = [foreground] * len(cell_data) # type: ignore[list-item] else: foregrounds = foreground # type: ignore @@ -376,15 +376,18 @@ class GridDrawMixin(GridPosMixin): foreground_func = [] for ii, grid in enumerate(cell_data): zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction] - - ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)] - fpart = zz - numpy.floor(zz) - mult = [1 - fpart, fpart][::sgn] # reverses if s negative + low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1)) + high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1)) - foreground = mult[0] * grid[tuple(ind)] - ind[direction] += 1 # type: ignore #(known safe) - foreground += mult[1] * grid[tuple(ind)] + low_ind = [low if dd == direction else slice(None) for dd in range(3)] + high_ind = [high if dd == direction else slice(None) for dd in range(3)] + + if low == high: + foreground = grid[tuple(low_ind)] + else: + mult = [1 - fpart, fpart][::sgn] # reverses if s negative + foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)] def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 # transform from natural position to index diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 2cb60c5..e7b3b28 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -238,6 +238,23 @@ def test_get_slice_uses_shifted_grid_bounds() -> None: grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0) +def test_draw_extrude_rectangle_uses_boundary_slice() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) + cell_data = grid.allocate(0) + source = numpy.array([[1, 2], + [3, 4]], dtype=float) + cell_data[0, :, :, 1] = source + + grid.draw_extrude_rectangle( + cell_data, + rectangle=[[0, 0, 2], [2, 2, 2]], + direction=2, + polarity=-1, + distance=2, + ) + + assert_allclose(cell_data[0, :, :, 0], source) + assert_allclose(cell_data[0, :, :, 1], source) def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None: From 96aad5a3a10ab779bbdf00da081cdbf85861096d Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 11:00:08 -0700 Subject: [PATCH 09/12] bump version to v2.1 --- gridlock/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 2f39696..e7be065 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -34,5 +34,5 @@ from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' -__version__ = '2.0' +__version__ = '2.1' version = __version__ From 066ca8f3b88cc03a30da43125358895ee0337e84 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 11:00:49 -0700 Subject: [PATCH 10/12] bump version to v2.2 2.1 had an existing tag --- gridlock/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridlock/__init__.py b/gridlock/__init__.py index e7be065..3f965fd 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -34,5 +34,5 @@ from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' -__version__ = '2.1' +__version__ = '2.2' version = __version__ From 85ae6e66cd4ee97192d6bb33249b5dc69e3d5668 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 21 Apr 2026 19:58:57 -0700 Subject: [PATCH 11/12] [Grid] enable negative shifts --- gridlock/base.py | 40 +++++++++++++-------------- gridlock/read.py | 2 +- gridlock/test/test_grid.py | 55 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 23 deletions(-) diff --git a/gridlock/base.py b/gridlock/base.py index aca9c69..e68d955 100644 --- a/gridlock/base.py +++ b/gridlock/base.py @@ -76,6 +76,21 @@ class GridBase(Protocol): el = [0 if p else -1 for p in self.periodic] return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] + def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: + if which_shifts is None: + return self.dxyz_with_ghost + + shifts = self.shifts[which_shifts, :] + edge_dxyz = [] + for a in range(3): + if shifts[a] < 0: + ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0] + edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a]))) + else: + ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1] + edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost))) + return edge_dxyz + @property def center(self) -> NDArray[numpy.float64]: """ @@ -115,15 +130,9 @@ class GridBase(Protocol): """ if which_shifts is None: return self.exyz - dxyz = self.dxyz_with_ghost + edge_dxyz = self._shifted_edge_dxyz(which_shifts) shifts = self.shifts[which_shifts, :] - - # If shift is negative, use left cell's dx to determine shift - for a in range(3): - if shifts[a] < 0: - dxyz[a] = numpy.roll(dxyz[a], 1) - - return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] + return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)] def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: """ @@ -137,20 +146,7 @@ class GridBase(Protocol): """ if which_shifts is None: return self.dxyz - shifts = self.shifts[which_shifts, :] - dxyz = self.dxyz_with_ghost - - # If shift is negative, use left cell's dx to determine size - sdxyz = [] - for a in range(3): - if shifts[a] < 0: - roll_dxyz = numpy.roll(dxyz[a], 1) - abs_shift = numpy.abs(shifts[a]) - sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) - else: - sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) - - return sdxyz + return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)] def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: """ diff --git a/gridlock/read.py b/gridlock/read.py index 9be52b1..f8a40a1 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -73,7 +73,7 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) # Extract indices and weights of planes - center3 = numpy.insert([0, 0], plane.axis, (plane.pos,)) + center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,)) center_index = self.pos2ind(center3, which_shifts, round_ind=False, check_bounds=False)[plane.axis] centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index e7b3b28..b4929a4 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -309,3 +309,58 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest. assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) pyplot.close(fig) + + + + +def test_negative_shift_nonperiodic_edges_and_widths() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + + assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0]) + assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5]) + assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25]) + + +def test_negative_shift_periodic_edges_and_widths() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False]) + + assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0]) + assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5]) + + +def test_negative_shift_coordinate_round_trip() -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + + ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False) + pos = grid.ind2pos(ind, 0, round_ind=False) + + assert_allclose(ind, [1.0, 0.0, 0.0]) + assert_allclose(pos, [1.25, 1.0, 0.5]) + + +def test_negative_shift_draw_cuboid_fractional_fill() -> None: + grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + arr = grid.allocate(0) + + grid.draw_cuboid( + arr, + x=dict(min=0, max=1), + y=dict(min=0, max=1), + z=dict(min=0, max=1), + foreground=1, + ) + + assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3]) + + +def test_negative_shift_get_slice_uses_shifted_centers() -> None: + grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) + cell_data = numpy.zeros(grid.cell_data_shape) + cell_data[0, 1, :, 0] = [7, 9] + x_center = float(grid.shifted_xyz(0)[0][1]) + + grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0) + + assert_allclose(grid_slice, [7, 9]) + + From 22cb410d84ff4f33b727376761ebc489b59c382e Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 21 Apr 2026 20:00:52 -0700 Subject: [PATCH 12/12] [GridData / save / load] Add GridData and update save format --- gridlock/__init__.py | 1 + gridlock/data.py | 176 +++++++++++++++++++++++++++++++++++++ gridlock/grid.py | 110 ++++++++++++++++++++--- gridlock/test/test_grid.py | 102 ++++++++++++++++++++- 4 files changed, 376 insertions(+), 13 deletions(-) create mode 100644 gridlock/data.py diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 3f965fd..759d1c1 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -31,6 +31,7 @@ from .utils import ( PlaneDict as PlaneDict, ) from .grid import Grid as Grid +from .data import GridData as GridData __author__ = 'Jan Petykiewicz' diff --git a/gridlock/data.py b/gridlock/data.py new file mode 100644 index 0000000..5e6faa5 --- /dev/null +++ b/gridlock/data.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass +from typing import Self +from collections.abc import Sequence + +import numpy +from numpy.typing import NDArray, ArrayLike + +from .draw import foreground_t +from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload +from .utils import ( + ExtentDict, + ExtentProtocol, + GridError, + PlaneDict, + PlaneProtocol, + SlabDict, + SlabProtocol, +) + + +@dataclass(slots=True) +class GridData: + grid: Grid + cell_data: NDArray + + def __post_init__(self) -> None: + if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape): + raise GridError( + f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}' + ) + + @staticmethod + def load(filename: str) -> 'GridData': + payload = _load_payload(filename) + if _payload_scalar_str(payload, 'kind') != 'grid_data': + raise GridError('Serialized payload does not contain GridData') + if 'cell_data' not in payload: + raise GridError('Serialized GridData payload is missing cell_data') + + return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data'])) + + def save(self, filename: str) -> Self: + payload = self.grid._serialization_payload(kind='grid_data') + payload['cell_data'] = self.cell_data + _save_npz_payload(filename, payload) + return self + + def copy(self) -> Self: + return GridData(self.grid.copy(), self.cell_data.copy()) + + def draw_polygons( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + polygons: Sequence[ArrayLike], + *, + offset2d: ArrayLike = (0, 0), + ) -> Self: + self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d) + return self + + def draw_polygon( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + polygon: ArrayLike, + *, + offset2d: ArrayLike = (0, 0), + ) -> Self: + self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d) + return self + + def draw_slab( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + ) -> Self: + self.grid.draw_slab(self.cell_data, foreground, slab) + return self + + def draw_cuboid( + self, + foreground: Sequence[foreground_t] | foreground_t, + *, + x: ExtentProtocol | ExtentDict, + y: ExtentProtocol | ExtentDict, + z: ExtentProtocol | ExtentDict, + ) -> Self: + self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z) + return self + + def draw_cylinder( + self, + h: SlabProtocol | SlabDict, + radius: float, + num_points: int, + center2d: ArrayLike, + foreground: Sequence[foreground_t] | foreground_t, + ) -> Self: + self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground) + return self + + def draw_extrude_rectangle( + self, + rectangle: ArrayLike, + direction: int, + polarity: int, + distance: float, + ) -> Self: + self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance) + return self + + def get_slice( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + ) -> NDArray: + return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period) + + def visualize_slice( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + pcolormesh_args: dict[str, object] | None = None, + ax: object | None = None, + ) -> tuple[object, object]: + return self.grid.visualize_slice( + self.cell_data, + plane, + which_shifts=which_shifts, + sample_period=sample_period, + finalize=finalize, + pcolormesh_args=pcolormesh_args, + ax=ax, + ) + + def visualize_edges( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + contour_args: dict[str, object] | None = None, + ax: object | None = None, + level_fraction: float = 0.7, + ) -> tuple[object, object]: + return self.grid.visualize_edges( + self.cell_data, + plane, + which_shifts=which_shifts, + sample_period=sample_period, + finalize=finalize, + contour_args=contour_args, + ax=ax, + level_fraction=level_fraction, + ) + + def visualize_isosurface( + self, + level: float | None = None, + which_shifts: int = 0, + sample_period: int = 1, + show_edges: bool = True, + finalize: bool = True, + ) -> tuple[object, object]: + return self.grid.visualize_isosurface( + self.cell_data, + level=level, + which_shifts=which_shifts, + sample_period=sample_period, + show_edges=show_edges, + finalize=finalize, + ) diff --git a/gridlock/grid.py b/gridlock/grid.py index 5bed422..eeb9708 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Self +from typing import TYPE_CHECKING, Any, ClassVar, Self from collections.abc import Callable, Sequence import numpy @@ -13,8 +13,78 @@ from .draw import GridDrawMixin from .read import GridReadMixin from .position import GridPosMixin +if TYPE_CHECKING: + from .data import GridData + foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] +_FORMAT_VERSION = 1 + + +def _is_npz_file(filename: str) -> bool: + with open(filename, 'rb') as f: + return f.read(2) == b'PK' + + +def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None: + with open(filename, 'wb') as f: + numpy.savez_compressed(f, **payload) + + +def _load_payload(filename: str) -> dict[str, Any]: + if _is_npz_file(filename): + with numpy.load(filename, allow_pickle=False) as payload: + return {key: payload[key] for key in payload.files} + + with open(filename, 'rb') as f: + legacy = pickle.load(f) + + if isinstance(legacy, Grid): + return legacy._serialization_payload(kind='grid') + if isinstance(legacy, dict): + grid = Grid([[-1, 1]] * 3) + grid.__dict__.update(legacy) + return grid._serialization_payload(kind='grid') + raise GridError('Unsupported serialized Grid payload') + + +def _payload_scalar_str(payload: dict[str, Any], key: str) -> str: + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + + value = numpy.asarray(payload[key]) + if value.size != 1: + raise GridError(f'Serialized key {key} must be scalar') + return str(value.reshape(())) + + +def _payload_scalar_int(payload: dict[str, Any], key: str) -> int: + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + + value = numpy.asarray(payload[key]) + if value.size != 1: + raise GridError(f'Serialized key {key} must be scalar') + return int(value.reshape(())) + + +def _grid_from_payload(payload: dict[str, Any]) -> 'Grid': + if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION: + raise GridError('Unsupported serialized Grid format version') + + exyz = [] + for axis in range(3): + key = f'exyz_{axis}' + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + exyz.append(numpy.array(payload[key], dtype=float)) + + if 'shifts' not in payload or 'periodic' not in payload: + raise GridError('Serialized Grid payload is missing shifts or periodic data') + + shifts = numpy.array(payload['shifts'], dtype=float) + periodic = numpy.array(payload['periodic'], dtype=bool).tolist() + return Grid(exyz, shifts=shifts, periodic=periodic) class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): @@ -110,6 +180,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = list(periodic) if len(self.periodic) != 3: raise GridError('periodic must be a bool or a sequence of length 3') + if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic): + raise GridError('periodic sequence entries must be bool values') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' @@ -121,9 +193,16 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): if (numpy.abs(self.shifts) > 1).any(): raise GridError('Only shifts in the range [-1, 1] are currently supported') - if (self.shifts < 0).any(): - # TODO: Test negative shifts - warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2) + def _serialization_payload(self, *, kind: str) -> dict[str, Any]: + payload: dict[str, Any] = { + 'kind': numpy.array(kind), + 'format_version': numpy.array(_FORMAT_VERSION, dtype=int), + 'shifts': self.shifts, + 'periodic': numpy.array(self.periodic, dtype=bool), + } + for axis, exyz in enumerate(self.exyz): + payload[f'exyz_{axis}'] = exyz + return payload @staticmethod def load(filename: str) -> 'Grid': @@ -133,12 +212,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Args: filename: Filename to load from. """ - with open(filename, 'rb') as f: - tmp_dict = pickle.load(f) - - g = Grid([[-1, 1]] * 3) - g.__dict__.update(tmp_dict) - return g + payload = _load_payload(filename) + kind = _payload_scalar_str(payload, 'kind') + if kind not in ('grid', 'grid_data'): + raise GridError(f'Unsupported serialized kind: {kind}') + return _grid_from_payload(payload) def save(self, filename: str) -> Self: """ @@ -150,10 +228,18 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Returns: self """ - with open(filename, 'wb') as f: - pickle.dump(self.__dict__, f, protocol=2) + _save_npz_payload(filename, self._serialization_payload(kind='grid')) return self + def with_data( + self, + fill_value: float | None = 1.0, + dtype: type[numpy.number] = numpy.float32, + ) -> 'GridData': + from .data import GridData + + return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype)) + def copy(self) -> Self: """ Returns: diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index b4929a4..ae0a73a 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,8 +1,9 @@ import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal +import pickle -from .. import Grid, Extent, GridError, Plane, Slab +from .. import Grid, GridData, Extent, GridError, Plane, Slab def test_draw_oncenter_2x2() -> None: @@ -311,6 +312,54 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest. pyplot.close(fig) +def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) + path = tmp_path / 'grid.state' + + grid.save(str(path)) + loaded = Grid.load(str(path)) + + assert path.exists() + for original, restored in zip(grid.exyz, loaded.exyz, strict=True): + assert_allclose(restored, original) + assert_allclose(loaded.shifts, grid.shifts) + assert loaded.periodic == grid.periodic + + +def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) + path = tmp_path / 'grid.pickle' + with open(path, 'wb') as f: + pickle.dump(grid.__dict__, f, protocol=2) + + loaded = Grid.load(str(path)) + + for original, restored in zip(grid.exyz, loaded.exyz, strict=True): + assert_allclose(restored, original) + assert_allclose(loaded.shifts, grid.shifts) + assert loaded.periodic == grid.periodic + + +def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: + data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0) + data.cell_data[0, 1, 0, 0] = 5.0 + path = tmp_path / 'griddata.state' + + data.save(str(path)) + loaded = GridData.load(str(path)) + + assert path.exists() + assert_allclose(loaded.cell_data, data.cell_data) + assert_allclose(loaded.grid.shifts, data.grid.shifts) + assert loaded.grid.periodic == data.grid.periodic + + +def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None: + path = tmp_path / 'grid.state' + Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path)) + + with pytest.raises(GridError): + GridData.load(str(path)) def test_negative_shift_nonperiodic_edges_and_widths() -> None: @@ -364,3 +413,54 @@ def test_negative_shift_get_slice_uses_shifted_centers() -> None: assert_allclose(grid_slice, [7, 9]) +def test_grid_with_data_returns_griddata() -> None: + grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + data = grid.with_data(fill_value=2.0) + + assert isinstance(data, GridData) + assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32)) + + +def test_griddata_constructor_validates_shape() -> None: + grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + GridData(grid, numpy.zeros((1, 1, 1))) + + +def test_griddata_draw_methods_are_chainable() -> None: + data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) + + chained = data.draw_cuboid( + foreground=1, + x=dict(min=0, max=1), + y=dict(min=0, max=1), + z=dict(min=0, max=1), + ).draw_polygon( + foreground=0.5, + slab=dict(axis='z', center=0.5, span=1.0), + polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float), + ) + + assert chained is data + assert data.cell_data.sum() > 0 + + +def test_griddata_read_methods_delegate() -> None: + data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) + data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float) + + assert_allclose( + data.get_slice(Plane(z=0.5)), + data.grid.get_slice(data.cell_data, Plane(z=0.5)), + ) + + +def test_griddata_copy_is_independent() -> None: + data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0) + cloned = data.copy() + cloned.cell_data[0, 0, 0, 0] = 5.0 + + assert data is not cloned + assert data.grid is not cloned.grid + assert data.cell_data[0, 0, 0, 0] == 1.0