Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 22cb410d84 | |||
| 85ae6e66cd |
6 changed files with 450 additions and 36 deletions
|
|
@ -31,6 +31,7 @@ from .utils import (
|
|||
PlaneDict as PlaneDict,
|
||||
)
|
||||
from .grid import Grid as Grid
|
||||
from .data import GridData as GridData
|
||||
|
||||
|
||||
__author__ = 'Jan Petykiewicz'
|
||||
|
|
|
|||
|
|
@ -76,6 +76,21 @@ class GridBase(Protocol):
|
|||
el = [0 if p else -1 for p in self.periodic]
|
||||
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
|
||||
|
||||
def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||
if which_shifts is None:
|
||||
return self.dxyz_with_ghost
|
||||
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
edge_dxyz = []
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0]
|
||||
edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a])))
|
||||
else:
|
||||
ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1]
|
||||
edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost)))
|
||||
return edge_dxyz
|
||||
|
||||
@property
|
||||
def center(self) -> NDArray[numpy.float64]:
|
||||
"""
|
||||
|
|
@ -115,15 +130,9 @@ class GridBase(Protocol):
|
|||
"""
|
||||
if which_shifts is None:
|
||||
return self.exyz
|
||||
dxyz = self.dxyz_with_ghost
|
||||
edge_dxyz = self._shifted_edge_dxyz(which_shifts)
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
|
||||
# If shift is negative, use left cell's dx to determine shift
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
dxyz[a] = numpy.roll(dxyz[a], 1)
|
||||
|
||||
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
|
||||
return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)]
|
||||
|
||||
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||
"""
|
||||
|
|
@ -137,20 +146,7 @@ class GridBase(Protocol):
|
|||
"""
|
||||
if which_shifts is None:
|
||||
return self.dxyz
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
dxyz = self.dxyz_with_ghost
|
||||
|
||||
# If shift is negative, use left cell's dx to determine size
|
||||
sdxyz = []
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
roll_dxyz = numpy.roll(dxyz[a], 1)
|
||||
abs_shift = numpy.abs(shifts[a])
|
||||
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
|
||||
else:
|
||||
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
|
||||
|
||||
return sdxyz
|
||||
return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)]
|
||||
|
||||
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||
"""
|
||||
|
|
|
|||
176
gridlock/data.py
Normal file
176
gridlock/data.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
from collections.abc import Sequence
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
|
||||
from .draw import foreground_t
|
||||
from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload
|
||||
from .utils import (
|
||||
ExtentDict,
|
||||
ExtentProtocol,
|
||||
GridError,
|
||||
PlaneDict,
|
||||
PlaneProtocol,
|
||||
SlabDict,
|
||||
SlabProtocol,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GridData:
|
||||
grid: Grid
|
||||
cell_data: NDArray
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape):
|
||||
raise GridError(
|
||||
f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load(filename: str) -> 'GridData':
|
||||
payload = _load_payload(filename)
|
||||
if _payload_scalar_str(payload, 'kind') != 'grid_data':
|
||||
raise GridError('Serialized payload does not contain GridData')
|
||||
if 'cell_data' not in payload:
|
||||
raise GridError('Serialized GridData payload is missing cell_data')
|
||||
|
||||
return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data']))
|
||||
|
||||
def save(self, filename: str) -> Self:
|
||||
payload = self.grid._serialization_payload(kind='grid_data')
|
||||
payload['cell_data'] = self.cell_data
|
||||
_save_npz_payload(filename, payload)
|
||||
return self
|
||||
|
||||
def copy(self) -> Self:
|
||||
return GridData(self.grid.copy(), self.cell_data.copy())
|
||||
|
||||
def draw_polygons(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
slab: SlabProtocol | SlabDict,
|
||||
polygons: Sequence[ArrayLike],
|
||||
*,
|
||||
offset2d: ArrayLike = (0, 0),
|
||||
) -> Self:
|
||||
self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d)
|
||||
return self
|
||||
|
||||
def draw_polygon(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
slab: SlabProtocol | SlabDict,
|
||||
polygon: ArrayLike,
|
||||
*,
|
||||
offset2d: ArrayLike = (0, 0),
|
||||
) -> Self:
|
||||
self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d)
|
||||
return self
|
||||
|
||||
def draw_slab(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
slab: SlabProtocol | SlabDict,
|
||||
) -> Self:
|
||||
self.grid.draw_slab(self.cell_data, foreground, slab)
|
||||
return self
|
||||
|
||||
def draw_cuboid(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
*,
|
||||
x: ExtentProtocol | ExtentDict,
|
||||
y: ExtentProtocol | ExtentDict,
|
||||
z: ExtentProtocol | ExtentDict,
|
||||
) -> Self:
|
||||
self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z)
|
||||
return self
|
||||
|
||||
def draw_cylinder(
|
||||
self,
|
||||
h: SlabProtocol | SlabDict,
|
||||
radius: float,
|
||||
num_points: int,
|
||||
center2d: ArrayLike,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
) -> Self:
|
||||
self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground)
|
||||
return self
|
||||
|
||||
def draw_extrude_rectangle(
|
||||
self,
|
||||
rectangle: ArrayLike,
|
||||
direction: int,
|
||||
polarity: int,
|
||||
distance: float,
|
||||
) -> Self:
|
||||
self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance)
|
||||
return self
|
||||
|
||||
def get_slice(
|
||||
self,
|
||||
plane: PlaneProtocol | PlaneDict,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
) -> NDArray:
|
||||
return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period)
|
||||
|
||||
def visualize_slice(
|
||||
self,
|
||||
plane: PlaneProtocol | PlaneDict,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
finalize: bool = True,
|
||||
pcolormesh_args: dict[str, object] | None = None,
|
||||
ax: object | None = None,
|
||||
) -> tuple[object, object]:
|
||||
return self.grid.visualize_slice(
|
||||
self.cell_data,
|
||||
plane,
|
||||
which_shifts=which_shifts,
|
||||
sample_period=sample_period,
|
||||
finalize=finalize,
|
||||
pcolormesh_args=pcolormesh_args,
|
||||
ax=ax,
|
||||
)
|
||||
|
||||
def visualize_edges(
|
||||
self,
|
||||
plane: PlaneProtocol | PlaneDict,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
finalize: bool = True,
|
||||
contour_args: dict[str, object] | None = None,
|
||||
ax: object | None = None,
|
||||
level_fraction: float = 0.7,
|
||||
) -> tuple[object, object]:
|
||||
return self.grid.visualize_edges(
|
||||
self.cell_data,
|
||||
plane,
|
||||
which_shifts=which_shifts,
|
||||
sample_period=sample_period,
|
||||
finalize=finalize,
|
||||
contour_args=contour_args,
|
||||
ax=ax,
|
||||
level_fraction=level_fraction,
|
||||
)
|
||||
|
||||
def visualize_isosurface(
|
||||
self,
|
||||
level: float | None = None,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
show_edges: bool = True,
|
||||
finalize: bool = True,
|
||||
) -> tuple[object, object]:
|
||||
return self.grid.visualize_isosurface(
|
||||
self.cell_data,
|
||||
level=level,
|
||||
which_shifts=which_shifts,
|
||||
sample_period=sample_period,
|
||||
show_edges=show_edges,
|
||||
finalize=finalize,
|
||||
)
|
||||
110
gridlock/grid.py
110
gridlock/grid.py
|
|
@ -1,4 +1,4 @@
|
|||
from typing import ClassVar, Self
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Self
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import numpy
|
||||
|
|
@ -13,8 +13,78 @@ from .draw import GridDrawMixin
|
|||
from .read import GridReadMixin
|
||||
from .position import GridPosMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .data import GridData
|
||||
|
||||
|
||||
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
|
||||
_FORMAT_VERSION = 1
|
||||
|
||||
|
||||
def _is_npz_file(filename: str) -> bool:
|
||||
with open(filename, 'rb') as f:
|
||||
return f.read(2) == b'PK'
|
||||
|
||||
|
||||
def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None:
|
||||
with open(filename, 'wb') as f:
|
||||
numpy.savez_compressed(f, **payload)
|
||||
|
||||
|
||||
def _load_payload(filename: str) -> dict[str, Any]:
|
||||
if _is_npz_file(filename):
|
||||
with numpy.load(filename, allow_pickle=False) as payload:
|
||||
return {key: payload[key] for key in payload.files}
|
||||
|
||||
with open(filename, 'rb') as f:
|
||||
legacy = pickle.load(f)
|
||||
|
||||
if isinstance(legacy, Grid):
|
||||
return legacy._serialization_payload(kind='grid')
|
||||
if isinstance(legacy, dict):
|
||||
grid = Grid([[-1, 1]] * 3)
|
||||
grid.__dict__.update(legacy)
|
||||
return grid._serialization_payload(kind='grid')
|
||||
raise GridError('Unsupported serialized Grid payload')
|
||||
|
||||
|
||||
def _payload_scalar_str(payload: dict[str, Any], key: str) -> str:
|
||||
if key not in payload:
|
||||
raise GridError(f'Missing serialized key: {key}')
|
||||
|
||||
value = numpy.asarray(payload[key])
|
||||
if value.size != 1:
|
||||
raise GridError(f'Serialized key {key} must be scalar')
|
||||
return str(value.reshape(()))
|
||||
|
||||
|
||||
def _payload_scalar_int(payload: dict[str, Any], key: str) -> int:
|
||||
if key not in payload:
|
||||
raise GridError(f'Missing serialized key: {key}')
|
||||
|
||||
value = numpy.asarray(payload[key])
|
||||
if value.size != 1:
|
||||
raise GridError(f'Serialized key {key} must be scalar')
|
||||
return int(value.reshape(()))
|
||||
|
||||
|
||||
def _grid_from_payload(payload: dict[str, Any]) -> 'Grid':
|
||||
if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION:
|
||||
raise GridError('Unsupported serialized Grid format version')
|
||||
|
||||
exyz = []
|
||||
for axis in range(3):
|
||||
key = f'exyz_{axis}'
|
||||
if key not in payload:
|
||||
raise GridError(f'Missing serialized key: {key}')
|
||||
exyz.append(numpy.array(payload[key], dtype=float))
|
||||
|
||||
if 'shifts' not in payload or 'periodic' not in payload:
|
||||
raise GridError('Serialized Grid payload is missing shifts or periodic data')
|
||||
|
||||
shifts = numpy.array(payload['shifts'], dtype=float)
|
||||
periodic = numpy.array(payload['periodic'], dtype=bool).tolist()
|
||||
return Grid(exyz, shifts=shifts, periodic=periodic)
|
||||
|
||||
|
||||
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||
|
|
@ -110,6 +180,8 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
self.periodic = list(periodic)
|
||||
if len(self.periodic) != 3:
|
||||
raise GridError('periodic must be a bool or a sequence of length 3')
|
||||
if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic):
|
||||
raise GridError('periodic sequence entries must be bool values')
|
||||
|
||||
if len(self.shifts.shape) != 2:
|
||||
raise GridError('Misshapen shifts: shifts must have two axes! '
|
||||
|
|
@ -121,9 +193,16 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
if (numpy.abs(self.shifts) > 1).any():
|
||||
raise GridError('Only shifts in the range [-1, 1] are currently supported')
|
||||
|
||||
if (self.shifts < 0).any():
|
||||
# TODO: Test negative shifts
|
||||
warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2)
|
||||
def _serialization_payload(self, *, kind: str) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
'kind': numpy.array(kind),
|
||||
'format_version': numpy.array(_FORMAT_VERSION, dtype=int),
|
||||
'shifts': self.shifts,
|
||||
'periodic': numpy.array(self.periodic, dtype=bool),
|
||||
}
|
||||
for axis, exyz in enumerate(self.exyz):
|
||||
payload[f'exyz_{axis}'] = exyz
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def load(filename: str) -> 'Grid':
|
||||
|
|
@ -133,12 +212,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
Args:
|
||||
filename: Filename to load from.
|
||||
"""
|
||||
with open(filename, 'rb') as f:
|
||||
tmp_dict = pickle.load(f)
|
||||
|
||||
g = Grid([[-1, 1]] * 3)
|
||||
g.__dict__.update(tmp_dict)
|
||||
return g
|
||||
payload = _load_payload(filename)
|
||||
kind = _payload_scalar_str(payload, 'kind')
|
||||
if kind not in ('grid', 'grid_data'):
|
||||
raise GridError(f'Unsupported serialized kind: {kind}')
|
||||
return _grid_from_payload(payload)
|
||||
|
||||
def save(self, filename: str) -> Self:
|
||||
"""
|
||||
|
|
@ -150,10 +228,18 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
Returns:
|
||||
self
|
||||
"""
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(self.__dict__, f, protocol=2)
|
||||
_save_npz_payload(filename, self._serialization_payload(kind='grid'))
|
||||
return self
|
||||
|
||||
def with_data(
|
||||
self,
|
||||
fill_value: float | None = 1.0,
|
||||
dtype: type[numpy.number] = numpy.float32,
|
||||
) -> 'GridData':
|
||||
from .data import GridData
|
||||
|
||||
return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype))
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""
|
||||
Returns:
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class GridReadMixin(GridPosMixin):
|
|||
surface = numpy.delete(range(3), plane.axis)
|
||||
|
||||
# Extract indices and weights of planes
|
||||
center3 = numpy.insert([0, 0], plane.axis, (plane.pos,))
|
||||
center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,))
|
||||
center_index = self.pos2ind(center3, which_shifts,
|
||||
round_ind=False, check_bounds=False)[plane.axis]
|
||||
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import pytest
|
||||
import numpy
|
||||
from numpy.testing import assert_allclose #, assert_array_equal
|
||||
import pickle
|
||||
|
||||
from .. import Grid, Extent, GridError, Plane, Slab
|
||||
from .. import Grid, GridData, Extent, GridError, Plane, Slab
|
||||
|
||||
|
||||
def test_draw_oncenter_2x2() -> None:
|
||||
|
|
@ -309,3 +310,157 @@ def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.
|
|||
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
|
||||
|
||||
pyplot.close(fig)
|
||||
|
||||
|
||||
def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
|
||||
path = tmp_path / 'grid.state'
|
||||
|
||||
grid.save(str(path))
|
||||
loaded = Grid.load(str(path))
|
||||
|
||||
assert path.exists()
|
||||
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
|
||||
assert_allclose(restored, original)
|
||||
assert_allclose(loaded.shifts, grid.shifts)
|
||||
assert loaded.periodic == grid.periodic
|
||||
|
||||
|
||||
def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
|
||||
path = tmp_path / 'grid.pickle'
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(grid.__dict__, f, protocol=2)
|
||||
|
||||
loaded = Grid.load(str(path))
|
||||
|
||||
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
|
||||
assert_allclose(restored, original)
|
||||
assert_allclose(loaded.shifts, grid.shifts)
|
||||
assert loaded.periodic == grid.periodic
|
||||
|
||||
|
||||
def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
|
||||
data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0)
|
||||
data.cell_data[0, 1, 0, 0] = 5.0
|
||||
path = tmp_path / 'griddata.state'
|
||||
|
||||
data.save(str(path))
|
||||
loaded = GridData.load(str(path))
|
||||
|
||||
assert path.exists()
|
||||
assert_allclose(loaded.cell_data, data.cell_data)
|
||||
assert_allclose(loaded.grid.shifts, data.grid.shifts)
|
||||
assert loaded.grid.periodic == data.grid.periodic
|
||||
|
||||
|
||||
def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None:
|
||||
path = tmp_path / 'grid.state'
|
||||
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path))
|
||||
|
||||
with pytest.raises(GridError):
|
||||
GridData.load(str(path))
|
||||
|
||||
|
||||
def test_negative_shift_nonperiodic_edges_and_widths() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
|
||||
assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0])
|
||||
assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5])
|
||||
assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25])
|
||||
|
||||
|
||||
def test_negative_shift_periodic_edges_and_widths() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False])
|
||||
|
||||
assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0])
|
||||
assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5])
|
||||
|
||||
|
||||
def test_negative_shift_coordinate_round_trip() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
|
||||
ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False)
|
||||
pos = grid.ind2pos(ind, 0, round_ind=False)
|
||||
|
||||
assert_allclose(ind, [1.0, 0.0, 0.0])
|
||||
assert_allclose(pos, [1.25, 1.0, 0.5])
|
||||
|
||||
|
||||
def test_negative_shift_draw_cuboid_fractional_fill() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
arr = grid.allocate(0)
|
||||
|
||||
grid.draw_cuboid(
|
||||
arr,
|
||||
x=dict(min=0, max=1),
|
||||
y=dict(min=0, max=1),
|
||||
z=dict(min=0, max=1),
|
||||
foreground=1,
|
||||
)
|
||||
|
||||
assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3])
|
||||
|
||||
|
||||
def test_negative_shift_get_slice_uses_shifted_centers() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
cell_data = numpy.zeros(grid.cell_data_shape)
|
||||
cell_data[0, 1, :, 0] = [7, 9]
|
||||
x_center = float(grid.shifted_xyz(0)[0][1])
|
||||
|
||||
grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0)
|
||||
|
||||
assert_allclose(grid_slice, [7, 9])
|
||||
|
||||
|
||||
def test_grid_with_data_returns_griddata() -> None:
|
||||
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||
data = grid.with_data(fill_value=2.0)
|
||||
|
||||
assert isinstance(data, GridData)
|
||||
assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32))
|
||||
|
||||
|
||||
def test_griddata_constructor_validates_shape() -> None:
|
||||
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||
|
||||
with pytest.raises(GridError):
|
||||
GridData(grid, numpy.zeros((1, 1, 1)))
|
||||
|
||||
|
||||
def test_griddata_draw_methods_are_chainable() -> None:
|
||||
data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
|
||||
|
||||
chained = data.draw_cuboid(
|
||||
foreground=1,
|
||||
x=dict(min=0, max=1),
|
||||
y=dict(min=0, max=1),
|
||||
z=dict(min=0, max=1),
|
||||
).draw_polygon(
|
||||
foreground=0.5,
|
||||
slab=dict(axis='z', center=0.5, span=1.0),
|
||||
polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float),
|
||||
)
|
||||
|
||||
assert chained is data
|
||||
assert data.cell_data.sum() > 0
|
||||
|
||||
|
||||
def test_griddata_read_methods_delegate() -> None:
|
||||
data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
|
||||
data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float)
|
||||
|
||||
assert_allclose(
|
||||
data.get_slice(Plane(z=0.5)),
|
||||
data.grid.get_slice(data.cell_data, Plane(z=0.5)),
|
||||
)
|
||||
|
||||
|
||||
def test_griddata_copy_is_independent() -> None:
|
||||
data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0)
|
||||
cloned = data.copy()
|
||||
cloned.cell_data[0, 0, 0, 0] = 5.0
|
||||
|
||||
assert data is not cloned
|
||||
assert data.grid is not cloned.grid
|
||||
assert data.cell_data[0, 0, 0, 0] == 1.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue