Compare commits

..

No commits in common. "master" and "cell_data" have entirely different histories.

20 changed files with 946 additions and 2105 deletions

29
.flake8
View file

@ -1,29 +0,0 @@
[flake8]
ignore =
# E501 line too long
E501,
# W391 newlines at EOF
W391,
# E241 multiple spaces after comma
E241,
# E302 expected 2 newlines
E302,
# W503 line break before binary operator (to be deprecated)
W503,
# E265 block comment should start with '# '
E265,
# E123 closing bracket does not match indentation of opening bracket's line
E123,
# E124 closing bracket does not match visual indentation
E124,
# E221 multiple spaces before operator
E221,
# E201 whitespace after '['
E201,
# E741 ambiguous variable name 'I'
E741,
per-file-ignores =
# F401 import without use
*/__init__.py: F401,

2
MANIFEST.in Normal file
View file

@ -0,0 +1,2 @@
include README.md
include LICENSE.md

View file

@ -14,7 +14,7 @@ the coordinates of the boundary points along each axis).
## Installation
Requirements:
* python >3.11 (written and tested with 3.12)
* python 3 (written and tested with 3.9)
* numpy
* [float_raster](https://mpxd.net/code/jan/float_raster)
* matplotlib (optional, used for visualization functions)

View file

@ -1 +0,0 @@
../LICENSE.md

View file

@ -1 +0,0 @@
../README.md

4
gridlock/VERSION.py Normal file
View file

@ -0,0 +1,4 @@
""" VERSION defintion. THIS FILE IS MANUALLY PARSED BY setup.py and REQUIRES A SPECIFIC FORMAT """
__version__ = '''
1.0
'''.strip()

View file

@ -15,25 +15,10 @@ Dependencies:
- mpl_toolkits.mplot3d [Grid.visualize_isosurface()]
- skimage [Grid.visualize_isosurface()]
"""
from .utils import (
GridError as GridError,
Extent as Extent,
ExtentProtocol as ExtentProtocol,
ExtentDict as ExtentDict,
Slab as Slab,
SlabProtocol as SlabProtocol,
SlabDict as SlabDict,
Plane as Plane,
PlaneProtocol as PlaneProtocol,
PlaneDict as PlaneDict,
)
from .grid import Grid as Grid
from .data import GridData as GridData
from .error import GridError
from .grid import Grid
__author__ = 'Jan Petykiewicz'
__version__ = '2.2'
from .VERSION import __version__
version = __version__

View file

@ -1,192 +0,0 @@
from typing import Protocol
import numpy
from numpy.typing import NDArray
from . import GridError
class GridBase(Protocol):
exyz: list[NDArray]
"""Cell edges. Monotonically increasing without duplicates."""
periodic: list[bool]
"""For each axis, determines how far the rightmost boundary gets shifted. """
shifts: NDArray
"""Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`"""
@property
def dxyz(self) -> list[NDArray]:
"""
Cell sizes for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell sizes
"""
return [numpy.diff(ee) for ee in self.exyz]
@property
def xyz(self) -> list[NDArray]:
"""
Cell centers for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell edges
"""
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
@property
def shape(self) -> NDArray[numpy.intp]:
"""
The number of cells in x, y, and z
Returns:
ndarray of [x_centers.size, y_centers.size, z_centers.size]
"""
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
@property
def num_grids(self) -> int:
"""
The number of grids (number of shifts)
"""
return self.shifts.shape[0]
@property
def cell_data_shape(self) -> NDArray[numpy.intp]:
"""
The shape of the cell_data ndarray (num_grids, *self.shape).
"""
return numpy.hstack((self.num_grids, self.shape))
@property
def dxyz_with_ghost(self) -> list[NDArray]:
"""
Gives dxyz with an additional 'ghost' cell at the end, whose value depends
on whether or not the axis has periodic boundary conditions. See main description
above to learn why this is necessary.
If periodic, final edge shifts same amount as first
Otherwise, final edge shifts same amount as second-to-last
Returns:
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
"""
el = [0 if p else -1 for p in self.periodic]
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
if which_shifts is None:
return self.dxyz_with_ghost
shifts = self.shifts[which_shifts, :]
edge_dxyz = []
for a in range(3):
if shifts[a] < 0:
ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0]
edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a])))
else:
ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1]
edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost)))
return edge_dxyz
@property
def center(self) -> NDArray[numpy.float64]:
"""
Center position of the entire grid, no shifts applied
Returns:
ndarray of [x_center, y_center, z_center]
"""
# center is just average of first and last xyz, which is just the average of the
# first two and last two exyz
centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
return numpy.array(centers, dtype=float)
@property
def dxyz_limits(self) -> tuple[NDArray, NDArray]:
"""
Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
weighted average is performed when shifting).
Returns:
Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
"""
d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
return d_min, d_max
def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
"""
Returns edges for which_shifts.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell edges
"""
if which_shifts is None:
return self.exyz
edge_dxyz = self._shifted_edge_dxyz(which_shifts)
shifts = self.shifts[which_shifts, :]
return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)]
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
"""
Returns cell sizes for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell sizes
"""
if which_shifts is None:
return self.dxyz
return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)]
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
"""
Returns cell centers for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell centers
"""
if which_shifts is None:
return self.xyz
exyz = self.shifted_exyz(which_shifts)
dxyz = self.shifted_dxyz(which_shifts)
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
"""
Return cell widths, with each dimension shifted by the corresponding shifts.
Returns:
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
"""
if self.num_grids != 3:
raise GridError('Autoshifting requires exactly 3 grids')
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray:
"""
Allocate an ndarray for storing grid data.
Args:
fill_value: Value to initialize the grid to. If None, an
uninitialized array is returned.
dtype: Numpy dtype for the array. Default is `numpy.float32`.
Returns:
The allocated array
"""
if fill_value is None:
return numpy.empty(self.cell_data_shape, dtype=dtype)
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)

View file

@ -1,176 +0,0 @@
from dataclasses import dataclass
from typing import Self
from collections.abc import Sequence
import numpy
from numpy.typing import NDArray, ArrayLike
from .draw import foreground_t
from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload
from .utils import (
ExtentDict,
ExtentProtocol,
GridError,
PlaneDict,
PlaneProtocol,
SlabDict,
SlabProtocol,
)
@dataclass(slots=True)
class GridData:
grid: Grid
cell_data: NDArray
def __post_init__(self) -> None:
if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape):
raise GridError(
f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}'
)
@staticmethod
def load(filename: str) -> 'GridData':
payload = _load_payload(filename)
if _payload_scalar_str(payload, 'kind') != 'grid_data':
raise GridError('Serialized payload does not contain GridData')
if 'cell_data' not in payload:
raise GridError('Serialized GridData payload is missing cell_data')
return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data']))
def save(self, filename: str) -> Self:
payload = self.grid._serialization_payload(kind='grid_data')
payload['cell_data'] = self.cell_data
_save_npz_payload(filename, payload)
return self
def copy(self) -> Self:
return GridData(self.grid.copy(), self.cell_data.copy())
def draw_polygons(
self,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
polygons: Sequence[ArrayLike],
*,
offset2d: ArrayLike = (0, 0),
) -> Self:
self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d)
return self
def draw_polygon(
self,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
polygon: ArrayLike,
*,
offset2d: ArrayLike = (0, 0),
) -> Self:
self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d)
return self
def draw_slab(
self,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
) -> Self:
self.grid.draw_slab(self.cell_data, foreground, slab)
return self
def draw_cuboid(
self,
foreground: Sequence[foreground_t] | foreground_t,
*,
x: ExtentProtocol | ExtentDict,
y: ExtentProtocol | ExtentDict,
z: ExtentProtocol | ExtentDict,
) -> Self:
self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z)
return self
def draw_cylinder(
self,
h: SlabProtocol | SlabDict,
radius: float,
num_points: int,
center2d: ArrayLike,
foreground: Sequence[foreground_t] | foreground_t,
) -> Self:
self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground)
return self
def draw_extrude_rectangle(
self,
rectangle: ArrayLike,
direction: int,
polarity: int,
distance: float,
) -> Self:
self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance)
return self
def get_slice(
self,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
) -> NDArray:
return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period)
def visualize_slice(
self,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
pcolormesh_args: dict[str, object] | None = None,
ax: object | None = None,
) -> tuple[object, object]:
return self.grid.visualize_slice(
self.cell_data,
plane,
which_shifts=which_shifts,
sample_period=sample_period,
finalize=finalize,
pcolormesh_args=pcolormesh_args,
ax=ax,
)
def visualize_edges(
self,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
contour_args: dict[str, object] | None = None,
ax: object | None = None,
level_fraction: float = 0.7,
) -> tuple[object, object]:
return self.grid.visualize_edges(
self.cell_data,
plane,
which_shifts=which_shifts,
sample_period=sample_period,
finalize=finalize,
contour_args=contour_args,
ax=ax,
level_fraction=level_fraction,
)
def visualize_isosurface(
self,
level: float | None = None,
which_shifts: int = 0,
sample_period: int = 1,
show_edges: bool = True,
finalize: bool = True,
) -> tuple[object, object]:
return self.grid.visualize_isosurface(
self.cell_data,
level=level,
which_shifts=which_shifts,
sample_period=sample_period,
show_edges=show_edges,
finalize=finalize,
)

10
gridlock/direction.py Normal file
View file

@ -0,0 +1,10 @@
from enum import Enum
class Direction(Enum):
"""
Enum for axis->integer mapping
"""
x = 0
y = 1
z = 2

View file

@ -1,14 +1,12 @@
"""
Drawing-related methods for Grid class
"""
from collections.abc import Sequence, Callable
from typing import List, Optional, Union, Sequence, Callable
import numpy
from numpy.typing import NDArray, ArrayLike
import numpy # type: ignore
from float_raster import raster
from .utils import GridError, Slab, SlabDict, SlabProtocol, Extent, ExtentDict, ExtentProtocol
from .position import GridPosMixin
from . import GridError
# NOTE: Maybe it would make sense to create a GridDrawer class
@ -17,387 +15,364 @@ from .position import GridPosMixin
# without having to pass `cell_data` again each time?
foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
foreground_t = float | foreground_callable_t
foreground_callable_t = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray]
class GridDrawMixin(GridPosMixin):
def draw_polygons(
self,
cell_data: NDArray,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
polygons: Sequence[ArrayLike],
*,
offset2d: ArrayLike = (0, 0),
) -> None:
"""
Draw polygons on an axis-aligned slab.
def draw_polygons(self,
cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray,
polygons: Sequence[numpy.ndarray],
thickness: float,
foreground: Union[Sequence[Union[float, foreground_callable_t]], float, foreground_callable_t],
) -> None:
"""
Draw polygons on an axis-aligned plane.
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list
of any of these (1 per grid). Callable values should take an ndarray the shape of the
grid and return an ndarray of equal shape containing the foreground value at the given x, y,
and z (natural, not grid coordinates).
slab: `Slab` or slab-like dict specifying the slab in which the polygons will be drawn.
polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
(non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each
polygon must have at least 3 vertices.
offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly
to the given polygon vertex coordinates. Default (0, 0).
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying an offset applied to all the polygons
polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
(non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each
polygon must have at least 3 vertices.
thickness: Thickness of the layer to draw
foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list
of any of these (1 per grid). Callable values should take an ndarray the shape of the
grid and return an ndarray of equal shape containing the foreground value at the given x, y,
and z (natural, not grid coordinates).
Raises:
GridError
"""
if isinstance(slab, dict):
slab = Slab(**slab)
Raises:
GridError
"""
if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction')
poly_list = [numpy.asarray(poly) for poly in polygons]
center = numpy.squeeze(center)
# Check polygons, and remove redundant coordinates
surface = numpy.delete(range(3), slab.axis)
# Check polygons, and remove redundant coordinates
surface = numpy.delete(range(3), surface_normal)
for ii in range(len(poly_list)):
polygon = poly_list[ii]
malformed = f'Malformed polygon: ({ii})'
if polygon.ndim != 2:
raise GridError(malformed + 'must be a 2-dimensional ndarray')
if polygon.shape[1] not in (2, 3):
for i, polygon in enumerate(polygons):
malformed = f'Malformed polygon: ({i})'
if polygon.shape[1] not in (2, 3):
raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
if polygon.shape[1] == 3:
if numpy.unique(polygon[:, slab.axis]).size != 1:
raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis])
polygon = polygon[:, surface]
poly_list[ii] = polygon
if polygon.shape[1] == 3:
polygon = polygon[surface, :]
if not polygon.shape[0] > 2:
raise GridError(malformed + 'must consist of more than 2 points')
if not polygon.shape[0] > 2:
raise GridError(malformed + 'must consist of more than 2 points')
if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1:
raise GridError(malformed + 'must be in plane with surface normal '
+ 'xyz'[surface_normal])
# Broadcast foreground where necessary
foregrounds: Sequence[foreground_callable_t] | Sequence[float]
if isinstance(foreground, numpy.ndarray):
raise GridError('ndarray not supported for foreground')
if callable(foreground) or numpy.isscalar(foreground):
foregrounds = [foreground] * len(cell_data) # type: ignore[list-item]
# Broadcast foreground where necessary
if numpy.size(foreground) == 1:
foreground = [foreground] * len(cell_data)
elif isinstance(foreground, numpy.ndarray):
raise GridError('ndarray not supported for foreground')
# ## Compute sub-domain of the grid occupied by polygons
# 1) Compute outer bounds (bd) of polygons
bd_2d_min = [0, 0]
bd_2d_max = [0, 0]
for polygon in polygons:
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center
# 2) Find indices (bdi) just outside bd elements
buf = 2 # size of safety buffer
# Use s_min and s_max with unshifted pos2ind to get absolute limits on
# the indices the polygons might affect
s_min = self.shifts.min(axis=0)
s_max = self.shifts.max(axis=0)
bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf
bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf
bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int)
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
# 3) Adjust polygons for center
polygons = [poly + center[surface] for poly in polygons]
# ## Generate weighing function
def to_3d(vector: numpy.ndarray, val: float = 0.0) -> numpy.ndarray:
v_2d = numpy.array(vector, dtype=float)
return numpy.insert(v_2d, surface_normal, (val,))
# iterate over grids
for i, grid in enumerate(cell_data):
# ## Evaluate or expand foreground[i]
if callable(foreground[i]):
# meshgrid over the (shifted) domain
domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k]+1] for k in range(3)]
(x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij')
# evaluate on the meshgrid
foreground_i = foreground[i](x0, y0, z0)
if not numpy.isfinite(foreground_i).all():
raise GridError(f'Non-finite values in foreground[{i}]')
elif numpy.size(foreground[i]) != 1:
raise GridError(f'Unsupported foreground[{i}]: {type(foreground[i])}')
else:
foregrounds = foreground # type: ignore
# foreground[i] is scalar non-callable
foreground_i = foreground[i]
# ## Compute sub-domain of the grid occupied by polygons
# 1) Compute outer bounds (bd) of polygons
bd_2d_min = numpy.array([0, 0])
bd_2d_max = numpy.array([0, 0])
for polygon in poly_list:
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) + offset2d
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) + offset2d
bd_min = numpy.insert(bd_2d_min, slab.axis, slab.min)
bd_max = numpy.insert(bd_2d_max, slab.axis, slab.max)
w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int))
# 2) Find indices (bdi) just outside bd elements
buf = 2 # size of safety buffer
# Use s_min and s_max with unshifted pos2ind to get absolute limits on
# the indices the polygons might affect
s_min = self.shifts.min(axis=0)
s_max = self.shifts.max(axis=0)
bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf
bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf
bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int)
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
# Draw each polygon separately
for polygon in polygons:
# 3) Adjust polygons for offset2d
poly_list = [poly + offset2d for poly in poly_list]
# Get the boundaries of the polygon
pbd_min = polygon.min(axis=0)
pbd_max = polygon.max(axis=0)
# ## Generate weighing function
def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
v_2d = numpy.array(vector, dtype=float)
return numpy.insert(v_2d, slab.axis, (val,))
# Find indices in w_xy just outside polygon
# using per-grid xy-weights (self.shifted_xyz())
corner_min = self.pos2ind(to_3d(pbd_min), i,
check_bounds=False)[surface].astype(int)
corner_max = self.pos2ind(to_3d(pbd_max), i,
check_bounds=False)[surface].astype(int)
# iterate over grids
foreground_val: NDArray | float
for i, _ in enumerate(cell_data):
# ## Evaluate or expand foregrounds[i]
foregrounds_i = foregrounds[i]
if callable(foregrounds_i):
# meshgrid over the (shifted) domain
domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)]
(x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij')
# Find indices in w_xy which are modified by polygon
# First for the edge coordinates (+1 since we're indexing edges)
edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max)]
# Then for the pixel centers (-bdi_min since we're
# calculating weights within a subspace)
centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface],
corner_max - bdi_min[surface]))
# evaluate on the meshgrid
foreground_val = foregrounds_i(x0, y0, z0)
if not numpy.isfinite(foreground_val).all():
raise GridError(f'Non-finite values in foreground[{i}]')
elif numpy.size(foregrounds_i) != 1:
raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}')
aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices))
w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y)
# Clamp overlapping polygons to 1
w_xy = numpy.minimum(w_xy, 1.0)
# 2) Generate weights in z-direction
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
def get_zi(offset, i=i, w_z=w_z):
edges = self.shifted_exyz(i)[surface_normal]
point = center[surface_normal] + offset
grid_coord = numpy.digitize(point, edges) - 1
w_coord = grid_coord - bdi_min[surface_normal]
if w_coord < 0:
w_coord = 0
f = 0
elif w_coord >= w_z.size:
w_coord = w_z.size - 1
f = 1
else:
# foreground[i] is scalar non-callable
foreground_val = foregrounds_i
dz = self.shifted_dxyz(i)[surface_normal][grid_coord]
f = (point - edges[grid_coord]) / dz
return f, w_coord
w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int))
zi_top_f, zi_top = get_zi(+thickness / 2.0)
zi_bot_f, zi_bot = get_zi(-thickness / 2.0)
# Draw each polygon separately
for polygon in poly_list:
w_z[zi_bot + 1:zi_top] = 1
# Get the boundaries of the polygon
pbd_min = polygon.min(axis=0)
pbd_max = polygon.max(axis=0)
if zi_bot < zi_top:
w_z[zi_top] = zi_top_f
w_z[zi_bot] = 1 - zi_bot_f
else:
w_z[zi_bot] = zi_top_f - zi_bot_f
# Find indices in w_xy just outside polygon
# using per-grid xy-weights (self.shifted_xyz())
corner_min = self.pos2ind(to_3d(pbd_min), i,
check_bounds=False)[surface].astype(int)
corner_max = self.pos2ind(to_3d(pbd_max), i,
check_bounds=False)[surface].astype(int)
# 3) Generate total weight function
w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,)))
# Find indices in w_xy which are modified by polygon
# First for the edge coordinates (+1 since we're indexing edges)
edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)]
# Then for the pixel centers (-bdi_min since we're
# calculating weights within a subspace)
centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface],
corner_max - bdi_min[surface], strict=True))
aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True))
w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y)
# Clamp overlapping polygons to 1
w_xy = numpy.minimum(w_xy, 1.0)
# 2) Generate weights in z-direction
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[slab.axis], ))
def get_zi(point: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001
edges = self.shifted_exyz(i)[slab.axis]
grid_coord = numpy.digitize(point, edges) - 1
w_coord = grid_coord - bdi_min[slab.axis]
if w_coord < 0:
w_coord = 0
f = 0
elif w_coord >= w_z.size:
w_coord = w_z.size - 1
f = 1
else:
dz = self.shifted_dxyz(i)[slab.axis][grid_coord]
f = (point - edges[grid_coord]) / dz
return f, w_coord
zi_top_f, zi_top = get_zi(slab.max)
zi_bot_f, zi_bot = get_zi(slab.min)
w_z[zi_bot + 1:zi_top] = 1
if zi_bot < zi_top:
w_z[zi_top] = zi_top_f
w_z[zi_bot] = 1 - zi_bot_f
else:
w_z[zi_bot] = zi_top_f - zi_bot_f
# 3) Generate total weight function
w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], slab.axis, (2,)))
# ## Modify the grid
g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val
# ## Modify the grid
g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_i
def draw_polygon(
self,
cell_data: NDArray,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
polygon: ArrayLike,
*,
offset2d: ArrayLike = (0, 0),
) -> None:
"""
Draw a polygon on an axis-aligned plane.
def draw_polygon(self,
cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray,
polygon: numpy.ndarray,
thickness: float,
foreground: Union[Sequence[Union[float, foreground_callable_t]], float, foreground_callable_t],
) -> None:
"""
Draw a polygon on an axis-aligned plane.
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
slab: `Slab` or slab-like dict specifying the slab in which the polygon will be drawn.
polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at
least 3 vertices.
offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly
to the given polygon vertex coordinates. Default (0, 0).
"""
self.draw_polygons(
cell_data = cell_data,
slab = slab,
polygons = [polygon],
foreground = foreground,
offset2d = offset2d,
)
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying an offset applied to the polygon
polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at
least 3 vertices.
thickness: Thickness of the layer to draw
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
"""
self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground)
def draw_slab(
self,
cell_data: NDArray,
foreground: Sequence[foreground_t] | foreground_t,
slab: SlabProtocol | SlabDict,
) -> None:
"""
Draw an axis-aligned infinite slab.
def draw_slab(self,
cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray,
thickness: float,
foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t],
) -> None:
"""
Draw an axis-aligned infinite slab.
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
slab: `Slab` or slab-like dict (geometrical slab specification)
"""
if isinstance(slab, dict):
slab = Slab(**slab)
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: `surface_normal` coordinate value at the center of the slab
thickness: Thickness of the layer to draw
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
"""
# Turn surface_normal into its integer representation
if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction')
# Find center of slab
center_shift = self.center
center_shift[slab.axis] = slab.center
if numpy.size(center) != 1:
center = numpy.squeeze(center)
if len(center) == 3:
center = center[surface_normal]
else:
raise GridError(f'Bad center: {center}')
surface = numpy.delete(range(3), slab.axis)
u_min, u_max = self.exyz[surface[0]][[0, -1]]
v_min, v_max = self.exyz[surface[1]][[0, -1]]
# Find center of slab
center_shift = self.center
center_shift[surface_normal] = center
margin = 4 * numpy.max([self.dxyz[surface[0]].max(),
self.dxyz[surface[1]].max()])
surface = numpy.delete(range(3), surface_normal)
p = numpy.array([[u_min - margin, v_max + margin],
[u_max + margin, v_max + margin],
[u_max + margin, v_min - margin],
[u_min - margin, v_min - margin]], dtype=float)
xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface]
xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface]
self.draw_polygon(
cell_data = cell_data,
slab = slab,
polygon = p,
foreground = foreground,
)
dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float)
xyz_min -= 4 * dxyz
xyz_max += 4 * dxyz
p = numpy.array([[xyz_min[0], xyz_max[1]],
[xyz_max[0], xyz_max[1]],
[xyz_max[0], xyz_min[1]],
[xyz_min[0], xyz_min[1]]], dtype=float)
self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground)
def draw_cuboid(
self,
cell_data: NDArray,
foreground: Sequence[foreground_t] | foreground_t,
*,
x: ExtentProtocol | ExtentDict,
y: ExtentProtocol | ExtentDict,
z: ExtentProtocol | ExtentDict,
) -> None:
"""
Draw an axis-aligned cuboid
def draw_cuboid(self,
cell_data: numpy.ndarray,
center: numpy.ndarray,
dimensions: numpy.ndarray,
foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t],
) -> None:
"""
Draw an axis-aligned cuboid
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
x: `Extent` or extent-like dict specifying the x-extent of the cuboid.
y: `Extent` or extent-like dict specifying the y-extent of the cuboid.
z: `Extent` or extent-like dict specifying the z-extent of the cuboid.
"""
if isinstance(x, dict):
x = Extent(**x)
if isinstance(y, dict):
y = Extent(**y)
if isinstance(z, dict):
z = Extent(**z)
p = numpy.array([[x.min, y.max],
[x.max, y.max],
[x.max, y.min],
[x.min, y.min]], dtype=float)
slab = Slab(axis=2, center=z.center, span=z.span)
self.draw_polygon(cell_data=cell_data, slab=slab, polygon=p, foreground=foreground)
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
center: 3-element ndarray or list specifying the cuboid's center
dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge
sizes of the cuboid
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
"""
p = numpy.array([[-dimensions[0], +dimensions[1]],
[+dimensions[0], +dimensions[1]],
[+dimensions[0], -dimensions[1]],
[-dimensions[0], -dimensions[1]]], dtype=float) / 2.0
thickness = dimensions[2]
self.draw_polygon(cell_data, 2, center, p, thickness, foreground)
def draw_cylinder(
self,
cell_data: NDArray,
h: SlabProtocol | SlabDict,
radius: float,
num_points: int,
center2d: ArrayLike,
foreground: Sequence[foreground_t] | foreground_t,
) -> None:
"""
Draw an axis-aligned cylinder. Approximated by a num_points-gon
def draw_cylinder(self,
cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray,
radius: float,
thickness: float,
num_points: int,
foreground: Union[List[Union[float, foreground_callable_t]], float, foreground_callable_t],
) -> None:
"""
Draw an axis-aligned cylinder. Approximated by a num_points-gon
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
h:
radius:
num_points: The circle is approximated by a polygon with `num_points` vertices
center2d:
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
"""
if isinstance(h, dict):
h = Slab(**h)
theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)[:, None]
xy0 = numpy.hstack((numpy.sin(theta), numpy.cos(theta)))
polygon = radius * xy0
self.draw_polygon(cell_data=cell_data, slab=h, polygon=polygon, foreground=foreground, offset2d=center2d)
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying the cylinder's center
radius: cylinder radius
thickness: Thickness of the layer to draw
num_points: The circle is approximated by a polygon with `num_points` vertices
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
"""
theta = numpy.linspace(0, 2*numpy.pi, num_points, endpoint=False)
x = radius * numpy.sin(theta)
y = radius * numpy.cos(theta)
polygon = numpy.hstack((x[:, None], y[:, None]))
self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground)
def draw_extrude_rectangle(
self,
cell_data: NDArray,
rectangle: ArrayLike,
direction: int,
polarity: int,
distance: float,
) -> None:
"""
Extrude a rectangle of a previously-drawn structure along an axis.
def draw_extrude_rectangle(self,
cell_data: numpy.ndarray,
rectangle: numpy.ndarray,
direction: int,
polarity: int,
distance: float,
) -> None:
"""
Extrude a rectangle of a previously-drawn structure along an axis.
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
rectangle: 2x3 ndarray or list specifying the rectangle's corners
direction: Direction to extrude in. Integer in `range(3)`.
polarity: +1 or -1, direction along axis to extrude in
distance: How far to extrude
"""
sgn = numpy.sign(polarity)
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
rectangle: 2x3 ndarray or list specifying the rectangle's corners
direction: Direction to extrude in. Integer in `range(3)`.
polarity: +1 or -1, direction along axis to extrude in
distance: How far to extrude
"""
s = numpy.sign(polarity)
rectangle = numpy.asarray(rectangle, dtype=float)
if sgn == 0:
raise GridError('0 is not a valid polarity')
if direction not in range(3):
raise GridError(f'Invalid direction: {direction}')
if rectangle[0, direction] != rectangle[1, direction]:
raise GridError('Rectangle entries along extrusion direction do not match.')
rectangle = numpy.array(rectangle, dtype=float)
if s == 0:
raise GridError('0 is not a valid polarity')
if direction not in range(3):
raise GridError(f'Invalid direction: {direction}')
if rectangle[0, direction] != rectangle[1, direction]:
raise GridError('Rectangle entries along extrusion direction do not match.')
center = rectangle.sum(axis=0) / 2.0
center[direction] += sgn * distance / 2.0
center = rectangle.sum(axis=0) / 2.0
center[direction] += s * distance / 2.0
surface = numpy.delete(range(3), direction)
surface = numpy.delete(range(3), direction)
dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface]
poly = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5,
numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T
thickness = distance
dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface]
p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0]/2.0,
numpy.array([-1, 1, 1, -1], dtype=float) * dim[1]/2.0)).T
thickness = distance
foreground_func = []
for ii, grid in enumerate(cell_data):
zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction]
fpart = zz - numpy.floor(zz)
low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1))
high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1))
foreground_func = []
for i, grid in enumerate(cell_data):
z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction]
low_ind = [low if dd == direction else slice(None) for dd in range(3)]
high_ind = [high if dd == direction else slice(None) for dd in range(3)]
ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)]
if low == high:
foreground = grid[tuple(low_ind)]
else:
mult = [1 - fpart, fpart][::sgn] # reverses if s negative
foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)]
fpart = z - numpy.floor(z)
mult = [1-fpart, fpart][::s] # reverses if s negative
def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
# transform from natural position to index
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=ii)
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64)
# reshape to original shape and keep only in-plane components
qi, ri = (numpy.reshape(xyzi[:, kk], xs.shape) for kk in surface)
return foreground[qi, ri]
foreground = mult[0] * grid[tuple(ind)]
ind[direction] += 1
foreground += mult[1] * grid[tuple(ind)]
foreground_func.append(f_foreground)
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> numpy.ndarray:
# transform from natural position to index
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
for qrs in zip(xs.flat, ys.flat, zs.flat)], dtype=int)
# reshape to original shape and keep only in-plane components
qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
return foreground[qi, ri]
foreground_func.append(f_foreground)
self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func)
slab = Slab(axis=direction, center=center[direction], span=thickness)
self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface])

2
gridlock/error.py Normal file
View file

@ -0,0 +1,2 @@
class GridError(Exception):
pass

View file

@ -1,4 +1,4 @@
import numpy
import numpy # type: ignore
from gridlock import Grid
@ -6,18 +6,18 @@ if __name__ == '__main__':
# xyz = [numpy.arange(-5.0, 6.0), numpy.arange(-4.0, 5.0), [-1.0, 1.0]]
# eg = Grid(xyz)
# egc = Grid.allocate(0.0)
# # eg.draw_slab(egc, slab=dict(axis=2, center=0, span=10), foreground=2)
# eg.draw_cylinder(egc, h=slab(axis=2, center=0, span=10),
# center2d=[0, 0], radius=4, thickness=10, num_points=1000, foreground=1)
# eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2)
# # eg.draw_slab(egc, surface_normal=2, center=0, thickness=10, foreground=2)
# eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=4,
# thickness=10, num_points=1000, foreground=1)
# eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
# xyz2 = [numpy.arange(-5.0, 6.0), [-1.0, 1.0], numpy.arange(-4.0, 5.0)]
# eg2 = Grid(xyz2)
# eg2c = Grid.allocate(0.0)
# # eg2.draw_slab(eg2c, slab=dict(axis=2, center=0, span=10), foreground=2)
# eg2.draw_cylinder(eg2c, h=slab(axis=1, center=0, span=10), center2d=[0, 0],
# radius=4, num_points=1000, foreground=1.0)
# eg2.visualize_slice(eg2c, plane=dict(y=0), which_shifts=1)
# # eg2.draw_slab(eg2c, surface_normal=2, center=0, thickness=10, foreground=2)
# eg2.draw_cylinder(eg2c, surface_normal=1, center=[0, 0, 0],
# radius=4, thickness=10, num_points=1000, foreground=1.0)
# eg2.visualize_slice(eg2c, surface_normal=1, center=0, which_shifts=1)
# n = 20
# m = 3
@ -29,27 +29,16 @@ if __name__ == '__main__':
# numpy.linspace(-5.5, 5.5, 10)]
half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5]
xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x, dtype=float),
numpy.linspace(-5.5, 5.5, 10, dtype=float),
numpy.linspace(-5.5, 5.5, 10, dtype=float)]
xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x,
numpy.linspace(-5.5, 5.5, 10),
numpy.linspace(-5.5, 5.5, 10)]
eg = Grid(xyz3)
egc = eg.allocate(0)
# eg.draw_slab(Direction.z, 0, 10, 2)
eg.save('/home/jan/Desktop/test.pickle')
eg.draw_cylinder(
egc,
h=dict(axis='z', center=0, span=10),
center2d=[0, 0],
radius=2.0,
num_points=1000,
foreground=1,
)
eg.draw_extrude_rectangle(
egc,
rectangle=[[-2, 1, -1], [0, 1, 1]],
direction=1,
polarity=+1,
distance=5,
)
eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2)
eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0,
thickness=10, num_poitns=1000, foreground=1)
eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]],
direction=1, poalarity=+1, distance=5)
eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
eg.visualize_isosurface(egc, which_shifts=2)

View file

@ -1,93 +1,20 @@
from typing import TYPE_CHECKING, Any, ClassVar, Self
from collections.abc import Callable, Sequence
from typing import List, Tuple, Callable, Dict, Optional, Union, Sequence, ClassVar, TypeVar
import numpy
from numpy.typing import NDArray, ArrayLike
import numpy # type: ignore
from numpy import diff, floor, ceil, zeros, hstack, newaxis
import pickle
import warnings
import copy
from . import GridError
from .draw import GridDrawMixin
from .read import GridReadMixin
from .position import GridPosMixin
if TYPE_CHECKING:
from .data import GridData
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
_FORMAT_VERSION = 1
foreground_callable_type = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray]
T = TypeVar('T', bound='Grid')
def _is_npz_file(filename: str) -> bool:
with open(filename, 'rb') as f:
return f.read(2) == b'PK'
def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None:
with open(filename, 'wb') as f:
numpy.savez_compressed(f, **payload)
def _load_payload(filename: str) -> dict[str, Any]:
if _is_npz_file(filename):
with numpy.load(filename, allow_pickle=False) as payload:
return {key: payload[key] for key in payload.files}
with open(filename, 'rb') as f:
legacy = pickle.load(f)
if isinstance(legacy, Grid):
return legacy._serialization_payload(kind='grid')
if isinstance(legacy, dict):
grid = Grid([[-1, 1]] * 3)
grid.__dict__.update(legacy)
return grid._serialization_payload(kind='grid')
raise GridError('Unsupported serialized Grid payload')
def _payload_scalar_str(payload: dict[str, Any], key: str) -> str:
if key not in payload:
raise GridError(f'Missing serialized key: {key}')
value = numpy.asarray(payload[key])
if value.size != 1:
raise GridError(f'Serialized key {key} must be scalar')
return str(value.reshape(()))
def _payload_scalar_int(payload: dict[str, Any], key: str) -> int:
if key not in payload:
raise GridError(f'Missing serialized key: {key}')
value = numpy.asarray(payload[key])
if value.size != 1:
raise GridError(f'Serialized key {key} must be scalar')
return int(value.reshape(()))
def _grid_from_payload(payload: dict[str, Any]) -> 'Grid':
if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION:
raise GridError('Unsupported serialized Grid format version')
exyz = []
for axis in range(3):
key = f'exyz_{axis}'
if key not in payload:
raise GridError(f'Missing serialized key: {key}')
exyz.append(numpy.array(payload[key], dtype=float))
if 'shifts' not in payload or 'periodic' not in payload:
raise GridError('Serialized Grid payload is missing shifts or periodic data')
shifts = numpy.array(payload['shifts'], dtype=float)
periodic = numpy.array(payload['periodic'], dtype=bool).tolist()
return Grid(exyz, shifts=shifts, periodic=periodic)
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
class Grid:
"""
Simulation grid metadata for finite-difference simulations.
@ -121,35 +48,217 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
Because of this, we either assume this 'ghost' cell is the same size as the last
real cell, or, if `self.periodic[a]` is set to `True`, the same size as the first cell.
"""
exyz: list[NDArray]
exyz: List[numpy.ndarray]
"""Cell edges. Monotonically increasing without duplicates."""
periodic: list[bool]
periodic: List[bool]
"""For each axis, determines how far the rightmost boundary gets shifted. """
shifts: NDArray
shifts: numpy.ndarray
"""Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`"""
Yee_Shifts_E: ClassVar[NDArray] = 0.5 * numpy.array([
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
], dtype=float)
Yee_Shifts_E: ClassVar[numpy.ndarray] = 0.5 * numpy.array([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]], dtype=float)
"""Default shifts for Yee grid E-field"""
Yee_Shifts_H: ClassVar[NDArray] = 0.5 * numpy.array([
[0, 1, 1],
[1, 0, 1],
[1, 1, 0],
], dtype=float)
Yee_Shifts_H: ClassVar[numpy.ndarray] = 0.5 * numpy.array([[0, 1, 1],
[1, 0, 1],
[1, 1, 0]], dtype=float)
"""Default shifts for Yee grid H-field"""
def __init__(
self,
pixel_edge_coordinates: Sequence[ArrayLike],
shifts: ArrayLike = Yee_Shifts_E,
periodic: bool | Sequence[bool] = False,
) -> None:
from .draw import (
draw_polygons, draw_polygon, draw_slab, draw_cuboid,
draw_cylinder, draw_extrude_rectangle,
)
from .read import get_slice, visualize_slice, visualize_isosurface
from .position import ind2pos, pos2ind
@property
def dxyz(self) -> List[numpy.ndarray]:
"""
Cell sizes for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell sizes
"""
return [numpy.diff(ee) for ee in self.exyz]
@property
def xyz(self) -> List[numpy.ndarray]:
"""
Cell centers for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell edges
"""
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
@property
def shape(self) -> numpy.ndarray:
"""
The number of cells in x, y, and z
Returns:
ndarray of [x_centers.size, y_centers.size, z_centers.size]
"""
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
@property
def num_grids(self) -> int:
"""
The number of grids (number of shifts)
"""
return self.shifts.shape[0]
@property
def cell_data_shape(self):
"""
The shape of the cell_data ndarray (num_grids, *self.shape).
"""
return numpy.hstack((self.num_grids, self.shape))
@property
def dxyz_with_ghost(self) -> List[numpy.ndarray]:
"""
Gives dxyz with an additional 'ghost' cell at the end, whose value depends
on whether or not the axis has periodic boundary conditions. See main description
above to learn why this is necessary.
If periodic, final edge shifts same amount as first
Otherwise, final edge shifts same amount as second-to-last
Returns:
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
"""
el = [0 if p else -1 for p in self.periodic]
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)]
@property
def center(self) -> numpy.ndarray:
"""
Center position of the entire grid, no shifts applied
Returns:
ndarray of [x_center, y_center, z_center]
"""
# center is just average of first and last xyz, which is just the average of the
# first two and last two exyz
centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
return numpy.array(centers, dtype=float)
@property
def dxyz_limits(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
"""
Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
weighted average is performed when shifting).
Returns:
Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
"""
d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
return d_min, d_max
def shifted_exyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]:
"""
Returns edges for which_shifts.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell edges
"""
if which_shifts is None:
return self.exyz
dxyz = self.dxyz_with_ghost
shifts = self.shifts[which_shifts, :]
# If shift is negative, use left cell's dx to determine shift
for a in range(3):
if shifts[a] < 0:
dxyz[a] = numpy.roll(dxyz[a], 1)
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
def shifted_dxyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]:
"""
Returns cell sizes for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell sizes
"""
if which_shifts is None:
return self.dxyz
shifts = self.shifts[which_shifts, :]
dxyz = self.dxyz_with_ghost
# If shift is negative, use left cell's dx to determine size
sdxyz = []
for a in range(3):
if shifts[a] < 0:
roll_dxyz = numpy.roll(dxyz[a], 1)
abs_shift = numpy.abs(shifts[a])
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
else:
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
return sdxyz
def shifted_xyz(self, which_shifts: Optional[int]) -> List[numpy.ndarray]:
"""
Returns cell centers for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell centers
"""
if which_shifts is None:
return self.xyz
exyz = self.shifted_exyz(which_shifts)
dxyz = self.shifted_dxyz(which_shifts)
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
def autoshifted_dxyz(self) -> List[numpy.ndarray]:
"""
Return cell widths, with each dimension shifted by the corresponding shifts.
Returns:
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
"""
if self.num_grids != 3:
raise GridError('Autoshifting requires exactly 3 grids')
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
def allocate(self, fill_value: Optional[float] = 1.0, dtype=numpy.float32) -> numpy.ndarray:
"""
Allocate an ndarray for storing grid data.
Args:
fill_value: Value to initialize the grid to. If None, an
uninitialized array is returned.
dtype: Numpy dtype for the array. Default is `numpy.float32`.
Returns:
The allocated array
"""
if fill_value is None:
return numpy.empty(self.cell_data_shape, dtype=dtype)
else:
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
def __init__(self,
pixel_edge_coordinates: Sequence[numpy.ndarray],
shifts: numpy.ndarray = Yee_Shifts_E,
periodic: Union[bool, Sequence[bool]] = False,
) -> None:
"""
Args:
pixel_edge_coordinates: 3-element list of (ndarrays or lists) specifying the
@ -164,24 +273,17 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
Raises:
`GridError` on invalid input
"""
edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates]
if len(edge_arrs) != 3:
raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays')
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
self.exyz = [numpy.unique(pixel_edge_coordinates[i]) for i in range(3)]
self.shifts = numpy.array(shifts, dtype=float)
for i in range(3):
if self.exyz[i].size != edge_arrs[i].size:
if len(self.exyz[i]) != len(pixel_edge_coordinates[i]):
warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2)
if isinstance(periodic, bool):
self.periodic = [periodic] * 3
else:
self.periodic = list(periodic)
if len(self.periodic) != 3:
raise GridError('periodic must be a bool or a sequence of length 3')
if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic):
raise GridError('periodic sequence entries must be bool values')
if len(self.shifts.shape) != 2:
raise GridError('Misshapen shifts: shifts must have two axes! '
@ -193,16 +295,9 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
if (numpy.abs(self.shifts) > 1).any():
raise GridError('Only shifts in the range [-1, 1] are currently supported')
def _serialization_payload(self, *, kind: str) -> dict[str, Any]:
payload: dict[str, Any] = {
'kind': numpy.array(kind),
'format_version': numpy.array(_FORMAT_VERSION, dtype=int),
'shifts': self.shifts,
'periodic': numpy.array(self.periodic, dtype=bool),
}
for axis, exyz in enumerate(self.exyz):
payload[f'exyz_{axis}'] = exyz
return payload
if (self.shifts < 0).any():
# TODO: Test negative shifts
warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2)
@staticmethod
def load(filename: str) -> 'Grid':
@ -212,13 +307,14 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
Args:
filename: Filename to load from.
"""
payload = _load_payload(filename)
kind = _payload_scalar_str(payload, 'kind')
if kind not in ('grid', 'grid_data'):
raise GridError(f'Unsupported serialized kind: {kind}')
return _grid_from_payload(payload)
with open(filename, 'rb') as f:
tmp_dict = pickle.load(f)
def save(self, filename: str) -> Self:
g = Grid([[-1, 1]] * 3)
g.__dict__.update(tmp_dict)
return g
def save(self: T, filename: str) -> T:
"""
Save to file.
@ -228,19 +324,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
Returns:
self
"""
_save_npz_payload(filename, self._serialization_payload(kind='grid'))
with open(filename, 'wb') as f:
pickle.dump(self.__dict__, f, protocol=2)
return self
def with_data(
self,
fill_value: float | None = 1.0,
dtype: type[numpy.number] = numpy.float32,
) -> 'GridData':
from .data import GridData
return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype))
def copy(self) -> Self:
def copy(self: T) -> T:
"""
Returns:
Deep copy of the grid.

View file

@ -1,118 +1,115 @@
"""
Position-related methods for Grid class
"""
import numpy
from numpy.typing import NDArray, ArrayLike
from typing import List, Optional
import numpy # type: ignore
from . import GridError
from .base import GridBase
class GridPosMixin(GridBase):
def ind2pos(
self,
ind: NDArray,
which_shifts: int | None = None,
def ind2pos(self,
ind: numpy.ndarray,
which_shifts: Optional[int] = None,
round_ind: bool = True,
check_bounds: bool = True
) -> NDArray[numpy.float64]:
"""
Returns the natural position corresponding to the specified cell center indices.
The resulting position is clipped to the bounds of the grid
(to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`)
) -> numpy.ndarray:
"""
Returns the natural position corresponding to the specified cell center indices.
The resulting position is clipped to the bounds of the grid
(to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`)
Args:
ind: Indices of the position. Can be fractional. (3-element ndarray or list)
which_shifts: which grid number (`shifts`) to use
round_ind: Whether to round ind to the nearest integer position before indexing
(default `True`)
check_bounds: Whether to raise an `GridError` if the provided ind is outside of
the grid, as defined above (centers if `round_ind`, else edges) (default `True`)
Args:
ind: Indices of the position. Can be fractional. (3-element ndarray or list)
which_shifts: which grid number (`shifts`) to use
round_ind: Whether to round ind to the nearest integer position before indexing
(default `True`)
check_bounds: Whether to raise an `GridError` if the provided ind is outside of
the grid, as defined above (centers if `round_ind`, else edges) (default `True`)
Returns:
3-element ndarray specifying the natural position
Returns:
3-element ndarray specifying the natural position
Raises:
`GridError` if invalid `which_shifts`
`GridError` if `check_bounds` and out of bounds
"""
if which_shifts is not None and which_shifts >= self.shifts.shape[0]:
raise GridError('Invalid shifts')
ind = numpy.array(ind, dtype=float)
if check_bounds:
if round_ind:
low_bound = 0.0
high_bound = -1.0
else:
low_bound = -0.5
high_bound = -0.5
if (ind < low_bound).any() or (ind > self.shape + high_bound).any():
raise GridError(f'Position outside of grid: {ind}')
Raises:
`GridError` if invalid `which_shifts`
`GridError` if `check_bounds` and out of bounds
"""
if which_shifts is not None and which_shifts >= self.shifts.shape[0]:
raise GridError('Invalid shifts')
ind = numpy.array(ind, dtype=float)
if check_bounds:
if round_ind:
rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1)
sxyz = self.shifted_xyz(which_shifts)
position = [sxyz[a][rind[a]] for a in range(3)]
low_bound = 0.0
high_bound = -1.0
else:
sexyz = self.shifted_exyz(which_shifts)
position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a])
for a in range(3)]
return numpy.array(position, dtype=float)
low_bound = -0.5
high_bound = -0.5
if (ind < low_bound).any() or (ind > self.shape - high_bound).any():
raise GridError(f'Position outside of grid: {ind}')
if round_ind:
rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1)
sxyz = self.shifted_xyz(which_shifts)
position = [sxyz[a][rind[a]].astype(int) for a in range(3)]
else:
sexyz = self.shifted_exyz(which_shifts)
position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a])
for a in range(3)]
return numpy.array(position, dtype=float)
def pos2ind(
self,
r: ArrayLike,
which_shifts: int | None,
def pos2ind(self,
r: numpy.ndarray,
which_shifts: Optional[int],
round_ind: bool = True,
check_bounds: bool = True
) -> NDArray[numpy.float64]:
"""
Returns the cell-center indices corresponding to the specified natural position.
The resulting position is clipped to within the outer centers of the grid.
) -> numpy.ndarray:
"""
Returns the cell-center indices corresponding to the specified natural position.
The resulting position is clipped to within the outer centers of the grid.
Args:
r: Natural position that we will convert into indices (3-element ndarray or list)
which_shifts: which grid number (`shifts`) to use
round_ind: Whether to round the returned indices to the nearest integers.
check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges
Args:
r: Natural position that we will convert into indices (3-element ndarray or list)
which_shifts: which grid number (`shifts`) to use
round_ind: Whether to round the returned indices to the nearest integers.
check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges
Returns:
3-element ndarray specifying the indices
Returns:
3-element ndarray specifying the indices
Raises:
`GridError` if invalid `which_shifts`
`GridError` if `check_bounds` and out of bounds
"""
r = numpy.squeeze(r)
if r.size != 3:
raise GridError(f'r must be 3-element vector: {r}')
Raises:
`GridError` if invalid `which_shifts`
`GridError` if `check_bounds` and out of bounds
"""
r = numpy.squeeze(r)
if r.size != 3:
raise GridError(f'r must be 3-element vector: {r}')
if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]):
raise GridError(f'Invalid which_shifts: {which_shifts}')
if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]):
raise GridError(f'Invalid which_shifts: {which_shifts}')
sexyz = self.shifted_exyz(which_shifts)
sexyz = self.shifted_exyz(which_shifts)
if check_bounds:
for a in range(3):
if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]):
raise GridError(f'Position[{a}] outside of grid!')
grid_pos = numpy.zeros((3,))
if check_bounds:
for a in range(3):
xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in
xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds
if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]):
raise GridError(f'Position[{a}] outside of grid!')
# No need to interpolate if round_ind is true or we were outside the grid
if round_ind or xi != xi_clipped:
grid_pos[a] = xi_clipped
else:
# Interpolate
x = self.shifted_xyz(which_shifts)[a][xi]
dx = self.shifted_dxyz(which_shifts)[a][xi]
f = (r[a] - x) / dx
grid_pos = numpy.zeros((3,))
for a in range(3):
xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in
xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds
# Clip to centers
grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1)
return grid_pos
# No need to interpolate if round_ind is true or we were outside the grid
if round_ind or xi != xi_clipped:
grid_pos[a] = xi_clipped
else:
# Interpolate
x = self.shifted_xyz(which_shifts)[a][xi]
dx = self.shifted_dxyz(which_shifts)[a][xi]
f = (r[a] - x) / dx
# Clip to centers
grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1)
return grid_pos

