[GridData / save / load] Add GridData and update save format

This commit is contained in:
Jan Petykiewicz 2026-04-21 20:00:52 -07:00
commit 22cb410d84
4 changed files with 376 additions and 13 deletions

View file

@ -31,6 +31,7 @@ from .utils import (
PlaneDict as PlaneDict,
)
from .grid import Grid as Grid
from .data import GridData as GridData
__author__ = 'Jan Petykiewicz'

176
gridlock/data.py Normal file
View file

@ -0,0 +1,176 @@
from dataclasses import dataclass
from typing import Self
from collections.abc import Sequence
import numpy
from numpy.typing import NDArray, ArrayLike
from .draw import foreground_t
from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload
from .utils import (
ExtentDict,
ExtentProtocol,
GridError,
PlaneDict,
PlaneProtocol,
SlabDict,
SlabProtocol,
)
@dataclass(slots=True)
class GridData:
grid: Grid
cell_data: NDArray
def __post_init__(self) -> None:
if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape):
raise GridError(
f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}'
)
@staticmethod
def load(filename: str) -> 'GridData':
payload = _load_payload(filename)
if _payload_scalar_str(payload, 'kind') != 'grid_data':
raise GridError('Serialized payload does not contain GridData')
if 'cell_data' not in payload:
raise GridError('Serialized GridData payload is missing cell_data')
return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data']))
def save(self, filename: str) -> Self:
payload = self.grid._serialization_payload(kind='grid_data')
payload['cell_data'] = self.cell_data
_save_npz_payload(filename, payload)
return self
def copy(self) -> Self:
return GridData(self.grid.copy(), self.cell_data.copy())
def draw_polygons(
self,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
polygons: Sequence[ArrayLike],
*,
offset2d: ArrayLike = (0, 0),
) -> Self:
self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d)
return self
def draw_polygon(
self,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
polygon: ArrayLike,
*,
offset2d: ArrayLike = (0, 0),
) -> Self:
self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d)
return self
def draw_slab(
self,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
) -> Self:
self.grid.draw_slab(self.cell_data, foreground, slab)
return self
def draw_cuboid(
self,
foreground: Sequence[foreground_t] | foreground_t,
*,
x: ExtentProtocol | ExtentDict,
y: ExtentProtocol | ExtentDict,
z: ExtentProtocol | ExtentDict,
) -> Self:
self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z)
return self
def draw_cylinder(
self,
h: SlabProtocol | SlabDict,
radius: float,
num_points: int,
center2d: ArrayLike,
foreground: Sequence[foreground_t] | foreground_t,
) -> Self:
self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground)
return self
def draw_extrude_rectangle(
self,
rectangle: ArrayLike,
direction: int,
polarity: int,
distance: float,
) -> Self:
self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance)
return self
def get_slice(
self,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
) -> NDArray:
return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period)
def visualize_slice(
self,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
pcolormesh_args: dict[str, object] | None = None,
ax: object | None = None,
) -> tuple[object, object]:
return self.grid.visualize_slice(
self.cell_data,
plane,
which_shifts=which_shifts,
sample_period=sample_period,
finalize=finalize,
pcolormesh_args=pcolormesh_args,
ax=ax,
)
def visualize_edges(
self,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
contour_args: dict[str, object] | None = None,
ax: object | None = None,
level_fraction: float = 0.7,
) -> tuple[object, object]:
return self.grid.visualize_edges(
self.cell_data,
plane,
which_shifts=which_shifts,
sample_period=sample_period,
finalize=finalize,
contour_args=contour_args,
ax=ax,
level_fraction=level_fraction,
)
def visualize_isosurface(
self,
level: float | None = None,
which_shifts: int = 0,
sample_period: int = 1,
show_edges: bool = True,
finalize: bool = True,
) -> tuple[object, object]:
return self.grid.visualize_isosurface(
self.cell_data,
level=level,
which_shifts=which_shifts,
sample_period=sample_period,
show_edges=show_edges,
finalize=finalize,
)

View file

@ -1,4 +1,4 @@
from typing import ClassVar, Self
from typing import TYPE_CHECKING, Any, ClassVar, Self
from collections.abc import Callable, Sequence
import numpy
@ -13,8 +13,78 @@ from .draw import GridDrawMixin
from .read import GridReadMixin
from .position import GridPosMixin
if TYPE_CHECKING:
from .data import GridData
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
_FORMAT_VERSION = 1
def _is_npz_file(filename: str) -> bool:
with open(filename, 'rb') as f:
return f.read(2) == b'PK'
def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None:
with open(filename, 'wb') as f:
numpy.savez_compressed(f, **payload)
def _load_payload(filename: str) -> dict[str, Any]:
if _is_npz_file(filename):
with numpy.load(filename, allow_pickle=False) as payload:
return {key: payload[key] for key in payload.files}
with open(filename, 'rb') as f:
legacy = pickle.load(f)
if isinstance(legacy, Grid):
return legacy._serialization_payload(kind='grid')
if isinstance(legacy, dict):
grid = Grid([[-1, 1]] * 3)
grid.__dict__.update(legacy)
return grid._serialization_payload(kind='grid')
raise GridError('Unsupported serialized Grid payload')
def _payload_scalar_str(payload: dict[str, Any], key: str) -> str:
if key not in payload:
raise GridError(f'Missing serialized key: {key}')
value = numpy.asarray(payload[key])
if value.size != 1:
raise GridError(f'Serialized key {key} must be scalar')
return str(value.reshape(()))
def _payload_scalar_int(payload: dict[str, Any], key: str) -> int:
if key not in payload:
raise GridError(f'Missing serialized key: {key}')
value = numpy.asarray(payload[key])
if value.size != 1:
raise GridError(f'Serialized key {key} must be scalar')
return int(value.reshape(()))
def _grid_from_payload(payload: dict[str, Any]) -> 'Grid':
if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION:
raise GridError('Unsupported serialized Grid format version')
exyz = []
for axis in range(3):
key = f'exyz_{axis}'
if key not in payload:
raise GridError(f'Missing serialized key: {key}')
exyz.append(numpy.array(payload[key], dtype=float))
if 'shifts' not in payload or 'periodic' not in payload:
raise GridError('Serialized Grid payload is missing shifts or periodic data')
shifts = numpy.array(payload['shifts'], dtype=float)
periodic = numpy.array(payload['periodic'], dtype=bool).tolist()
return Grid(exyz, shifts=shifts, periodic=periodic)
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
@ -110,6 +180,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
self.periodic = list(periodic)
if len(self.periodic) != 3:
raise GridError('periodic must be a bool or a sequence of length 3')
if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic):
raise GridError('periodic sequence entries must be bool values')
if len(self.shifts.shape) != 2:
raise GridError('Misshapen shifts: shifts must have two axes! '
@ -121,9 +193,16 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
if (numpy.abs(self.shifts) > 1).any():
raise GridError('Only shifts in the range [-1, 1] are currently supported')
if (self.shifts < 0).any():
# TODO: Test negative shifts
warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2)
def _serialization_payload(self, *, kind: str) -> dict[str, Any]:
payload: dict[str, Any] = {
'kind': numpy.array(kind),
'format_version': numpy.array(_FORMAT_VERSION, dtype=int),
'shifts': self.shifts,
'periodic': numpy.array(self.periodic, dtype=bool),
}
for axis, exyz in enumerate(self.exyz):
payload[f'exyz_{axis}'] = exyz
return payload
@staticmethod
def load(filename: str) -> 'Grid':
@ -133,12 +212,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
Args:
filename: Filename to load from.
"""
with open(filename, 'rb') as f:
tmp_dict = pickle.load(f)
g = Grid([[-1, 1]] * 3)
g.__dict__.update(tmp_dict)
return g
payload = _load_payload(filename)
kind = _payload_scalar_str(payload, 'kind')
if kind not in ('grid', 'grid_data'):
raise GridError(f'Unsupported serialized kind: {kind}')
return _grid_from_payload(payload)
def save(self, filename: str) -> Self:
"""
@ -150,10 +228,18 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
Returns:
self
"""
with open(filename, 'wb') as f:
pickle.dump(self.__dict__, f, protocol=2)
_save_npz_payload(filename, self._serialization_payload(kind='grid'))
return self
def with_data(
self,
fill_value: float | None = 1.0,
dtype: type[numpy.number] = numpy.float32,
) -> 'GridData':
from .data import GridData
return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype))
def copy(self) -> Self:
"""
Returns:

View file

@ -1,8 +1,9 @@
import pytest
import numpy
from numpy.testing import assert_allclose #, assert_array_equal
import pickle
from .. import Grid, Extent, GridError, Plane, Slab
from .. import Grid, GridData, Extent, GridError, Plane, Slab
def test_draw_oncenter_2x2() -> None:
@ -311,6 +312,54 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.
pyplot.close(fig)
def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
path = tmp_path / 'grid.state'
grid.save(str(path))
loaded = Grid.load(str(path))
assert path.exists()
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
assert_allclose(restored, original)
assert_allclose(loaded.shifts, grid.shifts)
assert loaded.periodic == grid.periodic
def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
path = tmp_path / 'grid.pickle'
with open(path, 'wb') as f:
pickle.dump(grid.__dict__, f, protocol=2)
loaded = Grid.load(str(path))
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
assert_allclose(restored, original)
assert_allclose(loaded.shifts, grid.shifts)
assert loaded.periodic == grid.periodic
def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0)
data.cell_data[0, 1, 0, 0] = 5.0
path = tmp_path / 'griddata.state'
data.save(str(path))
loaded = GridData.load(str(path))
assert path.exists()
assert_allclose(loaded.cell_data, data.cell_data)
assert_allclose(loaded.grid.shifts, data.grid.shifts)
assert loaded.grid.periodic == data.grid.periodic
def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None:
path = tmp_path / 'grid.state'
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path))
with pytest.raises(GridError):
GridData.load(str(path))
def test_negative_shift_nonperiodic_edges_and_widths() -> None:
@ -364,3 +413,54 @@ def test_negative_shift_get_slice_uses_shifted_centers() -> None:
assert_allclose(grid_slice, [7, 9])
def test_grid_with_data_returns_griddata() -> None:
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
data = grid.with_data(fill_value=2.0)
assert isinstance(data, GridData)
assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32))
def test_griddata_constructor_validates_shape() -> None:
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
with pytest.raises(GridError):
GridData(grid, numpy.zeros((1, 1, 1)))
def test_griddata_draw_methods_are_chainable() -> None:
data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
chained = data.draw_cuboid(
foreground=1,
x=dict(min=0, max=1),
y=dict(min=0, max=1),
z=dict(min=0, max=1),
).draw_polygon(
foreground=0.5,
slab=dict(axis='z', center=0.5, span=1.0),
polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float),
)
assert chained is data
assert data.cell_data.sum() > 0
def test_griddata_read_methods_delegate() -> None:
data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float)
assert_allclose(
data.get_slice(Plane(z=0.5)),
data.grid.get_slice(data.cell_data, Plane(z=0.5)),
)
def test_griddata_copy_is_independent() -> None:
data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0)
cloned = data.copy()
cloned.cell_data[0, 0, 0, 0] = 5.0
assert data is not cloned
assert data.grid is not cloned.grid
assert data.cell_data[0, 0, 0, 0] == 1.0