Compare commits

..

No commits in common. "96aad5a3a10ab779bbdf00da081cdbf85861096d" and "32b6c207dcf703f6c695dc9fae3f7615bd5e8f15" have entirely different histories.

7 changed files with 67 additions and 311 deletions

View file

@ -34,5 +34,5 @@ from .grid import Grid as Grid
__author__ = 'Jan Petykiewicz'
__version__ = '2.1'
__version__ = '2.0'
version = __version__

View file

@ -61,25 +61,23 @@ class GridDrawMixin(GridPosMixin):
for ii in range(len(poly_list)):
polygon = poly_list[ii]
malformed = f'Malformed polygon: ({ii})'
if polygon.ndim != 2:
raise GridError(malformed + 'must be a 2-dimensional ndarray')
if polygon.shape[1] not in (2, 3):
raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
if polygon.shape[1] == 3:
if numpy.unique(polygon[:, slab.axis]).size != 1:
raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis])
polygon = polygon[:, surface]
polygon = polygon[surface, :]
poly_list[ii] = polygon
if not polygon.shape[0] > 2:
raise GridError(malformed + 'must consist of more than 2 points')
if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1:
raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis])
# Broadcast foreground where necessary
foregrounds: Sequence[foreground_callable_t] | Sequence[float]
if isinstance(foreground, numpy.ndarray):
if numpy.size(foreground) == 1: # type: ignore
foregrounds = [foreground] * len(cell_data) # type: ignore
elif isinstance(foreground, numpy.ndarray):
raise GridError('ndarray not supported for foreground')
if callable(foreground) or numpy.isscalar(foreground):
foregrounds = [foreground] * len(cell_data) # type: ignore[list-item]
else:
foregrounds = foreground # type: ignore
@ -298,6 +296,8 @@ class GridDrawMixin(GridPosMixin):
if isinstance(z, dict):
z = Extent(**z)
center = numpy.asarray([x.center, y.center, z.center])
p = numpy.array([[x.min, y.max],
[x.max, y.max],
[x.max, y.min],
@ -376,18 +376,15 @@ class GridDrawMixin(GridPosMixin):
foreground_func = []
for ii, grid in enumerate(cell_data):
zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction]
ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)]
fpart = zz - numpy.floor(zz)
low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1))
high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1))
low_ind = [low if dd == direction else slice(None) for dd in range(3)]
high_ind = [high if dd == direction else slice(None) for dd in range(3)]
if low == high:
foreground = grid[tuple(low_ind)]
else:
mult = [1 - fpart, fpart][::sgn] # reverses if s negative
foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)]
foreground = mult[0] * grid[tuple(ind)]
ind[direction] += 1 # type: ignore #(known safe)
foreground += mult[1] * grid[tuple(ind)]
def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
# transform from natural position to index
@ -401,3 +398,4 @@ class GridDrawMixin(GridPosMixin):
slab = Slab(axis=direction, center=center[direction], span=thickness)
self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface])

View file

