Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 22cb410d84 | |||
| 85ae6e66cd |
6 changed files with 450 additions and 36 deletions
|
|
@ -31,6 +31,7 @@ from .utils import (
|
||||||
PlaneDict as PlaneDict,
|
PlaneDict as PlaneDict,
|
||||||
)
|
)
|
||||||
from .grid import Grid as Grid
|
from .grid import Grid as Grid
|
||||||
|
from .data import GridData as GridData
|
||||||
|
|
||||||
|
|
||||||
__author__ = 'Jan Petykiewicz'
|
__author__ = 'Jan Petykiewicz'
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,21 @@ class GridBase(Protocol):
|
||||||
el = [0 if p else -1 for p in self.periodic]
|
el = [0 if p else -1 for p in self.periodic]
|
||||||
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
|
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
|
||||||
|
|
||||||
|
def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||||
|
if which_shifts is None:
|
||||||
|
return self.dxyz_with_ghost
|
||||||
|
|
||||||
|
shifts = self.shifts[which_shifts, :]
|
||||||
|
edge_dxyz = []
|
||||||
|
for a in range(3):
|
||||||
|
if shifts[a] < 0:
|
||||||
|
ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0]
|
||||||
|
edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a])))
|
||||||
|
else:
|
||||||
|
ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1]
|
||||||
|
edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost)))
|
||||||
|
return edge_dxyz
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def center(self) -> NDArray[numpy.float64]:
|
def center(self) -> NDArray[numpy.float64]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -115,15 +130,9 @@ class GridBase(Protocol):
|
||||||
"""
|
"""
|
||||||
if which_shifts is None:
|
if which_shifts is None:
|
||||||
return self.exyz
|
return self.exyz
|
||||||
dxyz = self.dxyz_with_ghost
|
edge_dxyz = self._shifted_edge_dxyz(which_shifts)
|
||||||
shifts = self.shifts[which_shifts, :]
|
shifts = self.shifts[which_shifts, :]
|
||||||
|
return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)]
|
||||||
# If shift is negative, use left cell's dx to determine shift
|
|
||||||
for a in range(3):
|
|
||||||
if shifts[a] < 0:
|
|
||||||
dxyz[a] = numpy.roll(dxyz[a], 1)
|
|
||||||
|
|
||||||
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
|
|
||||||
|
|
||||||
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -137,20 +146,7 @@ class GridBase(Protocol):
|
||||||
"""
|
"""
|
||||||
if which_shifts is None:
|
if which_shifts is None:
|
||||||
return self.dxyz
|
return self.dxyz
|
||||||
shifts = self.shifts[which_shifts, :]
|
return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)]
|
||||||
dxyz = self.dxyz_with_ghost
|
|
||||||
|
|
||||||
# If shift is negative, use left cell's dx to determine size
|
|
||||||
sdxyz = []
|
|
||||||
for a in range(3):
|
|
||||||
if shifts[a] < 0:
|
|
||||||
roll_dxyz = numpy.roll(dxyz[a], 1)
|
|
||||||
abs_shift = numpy.abs(shifts[a])
|
|
||||||
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
|
|
||||||
else:
|
|
||||||
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
|
|
||||||
|
|
||||||
return sdxyz
|
|
||||||
|
|
||||||
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
176
gridlock/data.py
Normal file
176
gridlock/data.py
Normal file
|
|
@ -0,0 +1,176 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Self
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from numpy.typing import NDArray, ArrayLike
|
||||||
|
|
||||||
|
from .draw import foreground_t
|
||||||
|
from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload
|
||||||
|
from .utils import (
|
||||||
|
ExtentDict,
|
||||||
|
ExtentProtocol,
|
||||||
|
GridError,
|
||||||
|
PlaneDict,
|
||||||
|
PlaneProtocol,
|
||||||
|
SlabDict,
|
||||||
|
SlabProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class GridData:
|
||||||
|
grid: Grid
|
||||||
|
cell_data: NDArray
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape):
|
||||||
|
raise GridError(
|
||||||
|
f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(filename: str) -> 'GridData':
|
||||||
|
payload = _load_payload(filename)
|
||||||
|
if _payload_scalar_str(payload, 'kind') != 'grid_data':
|
||||||
|
raise GridError('Serialized payload does not contain GridData')
|
||||||
|
if 'cell_data' not in payload:
|
||||||
|
raise GridError('Serialized GridData payload is missing cell_data')
|
||||||
|
|
||||||
|
return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data']))
|
||||||
|
|
||||||
|
def save(self, filename: str) -> Self:
|
||||||
|
payload = self.grid._serialization_payload(kind='grid_data')
|
||||||
|
payload['cell_data'] = self.cell_data
|
||||||
|
_save_npz_payload(filename, payload)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def copy(self) -> Self:
|
||||||
|
return GridData(self.grid.copy(), self.cell_data.copy())
|
||||||
|
|
||||||
|
def draw_polygons(
|
||||||
|
self,
|
||||||
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
|
slab: SlabProtocol | SlabDict,
|
||||||
|
polygons: Sequence[ArrayLike],
|
||||||
|
*,
|
||||||
|
offset2d: ArrayLike = (0, 0),
|
||||||
|
) -> Self:
|
||||||
|
self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def draw_polygon(
|
||||||
|
self,
|
||||||
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
|
slab: SlabProtocol | SlabDict,
|
||||||
|
polygon: ArrayLike,
|
||||||
|
*,
|
||||||
|
offset2d: ArrayLike = (0, 0),
|
||||||
|
) -> Self:
|
||||||
|
self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def draw_slab(
|
||||||
|
self,
|
||||||
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
|
slab: SlabProtocol | SlabDict,
|
||||||
|
) -> Self:
|
||||||
|
self.grid.draw_slab(self.cell_data, foreground, slab)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def draw_cuboid(
|
||||||
|
self,
|
||||||
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
|
*,
|
||||||
|
x: ExtentProtocol | ExtentDict,
|
||||||
|
y: ExtentProtocol | ExtentDict,
|
||||||
|
z: ExtentProtocol | ExtentDict,
|
||||||
|
) -> Self:
|
||||||
|
self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def draw_cylinder(
|
||||||
|
self,
|
||||||
|
h: SlabProtocol | SlabDict,
|
||||||
|
radius: float,
|
||||||
|
num_points: int,
|
||||||
|
center2d: ArrayLike,
|
||||||
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
|
) -> Self:
|
||||||
|
self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def draw_extrude_rectangle(
|
||||||
|
self,
|
||||||
|
rectangle: ArrayLike,
|
||||||
|
direction: int,
|
||||||
|
polarity: int,
|
||||||
|
distance: float,
|
||||||
|
) -> Self:
|
||||||
|
self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_slice(
|
||||||
|
self,
|
||||||
|
plane: PlaneProtocol | PlaneDict,
|
||||||
|
which_shifts: int = 0,
|
||||||
|
sample_period: int = 1,
|
||||||
|
) -> NDArray:
|
||||||
|
return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period)
|
||||||
|
|
||||||
|
def visualize_slice(
|
||||||
|
self,
|
||||||
|
plane: PlaneProtocol | PlaneDict,
|
||||||
|
which_shifts: int = 0,
|
||||||
|
sample_period: int = 1,
|
||||||
|
finalize: bool = True,
|
||||||
|
pcolormesh_args: dict[str, object] | None = None,
|
||||||
|
ax: object | None = None,
|
||||||
|
) -> tuple[object, object]:
|
||||||
|
return self.grid.visualize_slice(
|
||||||
|
self.cell_data,
|
||||||
|
plane,
|
||||||
|
which_shifts=which_shifts,
|
||||||
|
sample_period=sample_period,
|
||||||
|
finalize=finalize,
|
||||||
|
pcolormesh_args=pcolormesh_args,
|
||||||
|
ax=ax,
|
||||||
|
)
|
||||||
|
|
||||||
|
def visualize_edges(
|
||||||
|
self,
|
||||||
|
plane: PlaneProtocol | PlaneDict,
|
||||||
|
which_shifts: int = 0,
|
||||||
|
sample_period: int = 1,
|
||||||
|
finalize: bool = True,
|
||||||
|
contour_args: dict[str, object] | None = None,
|
||||||
|
ax: object | None = None,
|
||||||
|
level_fraction: float = 0.7,
|
||||||
|
) -> tuple[object, object]:
|
||||||
|
return self.grid.visualize_edges(
|
||||||
|
self.cell_data,
|
||||||
|
plane,
|
||||||
|
which_shifts=which_shifts,
|
||||||
|
sample_period=sample_period,
|
||||||
|
finalize=finalize,
|
||||||
|
contour_args=contour_args,
|
||||||
|
ax=ax,
|
||||||
|
level_fraction=level_fraction,
|
||||||
|
)
|
||||||
|
|
||||||
|
def visualize_isosurface(
|
||||||
|
self,
|
||||||
|
level: float | None = None,
|
||||||
|
which_shifts: int = 0,
|
||||||
|
sample_period: int = 1,
|
||||||
|
show_edges: bool = True,
|
||||||
|
finalize: bool = True,
|
||||||
|
) -> tuple[object, object]:
|
||||||
|
return self.grid.visualize_isosurface(
|
||||||
|
self.cell_data,
|
||||||
|
level=level,
|
||||||
|
which_shifts=which_shifts,
|
||||||
|
sample_period=sample_period,
|
||||||
|
show_edges=show_edges,
|
||||||
|
finalize=finalize,
|
||||||
|
)
|
||||||
110
gridlock/grid.py
110
gridlock/grid.py
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import ClassVar, Self
|
from typing import TYPE_CHECKING, Any, ClassVar, Self
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
|
@ -13,8 +13,78 @@ from .draw import GridDrawMixin
|
||||||
from .read import GridReadMixin
|
from .read import GridReadMixin
|
||||||
from .position import GridPosMixin
|
from .position import GridPosMixin
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .data import GridData
|
||||||
|
|
||||||
|
|
||||||
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
|
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
|
||||||
|
_FORMAT_VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _is_npz_file(filename: str) -> bool:
|
||||||
|
with open(filename, 'rb') as f:
|
||||||
|
return f.read(2) == b'PK'
|
||||||
|
|
||||||
|
|
||||||
|
def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None:
|
||||||
|
with open(filename, 'wb') as f:
|
||||||
|
numpy.savez_compressed(f, **payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_payload(filename: str) -> dict[str, Any]:
|
||||||
|
if _is_npz_file(filename):
|
||||||
|
with numpy.load(filename, allow_pickle=False) as payload:
|
||||||
|
return {key: payload[key] for key in payload.files}
|
||||||
|
|
||||||
|
with open(filename, 'rb') as f:
|
||||||
|
legacy = pickle.load(f)
|
||||||
|
|
||||||
|
if isinstance(legacy, Grid):
|
||||||
|
return legacy._serialization_payload(kind='grid')
|
||||||
|
if isinstance(legacy, dict):
|
||||||
|
grid = Grid([[-1, 1]] * 3)
|
||||||
|
grid.__dict__.update(legacy)
|
||||||
|
return grid._serialization_payload(kind='grid')
|
||||||
|
raise GridError('Unsupported serialized Grid payload')
|
||||||
|
|
||||||
|
|
||||||
|
def _payload_scalar_str(payload: dict[str, Any], key: str) -> str:
|
||||||
|
if key not in payload:
|
||||||
|
raise GridError(f'Missing serialized key: {key}')
|
||||||
|
|
||||||
|
value = numpy.asarray(payload[key])
|
||||||
|
if value.size != 1:
|
||||||
|
raise GridError(f'Serialized key {key} must be scalar')
|
||||||
|
return str(value.reshape(()))
|
||||||
|
|
||||||
|
|
||||||
|
def _payload_scalar_int(payload: dict[str, Any], key: str) -> int:
|
||||||
|
if key not in payload:
|
||||||
|
raise GridError(f'Missing serialized key: {key}')
|
||||||
|
|
||||||
|
value = numpy.asarray(payload[key])
|
||||||
|
if value.size != 1:
|
||||||
|
raise GridError(f'Serialized key {key} must be scalar')
|
||||||
|
return int(value.reshape(()))
|
||||||
|
|
||||||
|
|
||||||
|
def _grid_from_payload(payload: dict[str, Any]) -> 'Grid':
|
||||||
|
if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION:
|
||||||
|
raise GridError('Unsupported serialized Grid format version')
|
||||||
|
|
||||||
|
exyz = []
|
||||||
|
for axis in range(3):
|
||||||
|
key = f'exyz_{axis}'
|
||||||
|
if key not in payload:
|
||||||
|
raise GridError(f'Missing serialized key: {key}')
|
||||||
|
exyz.append(numpy.array(payload[key], dtype=float))
|
||||||
|
|
||||||
|
if 'shifts' not in payload or 'periodic' not in payload:
|
||||||
|
raise GridError('Serialized Grid payload is missing shifts or periodic data')
|
||||||
|
|
||||||
|
shifts = numpy.array(payload['shifts'], dtype=float)
|
||||||
|
periodic = numpy.array(payload['periodic'], dtype=bool).tolist()
|
||||||
|
return Grid(exyz, shifts=shifts, periodic=periodic)
|
||||||
|
|
||||||
|
|
||||||
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||||
|
|
@ -110,6 +180,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||||
self.periodic = list(periodic)
|
self.periodic = list(periodic)
|
||||||
if len(self.periodic) != 3:
|
if len(self.periodic) != 3:
|
||||||
raise GridError('periodic must be a bool or a sequence of length 3')
|
raise GridError('periodic must be a bool or a sequence of length 3')
|
||||||
|
if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic):
|
||||||
|
raise GridError('periodic sequence entries must be bool values')
|
||||||
|
|
||||||
if len(self.shifts.shape) != 2:
|
if len(self.shifts.shape) != 2:
|
||||||
raise GridError('Misshapen shifts: shifts must have two axes! '
|
raise GridError('Misshapen shifts: shifts must have two axes! '
|
||||||
|
|
@ -121,9 +193,16 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||||
if (numpy.abs(self.shifts) > 1).any():
|
if (numpy.abs(self.shifts) > 1).any():
|
||||||
raise GridError('Only shifts in the range [-1, 1] are currently supported')
|
raise GridError('Only shifts in the range [-1, 1] are currently supported')
|
||||||
|
|
||||||
if (self.shifts < 0).any():
|
def _serialization_payload(self, *, kind: str) -> dict[str, Any]:
|
||||||
# TODO: Test negative shifts
|
payload: dict[str, Any] = {
|
||||||
warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2)
|
'kind': numpy.array(kind),
|
||||||
|
'format_version': numpy.array(_FORMAT_VERSION, dtype=int),
|
||||||
|
'shifts': self.shifts,
|
||||||
|
'periodic': numpy.array(self.periodic, dtype=bool),
|
||||||
|
}
|
||||||
|
for axis, exyz in enumerate(self.exyz):
|
||||||
|
payload[f'exyz_{axis}'] = exyz
|
||||||
|
return payload
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(filename: str) -> 'Grid':
|
def load(filename: str) -> 'Grid':
|
||||||
|
|
@ -133,12 +212,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||||
Args:
|
Args:
|
||||||
filename: Filename to load from.
|
filename: Filename to load from.
|
||||||
"""
|
"""
|
||||||
with open(filename, 'rb') as f:
|
payload = _load_payload(filename)
|
||||||
tmp_dict = pickle.load(f)
|
kind = _payload_scalar_str(payload, 'kind')
|
||||||
|
if kind not in ('grid', 'grid_data'):
|
||||||
g = Grid([[-1, 1]] * 3)
|
raise GridError(f'Unsupported serialized kind: {kind}')
|
||||||
g.__dict__.update(tmp_dict)
|
return _grid_from_payload(payload)
|
||||||
return g
|
|
||||||
|
|
||||||
def save(self, filename: str) -> Self:
|
def save(self, filename: str) -> Self:
|
||||||
"""
|
"""
|
||||||
|
|
@ -150,10 +228,18 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||||
Returns:
|
Returns:
|
||||||
self
|
self
|
||||||
"""
|
"""
|
||||||
with open(filename, 'wb') as f:
|
_save_npz_payload(filename, self._serialization_payload(kind='grid'))
|
||||||
pickle.dump(self.__dict__, f, protocol=2)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_data(
|
||||||
|
self,
|
||||||
|
fill_value: float | None = 1.0,
|
||||||
|
dtype: type[numpy.number] = numpy.float32,
|
||||||
|
) -> 'GridData':
|
||||||
|
from .data import GridData
|
||||||
|
|
||||||
|
return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype))
|
||||||
|
|
||||||
def copy(self) -> Self:
|
def copy(self) -> Self:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ class GridReadMixin(GridPosMixin):
|
||||||
surface = numpy.delete(range(3), plane.axis)
|
surface = numpy.delete(range(3), plane.axis)
|
||||||
|
|
||||||
# Extract indices and weights of planes
|
# Extract indices and weights of planes
|
||||||
center3 = numpy.insert([0, 0], plane.axis, (plane.pos,))
|
center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,))
|
||||||
center_index = self.pos2ind(center3, which_shifts,
|
center_index = self.pos2ind(center3, which_shifts,
|
||||||
round_ind=False, check_bounds=False)[plane.axis]
|
round_ind=False, check_bounds=False)[plane.axis]
|
||||||
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
|
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
import pytest
|
import pytest
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.testing import assert_allclose #, assert_array_equal
|
from numpy.testing import assert_allclose #, assert_array_equal
|
||||||
|
import pickle
|
||||||
|
|
||||||
from .. import Grid, Extent, GridError, Plane, Slab
|
from .. import Grid, GridData, Extent, GridError, Plane, Slab
|
||||||
|
|
||||||
|
|
||||||
def test_draw_oncenter_2x2() -> None:
|
def test_draw_oncenter_2x2() -> None:
|
||||||
|
|
@ -309,3 +310,157 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.
|
||||||
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
|
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
|
||||||
|
|
||||||
pyplot.close(fig)
|
pyplot.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
|
||||||
|
path = tmp_path / 'grid.state'
|
||||||
|
|
||||||
|
grid.save(str(path))
|
||||||
|
loaded = Grid.load(str(path))
|
||||||
|
|
||||||
|
assert path.exists()
|
||||||
|
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
|
||||||
|
assert_allclose(restored, original)
|
||||||
|
assert_allclose(loaded.shifts, grid.shifts)
|
||||||
|
assert loaded.periodic == grid.periodic
|
||||||
|
|
||||||
|
|
||||||
|
def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
|
||||||
|
path = tmp_path / 'grid.pickle'
|
||||||
|
with open(path, 'wb') as f:
|
||||||
|
pickle.dump(grid.__dict__, f, protocol=2)
|
||||||
|
|
||||||
|
loaded = Grid.load(str(path))
|
||||||
|
|
||||||
|
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
|
||||||
|
assert_allclose(restored, original)
|
||||||
|
assert_allclose(loaded.shifts, grid.shifts)
|
||||||
|
assert loaded.periodic == grid.periodic
|
||||||
|
|
||||||
|
|
||||||
|
def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
|
||||||
|
data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0)
|
||||||
|
data.cell_data[0, 1, 0, 0] = 5.0
|
||||||
|
path = tmp_path / 'griddata.state'
|
||||||
|
|
||||||
|
data.save(str(path))
|
||||||
|
loaded = GridData.load(str(path))
|
||||||
|
|
||||||
|
assert path.exists()
|
||||||
|
assert_allclose(loaded.cell_data, data.cell_data)
|
||||||
|
assert_allclose(loaded.grid.shifts, data.grid.shifts)
|
||||||
|
assert loaded.grid.periodic == data.grid.periodic
|
||||||
|
|
||||||
|
|
||||||
|
def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None:
|
||||||
|
path = tmp_path / 'grid.state'
|
||||||
|
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path))
|
||||||
|
|
||||||
|
with pytest.raises(GridError):
|
||||||
|
GridData.load(str(path))
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_shift_nonperiodic_edges_and_widths() -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||||
|
|
||||||
|
assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0])
|
||||||
|
assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5])
|
||||||
|
assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25])
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_shift_periodic_edges_and_widths() -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False])
|
||||||
|
|
||||||
|
assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0])
|
||||||
|
assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5])
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_shift_coordinate_round_trip() -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||||
|
|
||||||
|
ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False)
|
||||||
|
pos = grid.ind2pos(ind, 0, round_ind=False)
|
||||||
|
|
||||||
|
assert_allclose(ind, [1.0, 0.0, 0.0])
|
||||||
|
assert_allclose(pos, [1.25, 1.0, 0.5])
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_shift_draw_cuboid_fractional_fill() -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||||
|
arr = grid.allocate(0)
|
||||||
|
|
||||||
|
grid.draw_cuboid(
|
||||||
|
arr,
|
||||||
|
x=dict(min=0, max=1),
|
||||||
|
y=dict(min=0, max=1),
|
||||||
|
z=dict(min=0, max=1),
|
||||||
|
foreground=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3])
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_shift_get_slice_uses_shifted_centers() -> None:
|
||||||
|
grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||||
|
cell_data = numpy.zeros(grid.cell_data_shape)
|
||||||
|
cell_data[0, 1, :, 0] = [7, 9]
|
||||||
|
x_center = float(grid.shifted_xyz(0)[0][1])
|
||||||
|
|
||||||
|
grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0)
|
||||||
|
|
||||||
|
assert_allclose(grid_slice, [7, 9])
|
||||||
|
|
||||||
|
|
||||||
|
def test_grid_with_data_returns_griddata() -> None:
|
||||||
|
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||||
|
data = grid.with_data(fill_value=2.0)
|
||||||
|
|
||||||
|
assert isinstance(data, GridData)
|
||||||
|
assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32))
|
||||||
|
|
||||||
|
|
||||||
|
def test_griddata_constructor_validates_shape() -> None:
|
||||||
|
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||||
|
|
||||||
|
with pytest.raises(GridError):
|
||||||
|
GridData(grid, numpy.zeros((1, 1, 1)))
|
||||||
|
|
||||||
|
|
||||||
|
def test_griddata_draw_methods_are_chainable() -> None:
|
||||||
|
data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
|
||||||
|
|
||||||
|
chained = data.draw_cuboid(
|
||||||
|
foreground=1,
|
||||||
|
x=dict(min=0, max=1),
|
||||||
|
y=dict(min=0, max=1),
|
||||||
|
z=dict(min=0, max=1),
|
||||||
|
).draw_polygon(
|
||||||
|
foreground=0.5,
|
||||||
|
slab=dict(axis='z', center=0.5, span=1.0),
|
||||||
|
polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chained is data
|
||||||
|
assert data.cell_data.sum() > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_griddata_read_methods_delegate() -> None:
|
||||||
|
data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
|
||||||
|
data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float)
|
||||||
|
|
||||||
|
assert_allclose(
|
||||||
|
data.get_slice(Plane(z=0.5)),
|
||||||
|
data.grid.get_slice(data.cell_data, Plane(z=0.5)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_griddata_copy_is_independent() -> None:
|
||||||
|
data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0)
|
||||||
|
cloned = data.copy()
|
||||||
|
cloned.cell_data[0, 0, 0, 0] = 5.0
|
||||||
|
|
||||||
|
assert data is not cloned
|
||||||
|
assert data.grid is not cloned.grid
|
||||||
|
assert data.cell_data[0, 0, 0, 0] == 1.0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue