Compare commits
No commits in common. "master" and "vis_edges" have entirely different histories.
9 changed files with 102 additions and 760 deletions
|
|
@ -31,9 +31,8 @@ from .utils import (
|
|||
PlaneDict as PlaneDict,
|
||||
)
|
||||
from .grid import Grid as Grid
|
||||
from .data import GridData as GridData
|
||||
|
||||
|
||||
__author__ = 'Jan Petykiewicz'
|
||||
__version__ = '2.2'
|
||||
__version__ = '2.0'
|
||||
version = __version__
|
||||
|
|
|
|||
|
|
@ -76,21 +76,6 @@ class GridBase(Protocol):
|
|||
el = [0 if p else -1 for p in self.periodic]
|
||||
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
|
||||
|
||||
def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||
if which_shifts is None:
|
||||
return self.dxyz_with_ghost
|
||||
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
edge_dxyz = []
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0]
|
||||
edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a])))
|
||||
else:
|
||||
ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1]
|
||||
edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost)))
|
||||
return edge_dxyz
|
||||
|
||||
@property
|
||||
def center(self) -> NDArray[numpy.float64]:
|
||||
"""
|
||||
|
|
@ -130,9 +115,15 @@ class GridBase(Protocol):
|
|||
"""
|
||||
if which_shifts is None:
|
||||
return self.exyz
|
||||
edge_dxyz = self._shifted_edge_dxyz(which_shifts)
|
||||
dxyz = self.dxyz_with_ghost
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)]
|
||||
|
||||
# If shift is negative, use left cell's dx to determine shift
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
dxyz[a] = numpy.roll(dxyz[a], 1)
|
||||
|
||||
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
|
||||
|
||||
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||
"""
|
||||
|
|
@ -146,7 +137,20 @@ class GridBase(Protocol):
|
|||
"""
|
||||
if which_shifts is None:
|
||||
return self.dxyz
|
||||
return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)]
|
||||
shifts = self.shifts[which_shifts, :]
|
||||
dxyz = self.dxyz_with_ghost
|
||||
|
||||
# If shift is negative, use left cell's dx to determine size
|
||||
sdxyz = []
|
||||
for a in range(3):
|
||||
if shifts[a] < 0:
|
||||
roll_dxyz = numpy.roll(dxyz[a], 1)
|
||||
abs_shift = numpy.abs(shifts[a])
|
||||
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
|
||||
else:
|
||||
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
|
||||
|
||||
return sdxyz
|
||||
|
||||
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||
"""
|
||||
|
|
|
|||
176
gridlock/data.py
176
gridlock/data.py
|
|
@ -1,176 +0,0 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
from collections.abc import Sequence
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
|
||||
from .draw import foreground_t
|
||||
from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload
|
||||
from .utils import (
|
||||
ExtentDict,
|
||||
ExtentProtocol,
|
||||
GridError,
|
||||
PlaneDict,
|
||||
PlaneProtocol,
|
||||
SlabDict,
|
||||
SlabProtocol,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GridData:
|
||||
grid: Grid
|
||||
cell_data: NDArray
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape):
|
||||
raise GridError(
|
||||
f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load(filename: str) -> 'GridData':
|
||||
payload = _load_payload(filename)
|
||||
if _payload_scalar_str(payload, 'kind') != 'grid_data':
|
||||
raise GridError('Serialized payload does not contain GridData')
|
||||
if 'cell_data' not in payload:
|
||||
raise GridError('Serialized GridData payload is missing cell_data')
|
||||
|
||||
return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data']))
|
||||
|
||||
def save(self, filename: str) -> Self:
|
||||
payload = self.grid._serialization_payload(kind='grid_data')
|
||||
payload['cell_data'] = self.cell_data
|
||||
_save_npz_payload(filename, payload)
|
||||
return self
|
||||
|
||||
def copy(self) -> Self:
|
||||
return GridData(self.grid.copy(), self.cell_data.copy())
|
||||
|
||||
def draw_polygons(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
slab: SlabProtocol | SlabDict,
|
||||
polygons: Sequence[ArrayLike],
|
||||
*,
|
||||
offset2d: ArrayLike = (0, 0),
|
||||
) -> Self:
|
||||
self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d)
|
||||
return self
|
||||
|
||||
def draw_polygon(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
slab: SlabProtocol | SlabDict,
|
||||
polygon: ArrayLike,
|
||||
*,
|
||||
offset2d: ArrayLike = (0, 0),
|
||||
) -> Self:
|
||||
self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d)
|
||||
return self
|
||||
|
||||
def draw_slab(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
slab: SlabProtocol | SlabDict,
|
||||
) -> Self:
|
||||
self.grid.draw_slab(self.cell_data, foreground, slab)
|
||||
return self
|
||||
|
||||
def draw_cuboid(
|
||||
self,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
*,
|
||||
x: ExtentProtocol | ExtentDict,
|
||||
y: ExtentProtocol | ExtentDict,
|
||||
z: ExtentProtocol | ExtentDict,
|
||||
) -> Self:
|
||||
self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z)
|
||||
return self
|
||||
|
||||
def draw_cylinder(
|
||||
self,
|
||||
h: SlabProtocol | SlabDict,
|
||||
radius: float,
|
||||
num_points: int,
|
||||
center2d: ArrayLike,
|
||||
foreground: Sequence[foreground_t] | foreground_t,
|
||||
) -> Self:
|
||||
self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground)
|
||||
return self
|
||||
|
||||
def draw_extrude_rectangle(
|
||||
self,
|
||||
rectangle: ArrayLike,
|
||||
direction: int,
|
||||
polarity: int,
|
||||
distance: float,
|
||||
) -> Self:
|
||||
self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance)
|
||||
return self
|
||||
|
||||
def get_slice(
|
||||
self,
|
||||
plane: PlaneProtocol | PlaneDict,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
) -> NDArray:
|
||||
return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period)
|
||||
|
||||
def visualize_slice(
|
||||
self,
|
||||
plane: PlaneProtocol | PlaneDict,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
finalize: bool = True,
|
||||
pcolormesh_args: dict[str, object] | None = None,
|
||||
ax: object | None = None,
|
||||
) -> tuple[object, object]:
|
||||
return self.grid.visualize_slice(
|
||||
self.cell_data,
|
||||
plane,
|
||||
which_shifts=which_shifts,
|
||||
sample_period=sample_period,
|
||||
finalize=finalize,
|
||||
pcolormesh_args=pcolormesh_args,
|
||||
ax=ax,
|
||||
)
|
||||
|
||||
def visualize_edges(
|
||||
self,
|
||||
plane: PlaneProtocol | PlaneDict,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
finalize: bool = True,
|
||||
contour_args: dict[str, object] | None = None,
|
||||
ax: object | None = None,
|
||||
level_fraction: float = 0.7,
|
||||
) -> tuple[object, object]:
|
||||
return self.grid.visualize_edges(
|
||||
self.cell_data,
|
||||
plane,
|
||||
which_shifts=which_shifts,
|
||||
sample_period=sample_period,
|
||||
finalize=finalize,
|
||||
contour_args=contour_args,
|
||||
ax=ax,
|
||||
level_fraction=level_fraction,
|
||||
)
|
||||
|
||||
def visualize_isosurface(
|
||||
self,
|
||||
level: float | None = None,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
show_edges: bool = True,
|
||||
finalize: bool = True,
|
||||
) -> tuple[object, object]:
|
||||
return self.grid.visualize_isosurface(
|
||||
self.cell_data,
|
||||
level=level,
|
||||
which_shifts=which_shifts,
|
||||
sample_period=sample_period,
|
||||
show_edges=show_edges,
|
||||
finalize=finalize,
|
||||
)
|
||||
|
|
@ -61,25 +61,23 @@ class GridDrawMixin(GridPosMixin):
|
|||
for ii in range(len(poly_list)):
|
||||
polygon = poly_list[ii]
|
||||
malformed = f'Malformed polygon: ({ii})'
|
||||
if polygon.ndim != 2:
|
||||
raise GridError(malformed + 'must be a 2-dimensional ndarray')
|
||||
if polygon.shape[1] not in (2, 3):
|
||||
raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
|
||||
if polygon.shape[1] == 3:
|
||||
if numpy.unique(polygon[:, slab.axis]).size != 1:
|
||||
raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis])
|
||||
polygon = polygon[:, surface]
|
||||
polygon = polygon[surface, :]
|
||||
poly_list[ii] = polygon
|
||||
|
||||
if not polygon.shape[0] > 2:
|
||||
raise GridError(malformed + 'must consist of more than 2 points')
|
||||
if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1:
|
||||
raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis])
|
||||
|
||||
# Broadcast foreground where necessary
|
||||
foregrounds: Sequence[foreground_callable_t] | Sequence[float]
|
||||
if isinstance(foreground, numpy.ndarray):
|
||||
if numpy.size(foreground) == 1: # type: ignore
|
||||
foregrounds = [foreground] * len(cell_data) # type: ignore
|
||||
elif isinstance(foreground, numpy.ndarray):
|
||||
raise GridError('ndarray not supported for foreground')
|
||||
if callable(foreground) or numpy.isscalar(foreground):
|
||||
foregrounds = [foreground] * len(cell_data) # type: ignore[list-item]
|
||||
else:
|
||||
foregrounds = foreground # type: ignore
|
||||
|
||||
|
|
@ -298,6 +296,8 @@ class GridDrawMixin(GridPosMixin):
|
|||
if isinstance(z, dict):
|
||||
z = Extent(**z)
|
||||
|
||||
center = numpy.asarray([x.center, y.center, z.center])
|
||||
|
||||
p = numpy.array([[x.min, y.max],
|
||||
[x.max, y.max],
|
||||
[x.max, y.min],
|
||||
|
|
@ -376,18 +376,15 @@ class GridDrawMixin(GridPosMixin):
|
|||
foreground_func = []
|
||||
for ii, grid in enumerate(cell_data):
|
||||
zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction]
|
||||
|
||||
ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)]
|
||||
|
||||
fpart = zz - numpy.floor(zz)
|
||||
low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1))
|
||||
high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1))
|
||||
mult = [1 - fpart, fpart][::sgn] # reverses if s negative
|
||||
|
||||
low_ind = [low if dd == direction else slice(None) for dd in range(3)]
|
||||
high_ind = [high if dd == direction else slice(None) for dd in range(3)]
|
||||
|
||||
if low == high:
|
||||
foreground = grid[tuple(low_ind)]
|
||||
else:
|
||||
mult = [1 - fpart, fpart][::sgn] # reverses if s negative
|
||||
foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)]
|
||||
foreground = mult[0] * grid[tuple(ind)]
|
||||
ind[direction] += 1 # type: ignore #(known safe)
|
||||
foreground += mult[1] * grid[tuple(ind)]
|
||||
|
||||
def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
|
||||
# transform from natural position to index
|
||||
|
|
@ -401,3 +398,4 @@ class GridDrawMixin(GridPosMixin):
|
|||
|
||||
slab = Slab(axis=direction, center=center[direction], span=thickness)
|
||||
self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface])
|
||||
|
||||
|
|
|
|||
114
gridlock/grid.py
114
gridlock/grid.py
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Any, ClassVar, Self
|
||||
from typing import ClassVar, Self
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import numpy
|
||||
|
|
@ -13,78 +13,8 @@ from .draw import GridDrawMixin
|
|||
from .read import GridReadMixin
|
||||
from .position import GridPosMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .data import GridData
|
||||
|
||||
|
||||
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
|
||||
_FORMAT_VERSION = 1
|
||||
|
||||
|
||||
def _is_npz_file(filename: str) -> bool:
|
||||
with open(filename, 'rb') as f:
|
||||
return f.read(2) == b'PK'
|
||||
|
||||
|
||||
def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None:
|
||||
with open(filename, 'wb') as f:
|
||||
numpy.savez_compressed(f, **payload)
|
||||
|
||||
|
||||
def _load_payload(filename: str) -> dict[str, Any]:
|
||||
if _is_npz_file(filename):
|
||||
with numpy.load(filename, allow_pickle=False) as payload:
|
||||
return {key: payload[key] for key in payload.files}
|
||||
|
||||
with open(filename, 'rb') as f:
|
||||
legacy = pickle.load(f)
|
||||
|
||||
if isinstance(legacy, Grid):
|
||||
return legacy._serialization_payload(kind='grid')
|
||||
if isinstance(legacy, dict):
|
||||
grid = Grid([[-1, 1]] * 3)
|
||||
grid.__dict__.update(legacy)
|
||||
return grid._serialization_payload(kind='grid')
|
||||
raise GridError('Unsupported serialized Grid payload')
|
||||
|
||||
|
||||
def _payload_scalar_str(payload: dict[str, Any], key: str) -> str:
|
||||
if key not in payload:
|
||||
raise GridError(f'Missing serialized key: {key}')
|
||||
|
||||
value = numpy.asarray(payload[key])
|
||||
if value.size != 1:
|
||||
raise GridError(f'Serialized key {key} must be scalar')
|
||||
return str(value.reshape(()))
|
||||
|
||||
|
||||
def _payload_scalar_int(payload: dict[str, Any], key: str) -> int:
|
||||
if key not in payload:
|
||||
raise GridError(f'Missing serialized key: {key}')
|
||||
|
||||
value = numpy.asarray(payload[key])
|
||||
if value.size != 1:
|
||||
raise GridError(f'Serialized key {key} must be scalar')
|
||||
return int(value.reshape(()))
|
||||
|
||||
|
||||
def _grid_from_payload(payload: dict[str, Any]) -> 'Grid':
|
||||
if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION:
|
||||
raise GridError('Unsupported serialized Grid format version')
|
||||
|
||||
exyz = []
|
||||
for axis in range(3):
|
||||
key = f'exyz_{axis}'
|
||||
if key not in payload:
|
||||
raise GridError(f'Missing serialized key: {key}')
|
||||
exyz.append(numpy.array(payload[key], dtype=float))
|
||||
|
||||
if 'shifts' not in payload or 'periodic' not in payload:
|
||||
raise GridError('Serialized Grid payload is missing shifts or periodic data')
|
||||
|
||||
shifts = numpy.array(payload['shifts'], dtype=float)
|
||||
periodic = numpy.array(payload['periodic'], dtype=bool).tolist()
|
||||
return Grid(exyz, shifts=shifts, periodic=periodic)
|
||||
|
||||
|
||||
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||
|
|
@ -165,8 +95,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
`GridError` on invalid input
|
||||
"""
|
||||
edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates]
|
||||
if len(edge_arrs) != 3:
|
||||
raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays')
|
||||
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
|
||||
self.shifts = numpy.array(shifts, dtype=float)
|
||||
|
||||
|
|
@ -178,10 +106,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
self.periodic = [periodic] * 3
|
||||
else:
|
||||
self.periodic = list(periodic)
|
||||
if len(self.periodic) != 3:
|
||||
raise GridError('periodic must be a bool or a sequence of length 3')
|
||||
if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic):
|
||||
raise GridError('periodic sequence entries must be bool values')
|
||||
|
||||
if len(self.shifts.shape) != 2:
|
||||
raise GridError('Misshapen shifts: shifts must have two axes! '
|
||||
|
|
@ -193,16 +117,9 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
if (numpy.abs(self.shifts) > 1).any():
|
||||
raise GridError('Only shifts in the range [-1, 1] are currently supported')
|
||||
|
||||
def _serialization_payload(self, *, kind: str) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
'kind': numpy.array(kind),
|
||||
'format_version': numpy.array(_FORMAT_VERSION, dtype=int),
|
||||
'shifts': self.shifts,
|
||||
'periodic': numpy.array(self.periodic, dtype=bool),
|
||||
}
|
||||
for axis, exyz in enumerate(self.exyz):
|
||||
payload[f'exyz_{axis}'] = exyz
|
||||
return payload
|
||||
if (self.shifts < 0).any():
|
||||
# TODO: Test negative shifts
|
||||
warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2)
|
||||
|
||||
@staticmethod
|
||||
def load(filename: str) -> 'Grid':
|
||||
|
|
@ -212,11 +129,12 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
Args:
|
||||
filename: Filename to load from.
|
||||
"""
|
||||
payload = _load_payload(filename)
|
||||
kind = _payload_scalar_str(payload, 'kind')
|
||||
if kind not in ('grid', 'grid_data'):
|
||||
raise GridError(f'Unsupported serialized kind: {kind}')
|
||||
return _grid_from_payload(payload)
|
||||
with open(filename, 'rb') as f:
|
||||
tmp_dict = pickle.load(f)
|
||||
|
||||
g = Grid([[-1, 1]] * 3)
|
||||
g.__dict__.update(tmp_dict)
|
||||
return g
|
||||
|
||||
def save(self, filename: str) -> Self:
|
||||
"""
|
||||
|
|
@ -228,18 +146,10 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
|||
Returns:
|
||||
self
|
||||
"""
|
||||
_save_npz_payload(filename, self._serialization_payload(kind='grid'))
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(self.__dict__, f, protocol=2)
|
||||
return self
|
||||
|
||||
def with_data(
|
||||
self,
|
||||
fill_value: float | None = 1.0,
|
||||
dtype: type[numpy.number] = numpy.float32,
|
||||
) -> 'GridData':
|
||||
from .data import GridData
|
||||
|
||||
return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype))
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""
|
||||
Returns:
|
||||
|
|
|
|||
|
|
@ -47,13 +47,13 @@ class GridPosMixin(GridBase):
|
|||
else:
|
||||
low_bound = -0.5
|
||||
high_bound = -0.5
|
||||
if (ind < low_bound).any() or (ind > self.shape + high_bound).any():
|
||||
if (ind < low_bound).any() or (ind > self.shape - high_bound).any():
|
||||
raise GridError(f'Position outside of grid: {ind}')
|
||||
|
||||
if round_ind:
|
||||
rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1)
|
||||
sxyz = self.shifted_xyz(which_shifts)
|
||||
position = [sxyz[a][rind[a]] for a in range(3)]
|
||||
position = [sxyz[a][rind[a]].astype(int) for a in range(3)]
|
||||
else:
|
||||
sexyz = self.shifted_exyz(which_shifts)
|
||||
position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a])
|
||||
|
|
|
|||
|
|
@ -20,26 +20,6 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class GridReadMixin(GridPosMixin):
|
||||
@staticmethod
|
||||
def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]:
|
||||
if centers.size > 1:
|
||||
midpoints = 0.5 * (centers[:-1] + centers[1:])
|
||||
first = centers[0] - 0.5 * (centers[1] - centers[0])
|
||||
last = centers[-1] + 0.5 * (centers[-1] - centers[-2])
|
||||
return numpy.hstack(([first], midpoints, [last]))
|
||||
return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float)
|
||||
|
||||
def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]:
|
||||
if sample_period <= 1:
|
||||
return self.shifted_exyz(which_shifts)
|
||||
|
||||
shifted_xyz = self.shifted_xyz(which_shifts)
|
||||
shifted_exyz = self.shifted_exyz(which_shifts)
|
||||
return [
|
||||
self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a])
|
||||
for a in range(3)
|
||||
]
|
||||
|
||||
def get_slice(
|
||||
self,
|
||||
cell_data: NDArray,
|
||||
|
|
@ -73,7 +53,7 @@ class GridReadMixin(GridPosMixin):
|
|||
surface = numpy.delete(range(3), plane.axis)
|
||||
|
||||
# Extract indices and weights of planes
|
||||
center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,))
|
||||
center3 = numpy.insert([0, 0], plane.axis, (plane.pos,))
|
||||
center_index = self.pos2ind(center3, which_shifts,
|
||||
round_ind=False, check_bounds=False)[plane.axis]
|
||||
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
|
||||
|
|
@ -83,13 +63,12 @@ class GridReadMixin(GridPosMixin):
|
|||
else:
|
||||
w = [1]
|
||||
|
||||
c_min, c_max = (self.shifted_xyz(which_shifts)[plane.axis][i] for i in [0, -1])
|
||||
c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1])
|
||||
if plane.pos < c_min or plane.pos > c_max:
|
||||
raise GridError('Coordinate of selected plane must be within simulation domain')
|
||||
|
||||
# Extract grid values from planes above and below visualized slice
|
||||
sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface)
|
||||
sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float))
|
||||
sliced_grid = numpy.zeros(self.shape[surface])
|
||||
for ci, weight in zip(centers, w, strict=True):
|
||||
s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3))
|
||||
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
|
||||
|
|
@ -143,11 +122,7 @@ class GridReadMixin(GridPosMixin):
|
|||
|
||||
surface = numpy.delete(range(3), plane.axis)
|
||||
|
||||
if sample_period == 1:
|
||||
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
|
||||
else:
|
||||
x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface)
|
||||
pcolormesh_args.setdefault('shading', 'nearest')
|
||||
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
|
||||
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
|
||||
x_label, y_label = ('xyz'[a] for a in surface)
|
||||
|
||||
|
|
@ -233,10 +208,10 @@ class GridReadMixin(GridPosMixin):
|
|||
fig, ax = pyplot.subplots()
|
||||
else:
|
||||
fig = ax.figure
|
||||
xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface)
|
||||
xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface)
|
||||
xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij')
|
||||
|
||||
ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args)
|
||||
mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args)
|
||||
|
||||
if finalize:
|
||||
pyplot.show()
|
||||
|
|
@ -282,14 +257,8 @@ class GridReadMixin(GridPosMixin):
|
|||
verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
|
||||
|
||||
# Convert vertices from index to position
|
||||
preview_exyz = self._sampled_exyz(which_shifts, sample_period)
|
||||
pos_verts = numpy.array([
|
||||
[
|
||||
numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a])
|
||||
for a in range(3)
|
||||
]
|
||||
for i in range(verts.shape[0])
|
||||
], dtype=float)
|
||||
pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False)
|
||||
for i in range(verts.shape[0])], dtype=float)
|
||||
xs, ys, zs = (pos_verts[:, a] for a in range(3))
|
||||
|
||||
# Draw the plot
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
import pytest
|
||||
# import pytest
|
||||
import numpy
|
||||
from numpy.testing import assert_allclose #, assert_array_equal
|
||||
import pickle
|
||||
|
||||
from .. import Grid, GridData, Extent, GridError, Plane, Slab
|
||||
from .. import Grid, Extent #, Slab, Plane
|
||||
|
||||
|
||||
def test_draw_oncenter_2x2() -> None:
|
||||
|
|
@ -117,350 +116,3 @@ def test_draw_2shift_4x4() -> None:
|
|||
[0, 0.125, 0.125, 0]])[None, :, :, None]
|
||||
|
||||
assert_allclose(arr, correct)
|
||||
|
||||
|
||||
def test_ind2pos_round_preserves_float_centers() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]])
|
||||
|
||||
pos = grid.ind2pos(numpy.array([1, 0, 0]), which_shifts=0)
|
||||
|
||||
assert_allclose(pos, [2.0, 1.0, 0.5])
|
||||
|
||||
|
||||
def test_ind2pos_enforces_bounds_for_rounded_and_fractional_indices() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]])
|
||||
|
||||
with pytest.raises(GridError):
|
||||
grid.ind2pos(numpy.array([2, 0, 0]), which_shifts=0, check_bounds=True)
|
||||
|
||||
edge_pos = grid.ind2pos(numpy.array([1.5, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True)
|
||||
assert_allclose(edge_pos, [3.0, 2.0, 1.0])
|
||||
|
||||
with pytest.raises(GridError):
|
||||
grid.ind2pos(numpy.array([1.6, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True)
|
||||
|
||||
|
||||
def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None:
|
||||
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]])
|
||||
arr_2d = grid.allocate(0)
|
||||
arr_3d = grid.allocate(0)
|
||||
slab = dict(axis='z', center=0.5, span=1.0)
|
||||
|
||||
polygon_2d = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float)
|
||||
polygon_3d = numpy.array([[0, 0, 0.5],
|
||||
[1, 0, 0.5],
|
||||
[1, 1, 0.5],
|
||||
[0, 1, 0.5]], dtype=float)
|
||||
|
||||
grid.draw_polygon(arr_2d, slab=slab, polygon=polygon_2d, foreground=1)
|
||||
grid.draw_polygon(arr_3d, slab=slab, polygon=polygon_3d, foreground=1)
|
||||
|
||||
assert_allclose(arr_3d, arr_2d)
|
||||
|
||||
|
||||
def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None:
|
||||
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]])
|
||||
arr = grid.allocate(0)
|
||||
polygon = numpy.array([[0, 0, 0.5],
|
||||
[1, 0, 0.5],
|
||||
[1, 1, 0.75],
|
||||
[0, 1, 0.5]], dtype=float)
|
||||
|
||||
with pytest.raises(GridError):
|
||||
grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1)
|
||||
|
||||
|
||||
def test_get_slice_supports_sampling() -> None:
|
||||
grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]])
|
||||
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
|
||||
|
||||
grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2)
|
||||
|
||||
assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0])
|
||||
|
||||
|
||||
def test_sampled_visualization_helpers_do_not_error() -> None:
|
||||
matplotlib = pytest.importorskip('matplotlib')
|
||||
matplotlib.use('Agg')
|
||||
from matplotlib import pyplot
|
||||
|
||||
grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]])
|
||||
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
|
||||
|
||||
fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False)
|
||||
fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False)
|
||||
|
||||
assert fig_slice is ax_slice.figure
|
||||
assert fig_edges is ax_edges.figure
|
||||
|
||||
pyplot.close(fig_slice)
|
||||
pyplot.close(fig_edges)
|
||||
|
||||
|
||||
def test_grid_constructor_rejects_invalid_coordinate_count() -> None:
|
||||
with pytest.raises(GridError):
|
||||
Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||
|
||||
with pytest.raises(GridError):
|
||||
Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||
|
||||
|
||||
def test_grid_constructor_rejects_invalid_periodic_length() -> None:
|
||||
with pytest.raises(GridError):
|
||||
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False])
|
||||
|
||||
|
||||
def test_extent_and_slab_reject_inverted_geometry() -> None:
|
||||
with pytest.raises(GridError):
|
||||
Extent(center=0, min=1)
|
||||
|
||||
with pytest.raises(GridError):
|
||||
Extent(min=2, max=1)
|
||||
|
||||
with pytest.raises(GridError):
|
||||
Slab(axis='z', center=1, max=0)
|
||||
|
||||
|
||||
def test_extent_accepts_scalar_like_inputs() -> None:
|
||||
extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0]))
|
||||
|
||||
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])
|
||||
|
||||
|
||||
def test_get_slice_uses_shifted_grid_bounds() -> None:
|
||||
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0.5, 0, 0]])
|
||||
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
|
||||
|
||||
grid_slice = grid.get_slice(cell_data, Plane(x=2.0), which_shifts=0)
|
||||
|
||||
assert_allclose(grid_slice, cell_data[0, 1, :, :])
|
||||
|
||||
with pytest.raises(GridError):
|
||||
grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0)
|
||||
|
||||
|
||||
def test_draw_extrude_rectangle_uses_boundary_slice() -> None:
|
||||
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]])
|
||||
cell_data = grid.allocate(0)
|
||||
source = numpy.array([[1, 2],
|
||||
[3, 4]], dtype=float)
|
||||
cell_data[0, :, :, 1] = source
|
||||
|
||||
grid.draw_extrude_rectangle(
|
||||
cell_data,
|
||||
rectangle=[[0, 0, 2], [2, 2, 2]],
|
||||
direction=2,
|
||||
polarity=-1,
|
||||
distance=2,
|
||||
)
|
||||
|
||||
assert_allclose(cell_data[0, :, :, 0], source)
|
||||
assert_allclose(cell_data[0, :, :, 1], source)
|
||||
|
||||
|
||||
def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None:
|
||||
grid = Grid([[0, 1, 3, 6, 10], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]])
|
||||
|
||||
sampled_exyz = grid._sampled_exyz(0, 2)
|
||||
|
||||
assert_allclose(sampled_exyz[0], [-1.5, 2.5, 6.5])
|
||||
|
||||
|
||||
def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
matplotlib = pytest.importorskip('matplotlib')
|
||||
matplotlib.use('Agg')
|
||||
skimage_measure = pytest.importorskip('skimage.measure')
|
||||
from matplotlib import pyplot
|
||||
from mpl_toolkits.mplot3d.axes3d import Axes3D
|
||||
|
||||
captured: dict[str, numpy.ndarray] = {}
|
||||
|
||||
def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]:
|
||||
verts = numpy.array([[0.5, 0.5, 0.5],
|
||||
[0.5, 1.5, 0.5],
|
||||
[1.5, 0.5, 0.5]], dtype=float)
|
||||
faces = numpy.array([[0, 1, 2]], dtype=int)
|
||||
return verts, faces, None, None
|
||||
|
||||
def fake_plot_trisurf( # noqa: ANN202
|
||||
_self: object,
|
||||
xs: numpy.ndarray,
|
||||
ys: numpy.ndarray,
|
||||
faces: numpy.ndarray,
|
||||
zs: numpy.ndarray,
|
||||
*_args: object,
|
||||
**_kwargs: object,
|
||||
) -> object:
|
||||
captured['xs'] = numpy.asarray(xs)
|
||||
captured['ys'] = numpy.asarray(ys)
|
||||
captured['faces'] = numpy.asarray(faces)
|
||||
captured['zs'] = numpy.asarray(zs)
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes)
|
||||
monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf)
|
||||
|
||||
grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]])
|
||||
cell_data = numpy.zeros(grid.cell_data_shape)
|
||||
|
||||
fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False)
|
||||
|
||||
assert_allclose(captured['xs'], [1.5, 1.5, 3.5])
|
||||
assert_allclose(captured['ys'], [1.5, 3.5, 1.5])
|
||||
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
|
||||
|
||||
pyplot.close(fig)
|
||||
|
||||
|
||||
def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
|
||||
path = tmp_path / 'grid.state'
|
||||
|
||||
grid.save(str(path))
|
||||
loaded = Grid.load(str(path))
|
||||
|
||||
assert path.exists()
|
||||
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
|
||||
assert_allclose(restored, original)
|
||||
assert_allclose(loaded.shifts, grid.shifts)
|
||||
assert loaded.periodic == grid.periodic
|
||||
|
||||
|
||||
def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
|
||||
path = tmp_path / 'grid.pickle'
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(grid.__dict__, f, protocol=2)
|
||||
|
||||
loaded = Grid.load(str(path))
|
||||
|
||||
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
|
||||
assert_allclose(restored, original)
|
||||
assert_allclose(loaded.shifts, grid.shifts)
|
||||
assert loaded.periodic == grid.periodic
|
||||
|
||||
|
||||
def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
|
||||
data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0)
|
||||
data.cell_data[0, 1, 0, 0] = 5.0
|
||||
path = tmp_path / 'griddata.state'
|
||||
|
||||
data.save(str(path))
|
||||
loaded = GridData.load(str(path))
|
||||
|
||||
assert path.exists()
|
||||
assert_allclose(loaded.cell_data, data.cell_data)
|
||||
assert_allclose(loaded.grid.shifts, data.grid.shifts)
|
||||
assert loaded.grid.periodic == data.grid.periodic
|
||||
|
||||
|
||||
def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None:
|
||||
path = tmp_path / 'grid.state'
|
||||
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path))
|
||||
|
||||
with pytest.raises(GridError):
|
||||
GridData.load(str(path))
|
||||
|
||||
|
||||
def test_negative_shift_nonperiodic_edges_and_widths() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
|
||||
assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0])
|
||||
assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5])
|
||||
assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25])
|
||||
|
||||
|
||||
def test_negative_shift_periodic_edges_and_widths() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False])
|
||||
|
||||
assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0])
|
||||
assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5])
|
||||
|
||||
|
||||
def test_negative_shift_coordinate_round_trip() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
|
||||
ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False)
|
||||
pos = grid.ind2pos(ind, 0, round_ind=False)
|
||||
|
||||
assert_allclose(ind, [1.0, 0.0, 0.0])
|
||||
assert_allclose(pos, [1.25, 1.0, 0.5])
|
||||
|
||||
|
||||
def test_negative_shift_draw_cuboid_fractional_fill() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
arr = grid.allocate(0)
|
||||
|
||||
grid.draw_cuboid(
|
||||
arr,
|
||||
x=dict(min=0, max=1),
|
||||
y=dict(min=0, max=1),
|
||||
z=dict(min=0, max=1),
|
||||
foreground=1,
|
||||
)
|
||||
|
||||
assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3])
|
||||
|
||||
|
||||
def test_negative_shift_get_slice_uses_shifted_centers() -> None:
|
||||
grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
|
||||
cell_data = numpy.zeros(grid.cell_data_shape)
|
||||
cell_data[0, 1, :, 0] = [7, 9]
|
||||
x_center = float(grid.shifted_xyz(0)[0][1])
|
||||
|
||||
grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0)
|
||||
|
||||
assert_allclose(grid_slice, [7, 9])
|
||||
|
||||
|
||||
def test_grid_with_data_returns_griddata() -> None:
|
||||
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||
data = grid.with_data(fill_value=2.0)
|
||||
|
||||
assert isinstance(data, GridData)
|
||||
assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32))
|
||||
|
||||
|
||||
def test_griddata_constructor_validates_shape() -> None:
|
||||
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
|
||||
|
||||
with pytest.raises(GridError):
|
||||
GridData(grid, numpy.zeros((1, 1, 1)))
|
||||
|
||||
|
||||
def test_griddata_draw_methods_are_chainable() -> None:
|
||||
data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
|
||||
|
||||
chained = data.draw_cuboid(
|
||||
foreground=1,
|
||||
x=dict(min=0, max=1),
|
||||
y=dict(min=0, max=1),
|
||||
z=dict(min=0, max=1),
|
||||
).draw_polygon(
|
||||
foreground=0.5,
|
||||
slab=dict(axis='z', center=0.5, span=1.0),
|
||||
polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float),
|
||||
)
|
||||
|
||||
assert chained is data
|
||||
assert data.cell_data.sum() > 0
|
||||
|
||||
|
||||
def test_griddata_read_methods_delegate() -> None:
|
||||
data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
|
||||
data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float)
|
||||
|
||||
assert_allclose(
|
||||
data.get_slice(Plane(z=0.5)),
|
||||
data.grid.get_slice(data.cell_data, Plane(z=0.5)),
|
||||
)
|
||||
|
||||
|
||||
def test_griddata_copy_is_independent() -> None:
|
||||
data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0)
|
||||
cloned = data.copy()
|
||||
cloned.cell_data[0, 0, 0, 0] = 5.0
|
||||
|
||||
assert data is not cloned
|
||||
assert data.grid is not cloned.grid
|
||||
assert data.cell_data[0, 0, 0, 0] == 1.0
|
||||
|
|
|
|||
|
|
@ -1,25 +1,12 @@
|
|||
from typing import Protocol, TypedDict, runtime_checkable, cast
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy
|
||||
|
||||
|
||||
class GridError(Exception):
|
||||
""" Base error type for `gridlock` """
|
||||
pass
|
||||
|
||||
|
||||
def _coerce_scalar(name: str, value: object) -> float:
|
||||
arr = numpy.asarray(value)
|
||||
if arr.size != 1:
|
||||
raise GridError(f'{name} must be a scalar value')
|
||||
|
||||
try:
|
||||
return float(arr.reshape(()))
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise GridError(f'{name} must be a real scalar value') from exc
|
||||
|
||||
|
||||
class ExtentDict(TypedDict, total=False):
|
||||
"""
|
||||
Geometrical definition of an extent (1D bounded region)
|
||||
|
|
@ -71,46 +58,44 @@ class Extent(ExtentProtocol):
|
|||
max: float | None = None,
|
||||
span: float | None = None,
|
||||
) -> None:
|
||||
values = {
|
||||
'min': None if min is None else _coerce_scalar('min', min),
|
||||
'center': None if center is None else _coerce_scalar('center', center),
|
||||
'max': None if max is None else _coerce_scalar('max', max),
|
||||
'span': None if span is None else _coerce_scalar('span', span),
|
||||
}
|
||||
if sum(value is not None for value in values.values()) != 2:
|
||||
raise GridError('Exactly two of min, center, max, span must be provided')
|
||||
if sum(cc is None for cc in (min, center, max, span)) != 2:
|
||||
raise GridError('Exactly two of min, center, max, span must be None!')
|
||||
|
||||
min_v = values['min']
|
||||
center_v = values['center']
|
||||
max_v = values['max']
|
||||
span_v = values['span']
|
||||
if span is None:
|
||||
if center is None:
|
||||
assert min is not None
|
||||
assert max is not None
|
||||
assert max >= min
|
||||
center = 0.5 * (max + min)
|
||||
span = max - min
|
||||
elif max is None:
|
||||
assert min is not None
|
||||
assert center is not None
|
||||
span = 2 * (center - min)
|
||||
elif min is None:
|
||||
assert center is not None
|
||||
assert max is not None
|
||||
span = 2 * (max - center)
|
||||
else: # noqa: PLR5501
|
||||
if center is not None:
|
||||
pass
|
||||
elif max is None:
|
||||
assert min is not None
|
||||
assert span is not None
|
||||
center = min + 0.5 * span
|
||||
elif min is None:
|
||||
assert max is not None
|
||||
assert span is not None
|
||||
center = max - 0.5 * span
|
||||
|
||||
if span_v is not None and span_v < 0:
|
||||
raise GridError('span must be non-negative')
|
||||
|
||||
if min_v is not None and max_v is not None:
|
||||
if max_v < min_v:
|
||||
raise GridError('max must be greater than or equal to min')
|
||||
center_v = 0.5 * (max_v + min_v)
|
||||
span_v = max_v - min_v
|
||||
elif center_v is not None and min_v is not None:
|
||||
span_v = 2 * (center_v - min_v)
|
||||
if span_v < 0:
|
||||
raise GridError('min must be less than or equal to center')
|
||||
elif center_v is not None and max_v is not None:
|
||||
span_v = 2 * (max_v - center_v)
|
||||
if span_v < 0:
|
||||
raise GridError('center must be less than or equal to max')
|
||||
elif min_v is not None and span_v is not None:
|
||||
center_v = min_v + 0.5 * span_v
|
||||
elif max_v is not None and span_v is not None:
|
||||
center_v = max_v - 0.5 * span_v
|
||||
|
||||
if center_v is None or span_v is None:
|
||||
raise GridError('Unable to construct extent from the provided values')
|
||||
|
||||
self.center = center_v
|
||||
self.span = span_v
|
||||
assert center is not None
|
||||
assert span is not None
|
||||
if hasattr(center, '__len__'):
|
||||
assert len(center) == 1
|
||||
if hasattr(span, '__len__'):
|
||||
assert len(span) == 1
|
||||
self.center = center
|
||||
self.span = span
|
||||
|
||||
|
||||
class SlabDict(TypedDict, total=False):
|
||||
|
|
@ -246,3 +231,4 @@ class Plane(PlaneProtocol):
|
|||
if hasattr(cpos, '__len__'):
|
||||
assert len(cpos) == 1
|
||||
self.pos = cpos
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue