Compare commits

..

No commits in common. "8e7e0edb1f8c070249e301427f0ad36ff3699ec5" and "9ab97e763cc5662df4297c51bc2331652698c345" have entirely different histories.

9 changed files with 758 additions and 826 deletions

View File

@ -15,8 +15,8 @@ Dependencies:
- mpl_toolkits.mplot3d [Grid.visualize_isosurface()] - mpl_toolkits.mplot3d [Grid.visualize_isosurface()]
- skimage [Grid.visualize_isosurface()] - skimage [Grid.visualize_isosurface()]
""" """
from .error import GridError as GridError from .error import GridError
from .grid import Grid as Grid from .grid import Grid
__author__ = 'Jan Petykiewicz' __author__ = 'Jan Petykiewicz'
__version__ = '1.1' __version__ = '1.1'

View File

@ -1,196 +0,0 @@
from typing import Protocol
import numpy
from numpy.typing import NDArray
from . import GridError
class GridBase(Protocol):
exyz: list[NDArray]
"""Cell edges. Monotonically increasing without duplicates."""
periodic: list[bool]
"""For each axis, determines how far the rightmost boundary gets shifted. """
shifts: NDArray
"""Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`"""
@property
def dxyz(self) -> list[NDArray]:
"""
Cell sizes for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell sizes
"""
return [numpy.diff(ee) for ee in self.exyz]
@property
def xyz(self) -> list[NDArray]:
"""
Cell centers for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell edges
"""
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
@property
def shape(self) -> NDArray[numpy.intp]:
"""
The number of cells in x, y, and z
Returns:
ndarray of [x_centers.size, y_centers.size, z_centers.size]
"""
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
@property
def num_grids(self) -> int:
"""
The number of grids (number of shifts)
"""
return self.shifts.shape[0]
@property
def cell_data_shape(self) -> NDArray[numpy.intp]:
"""
The shape of the cell_data ndarray (num_grids, *self.shape).
"""
return numpy.hstack((self.num_grids, self.shape))
@property
def dxyz_with_ghost(self) -> list[NDArray]:
"""
Gives dxyz with an additional 'ghost' cell at the end, whose value depends
on whether or not the axis has periodic boundary conditions. See main description
above to learn why this is necessary.
If periodic, final edge shifts same amount as first
Otherwise, final edge shifts same amount as second-to-last
Returns:
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
"""
el = [0 if p else -1 for p in self.periodic]
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
@property
def center(self) -> NDArray[numpy.float64]:
"""
Center position of the entire grid, no shifts applied
Returns:
ndarray of [x_center, y_center, z_center]
"""
# center is just average of first and last xyz, which is just the average of the
# first two and last two exyz
centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
return numpy.array(centers, dtype=float)
@property
def dxyz_limits(self) -> tuple[NDArray, NDArray]:
"""
Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
weighted average is performed when shifting).
Returns:
Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
"""
d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
return d_min, d_max
def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
"""
Returns edges for which_shifts.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell edges
"""
if which_shifts is None:
return self.exyz
dxyz = self.dxyz_with_ghost
shifts = self.shifts[which_shifts, :]
# If shift is negative, use left cell's dx to determine shift
for a in range(3):
if shifts[a] < 0:
dxyz[a] = numpy.roll(dxyz[a], 1)
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
"""
Returns cell sizes for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell sizes
"""
if which_shifts is None:
return self.dxyz
shifts = self.shifts[which_shifts, :]
dxyz = self.dxyz_with_ghost
# If shift is negative, use left cell's dx to determine size
sdxyz = []
for a in range(3):
if shifts[a] < 0:
roll_dxyz = numpy.roll(dxyz[a], 1)
abs_shift = numpy.abs(shifts[a])
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
else:
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
return sdxyz
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
"""
Returns cell centers for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell centers
"""
if which_shifts is None:
return self.xyz
exyz = self.shifted_exyz(which_shifts)
dxyz = self.shifted_dxyz(which_shifts)
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
"""
Return cell widths, with each dimension shifted by the corresponding shifts.
Returns:
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
"""
if self.num_grids != 3:
raise GridError('Autoshifting requires exactly 3 grids')
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray:
"""
Allocate an ndarray for storing grid data.
Args:
fill_value: Value to initialize the grid to. If None, an
uninitialized array is returned.
dtype: Numpy dtype for the array. Default is `numpy.float32`.
Returns:
The allocated array
"""
if fill_value is None:
return numpy.empty(self.cell_data_shape, dtype=dtype)
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)

View File

@ -1,14 +1,13 @@
""" """
Drawing-related methods for Grid class Drawing-related methods for Grid class
""" """
from collections.abc import Sequence, Callable from typing import Union, Sequence, Callable
import numpy import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray, ArrayLike
from float_raster import raster from float_raster import raster
from . import GridError from . import GridError
from .position import GridPosMixin
# NOTE: Maybe it would make sense to create a GridDrawer class # NOTE: Maybe it would make sense to create a GridDrawer class
@ -18,379 +17,375 @@ from .position import GridPosMixin
foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
foreground_t = float | foreground_callable_t foreground_t = Union[float, foreground_callable_t]
class GridDrawMixin(GridPosMixin): def draw_polygons(
def draw_polygons( self,
self, cell_data: NDArray,
cell_data: NDArray, surface_normal: int,
surface_normal: int, center: ArrayLike,
center: ArrayLike, polygons: Sequence[NDArray],
polygons: Sequence[ArrayLike], thickness: float,
thickness: float, foreground: Sequence[foreground_t] | foreground_t,
foreground: Sequence[foreground_t] | foreground_t, ) -> None:
) -> None: """
""" Draw polygons on an axis-aligned plane.
Draw polygons on an axis-aligned plane.
Args: Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying an offset applied to all the polygons center: 3-element ndarray or list specifying an offset applied to all the polygons
polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
(non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each
polygon must have at least 3 vertices. polygon must have at least 3 vertices.
thickness: Thickness of the layer to draw thickness: Thickness of the layer to draw
foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list
of any of these (1 per grid). Callable values should take an ndarray the shape of the of any of these (1 per grid). Callable values should take an ndarray the shape of the
grid and return an ndarray of equal shape containing the foreground value at the given x, y, grid and return an ndarray of equal shape containing the foreground value at the given x, y,
and z (natural, not grid coordinates). and z (natural, not grid coordinates).
Raises: Raises:
GridError GridError
""" """
if surface_normal not in range(3): if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction') raise GridError('Invalid surface_normal direction')
center = numpy.squeeze(center)
poly_list = [numpy.array(poly, copy=False) for poly in polygons]
# Check polygons, and remove redundant coordinates center = numpy.squeeze(center)
surface = numpy.delete(range(3), surface_normal)
for ii in range(len(poly_list)): # Check polygons, and remove redundant coordinates
polygon = poly_list[ii] surface = numpy.delete(range(3), surface_normal)
malformed = f'Malformed polygon: ({ii})'
if polygon.shape[1] not in (2, 3):
raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
if polygon.shape[1] == 3:
polygon = polygon[surface, :]
poly_list[ii] = polygon
if not polygon.shape[0] > 2: for i, polygon in enumerate(polygons):
raise GridError(malformed + 'must consist of more than 2 points') malformed = f'Malformed polygon: ({i})'
if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: if polygon.shape[1] not in (2, 3):
raise GridError(malformed + 'must be in plane with surface normal ' raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
+ 'xyz'[surface_normal]) if polygon.shape[1] == 3:
polygon = polygon[surface, :]
# Broadcast foreground where necessary if not polygon.shape[0] > 2:
foregrounds: Sequence[foreground_callable_t] | Sequence[float] raise GridError(malformed + 'must consist of more than 2 points')
if numpy.size(foreground) == 1: # type: ignore if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1:
foregrounds = [foreground] * len(cell_data) # type: ignore raise GridError(malformed + 'must be in plane with surface normal '
elif isinstance(foreground, numpy.ndarray): + 'xyz'[surface_normal])
raise GridError('ndarray not supported for foreground')
# Broadcast foreground where necessary
foregrounds: Sequence[foreground_callable_t] | Sequence[float]
if numpy.size(foreground) == 1: # type: ignore
foregrounds = [foreground] * len(cell_data) # type: ignore
elif isinstance(foreground, numpy.ndarray):
raise GridError('ndarray not supported for foreground')
else:
foregrounds = foreground # type: ignore
# ## Compute sub-domain of the grid occupied by polygons
# 1) Compute outer bounds (bd) of polygons
bd_2d_min = numpy.array([0, 0])
bd_2d_max = numpy.array([0, 0])
for polygon in polygons:
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center
# 2) Find indices (bdi) just outside bd elements
buf = 2 # size of safety buffer
# Use s_min and s_max with unshifted pos2ind to get absolute limits on
# the indices the polygons might affect
s_min = self.shifts.min(axis=0)
s_max = self.shifts.max(axis=0)
bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf
bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf
bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int)
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
# 3) Adjust polygons for center
polygons = [poly + center[surface] for poly in polygons]
# ## Generate weighing function
def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
v_2d = numpy.array(vector, dtype=float)
return numpy.insert(v_2d, surface_normal, (val,))
# iterate over grids
for i, _ in enumerate(cell_data):
# ## Evaluate or expand foregrounds[i]
foregrounds_i = foregrounds[i]
if callable(foregrounds_i):
# meshgrid over the (shifted) domain
domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)]
(x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij')
# evaluate on the meshgrid
foreground_val = foregrounds_i(x0, y0, z0)
if not numpy.isfinite(foreground_val).all():
raise GridError(f'Non-finite values in foreground[{i}]')
elif numpy.size(foregrounds_i) != 1:
raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}')
else: else:
foregrounds = foreground # type: ignore # foreground[i] is scalar non-callable
foreground_val = foregrounds_i
# ## Compute sub-domain of the grid occupied by polygons w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int))
# 1) Compute outer bounds (bd) of polygons
bd_2d_min = numpy.array([0, 0])
bd_2d_max = numpy.array([0, 0])
for polygon in poly_list:
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center
# 2) Find indices (bdi) just outside bd elements # Draw each polygon separately
buf = 2 # size of safety buffer for polygon in polygons:
# Use s_min and s_max with unshifted pos2ind to get absolute limits on
# the indices the polygons might affect
s_min = self.shifts.min(axis=0)
s_max = self.shifts.max(axis=0)
bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf
bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf
bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int)
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
# 3) Adjust polygons for center # Get the boundaries of the polygon
poly_list = [poly + center[surface] for poly in poly_list] pbd_min = polygon.min(axis=0)
pbd_max = polygon.max(axis=0)
# ## Generate weighing function # Find indices in w_xy just outside polygon
def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: # using per-grid xy-weights (self.shifted_xyz())
v_2d = numpy.array(vector, dtype=float) corner_min = self.pos2ind(to_3d(pbd_min), i,
return numpy.insert(v_2d, surface_normal, (val,)) check_bounds=False)[surface].astype(int)
corner_max = self.pos2ind(to_3d(pbd_max), i,
check_bounds=False)[surface].astype(int)
# iterate over grids # Find indices in w_xy which are modified by polygon
foreground_val: NDArray | float # First for the edge coordinates (+1 since we're indexing edges)
for i, _ in enumerate(cell_data): edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max)]
# ## Evaluate or expand foregrounds[i] # Then for the pixel centers (-bdi_min since we're
foregrounds_i = foregrounds[i] # calculating weights within a subspace)
if callable(foregrounds_i): centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface],
# meshgrid over the (shifted) domain corner_max - bdi_min[surface]))
domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)]
(x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij')
# evaluate on the meshgrid aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices))
foreground_val = foregrounds_i(x0, y0, z0) w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y)
if not numpy.isfinite(foreground_val).all():
raise GridError(f'Non-finite values in foreground[{i}]') # Clamp overlapping polygons to 1
elif numpy.size(foregrounds_i) != 1: w_xy = numpy.minimum(w_xy, 1.0)
raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}')
# 2) Generate weights in z-direction
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
def get_zi(offset, i=i, w_z=w_z):
edges = self.shifted_exyz(i)[surface_normal]
point = center[surface_normal] + offset
grid_coord = numpy.digitize(point, edges) - 1
w_coord = grid_coord - bdi_min[surface_normal]
if w_coord < 0:
w_coord = 0
f = 0
elif w_coord >= w_z.size:
w_coord = w_z.size - 1
f = 1
else: else:
# foreground[i] is scalar non-callable dz = self.shifted_dxyz(i)[surface_normal][grid_coord]
foreground_val = foregrounds_i f = (point - edges[grid_coord]) / dz
return f, w_coord
w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) zi_top_f, zi_top = get_zi(+thickness / 2.0)
zi_bot_f, zi_bot = get_zi(-thickness / 2.0)
# Draw each polygon separately w_z[zi_bot + 1:zi_top] = 1
for polygon in poly_list:
# Get the boundaries of the polygon if zi_bot < zi_top:
pbd_min = polygon.min(axis=0) w_z[zi_top] = zi_top_f
pbd_max = polygon.max(axis=0) w_z[zi_bot] = 1 - zi_bot_f
else:
w_z[zi_bot] = zi_top_f - zi_bot_f
# Find indices in w_xy just outside polygon # 3) Generate total weight function
# using per-grid xy-weights (self.shifted_xyz()) w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,)))
corner_min = self.pos2ind(to_3d(pbd_min), i,
check_bounds=False)[surface].astype(int)
corner_max = self.pos2ind(to_3d(pbd_max), i,
check_bounds=False)[surface].astype(int)
# Find indices in w_xy which are modified by polygon # ## Modify the grid
# First for the edge coordinates (+1 since we're indexing edges) g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)] cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val
# Then for the pixel centers (-bdi_min since we're
# calculating weights within a subspace)
centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface],
corner_max - bdi_min[surface], strict=True))
aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True))
w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y)
# Clamp overlapping polygons to 1
w_xy = numpy.minimum(w_xy, 1.0)
# 2) Generate weights in z-direction
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001
edges = self.shifted_exyz(i)[surface_normal]
point = center[surface_normal] + offset
grid_coord = numpy.digitize(point, edges) - 1
w_coord = grid_coord - bdi_min[surface_normal]
if w_coord < 0:
w_coord = 0
f = 0
elif w_coord >= w_z.size:
w_coord = w_z.size - 1
f = 1
else:
dz = self.shifted_dxyz(i)[surface_normal][grid_coord]
f = (point - edges[grid_coord]) / dz
return f, w_coord
zi_top_f, zi_top = get_zi(+thickness / 2.0)
zi_bot_f, zi_bot = get_zi(-thickness / 2.0)
w_z[zi_bot + 1:zi_top] = 1
if zi_bot < zi_top:
w_z[zi_top] = zi_top_f
w_z[zi_bot] = 1 - zi_bot_f
else:
w_z[zi_bot] = zi_top_f - zi_bot_f
# 3) Generate total weight function
w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,)))
# ## Modify the grid
g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val
def draw_polygon( def draw_polygon(
self, self,
cell_data: NDArray, cell_data: NDArray,
surface_normal: int, surface_normal: int,
center: ArrayLike, center: ArrayLike,
polygon: ArrayLike, polygon: ArrayLike,
thickness: float, thickness: float,
foreground: Sequence[foreground_t] | foreground_t, foreground: Sequence[foreground_t] | foreground_t,
) -> None: ) -> None:
""" """
Draw a polygon on an axis-aligned plane. Draw a polygon on an axis-aligned plane.
Args: Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying an offset applied to the polygon center: 3-element ndarray or list specifying an offset applied to the polygon
polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at
least 3 vertices. least 3 vertices.
thickness: Thickness of the layer to draw thickness: Thickness of the layer to draw
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
""" """
self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground)
def draw_slab( def draw_slab(
self, self,
cell_data: NDArray, cell_data: NDArray,
surface_normal: int, surface_normal: int,
center: ArrayLike, center: ArrayLike,
thickness: float, thickness: float,
foreground: Sequence[foreground_t] | foreground_t, foreground: Sequence[foreground_t] | foreground_t,
) -> None: ) -> None:
""" """
Draw an axis-aligned infinite slab. Draw an axis-aligned infinite slab.
Args: Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: `surface_normal` coordinate value at the center of the slab center: `surface_normal` coordinate value at the center of the slab
thickness: Thickness of the layer to draw thickness: Thickness of the layer to draw
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
""" """
# Turn surface_normal into its integer representation # Turn surface_normal into its integer representation
if surface_normal not in range(3): if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction') raise GridError('Invalid surface_normal direction')
if numpy.size(center) != 1: if numpy.size(center) != 1:
center = numpy.squeeze(center) center = numpy.squeeze(center)
if len(center) == 3: if len(center) == 3:
center = center[surface_normal] center = center[surface_normal]
else: else:
raise GridError(f'Bad center: {center}') raise GridError(f'Bad center: {center}')
# Find center of slab # Find center of slab
center_shift = self.center center_shift = self.center
center_shift[surface_normal] = center center_shift[surface_normal] = center
surface = numpy.delete(range(3), surface_normal) surface = numpy.delete(range(3), surface_normal)
xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface] xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface]
xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface] xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface]
dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float) dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float)
xyz_min -= 4 * dxyz xyz_min -= 4 * dxyz
xyz_max += 4 * dxyz xyz_max += 4 * dxyz
p = numpy.array([[xyz_min[0], xyz_max[1]], p = numpy.array([[xyz_min[0], xyz_max[1]],
[xyz_max[0], xyz_max[1]], [xyz_max[0], xyz_max[1]],
[xyz_max[0], xyz_min[1]], [xyz_max[0], xyz_min[1]],
[xyz_min[0], xyz_min[1]]], dtype=float) [xyz_min[0], xyz_min[1]]], dtype=float)
self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground)
def draw_cuboid( def draw_cuboid(
self, self,
cell_data: NDArray, cell_data: NDArray,
center: ArrayLike, center: ArrayLike,
dimensions: ArrayLike, dimensions: ArrayLike,
foreground: Sequence[foreground_t] | foreground_t, foreground: Sequence[foreground_t] | foreground_t,
) -> None: ) -> None:
""" """
Draw an axis-aligned cuboid Draw an axis-aligned cuboid
Args: Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
center: 3-element ndarray or list specifying the cuboid's center center: 3-element ndarray or list specifying the cuboid's center
dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge
sizes of the cuboid sizes of the cuboid
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
""" """
dimensions = numpy.array(dimensions, copy=False) dimensions = numpy.array(dimensions, copy=False)
p = numpy.array([[-dimensions[0], +dimensions[1]], p = numpy.array([[-dimensions[0], +dimensions[1]],
[+dimensions[0], +dimensions[1]], [+dimensions[0], +dimensions[1]],
[+dimensions[0], -dimensions[1]], [+dimensions[0], -dimensions[1]],
[-dimensions[0], -dimensions[1]]], dtype=float) / 2.0 [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0
thickness = dimensions[2] thickness = dimensions[2]
self.draw_polygon(cell_data, 2, center, p, thickness, foreground) self.draw_polygon(cell_data, 2, center, p, thickness, foreground)
def draw_cylinder( def draw_cylinder(
self, self,
cell_data: NDArray, cell_data: NDArray,
surface_normal: int, surface_normal: int,
center: ArrayLike, center: ArrayLike,
radius: float, radius: float,
thickness: float, thickness: float,
num_points: int, num_points: int,
foreground: Sequence[foreground_t] | foreground_t, foreground: Sequence[foreground_t] | foreground_t,
) -> None: ) -> None:
""" """
Draw an axis-aligned cylinder. Approximated by a num_points-gon Draw an axis-aligned cylinder. Approximated by a num_points-gon
Args: Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying the cylinder's center center: 3-element ndarray or list specifying the cylinder's center
radius: cylinder radius radius: cylinder radius
thickness: Thickness of the layer to draw thickness: Thickness of the layer to draw
num_points: The circle is approximated by a polygon with `num_points` vertices num_points: The circle is approximated by a polygon with `num_points` vertices
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
""" """
theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False) theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)
x = radius * numpy.sin(theta) x = radius * numpy.sin(theta)
y = radius * numpy.cos(theta) y = radius * numpy.cos(theta)
polygon = numpy.hstack((x[:, None], y[:, None])) polygon = numpy.hstack((x[:, None], y[:, None]))
self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground)
def draw_extrude_rectangle( def draw_extrude_rectangle(
self, self,
cell_data: NDArray, cell_data: NDArray,
rectangle: ArrayLike, rectangle: ArrayLike,
direction: int, direction: int,
polarity: int, polarity: int,
distance: float, distance: float,
) -> None: ) -> None:
""" """
Extrude a rectangle of a previously-drawn structure along an axis. Extrude a rectangle of a previously-drawn structure along an axis.
Args: Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
rectangle: 2x3 ndarray or list specifying the rectangle's corners rectangle: 2x3 ndarray or list specifying the rectangle's corners
direction: Direction to extrude in. Integer in `range(3)`. direction: Direction to extrude in. Integer in `range(3)`.
polarity: +1 or -1, direction along axis to extrude in polarity: +1 or -1, direction along axis to extrude in
distance: How far to extrude distance: How far to extrude
""" """
s = numpy.sign(polarity) s = numpy.sign(polarity)
rectangle = numpy.array(rectangle, dtype=float) rectangle = numpy.array(rectangle, dtype=float)
if s == 0: if s == 0:
raise GridError('0 is not a valid polarity') raise GridError('0 is not a valid polarity')
if direction not in range(3): if direction not in range(3):
raise GridError(f'Invalid direction: {direction}') raise GridError(f'Invalid direction: {direction}')
if rectangle[0, direction] != rectangle[1, direction]: if rectangle[0, direction] != rectangle[1, direction]:
raise GridError('Rectangle entries along extrusion direction do not match.') raise GridError('Rectangle entries along extrusion direction do not match.')
center = rectangle.sum(axis=0) / 2.0 center = rectangle.sum(axis=0) / 2.0
center[direction] += s * distance / 2.0 center[direction] += s * distance / 2.0
surface = numpy.delete(range(3), direction) surface = numpy.delete(range(3), direction)
dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface]
p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5,
numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T
thickness = distance thickness = distance
foreground_func = [] foreground_func = []
for i, grid in enumerate(cell_data): for i, grid in enumerate(cell_data):
z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction]
ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)]
fpart = z - numpy.floor(z) fpart = z - numpy.floor(z)
mult = [1 - fpart, fpart][::s] # reverses if s negative mult = [1 - fpart, fpart][::s] # reverses if s negative
foreground = mult[0] * grid[tuple(ind)] foreground = mult[0] * grid[tuple(ind)]
ind[direction] += 1 # type: ignore #(known safe) ind[direction] += 1 # type: ignore #(known safe)
foreground += mult[1] * grid[tuple(ind)] foreground += mult[1] * grid[tuple(ind)]
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]:
# transform from natural position to index # transform from natural position to index
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64) for qrs in zip(xs.flat, ys.flat, zs.flat)], dtype=int)
# reshape to original shape and keep only in-plane components # reshape to original shape and keep only in-plane components
qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
return foreground[qi, ri] return foreground[qi, ri]
foreground_func.append(f_foreground) foreground_func.append(f_foreground)
self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func) self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func)

View File

@ -29,7 +29,7 @@ if __name__ == '__main__':
# numpy.linspace(-5.5, 5.5, 10)] # numpy.linspace(-5.5, 5.5, 10)]
half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5] half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5]
xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x), xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x,
numpy.linspace(-5.5, 5.5, 10), numpy.linspace(-5.5, 5.5, 10),
numpy.linspace(-5.5, 5.5, 10)] numpy.linspace(-5.5, 5.5, 10)]
eg = Grid(xyz3) eg = Grid(xyz3)
@ -37,8 +37,8 @@ if __name__ == '__main__':
# eg.draw_slab(Direction.z, 0, 10, 2) # eg.draw_slab(Direction.z, 0, 10, 2)
eg.save('/home/jan/Desktop/test.pickle') eg.save('/home/jan/Desktop/test.pickle')
eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0, eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0,
thickness=10, num_points=1000, foreground=1) thickness=10, num_poitns=1000, foreground=1)
eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]], eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]],
direction=1, polarity=+1, distance=5) direction=1, poalarity=+1, distance=5)
eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
eg.visualize_isosurface(egc, which_shifts=2) eg.visualize_isosurface(egc, which_shifts=2)

View File

@ -1,5 +1,4 @@
from typing import ClassVar, Self from typing import Callable, Sequence, ClassVar, Self
from collections.abc import Callable, Sequence
import numpy import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray, ArrayLike
@ -9,15 +8,12 @@ import warnings
import copy import copy
from . import GridError from . import GridError
from .draw import GridDrawMixin
from .read import GridReadMixin
from .position import GridPosMixin
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): class Grid:
""" """
Simulation grid metadata for finite-difference simulations. Simulation grid metadata for finite-difference simulations.
@ -74,6 +70,193 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
], dtype=float) ], dtype=float)
"""Default shifts for Yee grid H-field""" """Default shifts for Yee grid H-field"""
from .draw import (
draw_polygons, draw_polygon, draw_slab, draw_cuboid,
draw_cylinder, draw_extrude_rectangle,
)
from .read import get_slice, visualize_slice, visualize_isosurface
from .position import ind2pos, pos2ind
@property
def dxyz(self) -> list[NDArray]:
"""
Cell sizes for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell sizes
"""
return [numpy.diff(ee) for ee in self.exyz]
@property
def xyz(self) -> list[NDArray]:
"""
Cell centers for each axis, no shifts applied
Returns:
List of 3 ndarrays of cell edges
"""
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
@property
def shape(self) -> NDArray[numpy.int_]:
"""
The number of cells in x, y, and z
Returns:
ndarray of [x_centers.size, y_centers.size, z_centers.size]
"""
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
@property
def num_grids(self) -> int:
"""
The number of grids (number of shifts)
"""
return self.shifts.shape[0]
@property
def cell_data_shape(self):
"""
The shape of the cell_data ndarray (num_grids, *self.shape).
"""
return numpy.hstack((self.num_grids, self.shape))
@property
def dxyz_with_ghost(self) -> list[NDArray]:
"""
Gives dxyz with an additional 'ghost' cell at the end, whose value depends
on whether or not the axis has periodic boundary conditions. See main description
above to learn why this is necessary.
If periodic, final edge shifts same amount as first
Otherwise, final edge shifts same amount as second-to-last
Returns:
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
"""
el = [0 if p else -1 for p in self.periodic]
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)]
@property
def center(self) -> NDArray[numpy.float64]:
"""
Center position of the entire grid, no shifts applied
Returns:
ndarray of [x_center, y_center, z_center]
"""
# center is just average of first and last xyz, which is just the average of the
# first two and last two exyz
centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
return numpy.array(centers, dtype=float)
@property
def dxyz_limits(self) -> tuple[NDArray, NDArray]:
"""
Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
weighted average is performed when shifting).
Returns:
Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
"""
d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
return d_min, d_max
def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
"""
Returns edges for which_shifts.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell edges
"""
if which_shifts is None:
return self.exyz
dxyz = self.dxyz_with_ghost
shifts = self.shifts[which_shifts, :]
# If shift is negative, use left cell's dx to determine shift
for a in range(3):
if shifts[a] < 0:
dxyz[a] = numpy.roll(dxyz[a], 1)
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
"""
Returns cell sizes for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell sizes
"""
if which_shifts is None:
return self.dxyz
shifts = self.shifts[which_shifts, :]
dxyz = self.dxyz_with_ghost
# If shift is negative, use left cell's dx to determine size
sdxyz = []
for a in range(3):
if shifts[a] < 0:
roll_dxyz = numpy.roll(dxyz[a], 1)
abs_shift = numpy.abs(shifts[a])
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
else:
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
return sdxyz
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
"""
Returns cell centers for `which_shifts`.
Args:
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
Returns:
List of 3 ndarrays of cell centers
"""
if which_shifts is None:
return self.xyz
exyz = self.shifted_exyz(which_shifts)
dxyz = self.shifted_dxyz(which_shifts)
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
"""
Return cell widths, with each dimension shifted by the corresponding shifts.
Returns:
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
"""
if self.num_grids != 3:
raise GridError('Autoshifting requires exactly 3 grids')
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray:
"""
Allocate an ndarray for storing grid data.
Args:
fill_value: Value to initialize the grid to. If None, an
uninitialized array is returned.
dtype: Numpy dtype for the array. Default is `numpy.float32`.
Returns:
The allocated array
"""
if fill_value is None:
return numpy.empty(self.cell_data_shape, dtype=dtype)
else:
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
def __init__( def __init__(
self, self,
pixel_edge_coordinates: Sequence[ArrayLike], pixel_edge_coordinates: Sequence[ArrayLike],
@ -94,12 +277,11 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
Raises: Raises:
`GridError` on invalid input `GridError` on invalid input
""" """
edge_arrs = [numpy.array(cc, copy=False) for cc in pixel_edge_coordinates] self.exyz = [numpy.unique(pixel_edge_coordinates[i]) for i in range(3)]
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
self.shifts = numpy.array(shifts, dtype=float) self.shifts = numpy.array(shifts, dtype=float)
for i in range(3): for i in range(3):
if self.exyz[i].size != edge_arrs[i].size: if len(self.exyz[i]) != len(pixel_edge_coordinates[i]):
warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2) warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2)
if isinstance(periodic, bool): if isinstance(periodic, bool):