View file

@ -1,318 +1,183 @@
"""
Readback and visualization methods for Grid class
"""
from typing import Any, TYPE_CHECKING
from typing import Dict, Optional, Union, Any
import numpy
from numpy.typing import NDArray
from .utils import GridError, Plane, PlaneDict, PlaneProtocol
from .position import GridPosMixin
if TYPE_CHECKING:
import matplotlib.axes
import matplotlib.figure
import numpy # type: ignore
from . import GridError
# .visualize_* uses matplotlib
# .visualize_isosurface uses skimage
# .visualize_isosurface uses mpl_toolkits.mplot3d
class GridReadMixin(GridPosMixin):
@staticmethod
def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]:
if centers.size > 1:
midpoints = 0.5 * (centers[:-1] + centers[1:])
first = centers[0] - 0.5 * (centers[1] - centers[0])
last = centers[-1] + 0.5 * (centers[-1] - centers[-2])
return numpy.hstack(([first], midpoints, [last]))
return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float)
def get_slice(self,
cell_data: numpy.ndarray,
surface_normal: int,
center: float,
which_shifts: int = 0,
sample_period: int = 1
) -> numpy.ndarray:
"""
Retrieve a slice of a grid.
Interpolates if given a position between two planes.
def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]:
if sample_period <= 1:
return self.shifted_exyz(which_shifts)
Args:
cell_data: Cell data to slice
surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
center: Scalar specifying position along surface_normal axis.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
shifted_xyz = self.shifted_xyz(which_shifts)
shifted_exyz = self.shifted_exyz(which_shifts)
return [
self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a])
for a in range(3)
]
Returns:
Array containing the portion of the grid.
"""
if numpy.size(center) != 1 or not numpy.isreal(center):
raise GridError('center must be a real scalar')
def get_slice(
self,
cell_data: NDArray,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1
) -> NDArray:
"""
Retrieve a slice of a grid.
Interpolates if given a position between two grid planes.
sp = round(sample_period)
if sp <= 0:
raise GridError('sample_period must be positive')
Args:
cell_data: Cell data to slice
plane: Axis and position (`Plane`) of the plane to read.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
if numpy.size(which_shifts) != 1 or which_shifts < 0:
raise GridError('Invalid which_shifts')
Returns:
Array containing the portion of the grid.
"""
if isinstance(plane, dict):
plane = Plane(**plane)
if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction')
sp = round(sample_period)
if sp <= 0:
raise GridError('sample_period must be positive')
surface = numpy.delete(range(3), surface_normal)
if numpy.size(which_shifts) != 1 or which_shifts < 0:
raise GridError('Invalid which_shifts')
# Extract indices and weights of planes
center3 = numpy.insert([0, 0], surface_normal, (center,))
center_index = self.pos2ind(center3, which_shifts,
round_ind=False, check_bounds=False)[surface_normal]
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
if len(centers) == 2:
fpart = center_index - numpy.floor(center_index)
w = [1 - fpart, fpart] # longer distance -> less weight
else:
w = [1]
surface = numpy.delete(range(3), plane.axis)
c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1])
if center < c_min or center > c_max:
raise GridError('Coordinate of selected plane must be within simulation domain')
# Extract indices and weights of planes
center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,))
center_index = self.pos2ind(center3, which_shifts,
round_ind=False, check_bounds=False)[plane.axis]
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
if len(centers) == 2:
fpart = center_index - numpy.floor(center_index)
w = [1 - fpart, fpart] # longer distance -> less weight
else:
w = [1]
# Extract grid values from planes above and below visualized slice
sliced_grid = numpy.zeros(self.shape[surface])
for ci, weight in zip(centers, w):
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
c_min, c_max = (self.shifted_xyz(which_shifts)[plane.axis][i] for i in [0, -1])
if plane.pos < c_min or plane.pos > c_max:
raise GridError('Coordinate of selected plane must be within simulation domain')
# Remove extra dimensions
sliced_grid = numpy.squeeze(sliced_grid)
# Extract grid values from planes above and below visualized slice
sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface)
sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float))
for ci, weight in zip(centers, w, strict=True):
s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3))
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
# Remove extra dimensions
sliced_grid = numpy.squeeze(sliced_grid)
return sliced_grid
return sliced_grid
def visualize_slice(
self,
cell_data: NDArray,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
pcolormesh_args: dict[str, Any] | None = None,
ax: 'matplotlib.axes.Axes | None' = None,
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
"""
Visualize a slice of a grid.
Interpolates if given a position between two grid planes.
def visualize_slice(self,
cell_data: numpy.ndarray,
surface_normal: int,
center: float,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
pcolormesh_args: Optional[Dict[str, Any]] = None,
) -> None:
"""
Visualize a slice of a grid.
Interpolates if given a position between two planes.
Args:
cell_data: Cell data to visualize
plane: Axis and position (`Plane`) of the plane to read.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
pcolormesh_args: Args passed through to matplotlib `pcolormesh()`
ax: If provided, plot to these axes (instead of creating a new figure & axes)
Args:
surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
center: Scalar specifying position along surface_normal axis.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
"""
from matplotlib import pyplot
Returns:
(Figure, Axes)
"""
from matplotlib import pyplot
if pcolormesh_args is None:
pcolormesh_args = {}
if isinstance(plane, dict):
plane = Plane(**plane)
grid_slice = self.get_slice(cell_data=cell_data,
surface_normal=surface_normal,
center=center,
which_shifts=which_shifts,
sample_period=sample_period)
if pcolormesh_args is None:
pcolormesh_args = {}
surface = numpy.delete(range(3), surface_normal)
grid_slice = self.get_slice(
cell_data = cell_data,
plane = plane,
which_shifts = which_shifts,
sample_period = sample_period,
)
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
x_label, y_label = ('xyz'[a] for a in surface)
surface = numpy.delete(range(3), plane.axis)
if sample_period == 1:
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
else:
x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface)
pcolormesh_args.setdefault('shading', 'nearest')
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
x_label, y_label = ('xyz'[a] for a in surface)
if ax is None:
fig, ax = pyplot.subplots()
else:
fig = ax.figure
mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args)
fig.colorbar(mappable)
ax.set_aspect('equal', adjustable='box')
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
if finalize:
pyplot.show()
return fig, ax
pyplot.figure()
pyplot.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args)
pyplot.colorbar()
pyplot.gca().set_aspect('equal', adjustable='box')
pyplot.xlabel(x_label)
pyplot.ylabel(y_label)
if finalize:
pyplot.show()
def visualize_edges(
self,
cell_data: NDArray,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
contour_args: dict[str, Any] | None = None,
ax: 'matplotlib.axes.Axes | None' = None,
level_fraction: float = 0.7,
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
"""
Visualize the edges of a grid slice.
This is intended as an overlay on top of visualize_slice (e.g. showing epsilon boundaries
on an E-field plot).
def visualize_isosurface(self,
cell_data: numpy.ndarray,
level: Optional[float] = None,
which_shifts: int = 0,
sample_period: int = 1,
show_edges: bool = True,
finalize: bool = True,
) -> None:
"""
Draw an isosurface plot of the device.
Interpolates if given a position between two grid planes.
Args:
cell_data: Cell data to visualize
level: Value at which to find isosurface. Default (None) uses mean value in grid.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
show_edges: Whether to draw triangle edges. Default `True`
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
"""
from matplotlib import pyplot
import skimage.measure
# Claims to be unused, but needed for subplot(projection='3d')
from mpl_toolkits.mplot3d import Axes3D
Args:
cell_data: Cell data to visualize
plane: Axis and position (`Plane`) of the plane to read.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
contour_args: Args passed through to matplotlib `pcolormesh()`
ax: If provided, plot to these axes (instead of creating a new figure & axes)
level_fraction: Value between 0 and 1 which tunes how many contours are generated.
1 indicates that every possible step should have its own contour.
# Get data from cell_data
grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period]
if level is None:
level = grid.mean()
Returns:
(Figure, Axes)
"""
from matplotlib import pyplot
# Find isosurface with marching cubes
verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
if level_fraction > 1:
raise GridError(f'{level_fraction=} must be between 0 and 1')
# Convert vertices from index to position
pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False)
for i in range(verts.shape[0])], dtype=float)
xs, ys, zs = (pos_verts[:, a] for a in range(3))
if isinstance(plane, dict):
plane = Plane(**plane)
# Draw the plot
fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d')
if show_edges:
ax.plot_trisurf(xs, ys, faces, zs)
else:
ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none')
if contour_args is None:
contour_args = dict(alpha=0.8, colors='gray')
# Add a fake plot of a cube to force the axes to be equal lengths
max_range = numpy.array([xs.max() - xs.min(),
ys.max() - ys.min(),
zs.max() - zs.min()], dtype=float).max()
mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2]
xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min())
ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min())
zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min())
# Comment or uncomment following both lines to test the fake bounding box:
for xb, yb, zb in zip(xbs, ybs, zbs):
ax.plot([xb], [yb], [zb], 'w')
grid_slice = self.get_slice(
cell_data = cell_data,
plane = plane,
which_shifts = which_shifts,
sample_period = sample_period,
)
cvals, cval_counts = numpy.unique(grid_slice, return_counts=True)
if cvals.size == 1:
levels = [cvals[0] + 1]
else:
cval_order = numpy.argsort(cval_counts)[::-1]
level_count = 2
while cval_counts[cval_order[:level_count]].sum() < level_fraction:
level_count += 1
ctr_levels = cvals[cval_order[:level_count]]
levels = numpy.diff(ctr_levels[::-1]) + ctr_levels[:0:-1]
surface = numpy.delete(range(3), plane.axis)
if ax is None:
fig, ax = pyplot.subplots()
else:
fig = ax.figure
xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface)
xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij')
ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args)
if finalize:
pyplot.show()
return fig, ax
def visualize_isosurface(
self,
cell_data: NDArray,
level: float | None = None,
which_shifts: int = 0,
sample_period: int = 1,
show_edges: bool = True,
finalize: bool = True,
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
"""
Draw an isosurface plot of the device.
Args:
cell_data: Cell data to visualize
level: Value at which to find isosurface. Default (None) uses mean value in grid.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
show_edges: Whether to draw triangle edges. Default `True`
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
Returns:
(Figure, Axes)
"""
from matplotlib import pyplot
import skimage.measure
# Claims to be unused, but needed for subplot(projection='3d')
from mpl_toolkits.mplot3d import Axes3D
del Axes3D # imported for side effects only
# Get data from cell_data
grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period]
if level is None:
level = grid.mean()
# Find isosurface with marching cubes
verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
# Convert vertices from index to position
preview_exyz = self._sampled_exyz(which_shifts, sample_period)
pos_verts = numpy.array([
[
numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a])
for a in range(3)
]
for i in range(verts.shape[0])
], dtype=float)
xs, ys, zs = (pos_verts[:, a] for a in range(3))
# Draw the plot
fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d')
if show_edges:
ax.plot_trisurf(xs, ys, faces, zs) # type: ignore
else:
ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') # type: ignore
# Add a fake plot of a cube to force the axes to be equal lengths
max_range = numpy.array([xs.max() - xs.min(),
ys.max() - ys.min(),
zs.max() - zs.min()], dtype=float).max()
mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2]
xbs = 0.5 * max_range * mg[0].ravel() + 0.5 * (xs.max() + xs.min())
ybs = 0.5 * max_range * mg[1].ravel() + 0.5 * (ys.max() + ys.min())
zbs = 0.5 * max_range * mg[2].ravel() + 0.5 * (zs.max() + zs.min())
# Comment or uncomment following both lines to test the fake bounding box:
for xb, yb, zb in zip(xbs, ybs, zbs, strict=True):
ax.plot([xb], [yb], [zb], 'w')
if finalize:
pyplot.show()
return fig, ax
if finalize:
pyplot.show()

View file

@ -1,9 +1,8 @@
import pytest
import numpy
from numpy.testing import assert_allclose #, assert_array_equal
import pickle
import pytest # type: ignore
import numpy # type: ignore
from numpy.testing import assert_allclose, assert_array_equal # type: ignore
from .. import Grid, GridData, Extent, GridError, Plane, Slab
from .. import Grid
def test_draw_oncenter_2x2() -> None:
@ -13,13 +12,7 @@ def test_draw_oncenter_2x2() -> None:
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
grid.draw_cuboid(
arr,
x=dict(center=0, span=1),
y=Extent(center=0, span=1),
z=dict(center=0, span=10),
foreground=1,
)
grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[1, 1, 10], foreground=1)
correct = numpy.array([[0.25, 0.25],
[0.25, 0.25]])[None, :, :, None]
@ -34,13 +27,7 @@ def test_draw_ongrid_4x4() -> None:
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
grid.draw_cuboid(
arr,
x=dict(center=0, span=2),
y=dict(min=-1, max=1),
z=dict(center=0, min=-5),
foreground=1,
)
grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[2, 2, 10], foreground=1)
correct = numpy.array([[0, 0, 0, 0],
[0, 1, 1, 0],
@ -57,13 +44,7 @@ def test_draw_xshift_4x4() -> None:
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
grid.draw_cuboid(
arr,
x=dict(center=0.5, span=1.5),
y=dict(min=-1, max=1),
z=dict(center=0, span=10),
foreground=1,
)
grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 2, 10], foreground=1)
correct = numpy.array([[0, 0, 0, 0],
[0, 0.25, 0.25, 0],
@ -80,13 +61,7 @@ def test_draw_yshift_4x4() -> None:
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
grid.draw_cuboid(
arr,
x=dict(min=-1, max=1),
y=dict(center=0.5, span=1.5),
z=dict(center=0, span=10),
foreground=1,
)
grid.draw_cuboid(arr, center=[0, 0.5, 0], dimensions=[2, 1.5, 10], foreground=1)
correct = numpy.array([[0, 0, 0, 0],
[0, 0.25, 1, 0.25],
@ -103,13 +78,7 @@ def test_draw_2shift_4x4() -> None:
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
grid.draw_cuboid(
arr,
x=dict(center=0.5, span=1.5),
y=dict(min=-0.5, max=0.5),
z=dict(center=0, span=10),
foreground=1,
)
grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 1, 10], foreground=1)
correct = numpy.array([[0, 0, 0, 0],
[0, 0.125, 0.125, 0],
@ -117,350 +86,3 @@ def test_draw_2shift_4x4() -> None:
[0, 0.125, 0.125, 0]])[None, :, :, None]
assert_allclose(arr, correct)
def test_ind2pos_round_preserves_float_centers() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]])
pos = grid.ind2pos(numpy.array([1, 0, 0]), which_shifts=0)
assert_allclose(pos, [2.0, 1.0, 0.5])
def test_ind2pos_enforces_bounds_for_rounded_and_fractional_indices() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]])
with pytest.raises(GridError):
grid.ind2pos(numpy.array([2, 0, 0]), which_shifts=0, check_bounds=True)
edge_pos = grid.ind2pos(numpy.array([1.5, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True)
assert_allclose(edge_pos, [3.0, 2.0, 1.0])
with pytest.raises(GridError):
grid.ind2pos(numpy.array([1.6, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True)
def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]])
arr_2d = grid.allocate(0)
arr_3d = grid.allocate(0)
slab = dict(axis='z', center=0.5, span=1.0)
polygon_2d = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float)
polygon_3d = numpy.array([[0, 0, 0.5],
[1, 0, 0.5],
[1, 1, 0.5],
[0, 1, 0.5]], dtype=float)
grid.draw_polygon(arr_2d, slab=slab, polygon=polygon_2d, foreground=1)
grid.draw_polygon(arr_3d, slab=slab, polygon=polygon_3d, foreground=1)
assert_allclose(arr_3d, arr_2d)
def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
polygon = numpy.array([[0, 0, 0.5],
[1, 0, 0.5],
[1, 1, 0.75],
[0, 1, 0.5]], dtype=float)
with pytest.raises(GridError):
grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1)
def test_get_slice_supports_sampling() -> None:
grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2)
assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0])
def test_sampled_visualization_helpers_do_not_error() -> None:
matplotlib = pytest.importorskip('matplotlib')
matplotlib.use('Agg')
from matplotlib import pyplot
grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False)
fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False)
assert fig_slice is ax_slice.figure
assert fig_edges is ax_edges.figure
pyplot.close(fig_slice)
pyplot.close(fig_edges)
def test_grid_constructor_rejects_invalid_coordinate_count() -> None:
with pytest.raises(GridError):
Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]])
with pytest.raises(GridError):
Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
def test_grid_constructor_rejects_invalid_periodic_length() -> None:
with pytest.raises(GridError):
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False])
def test_extent_and_slab_reject_inverted_geometry() -> None:
with pytest.raises(GridError):
Extent(center=0, min=1)
with pytest.raises(GridError):
Extent(min=2, max=1)
with pytest.raises(GridError):
Slab(axis='z', center=1, max=0)
def test_extent_accepts_scalar_like_inputs() -> None:
extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0]))
assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0])
def test_get_slice_uses_shifted_grid_bounds() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0.5, 0, 0]])
cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape)
grid_slice = grid.get_slice(cell_data, Plane(x=2.0), which_shifts=0)
assert_allclose(grid_slice, cell_data[0, 1, :, :])
with pytest.raises(GridError):
grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0)
def test_draw_extrude_rectangle_uses_boundary_slice() -> None:
grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]])
cell_data = grid.allocate(0)
source = numpy.array([[1, 2],
[3, 4]], dtype=float)
cell_data[0, :, :, 1] = source
grid.draw_extrude_rectangle(
cell_data,
rectangle=[[0, 0, 2], [2, 2, 2]],
direction=2,
polarity=-1,
distance=2,
)
assert_allclose(cell_data[0, :, :, 0], source)
assert_allclose(cell_data[0, :, :, 1], source)
def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None:
grid = Grid([[0, 1, 3, 6, 10], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]])
sampled_exyz = grid._sampled_exyz(0, 2)
assert_allclose(sampled_exyz[0], [-1.5, 2.5, 6.5])
def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None:
matplotlib = pytest.importorskip('matplotlib')
matplotlib.use('Agg')
skimage_measure = pytest.importorskip('skimage.measure')
from matplotlib import pyplot
from mpl_toolkits.mplot3d.axes3d import Axes3D
captured: dict[str, numpy.ndarray] = {}
def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]:
verts = numpy.array([[0.5, 0.5, 0.5],
[0.5, 1.5, 0.5],
[1.5, 0.5, 0.5]], dtype=float)
faces = numpy.array([[0, 1, 2]], dtype=int)
return verts, faces, None, None
def fake_plot_trisurf( # noqa: ANN202
_self: object,
xs: numpy.ndarray,
ys: numpy.ndarray,
faces: numpy.ndarray,
zs: numpy.ndarray,
*_args: object,
**_kwargs: object,
) -> object:
captured['xs'] = numpy.asarray(xs)
captured['ys'] = numpy.asarray(ys)
captured['faces'] = numpy.asarray(faces)
captured['zs'] = numpy.asarray(zs)
return object()
monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes)
monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf)
grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]])
cell_data = numpy.zeros(grid.cell_data_shape)
fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False)
assert_allclose(captured['xs'], [1.5, 1.5, 3.5])
assert_allclose(captured['ys'], [1.5, 3.5, 1.5])
assert_allclose(captured['zs'], [1.5, 1.5, 1.5])
pyplot.close(fig)
def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
path = tmp_path / 'grid.state'
grid.save(str(path))
loaded = Grid.load(str(path))
assert path.exists()
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
assert_allclose(restored, original)
assert_allclose(loaded.shifts, grid.shifts)
assert loaded.periodic == grid.periodic
def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False])
path = tmp_path / 'grid.pickle'
with open(path, 'wb') as f:
pickle.dump(grid.__dict__, f, protocol=2)
loaded = Grid.load(str(path))
for original, restored in zip(grid.exyz, loaded.exyz, strict=True):
assert_allclose(restored, original)
assert_allclose(loaded.shifts, grid.shifts)
assert loaded.periodic == grid.periodic
def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None:
data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0)
data.cell_data[0, 1, 0, 0] = 5.0
path = tmp_path / 'griddata.state'
data.save(str(path))
loaded = GridData.load(str(path))
assert path.exists()
assert_allclose(loaded.cell_data, data.cell_data)
assert_allclose(loaded.grid.shifts, data.grid.shifts)
assert loaded.grid.periodic == data.grid.periodic
def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None:
path = tmp_path / 'grid.state'
Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path))
with pytest.raises(GridError):
GridData.load(str(path))
def test_negative_shift_nonperiodic_edges_and_widths() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0])
assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5])
assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25])
def test_negative_shift_periodic_edges_and_widths() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False])
assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0])
assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5])
def test_negative_shift_coordinate_round_trip() -> None:
grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False)
pos = grid.ind2pos(ind, 0, round_ind=False)
assert_allclose(ind, [1.0, 0.0, 0.0])
assert_allclose(pos, [1.25, 1.0, 0.5])
def test_negative_shift_draw_cuboid_fractional_fill() -> None:
grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
arr = grid.allocate(0)
grid.draw_cuboid(
arr,
x=dict(min=0, max=1),
y=dict(min=0, max=1),
z=dict(min=0, max=1),
foreground=1,
)
assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3])
def test_negative_shift_get_slice_uses_shifted_centers() -> None:
grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False])
cell_data = numpy.zeros(grid.cell_data_shape)
cell_data[0, 1, :, 0] = [7, 9]
x_center = float(grid.shifted_xyz(0)[0][1])
grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0)
assert_allclose(grid_slice, [7, 9])
def test_grid_with_data_returns_griddata() -> None:
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
data = grid.with_data(fill_value=2.0)
assert isinstance(data, GridData)
assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32))
def test_griddata_constructor_validates_shape() -> None:
grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]])
with pytest.raises(GridError):
GridData(grid, numpy.zeros((1, 1, 1)))
def test_griddata_draw_methods_are_chainable() -> None:
data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
chained = data.draw_cuboid(
foreground=1,
x=dict(min=0, max=1),
y=dict(min=0, max=1),
z=dict(min=0, max=1),
).draw_polygon(
foreground=0.5,
slab=dict(axis='z', center=0.5, span=1.0),
polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float),
)
assert chained is data
assert data.cell_data.sum() > 0
def test_griddata_read_methods_delegate() -> None:
data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0)
data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float)
assert_allclose(
data.get_slice(Plane(z=0.5)),
data.grid.get_slice(data.cell_data, Plane(z=0.5)),
)
def test_griddata_copy_is_independent() -> None:
data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0)
cloned = data.copy()
cloned.cell_data[0, 0, 0, 0] = 5.0
assert data is not cloned
assert data.grid is not cloned.grid
assert data.cell_data[0, 0, 0, 0] == 1.0

View file

@ -1,248 +0,0 @@
from typing import Protocol, TypedDict, runtime_checkable, cast
from dataclasses import dataclass
import numpy
class GridError(Exception):
""" Base error type for `gridlock` """
pass
def _coerce_scalar(name: str, value: object) -> float:
arr = numpy.asarray(value)
if arr.size != 1:
raise GridError(f'{name} must be a scalar value')
try:
return float(arr.reshape(()))
except (TypeError, ValueError) as exc:
raise GridError(f'{name} must be a real scalar value') from exc
class ExtentDict(TypedDict, total=False):
"""
Geometrical definition of an extent (1D bounded region)
Must contain exactly two of `min`, `max`, `center`, or `span`.
"""
min: float
center: float
max: float
span: float
@runtime_checkable
class ExtentProtocol(Protocol):
"""
Anything that looks like an `Extent`
"""
center: float
span: float
@property
def max(self) -> float: ...
@property
def min(self) -> float: ...
@dataclass(init=False, slots=True)
class Extent(ExtentProtocol):
"""
Geometrical definition of an extent (1D bounded region)
May be constructed with any two of `min`, `max`, `center`, or `span`.
"""
center: float
span: float
@property
def max(self) -> float:
return self.center + self.span / 2
@property
def min(self) -> float:
return self.center - self.span / 2
def __init__(
self,
*,
min: float | None = None,
center: float | None = None,
max: float | None = None,
span: float | None = None,
) -> None:
values = {
'min': None if min is None else _coerce_scalar('min', min),
'center': None if center is None else _coerce_scalar('center', center),
'max': None if max is None else _coerce_scalar('max', max),
'span': None if span is None else _coerce_scalar('span', span),
}
if sum(value is not None for value in values.values()) != 2:
raise GridError('Exactly two of min, center, max, span must be provided')
min_v = values['min']
center_v = values['center']
max_v = values['max']
span_v = values['span']
if span_v is not None and span_v < 0:
raise GridError('span must be non-negative')
if min_v is not None and max_v is not None:
if max_v < min_v:
raise GridError('max must be greater than or equal to min')
center_v = 0.5 * (max_v + min_v)
span_v = max_v - min_v
elif center_v is not None and min_v is not None:
span_v = 2 * (center_v - min_v)
if span_v < 0:
raise GridError('min must be less than or equal to center')
elif center_v is not None and max_v is not None:
span_v = 2 * (max_v - center_v)
if span_v < 0:
raise GridError('center must be less than or equal to max')
elif min_v is not None and span_v is not None:
center_v = min_v + 0.5 * span_v
elif max_v is not None and span_v is not None:
center_v = max_v - 0.5 * span_v
if center_v is None or span_v is None:
raise GridError('Unable to construct extent from the provided values')
self.center = center_v
self.span = span_v
class SlabDict(TypedDict, total=False):
"""
Geometrical definition of a slab (3D region bounded on one axis only)
Must contain `axis` plus any two of `min`, `max`, `center`, or `span`.
"""
min: float
center: float
max: float
span: float
axis: int | str
@runtime_checkable
class SlabProtocol(ExtentProtocol, Protocol):
"""
Anything that looks like a `Slab`
"""
axis: int
center: float
span: float
@property
def max(self) -> float: ...
@property
def min(self) -> float: ...
@dataclass(init=False, slots=True)
class Slab(Extent, SlabProtocol):
"""
Geometrical definition of a slab (3D region bounded on one axis only)
May be constructed with `axis` (bounded axis) plus any two of `min`, `max`, `center`, or `span`.
"""
axis: int
def __init__(
self,
axis: int | str,
*,
min: float | None = None,
center: float | None = None,
max: float | None = None,
span: float | None = None,
) -> None:
Extent.__init__(self, min=min, center=center, max=max, span=span)
if isinstance(axis, str):
axis_int = 'xyz'.find(axis.lower())
else:
axis_int = axis
if axis_int not in range(3):
raise GridError(f'Invalid axis (slab normal direction): {axis}')
self.axis = axis_int
def as_plane(self, where: str) -> 'Plane':
if where == 'center':
return Plane(axis=self.axis, pos=self.center)
if where == 'min':
return Plane(axis=self.axis, pos=self.min)
if where == 'max':
return Plane(axis=self.axis, pos=self.max)
raise GridError(f'Invalid {where=}')
class PlaneDict(TypedDict, total=False):
"""
Geometrical definition of a plane (2D unbounded region in 3D space)
Must contain exactly one of `x`, `y`, `z`, or both `axis` and `pos`
"""
x: float
y: float
z: float
axis: int
pos: float
@runtime_checkable
class PlaneProtocol(Protocol):
"""
Anything that looks like a `Plane`
"""
axis: int
pos: float
@dataclass(init=False, slots=True)
class Plane(PlaneProtocol):
"""
Geometrical definition of a plane (2D unbounded region in 3D space)
May be constructed with any of `x=4`, `y=5`, `z=-5`, or `axis=2, pos=-5`.
"""
axis: int
pos: float
def __init__(
self,
*,
axis: int | str | None = None,
pos: float | None = None,
x: float | None = None,
y: float | None = None,
z: float | None = None,
) -> None:
xx = x
yy = y
zz = z
if sum(aa is not None for aa in (pos, xx, yy, zz)) != 1:
raise GridError('Exactly one of pos, x, y, z must be non-None!')
if (axis is None) != (pos is None):
raise GridError('Either both or neither of `axis` and `pos` must be defined.')
if isinstance(axis, str):
axis_int = 'xyz'.find(axis.lower())
elif axis is None:
axis_int = (xx is None, yy is None, zz is None).index(False)
else:
axis_int = axis
if axis_int not in range(3):
raise GridError(f'Invalid axis (slab normal direction): {axis=} {x=} {y=} {z=}')
self.axis = axis_int
if pos is not None:
cpos = pos
else:
cpos = cast('float', (xx, yy, zz)[axis_int])
assert cpos is not None
if hasattr(cpos, '__len__'):
assert len(cpos) == 1
self.pos = cpos

