[read] fix sampling

This commit is contained in:
Jan Petykiewicz 2026-04-20 10:25:41 -07:00
commit 526b9e1666
2 changed files with 35 additions and 4 deletions

View file

@ -68,7 +68,8 @@ class GridReadMixin(GridPosMixin):
raise GridError('Coordinate of selected plane must be within simulation domain')
# Extract grid values from planes above and below visualized slice
sliced_grid = numpy.zeros(self.shape[surface])
sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface)
sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float))
for ci, weight in zip(centers, w, strict=True):
s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3))
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
@ -122,7 +123,11 @@ class GridReadMixin(GridPosMixin):
surface = numpy.delete(range(3), plane.axis)
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
if sample_period == 1:
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
else:
x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface)
pcolormesh_args.setdefault('shading', 'nearest')
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
x_label, y_label = ('xyz'[a] for a in surface)
@ -208,10 +213,10 @@ class GridReadMixin(GridPosMixin):
fig, ax = pyplot.subplots()
else:
fig = ax.figure
xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface)
xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface)
xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij')
mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args)
ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args)
if finalize:
pyplot.show()

View file

@ -168,3 +168,29 @@ def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None:
with pytest.raises(GridError):
grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1)
def test_get_slice_supports_sampling() -> None:
grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2)
assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0])
def test_sampled_visualization_helpers_do_not_error() -> None:
matplotlib = pytest.importorskip('matplotlib')
matplotlib.use('Agg')
from matplotlib import pyplot
grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False)
fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False)
assert fig_slice is ax_slice.figure
assert fig_edges is ax_edges.figure
pyplot.close(fig_slice)
pyplot.close(fig_edges)