@ -95,8 +95,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
`GridError` on invalid input
"""
edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates]
if len(edge_arrs) != 3:
raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays')
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
self.shifts = numpy.array(shifts, dtype=float)
@ -108,8 +106,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
self.periodic = [periodic] * 3
else:
self.periodic = list(periodic)
if len(self.periodic) != 3:
raise GridError('periodic must be a bool or a sequence of length 3')
if len(self.shifts.shape) != 2:
raise GridError('Misshapen shifts: shifts must have two axes! '

View file

@ -47,13 +47,13 @@ class GridPosMixin(GridBase):
else:
low_bound = -0.5
high_bound = -0.5
if (ind < low_bound).any() or (ind > self.shape + high_bound).any():
if (ind < low_bound).any() or (ind > self.shape - high_bound).any():
raise GridError(f'Position outside of grid: {ind}')
if round_ind:
rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1)
sxyz = self.shifted_xyz(which_shifts)
position = [sxyz[a][rind[a]] for a in range(3)]
position = [sxyz[a][rind[a]].astype(int) for a in range(3)]
else:
sexyz = self.shifted_exyz(which_shifts)
position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a])

View file

@ -20,26 +20,6 @@ if TYPE_CHECKING:
class GridReadMixin(GridPosMixin):
@staticmethod
def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]:
if centers.size > 1:
midpoints = 0.5 * (centers[:-1] + centers[1:])
first = centers[0] - 0.5 * (centers[1] - centers[0])
last = centers[-1] + 0.5 * (centers[-1] - centers[-2])
return numpy.hstack(([first], midpoints, [last]))
return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float)
def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]:
if sample_period <= 1:
return self.shifted_exyz(which_shifts)
shifted_xyz = self.shifted_xyz(which_shifts)
shifted_exyz = self.shifted_exyz(which_shifts)
return [
self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a])
for a in range(3)
]
def get_slice(
self,
cell_data: NDArray,
@ -83,13 +63,12 @@ class GridReadMixin(GridPosMixin):
else:
w = [1]
c_min, c_max = (self.shifted_xyz(which_shifts)[plane.axis][i] for i in [0, -1])
c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1])
if plane.pos < c_min or plane.pos > c_max:
raise GridError('Coordinate of selected plane must be within simulation domain')
# Extract grid values from planes above and below visualized slice
sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface)
sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float))
sliced_grid = numpy.zeros(self.shape[surface])
for ci, weight in zip(centers, w, strict=True):
s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3))
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
@ -143,11 +122,7 @@ class GridReadMixin(GridPosMixin):
surface = numpy.delete(range(3), plane.axis)
if sample_period == 1:
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
else:
x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface)
pcolormesh_args.setdefault('shading', 'nearest')
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
x_label, y_label = ('xyz'[a] for a in surface)
@ -233,10 +208,10 @@ class GridReadMixin(GridPosMixin):
fig, ax = pyplot.subplots()
else:
fig = ax.figure
xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface)
xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface)
xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij')
ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args)
mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args)
if finalize:
pyplot.show()
@ -282,14 +257,8 @@ class GridReadMixin(GridPosMixin):
verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
# Convert vertices from index to position
preview_exyz = self._sampled_exyz(which_shifts, sample_period)
pos_verts = numpy.array([
[
numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a])
for a in range(3)
]
for i in range(verts.shape[0])
], dtype=float)
pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False)
for i in range(verts.shape[0])], dtype=float)
xs, ys, zs = (pos_verts[:, a] for a in range(3))
# Draw the plot

View file

@ -1,8 +1,8 @@
import pytest
# import pytest
import numpy
from numpy.testing import assert_allclose #, assert_array_equal
from .. import Grid, Extent, GridError, Plane, Slab
from .. import Grid, Extent #, Slab, Plane
def test_draw_oncenter_2x2() -> None:
@ -116,196 +116,3 @@ def test_draw_2shift_4x4() -> None:
[0, 0.125, 0.125, 0]])[None, :, :, None]
assert_allclose(arr, correct)
def test_ind2pos_round_preserves_float_centers() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]])
pos = grid.ind2pos(numpy.array([1, 0, 0]), which_shifts=0)
assert_allclose(pos, [2.0, 1.0, 0.5])
def test_ind2pos_enforces_bounds_for_rounded_and_fractional_indices() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]])
with pytest.raises(GridError):
grid.ind2pos(numpy.array([2, 0, 0]), which_shifts=0, check_bounds=True)
edge_pos = grid.ind2pos(numpy.array([1.5, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True)
assert_allclose(edge_pos, [3.0, 2.0, 1.0])
with pytest.raises(GridError):
grid.ind2pos(numpy.array([1.6, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True)
def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]])
arr_2d = grid.allocate(0)
arr_3d = grid.allocate(0)
slab = dict(axis='z', center=0.5, span=1.0)
polygon_2d = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float)
polygon_3d = numpy.array([[0, 0, 0.5],
[1, 0, 0.5],
[1, 1, 0.5],
[0, 1, 0.5]], dtype=float)
grid.draw_polygon(arr_2d, slab=slab, polygon=polygon_2d, foreground=1)
grid.draw_polygon(arr_3d, slab=slab, polygon=polygon_3d, foreground=1)
assert_allclose(arr_3d, arr_2d)
def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
polygon = numpy.array([[0, 0, 0.5],
[1, 0, 0.5],
[1, 1, 0.75],
[0, 1, 0.5]], dtype=float)
with pytest.raises(GridError):
grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1)
def test_get_slice_supports_sampling() -> None:
grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2)
assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0])
def test_sampled_visualization_helpers_do_not_error() -> None:
matplotlib = pytest.importorskip('matplotlib')
matplotlib.use('Agg')
from matplotlib import pyplot
grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False)
fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False)
assert fig_slice is ax_slice.figure
assert fig_edges is ax_edges.figure
pyplot.close(fig_slice)
pyplot.close(fig_edges)
def test_grid_constructor_rejects_invalid_coordinate_count() -> None:
with pytest.raises(GridError):
Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]])
with pytest.raises(GridError):
Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
def test_grid_constructor_rejects_invalid_periodic_length() -> None:
with pytest.raises(GridError):
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False])
def test_extent_and_slab_reject_inverted_geometry() -> None:
with pytest.raises(GridError):
Extent(center=0, min=1)
with pytest.raises(GridError):
Extent(min=2, max=1)
with pytest.raises(GridError):
Slab(axis='z', center=1, max=0)
def test_extent_accepts_scalar_like_inputs() -> None:
extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0]))
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])
def test_get_slice_uses_shifted_grid_bounds() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0.5, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
grid_slice = grid.get_slice(cell_data, Plane(x=2.0), which_shifts=0)
assert_allclose(grid_slice, cell_data[0, 1, :, :])
with pytest.raises(GridError):
grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0)
def test_draw_extrude_rectangle_uses_boundary_slice() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]])
cell_data = grid.allocate(0)
source = numpy.array([[1, 2],
[3, 4]], dtype=float)
cell_data[0, :, :, 1] = source
grid.draw_extrude_rectangle(
cell_data,
rectangle=[[0, 0, 2], [2, 2, 2]],
direction=2,
polarity=-1,
distance=2,
)
assert_allclose(cell_data[0, :, :, 0], source)
assert_allclose(cell_data[0, :, :, 1], source)
def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None:
grid = Grid([[0, 1, 3, 6, 10], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]])
sampled_exyz = grid._sampled_exyz(0, 2)
assert_allclose(sampled_exyz[0], [-1.5, 2.5, 6.5])
def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None:
matplotlib = pytest.importorskip('matplotlib')
matplotlib.use('Agg')
skimage_measure = pytest.importorskip('skimage.measure')
from matplotlib import pyplot
from mpl_toolkits.mplot3d.axes3d import Axes3D
captured: dict[str, numpy.ndarray] = {}
def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]:
verts = numpy.array([[0.5, 0.5, 0.5],
[0.5, 1.5, 0.5],
[1.5, 0.5, 0.5]], dtype=float)
faces = numpy.array([[0, 1, 2]], dtype=int)
return verts, faces, None, None
def fake_plot_trisurf( # noqa: ANN202
_self: object,
xs: numpy.ndarray,
ys: numpy.ndarray,
faces: numpy.ndarray,
zs: numpy.ndarray,
*_args: object,
**_kwargs: object,
) -> object:
captured['xs'] = numpy.asarray(xs)
captured['ys'] = numpy.asarray(ys)
captured['faces'] = numpy.asarray(faces)
captured['zs'] = numpy.asarray(zs)
return object()
monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes)
monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf)
grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]])
cell_data = numpy.zeros(grid.cell_data_shape)
fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False)
assert_allclose(captured['xs'], [1.5, 1.5, 3.5])
assert_allclose(captured['ys'], [1.5, 3.5, 1.5])
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
pyplot.close(fig)

View file

@ -1,25 +1,12 @@
from typing import Protocol, TypedDict, runtime_checkable, cast
from dataclasses import dataclass
import numpy
class GridError(Exception):
""" Base error type for `gridlock` """
pass
def _coerce_scalar(name: str, value: object) -> float:
arr = numpy.asarray(value)
if arr.size != 1:
raise GridError(f'{name} must be a scalar value')
try:
return float(arr.reshape(()))
except (TypeError, ValueError) as exc:
raise GridError(f'{name} must be a real scalar value') from exc
class ExtentDict(TypedDict, total=False):
"""
Geometrical definition of an extent (1D bounded region)
@ -71,46 +58,44 @@ class Extent(ExtentProtocol):
max: float | None = None,
span: float | None = None,
) -> None:
values = {
'min': None if min is None else _coerce_scalar('min', min),
'center': None if center is None else _coerce_scalar('center', center),
'max': None if max is None else _coerce_scalar('max', max),
'span': None if span is None else _coerce_scalar('span', span),
}
if sum(value is not None for value in values.values()) != 2:
raise GridError('Exactly two of min, center, max, span must be provided')
if sum(cc is None for cc in (min, center, max, span)) != 2:
raise GridError('Exactly two of min, center, max, span must be None!')
min_v = values['min']
center_v = values['center']
max_v = values['max']
span_v = values['span']
if span is None:
if center is None:
assert min is not None
assert max is not None
assert max >= min
center = 0.5 * (max + min)
span = max - min
elif max is None:
assert min is not None
assert center is not None
span = 2 * (center - min)
elif min is None:
assert center is not None
assert max is not None
span = 2 * (max - center)
else: # noqa: PLR5501
if center is not None:
pass
elif max is None:
assert min is not None
assert span is not None
center = min + 0.5 * span
elif min is None:
assert max is not None
assert span is not None
center = max - 0.5 * span
if span_v is not None and span_v < 0:
raise GridError('span must be non-negative')
if min_v is not None and max_v is not None:
if max_v < min_v:
raise GridError('max must be greater than or equal to min')
center_v = 0.5 * (max_v + min_v)
span_v = max_v - min_v
elif center_v is not None and min_v is not None:
span_v = 2 * (center_v - min_v)
if span_v < 0:
raise GridError('min must be less than or equal to center')
elif center_v is not None and max_v is not None:
span_v = 2 * (max_v - center_v)
if span_v < 0:
raise GridError('center must be less than or equal to max')
elif min_v is not None and span_v is not None:
center_v = min_v + 0.5 * span_v
elif max_v is not None and span_v is not None:
center_v = max_v - 0.5 * span_v
if center_v is None or span_v is None:
raise GridError('Unable to construct extent from the provided values')
self.center = center_v
self.span = span_v
assert center is not None
assert span is not None
if hasattr(center, '__len__'):
assert len(center) == 1
if hasattr(span, '__len__'):
assert len(span) == 1
self.center = center
self.span = span
class SlabDict(TypedDict, total=False):
@ -246,3 +231,4 @@ class Plane(PlaneProtocol):
if hasattr(cpos, '__len__'):
assert len(cpos) == 1
self.pos = cpos