[isosurface] fix sampling
This commit is contained in:
parent
15c2cf8351
commit
ddce4fa491
2 changed files with 74 additions and 2 deletions
|
|
@ -20,6 +20,26 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
class GridReadMixin(GridPosMixin):
|
class GridReadMixin(GridPosMixin):
|
||||||
|
@staticmethod
|
||||||
|
def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]:
|
||||||
|
if centers.size > 1:
|
||||||
|
midpoints = 0.5 * (centers[:-1] + centers[1:])
|
||||||
|
first = centers[0] - 0.5 * (centers[1] - centers[0])
|
||||||
|
last = centers[-1] + 0.5 * (centers[-1] - centers[-2])
|
||||||
|
return numpy.hstack(([first], midpoints, [last]))
|
||||||
|
return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float)
|
||||||
|
|
||||||
|
def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]:
|
||||||
|
if sample_period <= 1:
|
||||||
|
return self.shifted_exyz(which_shifts)
|
||||||
|
|
||||||
|
shifted_xyz = self.shifted_xyz(which_shifts)
|
||||||
|
shifted_exyz = self.shifted_exyz(which_shifts)
|
||||||
|
return [
|
||||||
|
self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a])
|
||||||
|
for a in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
def get_slice(
|
def get_slice(
|
||||||
self,
|
self,
|
||||||
cell_data: NDArray,
|
cell_data: NDArray,
|
||||||
|
|
@ -262,8 +282,14 @@ class GridReadMixin(GridPosMixin):
|
||||||
verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
|
verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
|
||||||
|
|
||||||
# Convert vertices from index to position
|
# Convert vertices from index to position
|
||||||
pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False)
|
preview_exyz = self._sampled_exyz(which_shifts, sample_period)
|
||||||
for i in range(verts.shape[0])], dtype=float)
|
pos_verts = numpy.array([
|
||||||
|
[
|
||||||
|
numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a])
|
||||||
|
for a in range(3)
|
||||||
|
]
|
||||||
|
for i in range(verts.shape[0])
|
||||||
|
], dtype=float)
|
||||||
xs, ys, zs = (pos_verts[:, a] for a in range(3))
|
xs, ys, zs = (pos_verts[:, a] for a in range(3))
|
||||||
|
|
||||||
# Draw the plot
|
# Draw the plot
|
||||||
|
|
|
||||||
|
|
@ -226,3 +226,49 @@ def test_extent_accepts_scalar_like_inputs() -> None:
|
||||||
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])
|
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
matplotlib = pytest.importorskip('matplotlib')
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
skimage_measure = pytest.importorskip('skimage.measure')
|
||||||
|
from matplotlib import pyplot
|
||||||
|
from mpl_toolkits.mplot3d.axes3d import Axes3D
|
||||||
|
|
||||||
|
captured: dict[str, numpy.ndarray] = {}
|
||||||
|
|
||||||
|
def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]:
|
||||||
|
verts = numpy.array([[0.5, 0.5, 0.5],
|
||||||
|
[0.5, 1.5, 0.5],
|
||||||
|
[1.5, 0.5, 0.5]], dtype=float)
|
||||||
|
faces = numpy.array([[0, 1, 2]], dtype=int)
|
||||||
|
return verts, faces, None, None
|
||||||
|
|
||||||
|
def fake_plot_trisurf( # noqa: ANN202
|
||||||
|
_self: object,
|
||||||
|
xs: numpy.ndarray,
|
||||||
|
ys: numpy.ndarray,
|
||||||
|
faces: numpy.ndarray,
|
||||||
|
zs: numpy.ndarray,
|
||||||
|
*_args: object,
|
||||||
|
**_kwargs: object,
|
||||||
|
) -> object:
|
||||||
|
captured['xs'] = numpy.asarray(xs)
|
||||||
|
captured['ys'] = numpy.asarray(ys)
|
||||||
|
captured['faces'] = numpy.asarray(faces)
|
||||||
|
captured['zs'] = numpy.asarray(zs)
|
||||||
|
return object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes)
|
||||||
|
monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf)
|
||||||
|
|
||||||
|
grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]])
|
||||||
|
cell_data = numpy.zeros(grid.cell_data_shape)
|
||||||
|
|
||||||
|
fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False)
|
||||||
|
|
||||||
|
assert_allclose(captured['xs'], [1.5, 1.5, 3.5])
|
||||||
|
assert_allclose(captured['ys'], [1.5, 3.5, 1.5])
|
||||||
|
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
|
||||||
|
|
||||||
|
pyplot.close(fig)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue