Compare commits

..

No commits in common. "master" and "arg_rework" have entirely different histories.

10 changed files with 124 additions and 900 deletions

View file

@ -31,9 +31,8 @@ from .utils import (
PlaneDict as PlaneDict,
)
from .grid import Grid as Grid
from .data import GridData as GridData
__author__ = 'Jan Petykiewicz'
__version__ = '2.2'
__version__ = '1.2'
version = __version__

View file

@ -76,21 +76,6 @@ class GridBase(Protocol):
el = [0 if p else -1 for p in self.periodic]
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
if which_shifts is None:
return self.dxyz_with_ghost
shifts = self.shifts[which_shifts, :]
edge_dxyz = []
for a in range(3):
if shifts[a] < 0:
ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0]
edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a])))
else:
ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1]
edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost)))
return edge_dxyz
@property
def center(self) -> NDArray[numpy.float64]:
"""
@ -130,9 +115,15 @@ class GridBase(Protocol):
"""
if which_shifts is None:
return self.exyz
edge_dxyz = self._shifted_edge_dxyz(which_shifts)
dxyz = self.dxyz_with_ghost
shifts = self.shifts[which_shifts, :]
return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)]
# If shift is negative, use left cell's dx to determine shift
for a in range(3):
if shifts[a] < 0:
dxyz[a] = numpy.roll(dxyz[a], 1)
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
"""
@ -146,7 +137,20 @@ class GridBase(Protocol):
"""
if which_shifts is None:
return self.dxyz
return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)]
shifts = self.shifts[which_shifts, :]
dxyz = self.dxyz_with_ghost
# If shift is negative, use left cell's dx to determine size
sdxyz = []
for a in range(3):
if shifts[a] < 0:
roll_dxyz = numpy.roll(dxyz[a], 1)
abs_shift = numpy.abs(shifts[a])
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
else:
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
return sdxyz
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
"""

View file

@ -1,176 +0,0 @@
from dataclasses import dataclass
from typing import Self
from collections.abc import Sequence
import numpy
from numpy.typing import NDArray, ArrayLike
from .draw import foreground_t
from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload
from .utils import (
ExtentDict,
ExtentProtocol,
GridError,
PlaneDict,
PlaneProtocol,
SlabDict,
SlabProtocol,
)
@dataclass(slots=True)
class GridData:
grid: Grid
cell_data: NDArray
def __post_init__(self) -> None:
if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape):
raise GridError(
f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}'
)
@staticmethod
def load(filename: str) -> 'GridData':
payload = _load_payload(filename)
if _payload_scalar_str(payload, 'kind') != 'grid_data':
raise GridError('Serialized payload does not contain GridData')
if 'cell_data' not in payload:
raise GridError('Serialized GridData payload is missing cell_data')
return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data']))
def save(self, filename: str) -> Self:
payload = self.grid._serialization_payload(kind='grid_data')
payload['cell_data'] = self.cell_data
_save_npz_payload(filename, payload)
return self
def copy(self) -> Self:
return GridData(self.grid.copy(), self.cell_data.copy())
def draw_polygons(
self,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
polygons: Sequence[ArrayLike],
*,
offset2d: ArrayLike = (0, 0),
) -> Self:
self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d)
return self
def draw_polygon(
self,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
polygon: ArrayLike,
*,
offset2d: ArrayLike = (0, 0),
) -> Self:
self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d)
return self
def draw_slab(
self,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
) -> Self:
self.grid.draw_slab(self.cell_data, foreground, slab)
return self
def draw_cuboid(
self,
foreground: Sequence[foreground_t] | foreground_t,
*,
x: ExtentProtocol | ExtentDict,
y: ExtentProtocol | ExtentDict,
z: ExtentProtocol | ExtentDict,
) -> Self:
self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z)
return self
def draw_cylinder(
self,
h: SlabProtocol | SlabDict,
radius: float,
num_points: int,
center2d: ArrayLike,
foreground: Sequence[foreground_t] | foreground_t,
) -> Self:
self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground)
return self
def draw_extrude_rectangle(
self,
rectangle: ArrayLike,
direction: int,
polarity: int,
distance: float,
) -> Self:
self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance)
return self
def get_slice(
self,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
) -> NDArray:
return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period)
def visualize_slice(
self,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
pcolormesh_args: dict[str, object] | None = None,
ax: object | None = None,
) -> tuple[object, object]:
return self.grid.visualize_slice(
self.cell_data,
plane,
which_shifts=which_shifts,
sample_period=sample_period,
finalize=finalize,
pcolormesh_args=pcolormesh_args,
ax=ax,
)
def visualize_edges(
self,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
contour_args: dict[str, object] | None = None,
ax: object | None = None,
level_fraction: float = 0.7,
) -> tuple[object, object]:
return self.grid.visualize_edges(
self.cell_data,
plane,
which_shifts=which_shifts,
sample_period=sample_period,
finalize=finalize,
contour_args=contour_args,
ax=ax,
level_fraction=level_fraction,
)
def visualize_isosurface(
self,
level: float | None = None,
which_shifts: int = 0,
sample_period: int = 1,
show_edges: bool = True,
finalize: bool = True,
) -> tuple[object, object]:
return self.grid.visualize_isosurface(
self.cell_data,
level=level,
which_shifts=which_shifts,
sample_period=sample_period,
show_edges=show_edges,
finalize=finalize,
)

View file

@ -21,31 +21,30 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
foreground_t = float | foreground_callable_t
class GridDrawMixin(GridPosMixin):
def draw_polygons(
self,
cell_data: NDArray,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
polygons: Sequence[ArrayLike],
foreground: Sequence[foreground_t] | foreground_t,
*,
offset2d: ArrayLike = (0, 0),
) -> None:
"""
Draw polygons on an axis-aligned slab.
Draw polygons on an axis-aligned plane.
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
center: 3-element ndarray or list specifying an offset applied to all the polygons
polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
(non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each
polygon must have at least 3 vertices.
foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list
of any of these (1 per grid). Callable values should take an ndarray the shape of the
grid and return an ndarray of equal shape containing the foreground value at the given x, y,
and z (natural, not grid coordinates).
slab: `Slab` or slab-like dict specifying the slab in which the polygons will be drawn.
polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
(non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each
polygon must have at least 3 vertices.
offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly
to the given polygon vertex coordinates. Default (0, 0).
Raises:
GridError
@ -61,25 +60,23 @@ class GridDrawMixin(GridPosMixin):
for ii in range(len(poly_list)):
polygon = poly_list[ii]
malformed = f'Malformed polygon: ({ii})'
if polygon.ndim != 2:
raise GridError(malformed + 'must be a 2-dimensional ndarray')
if polygon.shape[1] not in (2, 3):
raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
if polygon.shape[1] == 3:
if numpy.unique(polygon[:, slab.axis]).size != 1:
raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis])
polygon = polygon[:, surface]
polygon = polygon[surface, :]
poly_list[ii] = polygon
if not polygon.shape[0] > 2:
raise GridError(malformed + 'must consist of more than 2 points')
if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1:
raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis])
# Broadcast foreground where necessary
foregrounds: Sequence[foreground_callable_t] | Sequence[float]
if isinstance(foreground, numpy.ndarray):
if numpy.size(foreground) == 1: # type: ignore
foregrounds = [foreground] * len(cell_data) # type: ignore
elif isinstance(foreground, numpy.ndarray):
raise GridError('ndarray not supported for foreground')
if callable(foreground) or numpy.isscalar(foreground):
foregrounds = [foreground] * len(cell_data) # type: ignore[list-item]
else:
foregrounds = foreground # type: ignore
@ -203,9 +200,9 @@ class GridDrawMixin(GridPosMixin):
def draw_polygon(
self,
cell_data: NDArray,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
polygon: ArrayLike,
foreground: Sequence[foreground_t] | foreground_t,
*,
offset2d: ArrayLike = (0, 0),
) -> None:
@ -214,13 +211,11 @@ class GridDrawMixin(GridPosMixin):
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
slab: `Slab` or slab-like dict specifying the slab in which the polygon will be drawn.
slab: `Slab` in which to draw polygons.
polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at
least 3 vertices.
offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly
to the given polygon vertex coordinates. Default (0, 0).
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
"""
self.draw_polygons(
cell_data = cell_data,
@ -234,16 +229,17 @@ class GridDrawMixin(GridPosMixin):
def draw_slab(
self,
cell_data: NDArray,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
foreground: Sequence[foreground_t] | foreground_t,
) -> None:
"""
Draw an axis-aligned infinite slab.
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
slab:
thickness: Thickness of the layer to draw
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
slab: `Slab` or slab-like dict (geometrical slab specification)
"""
if isinstance(slab, dict):
slab = Slab(**slab)
@ -286,10 +282,10 @@ class GridDrawMixin(GridPosMixin):
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
center: 3-element ndarray or list specifying the cuboid's center
dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge
sizes of the cuboid
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
x: `Extent` or extent-like dict specifying the x-extent of the cuboid.
y: `Extent` or extent-like dict specifying the y-extent of the cuboid.
z: `Extent` or extent-like dict specifying the z-extent of the cuboid.
"""
if isinstance(x, dict):
x = Extent(**x)
@ -298,6 +294,8 @@ class GridDrawMixin(GridPosMixin):
if isinstance(z, dict):
z = Extent(**z)
center = numpy.asarray([x.center, y.center, z.center])
p = numpy.array([[x.min, y.max],
[x.max, y.max],
[x.max, y.min],
@ -376,18 +374,15 @@ class GridDrawMixin(GridPosMixin):
foreground_func = []
for ii, grid in enumerate(cell_data):
zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction]
ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)]
fpart = zz - numpy.floor(zz)
low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1))
high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1))
mult = [1 - fpart, fpart][::sgn] # reverses if s negative
low_ind = [low if dd == direction else slice(None) for dd in range(3)]
high_ind = [high if dd == direction else slice(None) for dd in range(3)]
if low == high:
foreground = grid[tuple(low_ind)]
else:
mult = [1 - fpart, fpart][::sgn] # reverses if s negative
foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)]
foreground = mult[0] * grid[tuple(ind)]
ind[direction] += 1 # type: ignore #(known safe)
foreground += mult[1] * grid[tuple(ind)]
def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
# transform from natural position to index
@ -401,3 +396,4 @@ class GridDrawMixin(GridPosMixin):
slab = Slab(axis=direction, center=center[direction], span=thickness)
self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface])

