[isosurface] fix sampling
This commit is contained in:
parent
15c2cf8351
commit
ddce4fa491
2 changed files with 74 additions and 2 deletions
|
|
@ -20,6 +20,26 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class GridReadMixin(GridPosMixin):
|
||||
@staticmethod
|
||||
def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]:
|
||||
if centers.size > 1:
|
||||
midpoints = 0.5 * (centers[:-1] + centers[1:])
|
||||
first = centers[0] - 0.5 * (centers[1] - centers[0])
|
||||
last = centers[-1] + 0.5 * (centers[-1] - centers[-2])
|
||||
return numpy.hstack(([first], midpoints, [last]))
|
||||
return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float)
|
||||
|
||||
def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]:
|
||||
if sample_period <= 1:
|
||||
return self.shifted_exyz(which_shifts)
|
||||
|
||||
shifted_xyz = self.shifted_xyz(which_shifts)
|
||||
shifted_exyz = self.shifted_exyz(which_shifts)
|
||||
return [
|
||||
self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a])
|
||||
for a in range(3)
|
||||
]
|
||||
|
||||
def get_slice(
|
||||
self,
|
||||
cell_data: NDArray,
|
||||
|
|
@ -262,8 +282,14 @@ class GridReadMixin(GridPosMixin):
|
|||
verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
|
||||
|
||||
# Convert vertices from index to position
|
||||
pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False)
|
||||
for i in range(verts.shape[0])], dtype=float)
|
||||
preview_exyz = self._sampled_exyz(which_shifts, sample_period)
|
||||
pos_verts = numpy.array([
|
||||
[
|
||||
numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a])
|
||||
for a in range(3)
|
||||
]
|
||||
for i in range(verts.shape[0])
|
||||
], dtype=float)
|
||||
xs, ys, zs = (pos_verts[:, a] for a in range(3))
|
||||
|
||||
# Draw the plot
|
||||
|
|
|
|||
|
|
@ -226,3 +226,49 @@ def test_extent_accepts_scalar_like_inputs() -> None:
|
|||
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])
|
||||
|
||||
|
||||
|
||||
|
||||
def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
matplotlib = pytest.importorskip('matplotlib')
|
||||
matplotlib.use('Agg')
|
||||
skimage_measure = pytest.importorskip('skimage.measure')
|
||||
from matplotlib import pyplot
|
||||
from mpl_toolkits.mplot3d.axes3d import Axes3D
|
||||
|
||||
captured: dict[str, numpy.ndarray] = {}
|
||||
|
||||
def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]:
|
||||
verts = numpy.array([[0.5, 0.5, 0.5],
|
||||
[0.5, 1.5, 0.5],
|
||||
[1.5, 0.5, 0.5]], dtype=float)
|
||||
faces = numpy.array([[0, 1, 2]], dtype=int)
|
||||
return verts, faces, None, None
|
||||
|
||||
def fake_plot_trisurf( # noqa: ANN202
|
||||
_self: object,
|
||||
xs: numpy.ndarray,
|
||||
ys: numpy.ndarray,
|
||||
faces: numpy.ndarray,
|
||||
zs: numpy.ndarray,
|
||||
*_args: object,
|
||||
**_kwargs: object,
|
||||
) -> object:
|
||||
captured['xs'] = numpy.asarray(xs)
|
||||
captured['ys'] = numpy.asarray(ys)
|
||||
captured['faces'] = numpy.asarray(faces)
|
||||
captured['zs'] = numpy.asarray(zs)
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes)
|
||||
monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf)
|
||||
|
||||
grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]])
|
||||
cell_data = numpy.zeros(grid.cell_data_shape)
|
||||
|
||||
fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False)
|
||||
|
||||
assert_allclose(captured['xs'], [1.5, 1.5, 3.5])
|
||||
assert_allclose(captured['ys'], [1.5, 3.5, 1.5])
|
||||
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
|
||||
|
||||
pyplot.close(fig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue