Compare commits

...

1 Commits

Author SHA1 Message Date
34f80202ba Major rework of arguments using Extent/Slab/Plane 2025-01-28 19:36:59 -08:00
8 changed files with 416 additions and 170 deletions

View File

@ -15,9 +15,24 @@ Dependencies:
- mpl_toolkits.mplot3d [Grid.visualize_isosurface()]
- skimage [Grid.visualize_isosurface()]
"""
from .error import GridError as GridError
from .utils import (
GridError as GridError,
Extent as Extent,
ExtentProtocol as ExtentProtocol,
ExtentDict as ExtentDict,
Slab as Slab,
SlabProtocol as SlabProtocol,
SlabDict as SlabDict,
Plane as Plane,
PlaneProtocol as PlaneProtocol,
PlaneDict as PlaneDict,
)
from .grid import Grid as Grid
__author__ = 'Jan Petykiewicz'
__version__ = '1.2'
version = __version__

View File

@ -1,10 +0,0 @@
from enum import Enum
class Direction(Enum):
"""
Enum for axis->integer mapping
"""
x = 0
y = 1
z = 2

View File

@ -7,7 +7,7 @@ import numpy
from numpy.typing import NDArray, ArrayLike
from float_raster import raster
from . import GridError
from .utils import GridError, Slab, SlabDict, SlabProtocol, Extent, ExtentDict, ExtentProtocol
from .position import GridPosMixin
@ -21,27 +21,26 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
foreground_t = float | foreground_callable_t
class GridDrawMixin(GridPosMixin):
def draw_polygons(
self,
cell_data: NDArray,
surface_normal: int,
center: ArrayLike,
slab: SlabProtocol | SlabDict,
polygons: Sequence[ArrayLike],
thickness: float,
foreground: Sequence[foreground_t] | foreground_t,
*,
offset2d: ArrayLike = (0, 0),
) -> None:
"""
Draw polygons on an axis-aligned plane.
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying an offset applied to all the polygons
polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
(non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each
(non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each
polygon must have at least 3 vertices.
thickness: Thickness of the layer to draw
foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list
of any of these (1 per grid). Callable values should take an ndarray the shape of the
grid and return an ndarray of equal shape containing the foreground value at the given x, y,
@ -50,13 +49,13 @@ class GridDrawMixin(GridPosMixin):
Raises:
GridError
"""
if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction')
center = numpy.squeeze(center)
if isinstance(slab, dict):
slab = Slab(**slab)
poly_list = [numpy.asarray(poly) for poly in polygons]
# Check polygons, and remove redundant coordinates
surface = numpy.delete(range(3), surface_normal)
surface = numpy.delete(range(3), slab.axis)
for ii in range(len(poly_list)):
polygon = poly_list[ii]
@ -69,9 +68,8 @@ class GridDrawMixin(GridPosMixin):
if not polygon.shape[0] > 2:
raise GridError(malformed + 'must consist of more than 2 points')
if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1:
raise GridError(malformed + 'must be in plane with surface normal '
+ 'xyz'[surface_normal])
if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1:
raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis])
# Broadcast foreground where necessary
foregrounds: Sequence[foreground_callable_t] | Sequence[float]
@ -87,10 +85,10 @@ class GridDrawMixin(GridPosMixin):
bd_2d_min = numpy.array([0, 0])
bd_2d_max = numpy.array([0, 0])
for polygon in poly_list:
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) + offset2d
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) + offset2d
bd_min = numpy.insert(bd_2d_min, slab.axis, slab.min)
bd_max = numpy.insert(bd_2d_max, slab.axis, slab.max)
# 2) Find indices (bdi) just outside bd elements
buf = 2 # size of safety buffer
@ -103,13 +101,13 @@ class GridDrawMixin(GridPosMixin):
bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int)
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
# 3) Adjust polygons for center
poly_list = [poly + center[surface] for poly in poly_list]
# 3) Adjust polygons for offset2d
poly_list = [poly + offset2d for poly in poly_list]
# ## Generate weighing function
def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
v_2d = numpy.array(vector, dtype=float)
return numpy.insert(v_2d, surface_normal, (val,))
return numpy.insert(v_2d, slab.axis, (val,))
# iterate over grids
foreground_val: NDArray | float
@ -162,13 +160,12 @@ class GridDrawMixin(GridPosMixin):
w_xy = numpy.minimum(w_xy, 1.0)
# 2) Generate weights in z-direction
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[slab.axis], ))
def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001
edges = self.shifted_exyz(i)[surface_normal]
point = center[surface_normal] + offset
def get_zi(point: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001
edges = self.shifted_exyz(i)[slab.axis]
grid_coord = numpy.digitize(point, edges) - 1
w_coord = grid_coord - bdi_min[surface_normal]
w_coord = grid_coord - bdi_min[slab.axis]
if w_coord < 0:
w_coord = 0
@ -177,12 +174,12 @@ class GridDrawMixin(GridPosMixin):
w_coord = w_z.size - 1
f = 1
else:
dz = self.shifted_dxyz(i)[surface_normal][grid_coord]
dz = self.shifted_dxyz(i)[slab.axis][grid_coord]
f = (point - edges[grid_coord]) / dz
return f, w_coord
zi_top_f, zi_top = get_zi(+thickness / 2.0)
zi_bot_f, zi_bot = get_zi(-thickness / 2.0)
zi_top_f, zi_top = get_zi(slab.max)
zi_bot_f, zi_bot = get_zi(slab.min)
w_z[zi_bot + 1:zi_top] = 1
@ -193,7 +190,7 @@ class GridDrawMixin(GridPosMixin):
w_z[zi_bot] = zi_top_f - zi_bot_f
# 3) Generate total weight function
w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,)))
w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], slab.axis, (2,)))
# ## Modify the grid
g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
@ -203,34 +200,36 @@ class GridDrawMixin(GridPosMixin):
def draw_polygon(
self,
cell_data: NDArray,
surface_normal: int,
center: ArrayLike,
slab: SlabProtocol | SlabDict,
polygon: ArrayLike,
thickness: float,
foreground: Sequence[foreground_t] | foreground_t,
*,
offset2d: ArrayLike = (0, 0),
) -> None:
"""
Draw a polygon on an axis-aligned plane.
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying an offset applied to the polygon
slab: `Slab` in which to draw polygons.
polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at
clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at
least 3 vertices.
thickness: Thickness of the layer to draw
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
"""
self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground)
self.draw_polygons(
cell_data = cell_data,
slab = slab,
polygons = [polygon],
foreground = foreground,
offset2d = offset2d,
)
def draw_slab(
self,
cell_data: NDArray,
surface_normal: int,
center: ArrayLike,
thickness: float,
slab: SlabProtocol | SlabDict,
foreground: Sequence[foreground_t] | foreground_t,
) -> None:
"""
@ -238,50 +237,45 @@ class GridDrawMixin(GridPosMixin):
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: `surface_normal` coordinate value at the center of the slab
slab:
thickness: Thickness of the layer to draw
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
"""
# Turn surface_normal into its integer representation
if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction')
if numpy.size(center) != 1:
center = numpy.squeeze(center)
if len(center) == 3:
center = center[surface_normal]
else:
raise GridError(f'Bad center: {center}')
if isinstance(slab, dict):
slab = Slab(**slab)
# Find center of slab
center_shift = self.center
center_shift[surface_normal] = center
center_shift[slab.axis] = slab.center
surface = numpy.delete(range(3), surface_normal)
surface = numpy.delete(range(3), slab.axis)
u_min, u_max = self.exyz[surface[0]][[0, -1]]
v_min, v_max = self.exyz[surface[1]][[0, -1]]
xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface]
xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface]
margin = 4 * numpy.max(self.dxyz[surface[0]].max(),
self.dxyz[surface[1]].max())
dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float)
p = numpy.array([[u_min - margin, v_max + margin],
[u_max + margin, v_max + margin],
[u_max + margin, v_min - margin],
[u_min - margin, v_min - margin]], dtype=float)
xyz_min -= 4 * dxyz
xyz_max += 4 * dxyz
p = numpy.array([[xyz_min[0], xyz_max[1]],
[xyz_max[0], xyz_max[1]],
[xyz_max[0], xyz_min[1]],
[xyz_min[0], xyz_min[1]]], dtype=float)
self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground)
self.draw_polygon(
cell_data = cell_data,
slab = slab,
polygon = p,
foreground = foreground,
)
def draw_cuboid(
self,
cell_data: NDArray,
center: ArrayLike,
dimensions: ArrayLike,
foreground: Sequence[foreground_t] | foreground_t,
*,
x: ExtentProtocol | ExtentDict,
y: ExtentProtocol | ExtentDict,
z: ExtentProtocol | ExtentDict,
) -> None:
"""
Draw an axis-aligned cuboid
@ -293,23 +287,30 @@ class GridDrawMixin(GridPosMixin):
sizes of the cuboid
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
"""
dimensions = numpy.asarray(dimensions)
p = numpy.array([[-dimensions[0], +dimensions[1]],
[+dimensions[0], +dimensions[1]],
[+dimensions[0], -dimensions[1]],
[-dimensions[0], -dimensions[1]]], dtype=float) * 0.5
thickness = dimensions[2]
self.draw_polygon(cell_data, 2, center, p, thickness, foreground)
if isinstance(x, dict):
x = Extent(**x)
if isinstance(y, dict):
y = Extent(**y)
if isinstance(z, dict):
z = Extent(**z)
center = numpy.asarray([x.center, y.center, z.center])
p = numpy.array([[x.min, y.max],
[x.max, y.max],
[x.max, y.min],
[x.min, y.min]], dtype=float)
slab = Slab(axis=2, center=z.center, span=z.span)
self.draw_polygon(cell_data=cell_data, slab=slab, polygon=p, foreground=foreground)
def draw_cylinder(
self,
cell_data: NDArray,
surface_normal: int,
center: ArrayLike,
h: SlabProtocol | SlabDict,
radius: float,
thickness: float,
num_points: int,
center2d: ArrayLike,
foreground: Sequence[foreground_t] | foreground_t,
) -> None:
"""
@ -317,18 +318,19 @@ class GridDrawMixin(GridPosMixin):
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying the cylinder's center
radius: cylinder radius
thickness: Thickness of the layer to draw
h:
radius:
num_points: The circle is approximated by a polygon with `num_points` vertices
center2d:
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
"""
theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)
x = radius * numpy.sin(theta)
y = radius * numpy.cos(theta)
polygon = numpy.hstack((x[:, None], y[:, None]))
self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground)
if isinstance(h, dict):
h = Slab(**h)
theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)[:, None]
xy0 = numpy.hstack((numpy.sin(theta), numpy.cos(theta)))
polygon = radius * xy0
self.draw_polygon(cell_data=cell_data, slab=h, polygon=polygon, foreground=foreground, offset2d=center2d)
def draw_extrude_rectangle(
@ -349,10 +351,10 @@ class GridDrawMixin(GridPosMixin):
polarity: +1 or -1, direction along axis to extrude in
distance: How far to extrude
"""
s = numpy.sign(polarity)
sgn = numpy.sign(polarity)
rectangle = numpy.array(rectangle, dtype=float)
if s == 0:
rectangle = numpy.asarray(rectangle, dtype=float)
if sgn == 0:
raise GridError('0 is not a valid polarity')
if direction not in range(3):
raise GridError(f'Invalid direction: {direction}')
@ -360,37 +362,38 @@ class GridDrawMixin(GridPosMixin):
raise GridError('Rectangle entries along extrusion direction do not match.')
center = rectangle.sum(axis=0) / 2.0
center[direction] += s * distance / 2.0
center[direction] += sgn * distance / 2.0
surface = numpy.delete(range(3), direction)
dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface]
p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5,
numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T
poly = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5,
numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T
thickness = distance
foreground_func = []
for i, grid in enumerate(cell_data):
z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction]
for ii, grid in enumerate(cell_data):
zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction]
ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)]
ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)]
fpart = z - numpy.floor(z)
mult = [1 - fpart, fpart][::s] # reverses if s negative
fpart = zz - numpy.floor(zz)
mult = [1 - fpart, fpart][::sgn] # reverses if s negative
foreground = mult[0] * grid[tuple(ind)]
ind[direction] += 1 # type: ignore #(known safe)
foreground += mult[1] * grid[tuple(ind)]
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
# transform from natural position to index
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=ii)
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64)
# reshape to original shape and keep only in-plane components
qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
qi, ri = (numpy.reshape(xyzi[:, kk], xs.shape) for kk in surface)
return foreground[qi, ri]
foreground_func.append(f_foreground)
self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func)
slab = Slab(axis=direction, center=center[direction], span=thickness)
self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface])

