diff --git a/gridlock/read.py b/gridlock/read.py index 998e79d..9df3e08 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -20,6 +20,26 @@ if TYPE_CHECKING: class GridReadMixin(GridPosMixin): + @staticmethod + def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]: + if centers.size > 1: + midpoints = 0.5 * (centers[:-1] + centers[1:]) + first = centers[0] - 0.5 * (centers[1] - centers[0]) + last = centers[-1] + 0.5 * (centers[-1] - centers[-2]) + return numpy.hstack(([first], midpoints, [last])) + return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float) + + def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]: + if sample_period <= 1: + return self.shifted_exyz(which_shifts) + + shifted_xyz = self.shifted_xyz(which_shifts) + shifted_exyz = self.shifted_exyz(which_shifts) + return [ + self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a]) + for a in range(3) + ] + def get_slice( self, cell_data: NDArray, @@ -262,8 +282,14 @@ class GridReadMixin(GridPosMixin): verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) # Convert vertices from index to position - pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) - for i in range(verts.shape[0])], dtype=float) + preview_exyz = self._sampled_exyz(which_shifts, sample_period) + pos_verts = numpy.array([ + [ + numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a]) + for a in range(3) + ] + for i in range(verts.shape[0]) + ], dtype=float) xs, ys, zs = (pos_verts[:, a] for a in range(3)) # Draw the plot diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 60929e8..9f2e4f3 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -226,3 +226,49 @@ def test_extent_accepts_scalar_like_inputs() -> None: assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) + + +def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: + matplotlib = pytest.importorskip('matplotlib') + matplotlib.use('Agg') + skimage_measure = pytest.importorskip('skimage.measure') + from matplotlib import pyplot + from mpl_toolkits.mplot3d.axes3d import Axes3D + + captured: dict[str, numpy.ndarray] = {} + + def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]: + verts = numpy.array([[0.5, 0.5, 0.5], + [0.5, 1.5, 0.5], + [1.5, 0.5, 0.5]], dtype=float) + faces = numpy.array([[0, 1, 2]], dtype=int) + return verts, faces, None, None + + def fake_plot_trisurf( # noqa: ANN202 + _self: object, + xs: numpy.ndarray, + ys: numpy.ndarray, + faces: numpy.ndarray, + zs: numpy.ndarray, + *_args: object, + **_kwargs: object, + ) -> object: + captured['xs'] = numpy.asarray(xs) + captured['ys'] = numpy.asarray(ys) + captured['faces'] = numpy.asarray(faces) + captured['zs'] = numpy.asarray(zs) + return object() + + monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes) + monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf) + + grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]]) + cell_data = numpy.zeros(grid.cell_data_shape) + + fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False) + + assert_allclose(captured['xs'], [1.5, 1.5, 3.5]) + assert_allclose(captured['ys'], [1.5, 3.5, 1.5]) + assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) + + pyplot.close(fig)