[isosurface] fix sampling

This commit is contained in:
Jan Petykiewicz 2026-04-20 10:50:48 -07:00
commit ddce4fa491
2 changed files with 74 additions and 2 deletions

View file

@ -20,6 +20,26 @@ if TYPE_CHECKING:
class GridReadMixin(GridPosMixin):
@staticmethod
def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]:
if centers.size > 1:
midpoints = 0.5 * (centers[:-1] + centers[1:])
first = centers[0] - 0.5 * (centers[1] - centers[0])
last = centers[-1] + 0.5 * (centers[-1] - centers[-2])
return numpy.hstack(([first], midpoints, [last]))
return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float)
def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]:
if sample_period <= 1:
return self.shifted_exyz(which_shifts)
shifted_xyz = self.shifted_xyz(which_shifts)
shifted_exyz = self.shifted_exyz(which_shifts)
return [
self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a])
for a in range(3)
]
def get_slice(
self,
cell_data: NDArray,
@ -262,8 +282,14 @@ class GridReadMixin(GridPosMixin):
verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
# Convert vertices from index to position
pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False)
for i in range(verts.shape[0])], dtype=float)
preview_exyz = self._sampled_exyz(which_shifts, sample_period)
pos_verts = numpy.array([
[
numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a])
for a in range(3)
]
for i in range(verts.shape[0])
], dtype=float)
xs, ys, zs = (pos_verts[:, a] for a in range(3))
# Draw the plot

View file

@ -226,3 +226,49 @@ def test_extent_accepts_scalar_like_inputs() -> None:
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])
def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None:
matplotlib = pytest.importorskip('matplotlib')
matplotlib.use('Agg')
skimage_measure = pytest.importorskip('skimage.measure')
from matplotlib import pyplot
from mpl_toolkits.mplot3d.axes3d import Axes3D
captured: dict[str, numpy.ndarray] = {}
def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]:
verts = numpy.array([[0.5, 0.5, 0.5],
[0.5, 1.5, 0.5],
[1.5, 0.5, 0.5]], dtype=float)
faces = numpy.array([[0, 1, 2]], dtype=int)
return verts, faces, None, None
def fake_plot_trisurf( # noqa: ANN202
_self: object,
xs: numpy.ndarray,
ys: numpy.ndarray,
faces: numpy.ndarray,
zs: numpy.ndarray,
*_args: object,
**_kwargs: object,
) -> object:
captured['xs'] = numpy.asarray(xs)
captured['ys'] = numpy.asarray(ys)
captured['faces'] = numpy.asarray(faces)
captured['zs'] = numpy.asarray(zs)
return object()
monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes)
monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf)
grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]])
cell_data = numpy.zeros(grid.cell_data_shape)
fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False)
assert_allclose(captured['xs'], [1.5, 1.5, 3.5])
assert_allclose(captured['ys'], [1.5, 3.5, 1.5])
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
pyplot.close(fig)