View File

@ -5,114 +5,112 @@ import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray, ArrayLike
from . import GridError from . import GridError
from .base import GridBase
class GridPosMixin(GridBase): def ind2pos(
def ind2pos( self,
self, ind: NDArray,
ind: NDArray, which_shifts: int | None = None,
which_shifts: int | None = None, round_ind: bool = True,
round_ind: bool = True, check_bounds: bool = True
check_bounds: bool = True ) -> NDArray[numpy.float64]:
) -> NDArray[numpy.float64]: """
""" Returns the natural position corresponding to the specified cell center indices.
Returns the natural position corresponding to the specified cell center indices. The resulting position is clipped to the bounds of the grid
The resulting position is clipped to the bounds of the grid (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`)
(to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`)
Args: Args:
ind: Indices of the position. Can be fractional. (3-element ndarray or list) ind: Indices of the position. Can be fractional. (3-element ndarray or list)
which_shifts: which grid number (`shifts`) to use which_shifts: which grid number (`shifts`) to use
round_ind: Whether to round ind to the nearest integer position before indexing round_ind: Whether to round ind to the nearest integer position before indexing
(default `True`) (default `True`)
check_bounds: Whether to raise an `GridError` if the provided ind is outside of check_bounds: Whether to raise an `GridError` if the provided ind is outside of
the grid, as defined above (centers if `round_ind`, else edges) (default `True`) the grid, as defined above (centers if `round_ind`, else edges) (default `True`)
Returns: Returns:
3-element ndarray specifying the natural position 3-element ndarray specifying the natural position
Raises: Raises:
`GridError` if invalid `which_shifts` `GridError` if invalid `which_shifts`
`GridError` if `check_bounds` and out of bounds `GridError` if `check_bounds` and out of bounds
""" """
if which_shifts is not None and which_shifts >= self.shifts.shape[0]: if which_shifts is not None and which_shifts >= self.shifts.shape[0]:
raise GridError('Invalid shifts') raise GridError('Invalid shifts')
ind = numpy.array(ind, dtype=float) ind = numpy.array(ind, dtype=float)
if check_bounds:
if round_ind:
low_bound = 0.0
high_bound = -1.0
else:
low_bound = -0.5
high_bound = -0.5
if (ind < low_bound).any() or (ind > self.shape - high_bound).any():
raise GridError(f'Position outside of grid: {ind}')
if check_bounds:
if round_ind: if round_ind:
rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) low_bound = 0.0
sxyz = self.shifted_xyz(which_shifts) high_bound = -1.0
position = [sxyz[a][rind[a]].astype(int) for a in range(3)]
else: else:
sexyz = self.shifted_exyz(which_shifts) low_bound = -0.5
position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) high_bound = -0.5
for a in range(3)] if (ind < low_bound).any() or (ind > self.shape - high_bound).any():
return numpy.array(position, dtype=float) raise GridError(f'Position outside of grid: {ind}')
def pos2ind(
self,
r: ArrayLike,
which_shifts: int | None,
round_ind: bool = True,
check_bounds: bool = True
) -> NDArray[numpy.float64]:
"""
Returns the cell-center indices corresponding to the specified natural position.
The resulting position is clipped to within the outer centers of the grid.
Args:
r: Natural position that we will convert into indices (3-element ndarray or list)
which_shifts: which grid number (`shifts`) to use
round_ind: Whether to round the returned indices to the nearest integers.
check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges
Returns:
3-element ndarray specifying the indices
Raises:
`GridError` if invalid `which_shifts`
`GridError` if `check_bounds` and out of bounds
"""
r = numpy.squeeze(r)
if r.size != 3:
raise GridError(f'r must be 3-element vector: {r}')
if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]):
raise GridError(f'Invalid which_shifts: {which_shifts}')
if round_ind:
rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1)
sxyz = self.shifted_xyz(which_shifts)
position = [sxyz[a][rind[a]].astype(int) for a in range(3)]
else:
sexyz = self.shifted_exyz(which_shifts) sexyz = self.shifted_exyz(which_shifts)
position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a])
for a in range(3)]
return numpy.array(position, dtype=float)
if check_bounds:
for a in range(3):
if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]):
raise GridError(f'Position[{a}] outside of grid!')
grid_pos = numpy.zeros((3,)) def pos2ind(
self,
r: ArrayLike,
which_shifts: int | None,
round_ind: bool = True,
check_bounds: bool = True
) -> NDArray[numpy.float64]:
"""
Returns the cell-center indices corresponding to the specified natural position.
The resulting position is clipped to within the outer centers of the grid.
Args:
r: Natural position that we will convert into indices (3-element ndarray or list)
which_shifts: which grid number (`shifts`) to use
round_ind: Whether to round the returned indices to the nearest integers.
check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges
Returns:
3-element ndarray specifying the indices
Raises:
`GridError` if invalid `which_shifts`
`GridError` if `check_bounds` and out of bounds
"""
r = numpy.squeeze(r)
if r.size != 3:
raise GridError(f'r must be 3-element vector: {r}')
if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]):
raise GridError(f'Invalid which_shifts: {which_shifts}')
sexyz = self.shifted_exyz(which_shifts)
if check_bounds:
for a in range(3): for a in range(3):
xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]):
xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds raise GridError(f'Position[{a}] outside of grid!')
# No need to interpolate if round_ind is true or we were outside the grid grid_pos = numpy.zeros((3,))
if round_ind or xi != xi_clipped: for a in range(3):
grid_pos[a] = xi_clipped xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in
else: xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds
# Interpolate
x = self.shifted_xyz(which_shifts)[a][xi]
dx = self.shifted_dxyz(which_shifts)[a][xi]
f = (r[a] - x) / dx
# Clip to centers # No need to interpolate if round_ind is true or we were outside the grid
grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1) if round_ind or xi != xi_clipped:
return grid_pos grid_pos[a] = xi_clipped
else:
# Interpolate
x = self.shifted_xyz(which_shifts)[a][xi]
dx = self.shifted_dxyz(which_shifts)[a][xi]
f = (r[a] - x) / dx
# Clip to centers
grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1)
return grid_pos