View file

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, ClassVar, Self
from typing import ClassVar, Self
from collections.abc import Callable, Sequence
import numpy
@ -13,78 +13,8 @@ from .draw import GridDrawMixin
from .read import GridReadMixin
from .position import GridPosMixin
if TYPE_CHECKING:
from .data import GridData
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
_FORMAT_VERSION = 1
def _is_npz_file(filename: str) -> bool:
with open(filename, 'rb') as f:
return f.read(2) == b'PK'
def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None:
with open(filename, 'wb') as f:
numpy.savez_compressed(f, **payload)
def _load_payload(filename: str) -> dict[str, Any]:
if _is_npz_file(filename):
with numpy.load(filename, allow_pickle=False) as payload:
return {key: payload[key] for key in payload.files}
with open(filename, 'rb') as f:
legacy = pickle.load(f)
if isinstance(legacy, Grid):
return legacy._serialization_payload(kind='grid')
if isinstance(legacy, dict):
grid = Grid([[-1, 1]] * 3)
grid.__dict__.update(legacy)
return grid._serialization_payload(kind='grid')
raise GridError('Unsupported serialized Grid payload')
def _payload_scalar_str(payload: dict[str, Any], key: str) -> str:
if key not in payload:
raise GridError(f'Missing serialized key: {key}')
value = numpy.asarray(payload[key])
if value.size != 1:
raise GridError(f'Serialized key {key} must be scalar')
return str(value.reshape(()))
def _payload_scalar_int(payload: dict[str, Any], key: str) -> int:
if key not in payload:
raise GridError(f'Missing serialized key: {key}')
value = numpy.asarray(payload[key])
if value.size != 1:
raise GridError(f'Serialized key {key} must be scalar')
return int(value.reshape(()))
def _grid_from_payload(payload: dict[str, Any]) -> 'Grid':
if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION:
raise GridError('Unsupported serialized Grid format version')
exyz = []
for axis in range(3):
key = f'exyz_{axis}'
if key not in payload:
raise GridError(f'Missing serialized key: {key}')
exyz.append(numpy.array(payload[key], dtype=float))
if 'shifts' not in payload or 'periodic' not in payload:
raise GridError('Serialized Grid payload is missing shifts or periodic data')
shifts = numpy.array(payload['shifts'], dtype=float)
periodic = numpy.array(payload['periodic'], dtype=bool).tolist()
return Grid(exyz, shifts=shifts, periodic=periodic)
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
@ -165,8 +95,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
`GridError` on invalid input
"""
edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates]
if len(edge_arrs) != 3:
raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays')
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
self.shifts = numpy.array(shifts, dtype=float)
@ -178,10 +106,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
self.periodic = [periodic] * 3
else:
self.periodic = list(periodic)
if len(self.periodic) != 3:
raise GridError('periodic must be a bool or a sequence of length 3')
if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic):
raise GridError('periodic sequence entries must be bool values')
if len(self.shifts.shape) != 2:
raise GridError('Misshapen shifts: shifts must have two axes! '
@ -193,16 +117,9 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
if (numpy.abs(self.shifts) > 1).any():
raise GridError('Only shifts in the range [-1, 1] are currently supported')
def _serialization_payload(self, *, kind: str) -> dict[str, Any]:
payload: dict[str, Any] = {
'kind': numpy.array(kind),
'format_version': numpy.array(_FORMAT_VERSION, dtype=int),
'shifts': self.shifts,
'periodic': numpy.array(self.periodic, dtype=bool),
}
for axis, exyz in enumerate(self.exyz):
payload[f'exyz_{axis}'] = exyz
return payload
if (self.shifts < 0).any():
# TODO: Test negative shifts
warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2)
@staticmethod
def load(filename: str) -> 'Grid':
@ -212,11 +129,12 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
Args:
filename: Filename to load from.
"""
payload = _load_payload(filename)
kind = _payload_scalar_str(payload, 'kind')
if kind not in ('grid', 'grid_data'):
raise GridError(f'Unsupported serialized kind: {kind}')
return _grid_from_payload(payload)
with open(filename, 'rb') as f:
tmp_dict = pickle.load(f)
g = Grid([[-1, 1]] * 3)
g.__dict__.update(tmp_dict)
return g
def save(self, filename: str) -> Self:
"""
@ -228,18 +146,10 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
Returns:
self
"""
_save_npz_payload(filename, self._serialization_payload(kind='grid'))
with open(filename, 'wb') as f:
pickle.dump(self.__dict__, f, protocol=2)
return self
def with_data(
self,
fill_value: float | None = 1.0,
dtype: type[numpy.number] = numpy.float32,
) -> 'GridData':
from .data import GridData
return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype))
def copy(self) -> Self:
"""
Returns:

View file

@ -47,13 +47,13 @@ class GridPosMixin(GridBase):
else:
low_bound = -0.5
high_bound = -0.5
if (ind < low_bound).any() or (ind > self.shape + high_bound).any():
if (ind < low_bound).any() or (ind > self.shape - high_bound).any():
raise GridError(f'Position outside of grid: {ind}')
if round_ind:
rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1)
sxyz = self.shifted_xyz(which_shifts)
position = [sxyz[a][rind[a]] for a in range(3)]
position = [sxyz[a][rind[a]].astype(int) for a in range(3)]
else:
sexyz = self.shifted_exyz(which_shifts)
position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a])

View file

@ -20,26 +20,6 @@ if TYPE_CHECKING:
class GridReadMixin(GridPosMixin):
@staticmethod
def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]:
if centers.size > 1:
midpoints = 0.5 * (centers[:-1] + centers[1:])
first = centers[0] - 0.5 * (centers[1] - centers[0])
last = centers[-1] + 0.5 * (centers[-1] - centers[-2])
return numpy.hstack(([first], midpoints, [last]))
return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float)
def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]:
if sample_period <= 1:
return self.shifted_exyz(which_shifts)
shifted_xyz = self.shifted_xyz(which_shifts)
shifted_exyz = self.shifted_exyz(which_shifts)
return [
self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a])
for a in range(3)
]
def get_slice(
self,
cell_data: NDArray,
@ -73,7 +53,7 @@ class GridReadMixin(GridPosMixin):
surface = numpy.delete(range(3), plane.axis)
# Extract indices and weights of planes
center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,))
center3 = numpy.insert([0, 0], plane.axis, (plane.pos,))
center_index = self.pos2ind(center3, which_shifts,
round_ind=False, check_bounds=False)[plane.axis]
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
@ -83,13 +63,12 @@ class GridReadMixin(GridPosMixin):
else:
w = [1]
c_min, c_max = (self.shifted_xyz(which_shifts)[plane.axis][i] for i in [0, -1])
c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1])
if plane.pos < c_min or plane.pos > c_max:
raise GridError('Coordinate of selected plane must be within simulation domain')
# Extract grid values from planes above and below visualized slice
sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface)
sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float))
sliced_grid = numpy.zeros(self.shape[surface])
for ci, weight in zip(centers, w, strict=True):
s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3))
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
@ -108,7 +87,6 @@ class GridReadMixin(GridPosMixin):
sample_period: int = 1,
finalize: bool = True,
pcolormesh_args: dict[str, Any] | None = None,
ax: 'matplotlib.axes.Axes | None' = None,
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
"""
Visualize a slice of a grid.
@ -120,8 +98,6 @@ class GridReadMixin(GridPosMixin):
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
pcolormesh_args: Args passed through to matplotlib `pcolormesh()`
ax: If provided, plot to these axes (instead of creating a new figure & axes)
Returns:
(Figure, Axes)
@ -135,109 +111,24 @@ class GridReadMixin(GridPosMixin):
pcolormesh_args = {}
grid_slice = self.get_slice(
cell_data = cell_data,
plane = plane,
which_shifts = which_shifts,
sample_period = sample_period,
cell_data=cell_data,
plane=plane,
which_shifts=which_shifts,
sample_period=sample_period,
)
surface = numpy.delete(range(3), plane.axis)
if sample_period == 1:
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
else:
x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface)
pcolormesh_args.setdefault('shading', 'nearest')
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
x_label, y_label = ('xyz'[a] for a in surface)
if ax is None:
fig, ax = pyplot.subplots()
else:
fig = ax.figure
fig, ax = pyplot.subplots()
mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args)
fig.colorbar(mappable)
ax.set_aspect('equal', adjustable='box')
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
if finalize:
pyplot.show()
return fig, ax
def visualize_edges(
self,
cell_data: NDArray,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
contour_args: dict[str, Any] | None = None,
ax: 'matplotlib.axes.Axes | None' = None,
level_fraction: float = 0.7,
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
"""
Visualize the edges of a grid slice.
This is intended as an overlay on top of visualize_slice (e.g. showing epsilon boundaries
on an E-field plot).
Interpolates if given a position between two grid planes.
Args:
cell_data: Cell data to visualize
plane: Axis and position (`Plane`) of the plane to read.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
contour_args: Args passed through to matplotlib `pcolormesh()`
ax: If provided, plot to these axes (instead of creating a new figure & axes)
level_fraction: Value between 0 and 1 which tunes how many contours are generated.
1 indicates that every possible step should have its own contour.
Returns:
(Figure, Axes)
"""
from matplotlib import pyplot
if level_fraction > 1:
raise GridError(f'{level_fraction=} must be between 0 and 1')
if isinstance(plane, dict):
plane = Plane(**plane)
if contour_args is None:
contour_args = dict(alpha=0.8, colors='gray')
grid_slice = self.get_slice(
cell_data = cell_data,
plane = plane,
which_shifts = which_shifts,
sample_period = sample_period,
)
cvals, cval_counts = numpy.unique(grid_slice, return_counts=True)
if cvals.size == 1:
levels = [cvals[0] + 1]
else:
cval_order = numpy.argsort(cval_counts)[::-1]
level_count = 2
while cval_counts[cval_order[:level_count]].sum() < level_fraction:
level_count += 1
ctr_levels = cvals[cval_order[:level_count]]
levels = numpy.diff(ctr_levels[::-1]) + ctr_levels[:0:-1]
surface = numpy.delete(range(3), plane.axis)
if ax is None:
fig, ax = pyplot.subplots()
else:
fig = ax.figure
xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface)
xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij')
ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args)
if finalize:
pyplot.show()
@ -282,14 +173,8 @@ class GridReadMixin(GridPosMixin):
verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
# Convert vertices from index to position
preview_exyz = self._sampled_exyz(which_shifts, sample_period)
pos_verts = numpy.array([
[
numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a])
for a in range(3)
]
for i in range(verts.shape[0])
], dtype=float)
pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False)
for i in range(verts.shape[0])], dtype=float)
xs, ys, zs = (pos_verts[:, a] for a in range(3))
# Draw the plot

