[GridData / save / load] Add GridData and update save format
This commit is contained in:
parent
85ae6e66cd
commit
22cb410d84
4 changed files with 376 additions and 13 deletions
|
|
@ -31,6 +31,7 @@ from .utils import (
|
|||
PlaneDict as PlaneDict,
|
||||
)
|
||||
from .grid import Grid as Grid
|
||||
from .data import GridData as GridData
|
||||
|
||||
|
||||
__author__ = 'Jan Petykiewicz'
|
||||
|
|
|
|||
176
gridlock/data.py
Normal file
176
gridlock/data.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
from collections.abc import Sequence
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
|
||||
from .draw import foreground_t
|
||||
from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload
|
||||
from .utils import (
|
||||
ExtentDict,
|
||||
ExtentProtocol,
|
||||
GridError,
|
||||
PlaneDict,
|
||||
PlaneProtocol,
|
||||
SlabDict,
|
||||
SlabProtocol,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GridData:
|
||||
grid: Grid
|
||||
cell_data: NDArray
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape):
|
||||
raise GridError(
|
||||
f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load(filename: str) -> 'GridData':
|
||||
payload = _load_payload(filename)
|
||||
if _payload_scalar_str(payload, 'kind') != 'grid_data':
|
||||
raise GridError('Serialized payload does not contain GridData')
|
||||
if 'cell_data' not in payload:
|
||||
raise GridError('Serialized GridData payload is missing cell_data')
|
||||
|
||||
return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data']))
|
||||
|
||||
def save(self, filename: str) -> Self:
|
||||
payload = self.grid._serialization_payload(kind='grid_data')
|
||||
payload['cell_data'] = self.cell_data
|
||||
_save_npz_payload(filename, payload)
|
||||
return self
|
||||
|
||||
def copy(self) -> Self:
|
||||
return GridData(self.grid.copy(), self.cell_data.copy())
|
||||
|
||||
def draw_polygons(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
slab: SlabProtocol | SlabDict,
|
||||
polygons: Sequence[ArrayLike],
|
||||
*,
|
||||
offset2d: ArrayLike = (0, 0),
|
||||
) -> Self:
|
||||
self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d)
|
||||
return self
|
||||
|
||||
def draw_polygon(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
slab: SlabProtocol | SlabDict,
|
||||
polygon: ArrayLike,
|
||||
*,
|
||||
offset2d: ArrayLike = (0, 0),
|
||||
) -> Self:
|
||||
self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d)
|
||||
return self
|
||||
|
||||
def draw_slab(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
slab: SlabProtocol | SlabDict,
|
||||
) -> Self:
|
||||
self.grid.draw_slab(self.cell_data, foreground, slab)
|
||||
return self
|
||||
|
||||
def draw_cuboid(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
*,
|
||||
x: ExtentProtocol | ExtentDict,
|
||||
y: ExtentProtocol | ExtentDict,
|
||||
z: ExtentProtocol | ExtentDict,
|
||||
) -> Self:
|
||||
self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z)
|
||||
return self
|
||||
|
||||
def draw_cylinder(
|
||||
self,
|
||||
h: SlabProtocol | SlabDict,
|
||||
radius: float,
|
||||
num_points: int,
|
||||
center2d: ArrayLike,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
) -> Self:
|
||||
self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground)
|
||||
return self
|
||||
|
||||
def draw_extrude_rectangle(
|
||||
self,
|
||||
rectangle: ArrayLike,
|
||||
direction: int,
|
||||
polarity: int,
|
||||
distance: float,
|
||||
) -> Self:
|
||||
self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance)
|
||||
return self
|
||||
|
||||
def get_slice(
|
||||
self,
|
||||
plane: PlaneProtocol | PlaneDict,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
) -> NDArray:
|
||||
return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period)
|
||||
|
||||
def visualize_slice(
|
||||
self,
|
||||
plane: PlaneProtocol | PlaneDict,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
finalize: bool = True,
|
||||
pcolormesh_args: dict[str, object] | None = None,
|
||||
ax: object | None = None,
|
||||
) -> tuple[object, object]:
|
||||
return self.grid.visualize_slice(
|
||||
self.cell_data,
|
||||
plane,
|
||||
which_shifts=which_shifts,
|
||||
sample_period=sample_period,
|
||||
finalize=finalize,
|
||||
pcolormesh_args=pcolormesh_args,
|
||||
ax=ax,
|
||||
)
|
||||
|
||||
def visualize_edges(
|
||||
self,
|
||||
plane: PlaneProtocol | PlaneDict,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
finalize: bool = True,
|
||||
contour_args: dict[str, object] | None = None,
|
||||
ax: object | None = None,
|
||||
level_fraction: float = 0.7,
|
||||
) -> tuple[object, object]:
|
||||
return self.grid.visualize_edges(
|
||||
self.cell_data,
|
||||
plane,
|
||||
which_shifts=which_shifts,
|
||||
sample_period=sample_period,
|
||||
finalize=finalize,
|
||||
contour_args=contour_args,
|
||||
ax=ax,
|
||||
level_fraction=level_fraction,
|
||||
)
|
||||
|
||||
def visualize_isosurface(
|
||||
self,
|
||||
level: float | None = None,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
show_edges: bool = True,
|
||||
finalize: bool = True,
|
||||
) -> tuple[object, object]:
|
||||
return self.grid.visualize_isosurface(
|
||||
self.cell_data,
|
||||
level=level,
|
||||
which_shifts=which_shifts,
|
||||
sample_period=sample_period,
|
||||
show_edges=show_edges,
|
||||
finalize=finalize,
|
||||
)
|
||||
110
gridlock/grid.py
110
gridlock/grid.py
|
|
@ -1,4 +1,4 @@
|
|||
from typing import ClassVar, Self
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Self
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import numpy
|
||||
|
|
@ -13,8 +13,78 @@ from .draw import GridDrawMixin
|
|||
from .read import GridReadMixin
|
||||
from .position import GridPosMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .data import GridData
|
||||
|
||||
|
||||
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
|
||||
_FORMAT_VERSION = 1
|
||||
|
||||
|
||||
def _is_npz_file(filename: str) -> bool:
|
||||
with open(filename, 'rb') as f:
|
||||
return f.read(2) == b'PK'
|
||||
|
||||
|
||||
def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None:
|
||||
with open(filename, 'wb') as f:
|
||||
numpy.savez_compressed(f, **payload)
|
||||
|
||||
|
||||
def _load_payload(filename: str) -> dict[str, Any]:
|
||||
if _is_npz_file(filename):
|
||||
with numpy.load(filename, allow_pickle=False) as payload:
|
||||
return {key: payload[key] for key in payload.files}
|
||||
|
||||
with open(filename, 'rb') as f:
|
||||
legacy = pickle.load(f)
|
||||
|
||||
if isinstance(legacy, Grid):
|
||||
return legacy._serialization_payload(kind='grid')
|
||||
if isinstance(legacy, dict):
|
||||
grid = Grid([[-1, 1]] * 3)
|
||||
grid.__dict__.update(legacy)
|
||||
return grid._serialization_payload(kind='grid')
|
||||
raise GridError('Unsupported serialized Grid payload')
|
||||
|
||||
|
||||
def _payload_scalar_str(payload: dict[str, Any], key: str) -> str:
|
||||
if key not in payload:
|
||||
raise GridError(f'Missing serialized key: {key}')
|
||||
|
||||
value = numpy.asarray(payload[key])
|
||||
if value.size != 1:
|
||||
raise GridError(f'Serialized key {key} must be scalar')
|
||||
return str(value.reshape(()))
|
||||
|
||||
|
||||
def _payload_scalar_int(payload: dict[str, Any], key: str) -> int:
|
||||
if key not in payload:
|
||||
raise GridError(f'Missing serialized key: {key}')
|
||||
|
||||
value = numpy.asarray(payload[key])
|
||||
if value.size != 1:
|
||||
raise GridError(f'Serialized key {key} must be scalar')
|
||||
return int(value.reshape(()))
|
||||
|
||||
|
||||
def _grid_from_payload(payload: dict[str, Any]) -> 'Grid':
|
||||
if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION:
|
||||
raise GridError('Unsupported serialized Grid format version')
|
||||
|
||||
exyz = []
|
||||
for axis in range(3):
|
||||
key = f'exyz_{axis}'
|
||||
if key not in payload:
|
||||
raise GridError(f'Missing serialized key: {key}')
|
||||
exyz.append(numpy.array(payload[key], dtype=float))
|
||||
|
||||
if 'shifts' not in payload or 'periodic' not in payload:
|
||||
raise GridError('Serialized Grid payload is missing shifts or periodic data')
|
||||
|
||||
shifts = numpy.array(payload['shifts'], dtype=float)
|
||||
periodic = numpy.array(payload['periodic'], dtype=bool).tolist()
|
||||
return Grid(exyz, shifts=shifts, periodic=periodic)
|
||||
|
||||
|
||||
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||
|
|
@ -110,6 +180,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
self.periodic = list(periodic)
|
||||
if len(self.periodic) != 3:
|
||||
raise GridError('periodic must be a bool or a sequence of length 3')
|
||||
if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic):
|
||||
raise GridError('periodic sequence entries must be bool values')
|
||||
|
||||
if len(self.shifts.shape) != 2:
|
||||
raise GridError('Misshapen shifts: shifts must have two axes! '
|
||||
|
|
@ -121,9 +193,16 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
if (numpy.abs(self.shifts) > 1).any():
|
||||
raise GridError('Only shifts in the range [-1, 1] are currently supported')
|
||||
|
||||
if (self.shifts < 0).any():
|
||||
# TODO: Test negative shifts
|
||||
warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2)
|
||||
def _serialization_payload(self, *, kind: str) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
'kind': numpy.array(kind),
|
||||
'format_version': numpy.array(_FORMAT_VERSION, dtype=int),
|
||||
'shifts': self.shifts,
|
||||
'periodic': numpy.array(self.periodic, dtype=bool),
|
||||
}
|
||||
for axis, exyz in enumerate(self.exyz):
|
||||
payload[f'exyz_{axis}'] = exyz
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def load(filename: str) -> 'Grid':
|
||||
|
|
@ -133,12 +212,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
Args:
|
||||
filename: Filename to load from.
|
||||
"""
|
||||
with open(filename, 'rb') as f:
|
||||
tmp_dict = pickle.load(f)
|
||||
|
||||
g = Grid([[-1, 1]] * 3)
|
||||
g.__dict__.update(tmp_dict)
|
||||
return g
|
||||
payload = _load_payload(filename)
|
||||
kind = _payload_scalar_str(payload, 'kind')
|
||||
if kind not in ('grid', 'grid_data'):
|
||||
raise GridError(f'Unsupported serialized kind: {kind}')
|
||||
return _grid_from_payload(payload)
|
||||
|
||||
def save(self, filename: str) -> Self:
|
||||
"""
|
||||
|
|
@ -150,10 +228,18 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
Returns:
|
||||
self
|
||||
"""
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(self.__dict__, f, protocol=2)
|
||||
_save_npz_payload(filename, self._serialization_payload(kind='grid'))
|
||||
return self
|
||||
|
||||
def with_data(
|
||||
self,
|
||||
fill_value: float | None = 1.0,
|
||||
dtype: type[numpy.number] = numpy.float32,
|
||||
) -> 'GridData':
|
||||
from .data import GridData
|
||||
|
||||
return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype))
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""
|
||||
Returns:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import pytest
|
||||
import numpy
|
||||
from numpy.testing import assert_allclose #, assert_array_equal
|
||||
import pickle
|
||||
|
||||
from .. import Grid, Extent, GridError, Plane, Slab
|
||||
from .. import Grid, GridData, Extent, GridError, Plane, Slab
|
||||
|
||||
|
||||
def test_draw_oncenter_2x2() -> None:
|
||||
|
|
@ -311,6 +312,54 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.
|
|||
pyplot.close(fig)
|
||||
|
||||
|
||||
def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
|
||||
path = tmp_path / 'grid.state'
|
||||
|
||||
grid.save(str(path))
|
||||
loaded = Grid.load(str(path))
|
||||
|
||||
assert path.exists()
|
||||
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
|
||||
assert_allclose(restored, original)
|
||||
assert_allclose(loaded.shifts, grid.shifts)
|
||||
assert loaded.periodic == grid.periodic
|
||||
|
||||
|
||||
def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
|
||||
path = tmp_path / 'grid.pickle'
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(grid.__dict__, f, protocol=2)
|
||||
|
||||
loaded = Grid.load(str(path))
|
||||
|
||||
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
|
||||
assert_allclose(restored, original)
|
||||
assert_allclose(loaded.shifts, grid.shifts)
|
||||
assert loaded.periodic == grid.periodic
|
||||
|
||||
|
||||
def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
|
||||
data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0)
|
||||
data.cell_data[0, 1, 0, 0] = 5.0
|
||||
path = tmp_path / 'griddata.state'
|
||||
|
||||
data.save(str(path))
|
||||
loaded = GridData.load(str(path))
|
||||
|
||||
assert path.exists()
|
||||
assert_allclose(loaded.cell_data, data.cell_data)
|
||||
assert_allclose(loaded.grid.shifts, data.grid.shifts)
|
||||
assert loaded.grid.periodic == data.grid.periodic
|
||||
|
||||
|
||||
def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None:
|
||||
path = tmp_path / 'grid.state'
|
||||
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path))
|
||||
|
||||
with pytest.raises(GridError):
|
||||
GridData.load(str(path))
|
||||
|
||||
|
||||
def test_negative_shift_nonperiodic_edges_and_widths() -> None:
|
||||
|
|
@ -364,3 +413,54 @@ def test_negative_shift_get_slice_uses_shifted_centers() -> None:
|
|||
assert_allclose(grid_slice, [7, 9])
|
||||
|
||||
|
||||
def test_grid_with_data_returns_griddata() -> None:
|
||||
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||
data = grid.with_data(fill_value=2.0)
|
||||
|
||||
assert isinstance(data, GridData)
|
||||
assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32))
|
||||
|
||||
|
||||
def test_griddata_constructor_validates_shape() -> None:
|
||||
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||
|
||||
with pytest.raises(GridError):
|
||||
GridData(grid, numpy.zeros((1, 1, 1)))
|
||||
|
||||
|
||||
def test_griddata_draw_methods_are_chainable() -> None:
|
||||
data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
|
||||
|
||||
chained = data.draw_cuboid(
|
||||
foreground=1,
|
||||
x=dict(min=0, max=1),
|
||||
y=dict(min=0, max=1),
|
||||
z=dict(min=0, max=1),
|
||||
).draw_polygon(
|
||||
foreground=0.5,
|
||||
slab=dict(axis='z', center=0.5, span=1.0),
|
||||
polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float),
|
||||
)
|
||||
|
||||
assert chained is data
|
||||
assert data.cell_data.sum() > 0
|
||||
|
||||
|
||||
def test_griddata_read_methods_delegate() -> None:
|
||||
data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
|
||||
data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float)
|
||||
|
||||
assert_allclose(
|
||||
data.get_slice(Plane(z=0.5)),
|
||||
data.grid.get_slice(data.cell_data, Plane(z=0.5)),
|
||||
)
|
||||
|
||||
|
||||
def test_griddata_copy_is_independent() -> None:
|
||||
data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0)
|
||||
cloned = data.copy()
|
||||
cloned.cell_data[0, 0, 0, 0] = 5.0
|
||||
|
||||
assert data is not cloned
|
||||
assert data.grid is not cloned.grid
|
||||
assert data.cell_data[0, 0, 0, 0] == 1.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue