From 526b9e1666b55c59cbf2fa684e9f20dc500b7ac8 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:25:41 -0700 Subject: [PATCH] [read] fix sampling --- gridlock/read.py | 13 +++++++++---- gridlock/test/test_grid.py | 26 ++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/gridlock/read.py b/gridlock/read.py index 503e996..998e79d 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -68,7 +68,8 @@ class GridReadMixin(GridPosMixin): raise GridError('Coordinate of selected plane must be within simulation domain') # Extract grid values from planes above and below visualized slice - sliced_grid = numpy.zeros(self.shape[surface]) + sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface) + sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float)) for ci, weight in zip(centers, w, strict=True): s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3)) sliced_grid += weight * cell_data[which_shifts][tuple(s)] @@ -122,7 +123,11 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) - x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + if sample_period == 1: + x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) + else: + x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) + pcolormesh_args.setdefault('shading', 'nearest') xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) @@ -208,10 +213,10 @@ class GridReadMixin(GridPosMixin): fig, ax = pyplot.subplots() else: fig = ax.figure - xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface) + xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') - mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) + ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) if finalize: pyplot.show() diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index a9e3d9e..84b0f7b 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -168,3 +168,29 @@ def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None: with pytest.raises(GridError): grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1) + +def test_get_slice_supports_sampling() -> None: + grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2) + + assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0]) + + +def test_sampled_visualization_helpers_do_not_error() -> None: + matplotlib = pytest.importorskip('matplotlib') + matplotlib.use('Agg') + from matplotlib import pyplot + + grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) + cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) + + fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False) + fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False) + + assert fig_slice is ax_slice.figure + assert fig_edges is ax_edges.figure + + pyplot.close(fig_slice) + pyplot.close(fig_edges)