View file

@ -1,98 +0,0 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "gridlock"
description = "Coupled gridding library"
readme = "README.md"
license = { file = "LICENSE.md" }
authors = [
{ name="Jan Petykiewicz", email="jan@mpxd.net" },
]
homepage = "https://mpxd.net/code/jan/gridlock"
repository = "https://mpxd.net/code/jan/gridlock"
keywords = [
"FDTD",
"gridding",
"simulation",
"nonuniform",
"FDFD",
"finite",
"difference",
]
classifiers = [
"Programming Language :: Python :: 3",
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: GNU Affero General Public License v3",
"Topic :: Multimedia :: Graphics :: 3D Rendering",
"Topic :: Scientific/Engineering :: Electronic Design Automation (EDA)",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Scientific/Engineering :: Visualization",
]
requires-python = ">=3.11"
include = [
"LICENSE.md"
]
dynamic = ["version"]
dependencies = [
"numpy>=1.26",
"float_raster>=0.8",
]
[tool.hatch.version]
path = "gridlock/__init__.py"
[project.optional-dependencies]
visualization = ["matplotlib"]
visualization-isosurface = [
"matplotlib",
"skimage>=0.13",
"mpl_toolkits",
]
[tool.ruff]
exclude = [
".git",
"dist",
]
line-length = 145
indent-width = 4
lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
lint.select = [
"NPY", "E", "F", "W", "B", "ANN", "UP", "SLOT", "SIM", "LOG",
"C4", "ISC", "PIE", "PT", "RET", "TCH", "PTH", "INT",
"ARG", "PL", "R", "TRY",
"G010", "G101", "G201", "G202",
"Q002", "Q003", "Q004",
]
lint.ignore = [
#"ANN001", # No annotation
"ANN002", # *args
"ANN003", # **kwargs
"ANN401", # Any
"SIM108", # single-line if / else assignment
"RET504", # x=y+z; return x
"PIE790", # unnecessary pass
"ISC003", # non-implicit string concatenation
"C408", # dict(x=y) instead of {'x': y}
"PLR09", # Too many xxx
"PLR2004", # magic number
"PLC0414", # import x as x
"TRY003", # Long exception message
"PTH123", # open()
]
[[tool.mypy.overrides]]
module = [
"matplotlib",
"matplotlib.axes",
"matplotlib.figure",
"mpl_toolkits.mplot3d",
]
ignore_missing_imports = true

47
setup.py Normal file
View file

@ -0,0 +1,47 @@
#!/usr/bin/env python3
from setuptools import setup, find_packages
with open('README.md', 'r') as f:
long_description = f.read()
with open('gridlock/VERSION.py', 'rt') as f:
version = f.readlines()[2].strip()
setup(name='gridlock',
version=version,
description='Coupled gridding library',
long_description=long_description,
long_description_content_type='text/markdown',
author='Jan Petykiewicz',
author_email='jan@mpxd.net',
url='https://mpxd.net/code/jan/gridlock',
packages=find_packages(),
package_data={
'gridlock': ['py.typed'],
},
install_requires=[
'numpy',
'float_raster',
],
extras_require={
'visualization': ['matplotlib'],
'visualization-isosurface': [
'matplotlib',
'skimage>=0.13',
'mpl_toolkits',
],
},
classifiers=[
'Programming Language :: Python :: 3',
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: GNU Affero General Public License v3',
'Topic :: Multimedia :: Graphics :: 3D Rendering',
'Topic :: Scientific/Engineering :: Electronic Design Automation (EDA)',
'Topic :: Scientific/Engineering :: Physics',
'Topic :: Scientific/Engineering :: Visualization',
],
)