View file

@ -1,9 +1,8 @@
import pytest
# import pytest
import numpy
from numpy.testing import assert_allclose #, assert_array_equal
import pickle
from .. import Grid, GridData, Extent, GridError, Plane, Slab
from .. import Grid, Extent #, Slab, Plane
def test_draw_oncenter_2x2() -> None:
@ -117,350 +116,3 @@ def test_draw_2shift_4x4() -> None:
[0, 0.125, 0.125, 0]])[None, :, :, None]
assert_allclose(arr, correct)
def test_ind2pos_round_preserves_float_centers() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]])
pos = grid.ind2pos(numpy.array([1, 0, 0]), which_shifts=0)
assert_allclose(pos, [2.0, 1.0, 0.5])
def test_ind2pos_enforces_bounds_for_rounded_and_fractional_indices() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]])
with pytest.raises(GridError):
grid.ind2pos(numpy.array([2, 0, 0]), which_shifts=0, check_bounds=True)
edge_pos = grid.ind2pos(numpy.array([1.5, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True)
assert_allclose(edge_pos, [3.0, 2.0, 1.0])
with pytest.raises(GridError):
grid.ind2pos(numpy.array([1.6, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True)
def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]])
arr_2d = grid.allocate(0)
arr_3d = grid.allocate(0)
slab = dict(axis='z', center=0.5, span=1.0)
polygon_2d = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float)
polygon_3d = numpy.array([[0, 0, 0.5],
[1, 0, 0.5],
[1, 1, 0.5],
[0, 1, 0.5]], dtype=float)
grid.draw_polygon(arr_2d, slab=slab, polygon=polygon_2d, foreground=1)
grid.draw_polygon(arr_3d, slab=slab, polygon=polygon_3d, foreground=1)
assert_allclose(arr_3d, arr_2d)
def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
polygon = numpy.array([[0, 0, 0.5],
[1, 0, 0.5],
[1, 1, 0.75],
[0, 1, 0.5]], dtype=float)
with pytest.raises(GridError):
grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1)
def test_get_slice_supports_sampling() -> None:
grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2)
assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0])
def test_sampled_visualization_helpers_do_not_error() -> None:
matplotlib = pytest.importorskip('matplotlib')
matplotlib.use('Agg')
from matplotlib import pyplot
grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False)
fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False)
assert fig_slice is ax_slice.figure
assert fig_edges is ax_edges.figure
pyplot.close(fig_slice)
pyplot.close(fig_edges)
def test_grid_constructor_rejects_invalid_coordinate_count() -> None:
with pytest.raises(GridError):
Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]])
with pytest.raises(GridError):
Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
def test_grid_constructor_rejects_invalid_periodic_length() -> None:
with pytest.raises(GridError):
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False])
def test_extent_and_slab_reject_inverted_geometry() -> None:
with pytest.raises(GridError):
Extent(center=0, min=1)
with pytest.raises(GridError):
Extent(min=2, max=1)
with pytest.raises(GridError):
Slab(axis='z', center=1, max=0)
def test_extent_accepts_scalar_like_inputs() -> None:
extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0]))
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])
def test_get_slice_uses_shifted_grid_bounds() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0.5, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
grid_slice = grid.get_slice(cell_data, Plane(x=2.0), which_shifts=0)
assert_allclose(grid_slice, cell_data[0, 1, :, :])
with pytest.raises(GridError):
grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0)
def test_draw_extrude_rectangle_uses_boundary_slice() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]])
cell_data = grid.allocate(0)
source = numpy.array([[1, 2],
[3, 4]], dtype=float)
cell_data[0, :, :, 1] = source
grid.draw_extrude_rectangle(
cell_data,
rectangle=[[0, 0, 2], [2, 2, 2]],
direction=2,
polarity=-1,
distance=2,
)
assert_allclose(cell_data[0, :, :, 0], source)
assert_allclose(cell_data[0, :, :, 1], source)
def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None:
grid = Grid([[0, 1, 3, 6, 10], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]])
sampled_exyz = grid._sampled_exyz(0, 2)
assert_allclose(sampled_exyz[0], [-1.5, 2.5, 6.5])
def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None:
matplotlib = pytest.importorskip('matplotlib')
matplotlib.use('Agg')
skimage_measure = pytest.importorskip('skimage.measure')
from matplotlib import pyplot
from mpl_toolkits.mplot3d.axes3d import Axes3D
captured: dict[str, numpy.ndarray] = {}
def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]:
verts = numpy.array([[0.5, 0.5, 0.5],
[0.5, 1.5, 0.5],
[1.5, 0.5, 0.5]], dtype=float)
faces = numpy.array([[0, 1, 2]], dtype=int)
return verts, faces, None, None
def fake_plot_trisurf( # noqa: ANN202
_self: object,
xs: numpy.ndarray,
ys: numpy.ndarray,
faces: numpy.ndarray,
zs: numpy.ndarray,
*_args: object,
**_kwargs: object,
) -> object:
captured['xs'] = numpy.asarray(xs)
captured['ys'] = numpy.asarray(ys)
captured['faces'] = numpy.asarray(faces)
captured['zs'] = numpy.asarray(zs)
return object()
monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes)
monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf)
grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]])
cell_data = numpy.zeros(grid.cell_data_shape)
fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False)
assert_allclose(captured['xs'], [1.5, 1.5, 3.5])
assert_allclose(captured['ys'], [1.5, 3.5, 1.5])
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
pyplot.close(fig)
def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
path = tmp_path / 'grid.state'
grid.save(str(path))
loaded = Grid.load(str(path))
assert path.exists()
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
assert_allclose(restored, original)
assert_allclose(loaded.shifts, grid.shifts)
assert loaded.periodic == grid.periodic
def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
path = tmp_path / 'grid.pickle'
with open(path, 'wb') as f:
pickle.dump(grid.__dict__, f, protocol=2)
loaded = Grid.load(str(path))
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
assert_allclose(restored, original)
assert_allclose(loaded.shifts, grid.shifts)
assert loaded.periodic == grid.periodic
def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0)
data.cell_data[0, 1, 0, 0] = 5.0
path = tmp_path / 'griddata.state'
data.save(str(path))
loaded = GridData.load(str(path))
assert path.exists()
assert_allclose(loaded.cell_data, data.cell_data)
assert_allclose(loaded.grid.shifts, data.grid.shifts)
assert loaded.grid.periodic == data.grid.periodic
def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None:
path = tmp_path / 'grid.state'
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path))
with pytest.raises(GridError):
GridData.load(str(path))
def test_negative_shift_nonperiodic_edges_and_widths() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0])
assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5])
assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25])
def test_negative_shift_periodic_edges_and_widths() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False])
assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0])
assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5])
def test_negative_shift_coordinate_round_trip() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False)
pos = grid.ind2pos(ind, 0, round_ind=False)
assert_allclose(ind, [1.0, 0.0, 0.0])
assert_allclose(pos, [1.25, 1.0, 0.5])
def test_negative_shift_draw_cuboid_fractional_fill() -> None:
grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
arr = grid.allocate(0)
grid.draw_cuboid(
arr,
x=dict(min=0, max=1),
y=dict(min=0, max=1),
z=dict(min=0, max=1),
foreground=1,
)
assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3])
def test_negative_shift_get_slice_uses_shifted_centers() -> None:
grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
cell_data = numpy.zeros(grid.cell_data_shape)
cell_data[0, 1, :, 0] = [7, 9]
x_center = float(grid.shifted_xyz(0)[0][1])
grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0)
assert_allclose(grid_slice, [7, 9])
def test_grid_with_data_returns_griddata() -> None:
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
data = grid.with_data(fill_value=2.0)
assert isinstance(data, GridData)
assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32))
def test_griddata_constructor_validates_shape() -> None:
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
with pytest.raises(GridError):
GridData(grid, numpy.zeros((1, 1, 1)))
def test_griddata_draw_methods_are_chainable() -> None:
data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
chained = data.draw_cuboid(
foreground=1,
x=dict(min=0, max=1),
y=dict(min=0, max=1),
z=dict(min=0, max=1),
).draw_polygon(
foreground=0.5,
slab=dict(axis='z', center=0.5, span=1.0),
polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float),
)
assert chained is data
assert data.cell_data.sum() > 0
def test_griddata_read_methods_delegate() -> None:
data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float)
assert_allclose(
data.get_slice(Plane(z=0.5)),
data.grid.get_slice(data.cell_data, Plane(z=0.5)),
)
def test_griddata_copy_is_independent() -> None:
data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0)
cloned = data.copy()
cloned.cell_data[0, 0, 0, 0] = 5.0
assert data is not cloned
assert data.grid is not cloned.grid
assert data.cell_data[0, 0, 0, 0] == 1.0

View file

@ -1,30 +1,13 @@
from typing import Protocol, TypedDict, runtime_checkable, cast
from typing import Protocol, TypedDict, runtime_checkable
from dataclasses import dataclass
import numpy
class GridError(Exception):
""" Base error type for `gridlock` """
pass
def _coerce_scalar(name: str, value: object) -> float:
arr = numpy.asarray(value)
if arr.size != 1:
raise GridError(f'{name} must be a scalar value')
try:
return float(arr.reshape(()))
except (TypeError, ValueError) as exc:
raise GridError(f'{name} must be a real scalar value') from exc
class ExtentDict(TypedDict, total=False):
"""
Geometrical definition of an extent (1D bounded region)
Must contain exactly two of `min`, `max`, `center`, or `span`.
"""
min: float
center: float
max: float
@ -33,9 +16,6 @@ class ExtentDict(TypedDict, total=False):
@runtime_checkable
class ExtentProtocol(Protocol):
"""
Anything that looks like an `Extent`
"""
center: float
span: float
@ -48,10 +28,6 @@ class ExtentProtocol(Protocol):
@dataclass(init=False, slots=True)
class Extent(ExtentProtocol):
"""
Geometrical definition of an extent (1D bounded region)
May be constructed with any two of `min`, `max`, `center`, or `span`.
"""
center: float
span: float
@ -71,53 +47,47 @@ class Extent(ExtentProtocol):
max: float | None = None,
span: float | None = None,
) -> None:
values = {
'min': None if min is None else _coerce_scalar('min', min),
'center': None if center is None else _coerce_scalar('center', center),
'max': None if max is None else _coerce_scalar('max', max),
'span': None if span is None else _coerce_scalar('span', span),
}
if sum(value is not None for value in values.values()) != 2:
raise GridError('Exactly two of min, center, max, span must be provided')
if sum(cc is None for cc in (min, center, max, span)) != 2:
raise GridError('Exactly two of min, center, max, span must be None!')
min_v = values['min']
center_v = values['center']
max_v = values['max']
span_v = values['span']
if span is None:
if center is None:
assert min is not None
assert max is not None
assert max >= min
center = 0.5 * (max + min)
span = max - min
elif max is None:
assert min is not None
assert center is not None
span = 2 * (center - min)
elif min is None:
assert center is not None
assert max is not None
span = 2 * (max - center)
else: # noqa: PLR5501
if center is not None:
pass
elif max is None:
assert min is not None
assert span is not None
center = min + 0.5 * span
elif min is None:
assert max is not None
assert span is not None
center = max - 0.5 * span
if span_v is not None and span_v < 0:
raise GridError('span must be non-negative')
if min_v is not None and max_v is not None:
if max_v < min_v:
raise GridError('max must be greater than or equal to min')
center_v = 0.5 * (max_v + min_v)
span_v = max_v - min_v
elif center_v is not None and min_v is not None:
span_v = 2 * (center_v - min_v)
if span_v < 0:
raise GridError('min must be less than or equal to center')
elif center_v is not None and max_v is not None:
span_v = 2 * (max_v - center_v)
if span_v < 0:
raise GridError('center must be less than or equal to max')
elif min_v is not None and span_v is not None:
center_v = min_v + 0.5 * span_v
elif max_v is not None and span_v is not None:
center_v = max_v - 0.5 * span_v
if center_v is None or span_v is None:
raise GridError('Unable to construct extent from the provided values')
self.center = center_v
self.span = span_v
assert center is not None
assert span is not None
if hasattr(center, '__len__'):
assert len(center) == 1
if hasattr(span, '__len__'):
assert len(span) == 1
self.center = center
self.span = span
class SlabDict(TypedDict, total=False):
"""
Geometrical definition of a slab (3D region bounded on one axis only)
Must contain `axis` plus any two of `min`, `max`, `center`, or `span`.
"""
min: float
center: float
max: float
@ -127,9 +97,6 @@ class SlabDict(TypedDict, total=False):
@runtime_checkable
class SlabProtocol(ExtentProtocol, Protocol):
"""
Anything that looks like a `Slab`
"""
axis: int
center: float
span: float
@ -143,10 +110,6 @@ class SlabProtocol(ExtentProtocol, Protocol):
@dataclass(init=False, slots=True)
class Slab(Extent, SlabProtocol):
"""
Geometrical definition of a slab (3D region bounded on one axis only)
May be constructed with `axis` (bounded axis) plus any two of `min`, `max`, `center`, or `span`.
"""
axis: int
def __init__(
@ -179,10 +142,6 @@ class Slab(Extent, SlabProtocol):
class PlaneDict(TypedDict, total=False):
"""
Geometrical definition of a plane (2D unbounded region in 3D space)
Must contain exactly one of `x`, `y`, `z`, or both `axis` and `pos`
"""
x: float
y: float
z: float
@ -192,19 +151,12 @@ class PlaneDict(TypedDict, total=False):
@runtime_checkable
class PlaneProtocol(Protocol):
"""
Anything that looks like a `Plane`
"""
axis: int
pos: float
@dataclass(init=False, slots=True)
class Plane(PlaneProtocol):
"""
Geometrical definition of a plane (2D unbounded region in 3D space)
May be constructed with any of `x=4`, `y=5`, `z=-5`, or `axis=2, pos=-5`.
"""
axis: int
pos: float
@ -240,9 +192,10 @@ class Plane(PlaneProtocol):
if pos is not None:
cpos = pos
else:
cpos = cast('float', (xx, yy, zz)[axis_int])
cpos = (xx, yy, zz)[axis_int]
assert cpos is not None
if hasattr(cpos, '__len__'):
assert len(cpos) == 1
self.pos = cpos

View file

@ -75,6 +75,7 @@ lint.ignore = [
"ANN002", # *args
"ANN003", # **kwargs
"ANN401", # Any
"ANN101", # self: Self
"SIM108", # single-line if / else assignment
"RET504", # x=y+z; return x
"PIE790", # unnecessary pass