diff --git a/gridlock/__init__.py b/gridlock/__init__.py index e7be065..2f39696 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -34,5 +34,5 @@ from .grid import Grid as Grid __author__ = 'Jan Petykiewicz' -__version__ = '2.1' +__version__ = '2.0' version = __version__ diff --git a/gridlock/draw.py b/gridlock/draw.py index 321ec15..9ba4623 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -61,25 +61,23 @@ class GridDrawMixin(GridPosMixin): for ii in range(len(poly_list)): polygon = poly_list[ii] malformed = f'Malformed polygon: ({ii})' - if polygon.ndim != 2: - raise GridError(malformed + 'must be a 2-dimensional ndarray') if polygon.shape[1] not in (2, 3): raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') if polygon.shape[1] == 3: - if numpy.unique(polygon[:, slab.axis]).size != 1: - raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) - polygon = polygon[:, surface] + polygon = polygon[surface, :] poly_list[ii] = polygon if not polygon.shape[0] > 2: raise GridError(malformed + 'must consist of more than 2 points') + if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1: + raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] - if isinstance(foreground, numpy.ndarray): + if numpy.size(foreground) == 1: # type: ignore + foregrounds = [foreground] * len(cell_data) # type: ignore + elif isinstance(foreground, numpy.ndarray): raise GridError('ndarray not supported for foreground') - if callable(foreground) or numpy.isscalar(foreground): - foregrounds = [foreground] * len(cell_data) # type: ignore[list-item] else: foregrounds = foreground # type: ignore @@ -298,6 +296,8 @@ class GridDrawMixin(GridPosMixin): if isinstance(z, dict): z = Extent(**z) + center = numpy.asarray([x.center, y.center, z.center]) + p = numpy.array([[x.min, y.max], [x.max, y.max], [x.max, y.min], @@ -376,18 +376,15 @@ class GridDrawMixin(GridPosMixin): foreground_func = [] for ii, grid in enumerate(cell_data): zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction] + + ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)] + fpart = zz - numpy.floor(zz) - low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1)) - high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1)) + mult = [1 - fpart, fpart][::sgn] # reverses if s negative - low_ind = [low if dd == direction else slice(None) for dd in range(3)] - high_ind = [high if dd == direction else slice(None) for dd in range(3)] - - if low == high: - foreground = grid[tuple(low_ind)] - else: - mult = [1 - fpart, fpart][::sgn] # reverses if s negative - foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)] + foreground = mult[0] * grid[tuple(ind)] + ind[direction] += 1 # type: ignore #(known safe) + foreground += mult[1] * grid[tuple(ind)] def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 # transform from natural position to index @@ -401,3 +398,4 @@ class GridDrawMixin(GridPosMixin): slab = Slab(axis=direction, center=center[direction], span=thickness) self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface]) + diff --git a/gridlock/grid.py b/gridlock/grid.py index 5bed422..5790dbd 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -95,8 +95,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): `GridError` on invalid input """ edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates] - if len(edge_arrs) != 3: - raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays') self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) @@ -108,8 +106,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = [periodic] * 3 else: self.periodic = list(periodic) - if len(self.periodic) != 3: - raise GridError('periodic must be a bool or a sequence of length 3') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' diff --git a/gridlock/position.py b/gridlock/position.py index 6344ea4..b705b99 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -47,13 +47,13 @@ class GridPosMixin(GridBase): else: low_bound = -0.5 high_bound = -0.5 - if (ind < low_bound).any() or (ind > self.shape + high_bound).any(): + if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): raise GridError(f'Position outside of grid: {ind}') if round_ind: rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) sxyz = self.shifted_xyz(which_shifts) - position = [sxyz[a][rind[a]] for a in range(3)] + position = [sxyz[a][rind[a]].astype(int) for a in range(3)] else: sexyz = self.shifted_exyz(which_shifts) position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) diff --git a/gridlock/read.py b/gridlock/read.py index 9be52b1..503e996 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -20,26 +20,6 @@ if TYPE_CHECKING: class GridReadMixin(GridPosMixin): - @staticmethod - def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]: - if centers.size > 1: - midpoints = 0.5 * (centers[:-1] + centers[1:]) - first = centers[0] - 0.5 * (centers[1] - centers[0]) - last = centers[-1] + 0.5 * (centers[-1] - centers[-2]) - return numpy.hstack(([first], midpoints, [last])) - return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float) - - def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]: - if sample_period <= 1: - return self.shifted_exyz(which_shifts) - - shifted_xyz = self.shifted_xyz(which_shifts) - shifted_exyz = self.shifted_exyz(which_shifts) - return [ - self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a]) - for a in range(3) - ] - def get_slice( self, cell_data: NDArray, @@ -83,13 +63,12 @@ class GridReadMixin(GridPosMixin): else: w = [1] - c_min, c_max = (self.shifted_xyz(which_shifts)[plane.axis][i] for i in [0, -1]) + c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1]) if plane.pos < c_min or plane.pos > c_max: raise GridError('Coordinate of selected plane must be within simulation domain') # Extract grid values from planes above and below visualized slice - sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface) - sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float)) + sliced_grid = numpy.zeros(self.shape[surface]) for ci, weight in zip(centers, w, strict=True): s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3)) sliced_grid += weight * cell_data[which_shifts][tuple(s)] @@ -143,11 +122,7 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) - if sample_period == 1: - x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) - else: - x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) - pcolormesh_args.setdefault('shading', 'nearest') + x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) @@ -233,10 +208,10 @@ class GridReadMixin(GridPosMixin): fig, ax = pyplot.subplots() else: fig = ax.figure - xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) + xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface) xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') - ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) + mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) if finalize: pyplot.show() @@ -282,14 +257,8 @@ class GridReadMixin(GridPosMixin): verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) # Convert vertices from index to position - preview_exyz = self._sampled_exyz(which_shifts, sample_period) - pos_verts = numpy.array([ - [ - numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a]) - for a in range(3) - ] - for i in range(verts.shape[0]) - ], dtype=float) + pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) + for i in range(verts.shape[0])], dtype=float) xs, ys, zs = (pos_verts[:, a] for a in range(3)) # Draw the plot diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index e7b3b28..8d9ca92 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,8 +1,8 @@ -import pytest +# import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid, Extent, GridError, Plane, Slab +from .. import Grid, Extent #, Slab, Plane def test_draw_oncenter_2x2() -> None: @@ -116,196 +116,3 @@ def test_draw_2shift_4x4() -> None: [0, 0.125, 0.125, 0]])[None, :, :, None] assert_allclose(arr, correct) - - -def test_ind2pos_round_preserves_float_centers() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) - - pos = grid.ind2pos(numpy.array([1, 0, 0]), which_shifts=0) - - assert_allclose(pos, [2.0, 1.0, 0.5]) - - -def test_ind2pos_enforces_bounds_for_rounded_and_fractional_indices() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) - - with pytest.raises(GridError): - grid.ind2pos(numpy.array([2, 0, 0]), which_shifts=0, check_bounds=True) - - edge_pos = grid.ind2pos(numpy.array([1.5, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) - assert_allclose(edge_pos, [3.0, 2.0, 1.0]) - - with pytest.raises(GridError): - grid.ind2pos(numpy.array([1.6, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) - - -def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) - arr_2d = grid.allocate(0) - arr_3d = grid.allocate(0) - slab = dict(axis='z', center=0.5, span=1.0) - - polygon_2d = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float) - polygon_3d = numpy.array([[0, 0, 0.5], - [1, 0, 0.5], - [1, 1, 0.5], - [0, 1, 0.5]], dtype=float) - - grid.draw_polygon(arr_2d, slab=slab, polygon=polygon_2d, foreground=1) - grid.draw_polygon(arr_3d, slab=slab, polygon=polygon_3d, foreground=1) - - assert_allclose(arr_3d, arr_2d) - - -def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) - arr = grid.allocate(0) - polygon = numpy.array([[0, 0, 0.5], - [1, 0, 0.5], - [1, 1, 0.75], - [0, 1, 0.5]], dtype=float) - - with pytest.raises(GridError): - grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1) - - -def test_get_slice_supports_sampling() -> None: - grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) - cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) - - grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2) - - assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0]) - - -def test_sampled_visualization_helpers_do_not_error() -> None: - matplotlib = pytest.importorskip('matplotlib') - matplotlib.use('Agg') - from matplotlib import pyplot - - grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) - cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) - - fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False) - fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False) - - assert fig_slice is ax_slice.figure - assert fig_edges is ax_edges.figure - - pyplot.close(fig_slice) - pyplot.close(fig_edges) - - -def test_grid_constructor_rejects_invalid_coordinate_count() -> None: - with pytest.raises(GridError): - Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]]) - - with pytest.raises(GridError): - Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) - - -def test_grid_constructor_rejects_invalid_periodic_length() -> None: - with pytest.raises(GridError): - Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False]) - - -def test_extent_and_slab_reject_inverted_geometry() -> None: - with pytest.raises(GridError): - Extent(center=0, min=1) - - with pytest.raises(GridError): - Extent(min=2, max=1) - - with pytest.raises(GridError): - Slab(axis='z', center=1, max=0) - - -def test_extent_accepts_scalar_like_inputs() -> None: - extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0])) - - assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) - - -def test_get_slice_uses_shifted_grid_bounds() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0.5, 0, 0]]) - cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) - - grid_slice = grid.get_slice(cell_data, Plane(x=2.0), which_shifts=0) - - assert_allclose(grid_slice, cell_data[0, 1, :, :]) - - with pytest.raises(GridError): - grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0) - - -def test_draw_extrude_rectangle_uses_boundary_slice() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) - cell_data = grid.allocate(0) - source = numpy.array([[1, 2], - [3, 4]], dtype=float) - cell_data[0, :, :, 1] = source - - grid.draw_extrude_rectangle( - cell_data, - rectangle=[[0, 0, 2], [2, 2, 2]], - direction=2, - polarity=-1, - distance=2, - ) - - assert_allclose(cell_data[0, :, :, 0], source) - assert_allclose(cell_data[0, :, :, 1], source) - - -def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None: - grid = Grid([[0, 1, 3, 6, 10], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) - - sampled_exyz = grid._sampled_exyz(0, 2) - - assert_allclose(sampled_exyz[0], [-1.5, 2.5, 6.5]) - - -def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: - matplotlib = pytest.importorskip('matplotlib') - matplotlib.use('Agg') - skimage_measure = pytest.importorskip('skimage.measure') - from matplotlib import pyplot - from mpl_toolkits.mplot3d.axes3d import Axes3D - - captured: dict[str, numpy.ndarray] = {} - - def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]: - verts = numpy.array([[0.5, 0.5, 0.5], - [0.5, 1.5, 0.5], - [1.5, 0.5, 0.5]], dtype=float) - faces = numpy.array([[0, 1, 2]], dtype=int) - return verts, faces, None, None - - def fake_plot_trisurf( # noqa: ANN202 - _self: object, - xs: numpy.ndarray, - ys: numpy.ndarray, - faces: numpy.ndarray, - zs: numpy.ndarray, - *_args: object, - **_kwargs: object, - ) -> object: - captured['xs'] = numpy.asarray(xs) - captured['ys'] = numpy.asarray(ys) - captured['faces'] = numpy.asarray(faces) - captured['zs'] = numpy.asarray(zs) - return object() - - monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes) - monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf) - - grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]]) - cell_data = numpy.zeros(grid.cell_data_shape) - - fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False) - - assert_allclose(captured['xs'], [1.5, 1.5, 3.5]) - assert_allclose(captured['ys'], [1.5, 3.5, 1.5]) - assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) - - pyplot.close(fig) diff --git a/gridlock/utils.py b/gridlock/utils.py index 585b999..8a8f11d 100644 --- a/gridlock/utils.py +++ b/gridlock/utils.py @@ -1,25 +1,12 @@ from typing import Protocol, TypedDict, runtime_checkable, cast from dataclasses import dataclass -import numpy - class GridError(Exception): """ Base error type for `gridlock` """ pass -def _coerce_scalar(name: str, value: object) -> float: - arr = numpy.asarray(value) - if arr.size != 1: - raise GridError(f'{name} must be a scalar value') - - try: - return float(arr.reshape(())) - except (TypeError, ValueError) as exc: - raise GridError(f'{name} must be a real scalar value') from exc - - class ExtentDict(TypedDict, total=False): """ Geometrical definition of an extent (1D bounded region) @@ -71,46 +58,44 @@ class Extent(ExtentProtocol): max: float | None = None, span: float | None = None, ) -> None: - values = { - 'min': None if min is None else _coerce_scalar('min', min), - 'center': None if center is None else _coerce_scalar('center', center), - 'max': None if max is None else _coerce_scalar('max', max), - 'span': None if span is None else _coerce_scalar('span', span), - } - if sum(value is not None for value in values.values()) != 2: - raise GridError('Exactly two of min, center, max, span must be provided') + if sum(cc is None for cc in (min, center, max, span)) != 2: + raise GridError('Exactly two of min, center, max, span must be None!') - min_v = values['min'] - center_v = values['center'] - max_v = values['max'] - span_v = values['span'] + if span is None: + if center is None: + assert min is not None + assert max is not None + assert max >= min + center = 0.5 * (max + min) + span = max - min + elif max is None: + assert min is not None + assert center is not None + span = 2 * (center - min) + elif min is None: + assert center is not None + assert max is not None + span = 2 * (max - center) + else: # noqa: PLR5501 + if center is not None: + pass + elif max is None: + assert min is not None + assert span is not None + center = min + 0.5 * span + elif min is None: + assert max is not None + assert span is not None + center = max - 0.5 * span - if span_v is not None and span_v < 0: - raise GridError('span must be non-negative') - - if min_v is not None and max_v is not None: - if max_v < min_v: - raise GridError('max must be greater than or equal to min') - center_v = 0.5 * (max_v + min_v) - span_v = max_v - min_v - elif center_v is not None and min_v is not None: - span_v = 2 * (center_v - min_v) - if span_v < 0: - raise GridError('min must be less than or equal to center') - elif center_v is not None and max_v is not None: - span_v = 2 * (max_v - center_v) - if span_v < 0: - raise GridError('center must be less than or equal to max') - elif min_v is not None and span_v is not None: - center_v = min_v + 0.5 * span_v - elif max_v is not None and span_v is not None: - center_v = max_v - 0.5 * span_v - - if center_v is None or span_v is None: - raise GridError('Unable to construct extent from the provided values') - - self.center = center_v - self.span = span_v + assert center is not None + assert span is not None + if hasattr(center, '__len__'): + assert len(center) == 1 + if hasattr(span, '__len__'): + assert len(span) == 1 + self.center = center + self.span = span class SlabDict(TypedDict, total=False): @@ -246,3 +231,4 @@ class Plane(PlaneProtocol): if hasattr(cpos, '__len__'): assert len(cpos) == 1 self.pos = cpos +