diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 759d1c1..3f965fd 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -31,7 +31,6 @@ from .utils import ( PlaneDict as PlaneDict, ) from .grid import Grid as Grid -from .data import GridData as GridData __author__ = 'Jan Petykiewicz' diff --git a/gridlock/base.py b/gridlock/base.py index e68d955..aca9c69 100644 --- a/gridlock/base.py +++ b/gridlock/base.py @@ -76,21 +76,6 @@ class GridBase(Protocol): el = [0 if p else -1 for p in self.periodic] return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] - def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: - if which_shifts is None: - return self.dxyz_with_ghost - - shifts = self.shifts[which_shifts, :] - edge_dxyz = [] - for a in range(3): - if shifts[a] < 0: - ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0] - edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a]))) - else: - ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1] - edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost))) - return edge_dxyz - @property def center(self) -> NDArray[numpy.float64]: """ @@ -130,9 +115,15 @@ class GridBase(Protocol): """ if which_shifts is None: return self.exyz - edge_dxyz = self._shifted_edge_dxyz(which_shifts) + dxyz = self.dxyz_with_ghost shifts = self.shifts[which_shifts, :] - return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)] + + # If shift is negative, use left cell's dx to determine shift + for a in range(3): + if shifts[a] < 0: + dxyz[a] = numpy.roll(dxyz[a], 1) + + return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: """ @@ -146,7 +137,20 @@ class GridBase(Protocol): """ if which_shifts is None: return self.dxyz - return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)] + shifts = self.shifts[which_shifts, :] + dxyz = self.dxyz_with_ghost + + # If shift is negative, use left cell's dx to determine size + sdxyz = [] + for a in range(3): + if shifts[a] < 0: + roll_dxyz = numpy.roll(dxyz[a], 1) + abs_shift = numpy.abs(shifts[a]) + sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) + else: + sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) + + return sdxyz def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: """ diff --git a/gridlock/data.py b/gridlock/data.py deleted file mode 100644 index 5e6faa5..0000000 --- a/gridlock/data.py +++ /dev/null @@ -1,176 +0,0 @@ -from dataclasses import dataclass -from typing import Self -from collections.abc import Sequence - -import numpy -from numpy.typing import NDArray, ArrayLike - -from .draw import foreground_t -from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload -from .utils import ( - ExtentDict, - ExtentProtocol, - GridError, - PlaneDict, - PlaneProtocol, - SlabDict, - SlabProtocol, -) - - -@dataclass(slots=True) -class GridData: - grid: Grid - cell_data: NDArray - - def __post_init__(self) -> None: - if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape): - raise GridError( - f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}' - ) - - @staticmethod - def load(filename: str) -> 'GridData': - payload = _load_payload(filename) - if _payload_scalar_str(payload, 'kind') != 'grid_data': - raise GridError('Serialized payload does not contain GridData') - if 'cell_data' not in payload: - raise GridError('Serialized GridData payload is missing cell_data') - - return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data'])) - - def save(self, filename: str) -> Self: - payload = self.grid._serialization_payload(kind='grid_data') - payload['cell_data'] = self.cell_data - _save_npz_payload(filename, payload) - return self - - def copy(self) -> Self: - return GridData(self.grid.copy(), self.cell_data.copy()) - - def draw_polygons( - self, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - polygons: Sequence[ArrayLike], - *, - offset2d: ArrayLike = (0, 0), - ) -> Self: - self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d) - return self - - def draw_polygon( - self, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - polygon: ArrayLike, - *, - offset2d: ArrayLike = (0, 0), - ) -> Self: - self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d) - return self - - def draw_slab( - self, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - ) -> Self: - self.grid.draw_slab(self.cell_data, foreground, slab) - return self - - def draw_cuboid( - self, - foreground: Sequence[foreground_t] | foreground_t, - *, - x: ExtentProtocol | ExtentDict, - y: ExtentProtocol | ExtentDict, - z: ExtentProtocol | ExtentDict, - ) -> Self: - self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z) - return self - - def draw_cylinder( - self, - h: SlabProtocol | SlabDict, - radius: float, - num_points: int, - center2d: ArrayLike, - foreground: Sequence[foreground_t] | foreground_t, - ) -> Self: - self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground) - return self - - def draw_extrude_rectangle( - self, - rectangle: ArrayLike, - direction: int, - polarity: int, - distance: float, - ) -> Self: - self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance) - return self - - def get_slice( - self, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - ) -> NDArray: - return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period) - - def visualize_slice( - self, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - pcolormesh_args: dict[str, object] | None = None, - ax: object | None = None, - ) -> tuple[object, object]: - return self.grid.visualize_slice( - self.cell_data, - plane, - which_shifts=which_shifts, - sample_period=sample_period, - finalize=finalize, - pcolormesh_args=pcolormesh_args, - ax=ax, - ) - - def visualize_edges( - self, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - contour_args: dict[str, object] | None = None, - ax: object | None = None, - level_fraction: float = 0.7, - ) -> tuple[object, object]: - return self.grid.visualize_edges( - self.cell_data, - plane, - which_shifts=which_shifts, - sample_period=sample_period, - finalize=finalize, - contour_args=contour_args, - ax=ax, - level_fraction=level_fraction, - ) - - def visualize_isosurface( - self, - level: float | None = None, - which_shifts: int = 0, - sample_period: int = 1, - show_edges: bool = True, - finalize: bool = True, - ) -> tuple[object, object]: - return self.grid.visualize_isosurface( - self.cell_data, - level=level, - which_shifts=which_shifts, - sample_period=sample_period, - show_edges=show_edges, - finalize=finalize, - ) diff --git a/gridlock/grid.py b/gridlock/grid.py index eeb9708..5bed422 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Self +from typing import ClassVar, Self from collections.abc import Callable, Sequence import numpy @@ -13,78 +13,8 @@ from .draw import GridDrawMixin from .read import GridReadMixin from .position import GridPosMixin -if TYPE_CHECKING: - from .data import GridData - foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] -_FORMAT_VERSION = 1 - - -def _is_npz_file(filename: str) -> bool: - with open(filename, 'rb') as f: - return f.read(2) == b'PK' - - -def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None: - with open(filename, 'wb') as f: - numpy.savez_compressed(f, **payload) - - -def _load_payload(filename: str) -> dict[str, Any]: - if _is_npz_file(filename): - with numpy.load(filename, allow_pickle=False) as payload: - return {key: payload[key] for key in payload.files} - - with open(filename, 'rb') as f: - legacy = pickle.load(f) - - if isinstance(legacy, Grid): - return legacy._serialization_payload(kind='grid') - if isinstance(legacy, dict): - grid = Grid([[-1, 1]] * 3) - grid.__dict__.update(legacy) - return grid._serialization_payload(kind='grid') - raise GridError('Unsupported serialized Grid payload') - - -def _payload_scalar_str(payload: dict[str, Any], key: str) -> str: - if key not in payload: - raise GridError(f'Missing serialized key: {key}') - - value = numpy.asarray(payload[key]) - if value.size != 1: - raise GridError(f'Serialized key {key} must be scalar') - return str(value.reshape(())) - - -def _payload_scalar_int(payload: dict[str, Any], key: str) -> int: - if key not in payload: - raise GridError(f'Missing serialized key: {key}') - - value = numpy.asarray(payload[key]) - if value.size != 1: - raise GridError(f'Serialized key {key} must be scalar') - return int(value.reshape(())) - - -def _grid_from_payload(payload: dict[str, Any]) -> 'Grid': - if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION: - raise GridError('Unsupported serialized Grid format version') - - exyz = [] - for axis in range(3): - key = f'exyz_{axis}' - if key not in payload: - raise GridError(f'Missing serialized key: {key}') - exyz.append(numpy.array(payload[key], dtype=float)) - - if 'shifts' not in payload or 'periodic' not in payload: - raise GridError('Serialized Grid payload is missing shifts or periodic data') - - shifts = numpy.array(payload['shifts'], dtype=float) - periodic = numpy.array(payload['periodic'], dtype=bool).tolist() - return Grid(exyz, shifts=shifts, periodic=periodic) class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): @@ -180,8 +110,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = list(periodic) if len(self.periodic) != 3: raise GridError('periodic must be a bool or a sequence of length 3') - if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic): - raise GridError('periodic sequence entries must be bool values') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' @@ -193,16 +121,9 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): if (numpy.abs(self.shifts) > 1).any(): raise GridError('Only shifts in the range [-1, 1] are currently supported') - def _serialization_payload(self, *, kind: str) -> dict[str, Any]: - payload: dict[str, Any] = { - 'kind': numpy.array(kind), - 'format_version': numpy.array(_FORMAT_VERSION, dtype=int), - 'shifts': self.shifts, - 'periodic': numpy.array(self.periodic, dtype=bool), - } - for axis, exyz in enumerate(self.exyz): - payload[f'exyz_{axis}'] = exyz - return payload + if (self.shifts < 0).any(): + # TODO: Test negative shifts + warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2) @staticmethod def load(filename: str) -> 'Grid': @@ -212,11 +133,12 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Args: filename: Filename to load from. """ - payload = _load_payload(filename) - kind = _payload_scalar_str(payload, 'kind') - if kind not in ('grid', 'grid_data'): - raise GridError(f'Unsupported serialized kind: {kind}') - return _grid_from_payload(payload) + with open(filename, 'rb') as f: + tmp_dict = pickle.load(f) + + g = Grid([[-1, 1]] * 3) + g.__dict__.update(tmp_dict) + return g def save(self, filename: str) -> Self: """ @@ -228,18 +150,10 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Returns: self """ - _save_npz_payload(filename, self._serialization_payload(kind='grid')) + with open(filename, 'wb') as f: + pickle.dump(self.__dict__, f, protocol=2) return self - def with_data( - self, - fill_value: float | None = 1.0, - dtype: type[numpy.number] = numpy.float32, - ) -> 'GridData': - from .data import GridData - - return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype)) - def copy(self) -> Self: """ Returns: diff --git a/gridlock/read.py b/gridlock/read.py index f8a40a1..9be52b1 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -73,7 +73,7 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) # Extract indices and weights of planes - center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,)) + center3 = numpy.insert([0, 0], plane.axis, (plane.pos,)) center_index = self.pos2ind(center3, which_shifts, round_ind=False, check_bounds=False)[plane.axis] centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index ae0a73a..e7b3b28 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,9 +1,8 @@ import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal -import pickle -from .. import Grid, GridData, Extent, GridError, Plane, Slab +from .. import Grid, Extent, GridError, Plane, Slab def test_draw_oncenter_2x2() -> None: @@ -310,157 +309,3 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest. assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) pyplot.close(fig) - - -def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) - path = tmp_path / 'grid.state' - - grid.save(str(path)) - loaded = Grid.load(str(path)) - - assert path.exists() - for original, restored in zip(grid.exyz, loaded.exyz, strict=True): - assert_allclose(restored, original) - assert_allclose(loaded.shifts, grid.shifts) - assert loaded.periodic == grid.periodic - - -def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) - path = tmp_path / 'grid.pickle' - with open(path, 'wb') as f: - pickle.dump(grid.__dict__, f, protocol=2) - - loaded = Grid.load(str(path)) - - for original, restored in zip(grid.exyz, loaded.exyz, strict=True): - assert_allclose(restored, original) - assert_allclose(loaded.shifts, grid.shifts) - assert loaded.periodic == grid.periodic - - -def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: - data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0) - data.cell_data[0, 1, 0, 0] = 5.0 - path = tmp_path / 'griddata.state' - - data.save(str(path)) - loaded = GridData.load(str(path)) - - assert path.exists() - assert_allclose(loaded.cell_data, data.cell_data) - assert_allclose(loaded.grid.shifts, data.grid.shifts) - assert loaded.grid.periodic == data.grid.periodic - - -def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None: - path = tmp_path / 'grid.state' - Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path)) - - with pytest.raises(GridError): - GridData.load(str(path)) - - -def test_negative_shift_nonperiodic_edges_and_widths() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - - assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0]) - assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5]) - assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25]) - - -def test_negative_shift_periodic_edges_and_widths() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False]) - - assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0]) - assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5]) - - -def test_negative_shift_coordinate_round_trip() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - - ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False) - pos = grid.ind2pos(ind, 0, round_ind=False) - - assert_allclose(ind, [1.0, 0.0, 0.0]) - assert_allclose(pos, [1.25, 1.0, 0.5]) - - -def test_negative_shift_draw_cuboid_fractional_fill() -> None: - grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - arr = grid.allocate(0) - - grid.draw_cuboid( - arr, - x=dict(min=0, max=1), - y=dict(min=0, max=1), - z=dict(min=0, max=1), - foreground=1, - ) - - assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3]) - - -def test_negative_shift_get_slice_uses_shifted_centers() -> None: - grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - cell_data = numpy.zeros(grid.cell_data_shape) - cell_data[0, 1, :, 0] = [7, 9] - x_center = float(grid.shifted_xyz(0)[0][1]) - - grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0) - - assert_allclose(grid_slice, [7, 9]) - - -def test_grid_with_data_returns_griddata() -> None: - grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) - data = grid.with_data(fill_value=2.0) - - assert isinstance(data, GridData) - assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32)) - - -def test_griddata_constructor_validates_shape() -> None: - grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) - - with pytest.raises(GridError): - GridData(grid, numpy.zeros((1, 1, 1))) - - -def test_griddata_draw_methods_are_chainable() -> None: - data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) - - chained = data.draw_cuboid( - foreground=1, - x=dict(min=0, max=1), - y=dict(min=0, max=1), - z=dict(min=0, max=1), - ).draw_polygon( - foreground=0.5, - slab=dict(axis='z', center=0.5, span=1.0), - polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float), - ) - - assert chained is data - assert data.cell_data.sum() > 0 - - -def test_griddata_read_methods_delegate() -> None: - data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) - data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float) - - assert_allclose( - data.get_slice(Plane(z=0.5)), - data.grid.get_slice(data.cell_data, Plane(z=0.5)), - ) - - -def test_griddata_copy_is_independent() -> None: - data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0) - cloned = data.copy() - cloned.cell_data[0, 0, 0, 0] = 5.0 - - assert data is not cloned - assert data.grid is not cloned.grid - assert data.cell_data[0, 0, 0, 0] == 1.0