From 22cb410d84ff4f33b727376761ebc489b59c382e Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Tue, 21 Apr 2026 20:00:52 -0700 Subject: [PATCH] [GridData / save / load] Add GridData and update save format --- gridlock/__init__.py | 1 + gridlock/data.py | 176 +++++++++++++++++++++++++++++++++++++ gridlock/grid.py | 110 ++++++++++++++++++++--- gridlock/test/test_grid.py | 102 ++++++++++++++++++++- 4 files changed, 376 insertions(+), 13 deletions(-) create mode 100644 gridlock/data.py diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 3f965fd..759d1c1 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -31,6 +31,7 @@ from .utils import ( PlaneDict as PlaneDict, ) from .grid import Grid as Grid +from .data import GridData as GridData __author__ = 'Jan Petykiewicz' diff --git a/gridlock/data.py b/gridlock/data.py new file mode 100644 index 0000000..5e6faa5 --- /dev/null +++ b/gridlock/data.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass +from typing import Self +from collections.abc import Sequence + +import numpy +from numpy.typing import NDArray, ArrayLike + +from .draw import foreground_t +from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload +from .utils import ( + ExtentDict, + ExtentProtocol, + GridError, + PlaneDict, + PlaneProtocol, + SlabDict, + SlabProtocol, +) + + +@dataclass(slots=True) +class GridData: + grid: Grid + cell_data: NDArray + + def __post_init__(self) -> None: + if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape): + raise GridError( + f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}' + ) + + @staticmethod + def load(filename: str) -> 'GridData': + payload = _load_payload(filename) + if _payload_scalar_str(payload, 'kind') != 'grid_data': + raise GridError('Serialized payload does not contain GridData') + if 'cell_data' not in payload: + raise GridError('Serialized GridData payload is missing cell_data') + + return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data'])) + + def save(self, filename: str) -> Self: + payload = self.grid._serialization_payload(kind='grid_data') + payload['cell_data'] = self.cell_data + _save_npz_payload(filename, payload) + return self + + def copy(self) -> Self: + return GridData(self.grid.copy(), self.cell_data.copy()) + + def draw_polygons( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + polygons: Sequence[ArrayLike], + *, + offset2d: ArrayLike = (0, 0), + ) -> Self: + self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d) + return self + + def draw_polygon( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + polygon: ArrayLike, + *, + offset2d: ArrayLike = (0, 0), + ) -> Self: + self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d) + return self + + def draw_slab( + self, + foreground: Sequence[foreground_t] | foreground_t, + slab: SlabProtocol | SlabDict, + ) -> Self: + self.grid.draw_slab(self.cell_data, foreground, slab) + return self + + def draw_cuboid( + self, + foreground: Sequence[foreground_t] | foreground_t, + *, + x: ExtentProtocol | ExtentDict, + y: ExtentProtocol | ExtentDict, + z: ExtentProtocol | ExtentDict, + ) -> Self: + self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z) + return self + + def draw_cylinder( + self, + h: SlabProtocol | SlabDict, + radius: float, + num_points: int, + center2d: ArrayLike, + foreground: Sequence[foreground_t] | foreground_t, + ) -> Self: + self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground) + return self + + def draw_extrude_rectangle( + self, + rectangle: ArrayLike, + direction: int, + polarity: int, + distance: float, + ) -> Self: + self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance) + return self + + def get_slice( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + ) -> NDArray: + return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period) + + def visualize_slice( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + pcolormesh_args: dict[str, object] | None = None, + ax: object | None = None, + ) -> tuple[object, object]: + return self.grid.visualize_slice( + self.cell_data, + plane, + which_shifts=which_shifts, + sample_period=sample_period, + finalize=finalize, + pcolormesh_args=pcolormesh_args, + ax=ax, + ) + + def visualize_edges( + self, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + sample_period: int = 1, + finalize: bool = True, + contour_args: dict[str, object] | None = None, + ax: object | None = None, + level_fraction: float = 0.7, + ) -> tuple[object, object]: + return self.grid.visualize_edges( + self.cell_data, + plane, + which_shifts=which_shifts, + sample_period=sample_period, + finalize=finalize, + contour_args=contour_args, + ax=ax, + level_fraction=level_fraction, + ) + + def visualize_isosurface( + self, + level: float | None = None, + which_shifts: int = 0, + sample_period: int = 1, + show_edges: bool = True, + finalize: bool = True, + ) -> tuple[object, object]: + return self.grid.visualize_isosurface( + self.cell_data, + level=level, + which_shifts=which_shifts, + sample_period=sample_period, + show_edges=show_edges, + finalize=finalize, + ) diff --git a/gridlock/grid.py b/gridlock/grid.py index 5bed422..eeb9708 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Self +from typing import TYPE_CHECKING, Any, ClassVar, Self from collections.abc import Callable, Sequence import numpy @@ -13,8 +13,78 @@ from .draw import GridDrawMixin from .read import GridReadMixin from .position import GridPosMixin +if TYPE_CHECKING: + from .data import GridData + foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] +_FORMAT_VERSION = 1 + + +def _is_npz_file(filename: str) -> bool: + with open(filename, 'rb') as f: + return f.read(2) == b'PK' + + +def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None: + with open(filename, 'wb') as f: + numpy.savez_compressed(f, **payload) + + +def _load_payload(filename: str) -> dict[str, Any]: + if _is_npz_file(filename): + with numpy.load(filename, allow_pickle=False) as payload: + return {key: payload[key] for key in payload.files} + + with open(filename, 'rb') as f: + legacy = pickle.load(f) + + if isinstance(legacy, Grid): + return legacy._serialization_payload(kind='grid') + if isinstance(legacy, dict): + grid = Grid([[-1, 1]] * 3) + grid.__dict__.update(legacy) + return grid._serialization_payload(kind='grid') + raise GridError('Unsupported serialized Grid payload') + + +def _payload_scalar_str(payload: dict[str, Any], key: str) -> str: + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + + value = numpy.asarray(payload[key]) + if value.size != 1: + raise GridError(f'Serialized key {key} must be scalar') + return str(value.reshape(())) + + +def _payload_scalar_int(payload: dict[str, Any], key: str) -> int: + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + + value = numpy.asarray(payload[key]) + if value.size != 1: + raise GridError(f'Serialized key {key} must be scalar') + return int(value.reshape(())) + + +def _grid_from_payload(payload: dict[str, Any]) -> 'Grid': + if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION: + raise GridError('Unsupported serialized Grid format version') + + exyz = [] + for axis in range(3): + key = f'exyz_{axis}' + if key not in payload: + raise GridError(f'Missing serialized key: {key}') + exyz.append(numpy.array(payload[key], dtype=float)) + + if 'shifts' not in payload or 'periodic' not in payload: + raise GridError('Serialized Grid payload is missing shifts or periodic data') + + shifts = numpy.array(payload['shifts'], dtype=float) + periodic = numpy.array(payload['periodic'], dtype=bool).tolist() + return Grid(exyz, shifts=shifts, periodic=periodic) class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): @@ -110,6 +180,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = list(periodic) if len(self.periodic) != 3: raise GridError('periodic must be a bool or a sequence of length 3') + if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic): + raise GridError('periodic sequence entries must be bool values') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' @@ -121,9 +193,16 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): if (numpy.abs(self.shifts) > 1).any(): raise GridError('Only shifts in the range [-1, 1] are currently supported') - if (self.shifts < 0).any(): - # TODO: Test negative shifts - warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2) + def _serialization_payload(self, *, kind: str) -> dict[str, Any]: + payload: dict[str, Any] = { + 'kind': numpy.array(kind), + 'format_version': numpy.array(_FORMAT_VERSION, dtype=int), + 'shifts': self.shifts, + 'periodic': numpy.array(self.periodic, dtype=bool), + } + for axis, exyz in enumerate(self.exyz): + payload[f'exyz_{axis}'] = exyz + return payload @staticmethod def load(filename: str) -> 'Grid': @@ -133,12 +212,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Args: filename: Filename to load from. """ - with open(filename, 'rb') as f: - tmp_dict = pickle.load(f) - - g = Grid([[-1, 1]] * 3) - g.__dict__.update(tmp_dict) - return g + payload = _load_payload(filename) + kind = _payload_scalar_str(payload, 'kind') + if kind not in ('grid', 'grid_data'): + raise GridError(f'Unsupported serialized kind: {kind}') + return _grid_from_payload(payload) def save(self, filename: str) -> Self: """ @@ -150,10 +228,18 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Returns: self """ - with open(filename, 'wb') as f: - pickle.dump(self.__dict__, f, protocol=2) + _save_npz_payload(filename, self._serialization_payload(kind='grid')) return self + def with_data( + self, + fill_value: float | None = 1.0, + dtype: type[numpy.number] = numpy.float32, + ) -> 'GridData': + from .data import GridData + + return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype)) + def copy(self) -> Self: """ Returns: diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index b4929a4..ae0a73a 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,8 +1,9 @@ import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal +import pickle -from .. import Grid, Extent, GridError, Plane, Slab +from .. import Grid, GridData, Extent, GridError, Plane, Slab def test_draw_oncenter_2x2() -> None: @@ -311,6 +312,54 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest. pyplot.close(fig) +def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) + path = tmp_path / 'grid.state' + + grid.save(str(path)) + loaded = Grid.load(str(path)) + + assert path.exists() + for original, restored in zip(grid.exyz, loaded.exyz, strict=True): + assert_allclose(restored, original) + assert_allclose(loaded.shifts, grid.shifts) + assert loaded.periodic == grid.periodic + + +def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None: + grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) + path = tmp_path / 'grid.pickle' + with open(path, 'wb') as f: + pickle.dump(grid.__dict__, f, protocol=2) + + loaded = Grid.load(str(path)) + + for original, restored in zip(grid.exyz, loaded.exyz, strict=True): + assert_allclose(restored, original) + assert_allclose(loaded.shifts, grid.shifts) + assert loaded.periodic == grid.periodic + + +def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: + data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0) + data.cell_data[0, 1, 0, 0] = 5.0 + path = tmp_path / 'griddata.state' + + data.save(str(path)) + loaded = GridData.load(str(path)) + + assert path.exists() + assert_allclose(loaded.cell_data, data.cell_data) + assert_allclose(loaded.grid.shifts, data.grid.shifts) + assert loaded.grid.periodic == data.grid.periodic + + +def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None: + path = tmp_path / 'grid.state' + Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path)) + + with pytest.raises(GridError): + GridData.load(str(path)) def test_negative_shift_nonperiodic_edges_and_widths() -> None: @@ -364,3 +413,54 @@ def test_negative_shift_get_slice_uses_shifted_centers() -> None: assert_allclose(grid_slice, [7, 9]) +def test_grid_with_data_returns_griddata() -> None: + grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + data = grid.with_data(fill_value=2.0) + + assert isinstance(data, GridData) + assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32)) + + +def test_griddata_constructor_validates_shape() -> None: + grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) + + with pytest.raises(GridError): + GridData(grid, numpy.zeros((1, 1, 1))) + + +def test_griddata_draw_methods_are_chainable() -> None: + data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) + + chained = data.draw_cuboid( + foreground=1, + x=dict(min=0, max=1), + y=dict(min=0, max=1), + z=dict(min=0, max=1), + ).draw_polygon( + foreground=0.5, + slab=dict(axis='z', center=0.5, span=1.0), + polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float), + ) + + assert chained is data + assert data.cell_data.sum() > 0 + + +def test_griddata_read_methods_delegate() -> None: + data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) + data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float) + + assert_allclose( + data.get_slice(Plane(z=0.5)), + data.grid.get_slice(data.cell_data, Plane(z=0.5)), + ) + + +def test_griddata_copy_is_independent() -> None: + data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0) + cloned = data.copy() + cloned.cell_data[0, 0, 0, 0] = 5.0 + + assert data is not cloned + assert data.grid is not cloned.grid + assert data.cell_data[0, 0, 0, 0] == 1.0