View File

@ -7,7 +7,6 @@ import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
from . import GridError from . import GridError
from .position import GridPosMixin
if TYPE_CHECKING: if TYPE_CHECKING:
import matplotlib.axes import matplotlib.axes
@ -19,187 +18,185 @@ if TYPE_CHECKING:
# .visualize_isosurface uses mpl_toolkits.mplot3d # .visualize_isosurface uses mpl_toolkits.mplot3d
class GridReadMixin(GridPosMixin): def get_slice(
def get_slice( self,
self, cell_data: NDArray,
cell_data: NDArray, surface_normal: int,
surface_normal: int, center: float,
center: float, which_shifts: int = 0,
which_shifts: int = 0, sample_period: int = 1
sample_period: int = 1 ) -> NDArray:
) -> NDArray: """
""" Retrieve a slice of a grid.
Retrieve a slice of a grid. Interpolates if given a position between two planes.
Interpolates if given a position between two planes.
Args: Args:
cell_data: Cell data to slice cell_data: Cell data to slice
surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
center: Scalar specifying position along surface_normal axis. center: Scalar specifying position along surface_normal axis.
which_shifts: Which grid to display. Default is the first grid (0). which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled) sample_period: Period for down-sampling the image. Default 1 (disabled)
Returns: Returns:
Array containing the portion of the grid. Array containing the portion of the grid.
""" """
if numpy.size(center) != 1 or not numpy.isreal(center): if numpy.size(center) != 1 or not numpy.isreal(center):
raise GridError('center must be a real scalar') raise GridError('center must be a real scalar')
sp = round(sample_period) sp = round(sample_period)
if sp <= 0: if sp <= 0:
raise GridError('sample_period must be positive') raise GridError('sample_period must be positive')
if numpy.size(which_shifts) != 1 or which_shifts < 0: if numpy.size(which_shifts) != 1 or which_shifts < 0:
raise GridError('Invalid which_shifts') raise GridError('Invalid which_shifts')
if surface_normal not in range(3): if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction') raise GridError('Invalid surface_normal direction')
surface = numpy.delete(range(3), surface_normal) surface = numpy.delete(range(3), surface_normal)
# Extract indices and weights of planes # Extract indices and weights of planes
center3 = numpy.insert([0, 0], surface_normal, (center,)) center3 = numpy.insert([0, 0], surface_normal, (center,))
center_index = self.pos2ind(center3, which_shifts, center_index = self.pos2ind(center3, which_shifts,
round_ind=False, check_bounds=False)[surface_normal] round_ind=False, check_bounds=False)[surface_normal]
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
if len(centers) == 2: if len(centers) == 2:
fpart = center_index - numpy.floor(center_index) fpart = center_index - numpy.floor(center_index)
w = [1 - fpart, fpart] # longer distance -> less weight w = [1 - fpart, fpart] # longer distance -> less weight
else: else:
w = [1] w = [1]
c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1]) c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1])
if center < c_min or center > c_max: if center < c_min or center > c_max:
raise GridError('Coordinate of selected plane must be within simulation domain') raise GridError('Coordinate of selected plane must be within simulation domain')
# Extract grid values from planes above and below visualized slice # Extract grid values from planes above and below visualized slice
sliced_grid = numpy.zeros(self.shape[surface]) sliced_grid = numpy.zeros(self.shape[surface])
for ci, weight in zip(centers, w, strict=True): for ci, weight in zip(centers, w):
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
sliced_grid += weight * cell_data[which_shifts][tuple(s)] sliced_grid += weight * cell_data[which_shifts][tuple(s)]
# Remove extra dimensions # Remove extra dimensions
sliced_grid = numpy.squeeze(sliced_grid) sliced_grid = numpy.squeeze(sliced_grid)
return sliced_grid return sliced_grid
def visualize_slice( def visualize_slice(
self, self,
cell_data: NDArray, cell_data: NDArray,
surface_normal: int, surface_normal: int,
center: float, center: float,
which_shifts: int = 0, which_shifts: int = 0,
sample_period: int = 1, sample_period: int = 1,
finalize: bool = True, finalize: bool = True,
pcolormesh_args: dict[str, Any] | None = None, pcolormesh_args: dict[str, Any] | None = None,
) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
""" """
Visualize a slice of a grid. Visualize a slice of a grid.
Interpolates if given a position between two planes. Interpolates if given a position between two planes.
Args: Args:
surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
center: Scalar specifying position along surface_normal axis. center: Scalar specifying position along surface_normal axis.
which_shifts: Which grid to display. Default is the first grid (0). which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled) sample_period: Period for down-sampling the image. Default 1 (disabled)
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
Returns: Returns:
(Figure, Axes) (Figure, Axes)
""" """
from matplotlib import pyplot from matplotlib import pyplot
if pcolormesh_args is None: if pcolormesh_args is None:
pcolormesh_args = {} pcolormesh_args = {}
grid_slice = self.get_slice(cell_data=cell_data, grid_slice = self.get_slice(cell_data=cell_data,
surface_normal=surface_normal, surface_normal=surface_normal,
center=center, center=center,
which_shifts=which_shifts, which_shifts=which_shifts,
sample_period=sample_period) sample_period=sample_period)
surface = numpy.delete(range(3), surface_normal) surface = numpy.delete(range(3), surface_normal)
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
x_label, y_label = ('xyz'[a] for a in surface) x_label, y_label = ('xyz'[a] for a in surface)
fig, ax = pyplot.subplots() fig, ax = pyplot.subplots()
mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args)
fig.colorbar(mappable) fig.colorbar(mappable)
ax.set_aspect('equal', adjustable='box') ax.set_aspect('equal', adjustable='box')
ax.set_xlabel(x_label) ax.set_xlabel(x_label)
ax.set_ylabel(y_label) ax.set_ylabel(y_label)
if finalize: if finalize:
pyplot.show() pyplot.show()
return fig, ax return fig, ax
def visualize_isosurface( def visualize_isosurface(
self, self,
cell_data: NDArray, cell_data: NDArray,
level: float | None = None, level: float | None = None,
which_shifts: int = 0, which_shifts: int = 0,
sample_period: int = 1, sample_period: int = 1,
show_edges: bool = True, show_edges: bool = True,
finalize: bool = True, finalize: bool = True,
) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
""" """
Draw an isosurface plot of the device. Draw an isosurface plot of the device.
Args: Args:
cell_data: Cell data to visualize cell_data: Cell data to visualize
level: Value at which to find isosurface. Default (None) uses mean value in grid. level: Value at which to find isosurface. Default (None) uses mean value in grid.
which_shifts: Which grid to display. Default is the first grid (0). which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled) sample_period: Period for down-sampling the image. Default 1 (disabled)
show_edges: Whether to draw triangle edges. Default `True` show_edges: Whether to draw triangle edges. Default `True`
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
Returns: Returns:
(Figure, Axes) (Figure, Axes)
""" """
from matplotlib import pyplot from matplotlib import pyplot
import skimage.measure import skimage.measure
# Claims to be unused, but needed for subplot(projection='3d') # Claims to be unused, but needed for subplot(projection='3d')
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
del Axes3D # imported for side effects only
# Get data from cell_data # Get data from cell_data
grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period]
if level is None: if level is None:
level = grid.mean() level = grid.mean()
# Find isosurface with marching cubes # Find isosurface with marching cubes
verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
# Convert vertices from index to position # Convert vertices from index to position
pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False)
for i in range(verts.shape[0])], dtype=float) for i in range(verts.shape[0])], dtype=float)
xs, ys, zs = (pos_verts[:, a] for a in range(3)) xs, ys, zs = (pos_verts[:, a] for a in range(3))
# Draw the plot # Draw the plot
fig = pyplot.figure() fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d') ax = fig.add_subplot(111, projection='3d')
if show_edges: if show_edges:
ax.plot_trisurf(xs, ys, faces, zs) ax.plot_trisurf(xs, ys, faces, zs)
else: else:
ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none')
# Add a fake plot of a cube to force the axes to be equal lengths # Add a fake plot of a cube to force the axes to be equal lengths
max_range = numpy.array([xs.max() - xs.min(), max_range = numpy.array([xs.max() - xs.min(),
ys.max() - ys.min(), ys.max() - ys.min(),
zs.max() - zs.min()], dtype=float).max() zs.max() - zs.min()], dtype=float).max()
mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2]
xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min()) xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min())
ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min()) ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min())
zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min()) zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min())
# Comment or uncomment following both lines to test the fake bounding box: # Comment or uncomment following both lines to test the fake bounding box:
for xb, yb, zb in zip(xbs, ybs, zbs, strict=True): for xb, yb, zb in zip(xbs, ybs, zbs):
ax.plot([xb], [yb], [zb], 'w') ax.plot([xb], [yb], [zb], 'w')
if finalize: if finalize:
pyplot.show() pyplot.show()
return fig, ax return fig, ax