View File

@ -1,2 +0,0 @@
class GridError(Exception):
pass

View File

@ -6,18 +6,18 @@ if __name__ == '__main__':
# xyz = [numpy.arange(-5.0, 6.0), numpy.arange(-4.0, 5.0), [-1.0, 1.0]]
# eg = Grid(xyz)
# egc = Grid.allocate(0.0)
# # eg.draw_slab(egc, surface_normal=2, center=0, thickness=10, foreground=2)
# eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=4,
# thickness=10, num_points=1000, foreground=1)
# eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
# # eg.draw_slab(egc, slab=dict(axis=2, center=0, span=10), foreground=2)
# eg.draw_cylinder(egc, h=slab(axis=2, center=0, span=10),
# center2d=[0, 0], radius=4, thickness=10, num_points=1000, foreground=1)
# eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2)
# xyz2 = [numpy.arange(-5.0, 6.0), [-1.0, 1.0], numpy.arange(-4.0, 5.0)]
# eg2 = Grid(xyz2)
# eg2c = Grid.allocate(0.0)
# # eg2.draw_slab(eg2c, surface_normal=2, center=0, thickness=10, foreground=2)
# eg2.draw_cylinder(eg2c, surface_normal=1, center=[0, 0, 0],
# radius=4, thickness=10, num_points=1000, foreground=1.0)
# eg2.visualize_slice(eg2c, surface_normal=1, center=0, which_shifts=1)
# # eg2.draw_slab(eg2c, slab=dict(axis=2, center=0, span=10), foreground=2)
# eg2.draw_cylinder(eg2c, h=slab(axis=1, center=0, span=10), center2d=[0, 0],
# radius=4, num_points=1000, foreground=1.0)
# eg2.visualize_slice(eg2c, plane=dict(y=0), which_shifts=1)
# n = 20
# m = 3
@ -36,9 +36,20 @@ if __name__ == '__main__':
egc = eg.allocate(0)
# eg.draw_slab(Direction.z, 0, 10, 2)
eg.save('/home/jan/Desktop/test.pickle')
eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0,
thickness=10, num_points=1000, foreground=1)
eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]],
direction=1, polarity=+1, distance=5)
eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
eg.draw_cylinder(
egc,
h=dict(axis='z', center=0, span=10),
center2d=[0, 0],
radius=2.0,
num_points=1000,
foreground=1,
)
eg.draw_extrude_rectangle(
egc,
rectangle=[[-2, 1, -1], [0, 1, 1]],
direction=1,
polarity=+1,
distance=5,
)
eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2)
eg.visualize_isosurface(egc, which_shifts=2)

View File

@ -6,7 +6,7 @@ from typing import Any, TYPE_CHECKING
import numpy
from numpy.typing import NDArray
from . import GridError
from .utils import GridError, Plane, PlaneDict, PlaneProtocol
from .position import GridPosMixin
if TYPE_CHECKING:
@ -23,27 +23,25 @@ class GridReadMixin(GridPosMixin):
def get_slice(
self,
cell_data: NDArray,
surface_normal: int,
center: float,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1
) -> NDArray:
"""
Retrieve a slice of a grid.
Interpolates if given a position between two planes.
Interpolates if given a position between two grid planes.
Args:
cell_data: Cell data to slice
surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
center: Scalar specifying position along surface_normal axis.
plane: Axis and position (`Plane`) of the plane to read.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
Returns:
Array containing the portion of the grid.
"""
if numpy.size(center) != 1 or not numpy.isreal(center):
raise GridError('center must be a real scalar')
if isinstance(plane, dict):
plane = Plane(**plane)
sp = round(sample_period)
if sp <= 0:
@ -52,15 +50,12 @@ class GridReadMixin(GridPosMixin):
if numpy.size(which_shifts) != 1 or which_shifts < 0:
raise GridError('Invalid which_shifts')
if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction')
surface = numpy.delete(range(3), surface_normal)
surface = numpy.delete(range(3), plane.axis)
# Extract indices and weights of planes
center3 = numpy.insert([0, 0], surface_normal, (center,))
center3 = numpy.insert([0, 0], plane.axis, (plane.pos,))
center_index = self.pos2ind(center3, which_shifts,
round_ind=False, check_bounds=False)[surface_normal]
round_ind=False, check_bounds=False)[plane.axis]
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
if len(centers) == 2:
fpart = center_index - numpy.floor(center_index)
@ -68,14 +63,14 @@ class GridReadMixin(GridPosMixin):
else:
w = [1]
c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1])
if center < c_min or center > c_max:
c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1])
if plane.pos < c_min or plane.pos > c_max:
raise GridError('Coordinate of selected plane must be within simulation domain')
# Extract grid values from planes above and below visualized slice
sliced_grid = numpy.zeros(self.shape[surface])
for ci, weight in zip(centers, w, strict=True):
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3))
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
# Remove extra dimensions
@ -87,20 +82,19 @@ class GridReadMixin(GridPosMixin):
def visualize_slice(
self,
cell_data: NDArray,
surface_normal: int,
center: float,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
pcolormesh_args: dict[str, Any] | None = None,
) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
"""
Visualize a slice of a grid.
Interpolates if given a position between two planes.
Interpolates if given a position between two grid planes.
Args:
surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
center: Scalar specifying position along surface_normal axis.
cell_data: Cell data to visualize
plane: Axis and position (`Plane`) of the plane to read.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
@ -110,16 +104,20 @@ class GridReadMixin(GridPosMixin):
"""
from matplotlib import pyplot
if isinstance(plane, dict):
plane = Plane(**plane)
if pcolormesh_args is None:
pcolormesh_args = {}
grid_slice = self.get_slice(cell_data=cell_data,
surface_normal=surface_normal,
center=center,
which_shifts=which_shifts,
sample_period=sample_period)
grid_slice = self.get_slice(
cell_data=cell_data,
plane=plane,
which_shifts=which_shifts,
sample_period=sample_period,
)
surface = numpy.delete(range(3), surface_normal)
surface = numpy.delete(range(3), plane.axis)
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
@ -145,7 +143,7 @@ class GridReadMixin(GridPosMixin):
sample_period: int = 1,
show_edges: bool = True,
finalize: bool = True,
) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
"""
Draw an isosurface plot of the device.
@ -183,18 +181,18 @@ class GridReadMixin(GridPosMixin):
fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d')
if show_edges:
ax.plot_trisurf(xs, ys, faces, zs)
ax.plot_trisurf(xs, ys, faces, zs) # type: ignore
else:
ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none')
ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') # type: ignore
# Add a fake plot of a cube to force the axes to be equal lengths
max_range = numpy.array([xs.max() - xs.min(),
ys.max() - ys.min(),
zs.max() - zs.min()], dtype=float).max()
mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2]
xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min())
ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min())
zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min())
xbs = 0.5 * max_range * mg[0].ravel() + 0.5 * (xs.max() + xs.min())
ybs = 0.5 * max_range * mg[1].ravel() + 0.5 * (ys.max() + ys.min())
zbs = 0.5 * max_range * mg[2].ravel() + 0.5 * (zs.max() + zs.min())
# Comment or uncomment following both lines to test the fake bounding box:
for xb, yb, zb in zip(xbs, ybs, zbs, strict=True):
ax.plot([xb], [yb], [zb], 'w')

View File

@ -2,7 +2,7 @@
import numpy
from numpy.testing import assert_allclose #, assert_array_equal
from .. import Grid
from .. import Grid, Extent #, Slab, Plane
def test_draw_oncenter_2x2() -> None:
@ -12,7 +12,13 @@ def test_draw_oncenter_2x2() -> None:
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[1, 1, 10], foreground=1)
grid.draw_cuboid(
arr,
x=dict(center=0, span=1),
y=Extent(center=0, span=1),
z=dict(center=0, span=10),
foreground=1,
)
correct = numpy.array([[0.25, 0.25],
[0.25, 0.25]])[None, :, :, None]
@ -27,7 +33,13 @@ def test_draw_ongrid_4x4() -> None:
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[2, 2, 10], foreground=1)
grid.draw_cuboid(
arr,
x=dict(center=0, span=2),
y=dict(min=-1, max=1),
z=dict(center=0, min=-5),
foreground=1,
)
correct = numpy.array([[0, 0, 0, 0],
[0, 1, 1, 0],
@ -44,7 +56,13 @@ def test_draw_xshift_4x4() -> None:
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 2, 10], foreground=1)
grid.draw_cuboid(
arr,
x=dict(center=0.5, span=1.5),
y=dict(min=-1, max=1),
z=dict(center=0, span=10),
foreground=1,
)
correct = numpy.array([[0, 0, 0, 0],
[0, 0.25, 0.25, 0],
@ -61,7 +79,13 @@ def test_draw_yshift_4x4() -> None:
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
grid.draw_cuboid(arr, center=[0, 0.5, 0], dimensions=[2, 1.5, 10], foreground=1)
grid.draw_cuboid(
arr,
x=dict(min=-1, max=1),
y=dict(center=0.5, span=1.5),
z=dict(center=0, span=10),
foreground=1,
)
correct = numpy.array([[0, 0, 0, 0],
[0, 0.25, 1, 0.25],
@ -78,7 +102,13 @@ def test_draw_2shift_4x4() -> None:
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
arr = grid.allocate(0)
grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 1, 10], foreground=1)
grid.draw_cuboid(
arr,
x=dict(center=0.5, span=1.5),
y=dict(min=-0.5, max=0.5),
z=dict(center=0, span=10),
foreground=1,
)
correct = numpy.array([[0, 0, 0, 0],
[0, 0.125, 0.125, 0],

201
gridlock/utils.py Normal file
View File

@ -0,0 +1,201 @@
from typing import Protocol, TypedDict, runtime_checkable
from dataclasses import dataclass
class GridError(Exception):
""" Base error type for `gridlock` """
pass
class ExtentDict(TypedDict, total=False):
min: float
center: float
max: float
span: float
@runtime_checkable
class ExtentProtocol(Protocol):
center: float
span: float
@property
def max(self) -> float: ...
@property
def min(self) -> float: ...
@dataclass(init=False, slots=True)
class Extent(ExtentProtocol):
center: float
span: float
@property
def max(self) -> float:
return self.center + self.span / 2
@property
def min(self) -> float:
return self.center - self.span / 2
def __init__(
self,
*,
min: float | None = None,
center: float | None = None,
max: float | None = None,
span: float | None = None,
) -> None:
if sum(cc is None for cc in (min, center, max, span)) != 2:
raise GridError('Exactly two of min, center, max, span must be None!')
if span is None:
if center is None:
assert min is not None
assert max is not None
assert max >= min
center = 0.5 * (max + min)
span = max - min
elif max is None:
assert min is not None
assert center is not None
span = 2 * (center - min)
elif min is None:
assert center is not None
assert max is not None
span = 2 * (max - center)
else: # noqa: PLR5501
if center is not None:
pass
elif max is None:
assert min is not None
assert span is not None
center = min + 0.5 * span
elif min is None:
assert max is not None
assert span is not None
center = max - 0.5 * span
assert center is not None
assert span is not None
if hasattr(center, '__len__'):
assert len(center) == 1
if hasattr(span, '__len__'):
assert len(span) == 1
self.center = center
self.span = span
class SlabDict(TypedDict, total=False):
min: float
center: float
max: float
span: float
axis: int | str
@runtime_checkable
class SlabProtocol(ExtentProtocol, Protocol):
axis: int
center: float
span: float
@property
def max(self) -> float: ...
@property
def min(self) -> float: ...
@dataclass(init=False, slots=True)
class Slab(Extent, SlabProtocol):
axis: int
def __init__(
self,
axis: int | str,
*,
min: float | None = None,
center: float | None = None,
max: float | None = None,
span: float | None = None,
) -> None:
Extent.__init__(self, min=min, center=center, max=max, span=span)
if isinstance(axis, str):
axis_int = 'xyz'.find(axis.lower())
else:
axis_int = axis
if axis_int not in range(3):
raise GridError(f'Invalid axis (slab normal direction): {axis}')
self.axis = axis_int
def as_plane(self, where: str) -> 'Plane':
if where == 'center':
return Plane(axis=self.axis, pos=self.center)
if where == 'min':
return Plane(axis=self.axis, pos=self.min)
if where == 'max':
return Plane(axis=self.axis, pos=self.max)
raise GridError(f'Invalid {where=}')
class PlaneDict(TypedDict, total=False):
x: float
y: float
z: float
axis: int
pos: float
@runtime_checkable
class PlaneProtocol(Protocol):
axis: int
pos: float
@dataclass(init=False, slots=True)
class Plane(PlaneProtocol):
axis: int
pos: float
def __init__(
self,
*,
axis: int | str | None = None,
pos: float | None = None,
x: float | None = None,
y: float | None = None,
z: float | None = None,
) -> None:
xx = x
yy = y
zz = z
if sum(aa is not None for aa in (pos, xx, yy, zz)) != 1:
raise GridError('Exactly one of pos, x, y, z must be non-None!')
if (axis is None) != (pos is None):
raise GridError('Either both or neither of `axis` and `pos` must be defined.')
if isinstance(axis, str):
axis_int = 'xyz'.find(axis.lower())
elif axis is None:
axis_int = (xx is None, yy is None, zz is None).index(False)
else:
axis_int = axis
if axis_int not in range(3):
raise GridError(f'Invalid axis (slab normal direction): {axis=} {x=} {y=} {z=}')
self.axis = axis_int
if pos is not None:
cpos = pos
else:
cpos = (xx, yy, zz)[axis_int]
assert cpos is not None
if hasattr(cpos, '__len__'):
assert len(cpos) == 1
self.pos = cpos