View File

@ -1,6 +1,6 @@
import pytest # type: ignore import pytest # type: ignore
import numpy import numpy
from numpy.testing import assert_allclose #, assert_array_equal from numpy.testing import assert_allclose, assert_array_equal
from .. import Grid from .. import Grid

View File

@ -53,47 +53,3 @@ visualization-isosurface = [
"skimage>=0.13", "skimage>=0.13",
"mpl_toolkits", "mpl_toolkits",
] ]
[tool.ruff]
exclude = [
".git",
"dist",
]
line-length = 145
indent-width = 4
lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
lint.select = [
"NPY", "E", "F", "W", "B", "ANN", "UP", "SLOT", "SIM", "LOG",
"C4", "ISC", "PIE", "PT", "RET", "TCH", "PTH", "INT",
"ARG", "PL", "R", "TRY",
"G010", "G101", "G201", "G202",
"Q002", "Q003", "Q004",
]
lint.ignore = [
#"ANN001", # No annotation
"ANN002", # *args
"ANN003", # **kwargs
"ANN401", # Any
"ANN101", # self: Self
"SIM108", # single-line if / else assignment
"RET504", # x=y+z; return x
"PIE790", # unnecessary pass
"ISC003", # non-implicit string concatenation
"C408", # dict(x=y) instead of {'x': y}
"PLR09", # Too many xxx
"PLR2004", # magic number
"PLC0414", # import x as x
"TRY003", # Long exception message
"PTH123", # open()
]
[[tool.mypy.overrides]]
module = [
"matplotlib",
"matplotlib.axes",
"matplotlib.figure",
"mpl_toolkits.mplot3d",
]
ignore_missing_imports = true