Compare commits
10 Commits
9ab97e763c
...
8e7e0edb1f
Author | SHA1 | Date | |
---|---|---|---|
8e7e0edb1f | |||
e5fdc3ce23 | |||
646911c4b5 | |||
e256f56f2b | |||
c32d94ed85 | |||
8c33a39c02 | |||
f84a75f35a | |||
5a20339eab | |||
e29c0901bd | |||
a15e4bc05e |
@ -15,8 +15,8 @@ Dependencies:
|
|||||||
- mpl_toolkits.mplot3d [Grid.visualize_isosurface()]
|
- mpl_toolkits.mplot3d [Grid.visualize_isosurface()]
|
||||||
- skimage [Grid.visualize_isosurface()]
|
- skimage [Grid.visualize_isosurface()]
|
||||||
"""
|
"""
|
||||||
from .error import GridError
|
from .error import GridError as GridError
|
||||||
from .grid import Grid
|
from .grid import Grid as Grid
|
||||||
|
|
||||||
__author__ = 'Jan Petykiewicz'
|
__author__ = 'Jan Petykiewicz'
|
||||||
__version__ = '1.1'
|
__version__ = '1.1'
|
||||||
|
196
gridlock/base.py
Normal file
196
gridlock/base.py
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from . import GridError
|
||||||
|
|
||||||
|
|
||||||
|
class GridBase(Protocol):
|
||||||
|
exyz: list[NDArray]
|
||||||
|
"""Cell edges. Monotonically increasing without duplicates."""
|
||||||
|
|
||||||
|
periodic: list[bool]
|
||||||
|
"""For each axis, determines how far the rightmost boundary gets shifted. """
|
||||||
|
|
||||||
|
shifts: NDArray
|
||||||
|
"""Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dxyz(self) -> list[NDArray]:
|
||||||
|
"""
|
||||||
|
Cell sizes for each axis, no shifts applied
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of 3 ndarrays of cell sizes
|
||||||
|
"""
|
||||||
|
return [numpy.diff(ee) for ee in self.exyz]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def xyz(self) -> list[NDArray]:
|
||||||
|
"""
|
||||||
|
Cell centers for each axis, no shifts applied
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of 3 ndarrays of cell edges
|
||||||
|
"""
|
||||||
|
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> NDArray[numpy.intp]:
|
||||||
|
"""
|
||||||
|
The number of cells in x, y, and z
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ndarray of [x_centers.size, y_centers.size, z_centers.size]
|
||||||
|
"""
|
||||||
|
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_grids(self) -> int:
|
||||||
|
"""
|
||||||
|
The number of grids (number of shifts)
|
||||||
|
"""
|
||||||
|
return self.shifts.shape[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cell_data_shape(self) -> NDArray[numpy.intp]:
|
||||||
|
"""
|
||||||
|
The shape of the cell_data ndarray (num_grids, *self.shape).
|
||||||
|
"""
|
||||||
|
return numpy.hstack((self.num_grids, self.shape))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dxyz_with_ghost(self) -> list[NDArray]:
|
||||||
|
"""
|
||||||
|
Gives dxyz with an additional 'ghost' cell at the end, whose value depends
|
||||||
|
on whether or not the axis has periodic boundary conditions. See main description
|
||||||
|
above to learn why this is necessary.
|
||||||
|
|
||||||
|
If periodic, final edge shifts same amount as first
|
||||||
|
Otherwise, final edge shifts same amount as second-to-last
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
|
||||||
|
"""
|
||||||
|
el = [0 if p else -1 for p in self.periodic]
|
||||||
|
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def center(self) -> NDArray[numpy.float64]:
|
||||||
|
"""
|
||||||
|
Center position of the entire grid, no shifts applied
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ndarray of [x_center, y_center, z_center]
|
||||||
|
"""
|
||||||
|
# center is just average of first and last xyz, which is just the average of the
|
||||||
|
# first two and last two exyz
|
||||||
|
centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
|
||||||
|
return numpy.array(centers, dtype=float)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dxyz_limits(self) -> tuple[NDArray, NDArray]:
|
||||||
|
"""
|
||||||
|
Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
|
||||||
|
ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
|
||||||
|
weighted average is performed when shifting).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
|
||||||
|
"""
|
||||||
|
d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
|
||||||
|
d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
|
||||||
|
return d_min, d_max
|
||||||
|
|
||||||
|
def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||||
|
"""
|
||||||
|
Returns edges for which_shifts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of 3 ndarrays of cell edges
|
||||||
|
"""
|
||||||
|
if which_shifts is None:
|
||||||
|
return self.exyz
|
||||||
|
dxyz = self.dxyz_with_ghost
|
||||||
|
shifts = self.shifts[which_shifts, :]
|
||||||
|
|
||||||
|
# If shift is negative, use left cell's dx to determine shift
|
||||||
|
for a in range(3):
|
||||||
|
if shifts[a] < 0:
|
||||||
|
dxyz[a] = numpy.roll(dxyz[a], 1)
|
||||||
|
|
||||||
|
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
|
||||||
|
|
||||||
|
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
||||||
|
"""
|
||||||
|
Returns cell sizes for `which_shifts`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of 3 ndarrays of cell sizes
|
||||||
|
"""
|
||||||
|
if which_shifts is None:
|
||||||
|
return self.dxyz
|
||||||
|
shifts = self.shifts[which_shifts, :]
|
||||||
|
dxyz = self.dxyz_with_ghost
|
||||||
|
|
||||||
|
# If shift is negative, use left cell's dx to determine size
|
||||||
|
sdxyz = []
|
||||||
|
for a in range(3):
|
||||||
|
if shifts[a] < 0:
|
||||||
|
roll_dxyz = numpy.roll(dxyz[a], 1)
|
||||||
|
abs_shift = numpy.abs(shifts[a])
|
||||||
|
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
|
||||||
|
else:
|
||||||
|
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
|
||||||
|
|
||||||
|
return sdxyz
|
||||||
|
|
||||||
|
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
||||||
|
"""
|
||||||
|
Returns cell centers for `which_shifts`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of 3 ndarrays of cell centers
|
||||||
|
"""
|
||||||
|
if which_shifts is None:
|
||||||
|
return self.xyz
|
||||||
|
exyz = self.shifted_exyz(which_shifts)
|
||||||
|
dxyz = self.shifted_dxyz(which_shifts)
|
||||||
|
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
|
||||||
|
|
||||||
|
def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
|
||||||
|
"""
|
||||||
|
Return cell widths, with each dimension shifted by the corresponding shifts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
|
||||||
|
"""
|
||||||
|
if self.num_grids != 3:
|
||||||
|
raise GridError('Autoshifting requires exactly 3 grids')
|
||||||
|
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
|
||||||
|
|
||||||
|
def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray:
|
||||||
|
"""
|
||||||
|
Allocate an ndarray for storing grid data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fill_value: Value to initialize the grid to. If None, an
|
||||||
|
uninitialized array is returned.
|
||||||
|
dtype: Numpy dtype for the array. Default is `numpy.float32`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The allocated array
|
||||||
|
"""
|
||||||
|
if fill_value is None:
|
||||||
|
return numpy.empty(self.cell_data_shape, dtype=dtype)
|
||||||
|
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
|
@ -1,13 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
Drawing-related methods for Grid class
|
Drawing-related methods for Grid class
|
||||||
"""
|
"""
|
||||||
from typing import Union, Sequence, Callable
|
from collections.abc import Sequence, Callable
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.typing import NDArray, ArrayLike
|
from numpy.typing import NDArray, ArrayLike
|
||||||
from float_raster import raster
|
from float_raster import raster
|
||||||
|
|
||||||
from . import GridError
|
from . import GridError
|
||||||
|
from .position import GridPosMixin
|
||||||
|
|
||||||
|
|
||||||
# NOTE: Maybe it would make sense to create a GridDrawer class
|
# NOTE: Maybe it would make sense to create a GridDrawer class
|
||||||
@ -17,15 +18,16 @@ from . import GridError
|
|||||||
|
|
||||||
|
|
||||||
foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
|
foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
|
||||||
foreground_t = Union[float, foreground_callable_t]
|
foreground_t = float | foreground_callable_t
|
||||||
|
|
||||||
|
|
||||||
|
class GridDrawMixin(GridPosMixin):
|
||||||
def draw_polygons(
|
def draw_polygons(
|
||||||
self,
|
self,
|
||||||
cell_data: NDArray,
|
cell_data: NDArray,
|
||||||
surface_normal: int,
|
surface_normal: int,
|
||||||
center: ArrayLike,
|
center: ArrayLike,
|
||||||
polygons: Sequence[NDArray],
|
polygons: Sequence[ArrayLike],
|
||||||
thickness: float,
|
thickness: float,
|
||||||
foreground: Sequence[foreground_t] | foreground_t,
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -50,18 +52,20 @@ def draw_polygons(
|
|||||||
"""
|
"""
|
||||||
if surface_normal not in range(3):
|
if surface_normal not in range(3):
|
||||||
raise GridError('Invalid surface_normal direction')
|
raise GridError('Invalid surface_normal direction')
|
||||||
|
|
||||||
center = numpy.squeeze(center)
|
center = numpy.squeeze(center)
|
||||||
|
poly_list = [numpy.array(poly, copy=False) for poly in polygons]
|
||||||
|
|
||||||
# Check polygons, and remove redundant coordinates
|
# Check polygons, and remove redundant coordinates
|
||||||
surface = numpy.delete(range(3), surface_normal)
|
surface = numpy.delete(range(3), surface_normal)
|
||||||
|
|
||||||
for i, polygon in enumerate(polygons):
|
for ii in range(len(poly_list)):
|
||||||
malformed = f'Malformed polygon: ({i})'
|
polygon = poly_list[ii]
|
||||||
|
malformed = f'Malformed polygon: ({ii})'
|
||||||
if polygon.shape[1] not in (2, 3):
|
if polygon.shape[1] not in (2, 3):
|
||||||
raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
|
raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
|
||||||
if polygon.shape[1] == 3:
|
if polygon.shape[1] == 3:
|
||||||
polygon = polygon[surface, :]
|
polygon = polygon[surface, :]
|
||||||
|
poly_list[ii] = polygon
|
||||||
|
|
||||||
if not polygon.shape[0] > 2:
|
if not polygon.shape[0] > 2:
|
||||||
raise GridError(malformed + 'must consist of more than 2 points')
|
raise GridError(malformed + 'must consist of more than 2 points')
|
||||||
@ -82,7 +86,7 @@ def draw_polygons(
|
|||||||
# 1) Compute outer bounds (bd) of polygons
|
# 1) Compute outer bounds (bd) of polygons
|
||||||
bd_2d_min = numpy.array([0, 0])
|
bd_2d_min = numpy.array([0, 0])
|
||||||
bd_2d_max = numpy.array([0, 0])
|
bd_2d_max = numpy.array([0, 0])
|
||||||
for polygon in polygons:
|
for polygon in poly_list:
|
||||||
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
|
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
|
||||||
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
|
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
|
||||||
bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
|
bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
|
||||||
@ -100,7 +104,7 @@ def draw_polygons(
|
|||||||
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
|
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
|
||||||
|
|
||||||
# 3) Adjust polygons for center
|
# 3) Adjust polygons for center
|
||||||
polygons = [poly + center[surface] for poly in polygons]
|
poly_list = [poly + center[surface] for poly in poly_list]
|
||||||
|
|
||||||
# ## Generate weighing function
|
# ## Generate weighing function
|
||||||
def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
|
def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
|
||||||
@ -108,6 +112,7 @@ def draw_polygons(
|
|||||||
return numpy.insert(v_2d, surface_normal, (val,))
|
return numpy.insert(v_2d, surface_normal, (val,))
|
||||||
|
|
||||||
# iterate over grids
|
# iterate over grids
|
||||||
|
foreground_val: NDArray | float
|
||||||
for i, _ in enumerate(cell_data):
|
for i, _ in enumerate(cell_data):
|
||||||
# ## Evaluate or expand foregrounds[i]
|
# ## Evaluate or expand foregrounds[i]
|
||||||
foregrounds_i = foregrounds[i]
|
foregrounds_i = foregrounds[i]
|
||||||
@ -129,7 +134,7 @@ def draw_polygons(
|
|||||||
w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int))
|
w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int))
|
||||||
|
|
||||||
# Draw each polygon separately
|
# Draw each polygon separately
|
||||||
for polygon in polygons:
|
for polygon in poly_list:
|
||||||
|
|
||||||
# Get the boundaries of the polygon
|
# Get the boundaries of the polygon
|
||||||
pbd_min = polygon.min(axis=0)
|
pbd_min = polygon.min(axis=0)
|
||||||
@ -144,13 +149,13 @@ def draw_polygons(
|
|||||||
|
|
||||||
# Find indices in w_xy which are modified by polygon
|
# Find indices in w_xy which are modified by polygon
|
||||||
# First for the edge coordinates (+1 since we're indexing edges)
|
# First for the edge coordinates (+1 since we're indexing edges)
|
||||||
edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max)]
|
edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)]
|
||||||
# Then for the pixel centers (-bdi_min since we're
|
# Then for the pixel centers (-bdi_min since we're
|
||||||
# calculating weights within a subspace)
|
# calculating weights within a subspace)
|
||||||
centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface],
|
centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface],
|
||||||
corner_max - bdi_min[surface]))
|
corner_max - bdi_min[surface], strict=True))
|
||||||
|
|
||||||
aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices))
|
aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True))
|
||||||
w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y)
|
w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y)
|
||||||
|
|
||||||
# Clamp overlapping polygons to 1
|
# Clamp overlapping polygons to 1
|
||||||
@ -159,7 +164,7 @@ def draw_polygons(
|
|||||||
# 2) Generate weights in z-direction
|
# 2) Generate weights in z-direction
|
||||||
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
|
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
|
||||||
|
|
||||||
def get_zi(offset, i=i, w_z=w_z):
|
def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001
|
||||||
edges = self.shifted_exyz(i)[surface_normal]
|
edges = self.shifted_exyz(i)[surface_normal]
|
||||||
point = center[surface_normal] + offset
|
point = center[surface_normal] + offset
|
||||||
grid_coord = numpy.digitize(point, edges) - 1
|
grid_coord = numpy.digitize(point, edges) - 1
|
||||||
@ -377,10 +382,10 @@ def draw_extrude_rectangle(
|
|||||||
ind[direction] += 1 # type: ignore #(known safe)
|
ind[direction] += 1 # type: ignore #(known safe)
|
||||||
foreground += mult[1] * grid[tuple(ind)]
|
foreground += mult[1] * grid[tuple(ind)]
|
||||||
|
|
||||||
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]:
|
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
|
||||||
# transform from natural position to index
|
# transform from natural position to index
|
||||||
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
|
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
|
||||||
for qrs in zip(xs.flat, ys.flat, zs.flat)], dtype=int)
|
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64)
|
||||||
# reshape to original shape and keep only in-plane components
|
# reshape to original shape and keep only in-plane components
|
||||||
qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
|
qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
|
||||||
return foreground[qi, ri]
|
return foreground[qi, ri]
|
||||||
|
@ -29,7 +29,7 @@ if __name__ == '__main__':
|
|||||||
# numpy.linspace(-5.5, 5.5, 10)]
|
# numpy.linspace(-5.5, 5.5, 10)]
|
||||||
|
|
||||||
half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5]
|
half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5]
|
||||||
xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x,
|
xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x),
|
||||||
numpy.linspace(-5.5, 5.5, 10),
|
numpy.linspace(-5.5, 5.5, 10),
|
||||||
numpy.linspace(-5.5, 5.5, 10)]
|
numpy.linspace(-5.5, 5.5, 10)]
|
||||||
eg = Grid(xyz3)
|
eg = Grid(xyz3)
|
||||||
@ -37,8 +37,8 @@ if __name__ == '__main__':
|
|||||||
# eg.draw_slab(Direction.z, 0, 10, 2)
|
# eg.draw_slab(Direction.z, 0, 10, 2)
|
||||||
eg.save('/home/jan/Desktop/test.pickle')
|
eg.save('/home/jan/Desktop/test.pickle')
|
||||||
eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0,
|
eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0,
|
||||||
thickness=10, num_poitns=1000, foreground=1)
|
thickness=10, num_points=1000, foreground=1)
|
||||||
eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]],
|
eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]],
|
||||||
direction=1, poalarity=+1, distance=5)
|
direction=1, polarity=+1, distance=5)
|
||||||
eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
|
eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
|
||||||
eg.visualize_isosurface(egc, which_shifts=2)
|
eg.visualize_isosurface(egc, which_shifts=2)
|
||||||
|
200
gridlock/grid.py
200
gridlock/grid.py
@ -1,4 +1,5 @@
|
|||||||
from typing import Callable, Sequence, ClassVar, Self
|
from typing import ClassVar, Self
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.typing import NDArray, ArrayLike
|
from numpy.typing import NDArray, ArrayLike
|
||||||
@ -8,12 +9,15 @@ import warnings
|
|||||||
import copy
|
import copy
|
||||||
|
|
||||||
from . import GridError
|
from . import GridError
|
||||||
|
from .draw import GridDrawMixin
|
||||||
|
from .read import GridReadMixin
|
||||||
|
from .position import GridPosMixin
|
||||||
|
|
||||||
|
|
||||||
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
|
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
|
||||||
|
|
||||||
|
|
||||||
class Grid:
|
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
|
||||||
"""
|
"""
|
||||||
Simulation grid metadata for finite-difference simulations.
|
Simulation grid metadata for finite-difference simulations.
|
||||||
|
|
||||||
@ -70,193 +74,6 @@ class Grid:
|
|||||||
], dtype=float)
|
], dtype=float)
|
||||||
"""Default shifts for Yee grid H-field"""
|
"""Default shifts for Yee grid H-field"""
|
||||||
|
|
||||||
from .draw import (
|
|
||||||
draw_polygons, draw_polygon, draw_slab, draw_cuboid,
|
|
||||||
draw_cylinder, draw_extrude_rectangle,
|
|
||||||
)
|
|
||||||
from .read import get_slice, visualize_slice, visualize_isosurface
|
|
||||||
from .position import ind2pos, pos2ind
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dxyz(self) -> list[NDArray]:
|
|
||||||
"""
|
|
||||||
Cell sizes for each axis, no shifts applied
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of 3 ndarrays of cell sizes
|
|
||||||
"""
|
|
||||||
return [numpy.diff(ee) for ee in self.exyz]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def xyz(self) -> list[NDArray]:
|
|
||||||
"""
|
|
||||||
Cell centers for each axis, no shifts applied
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of 3 ndarrays of cell edges
|
|
||||||
"""
|
|
||||||
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def shape(self) -> NDArray[numpy.int_]:
|
|
||||||
"""
|
|
||||||
The number of cells in x, y, and z
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ndarray of [x_centers.size, y_centers.size, z_centers.size]
|
|
||||||
"""
|
|
||||||
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_grids(self) -> int:
|
|
||||||
"""
|
|
||||||
The number of grids (number of shifts)
|
|
||||||
"""
|
|
||||||
return self.shifts.shape[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cell_data_shape(self):
|
|
||||||
"""
|
|
||||||
The shape of the cell_data ndarray (num_grids, *self.shape).
|
|
||||||
"""
|
|
||||||
return numpy.hstack((self.num_grids, self.shape))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dxyz_with_ghost(self) -> list[NDArray]:
|
|
||||||
"""
|
|
||||||
Gives dxyz with an additional 'ghost' cell at the end, whose value depends
|
|
||||||
on whether or not the axis has periodic boundary conditions. See main description
|
|
||||||
above to learn why this is necessary.
|
|
||||||
|
|
||||||
If periodic, final edge shifts same amount as first
|
|
||||||
Otherwise, final edge shifts same amount as second-to-last
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
|
|
||||||
"""
|
|
||||||
el = [0 if p else -1 for p in self.periodic]
|
|
||||||
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def center(self) -> NDArray[numpy.float64]:
|
|
||||||
"""
|
|
||||||
Center position of the entire grid, no shifts applied
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ndarray of [x_center, y_center, z_center]
|
|
||||||
"""
|
|
||||||
# center is just average of first and last xyz, which is just the average of the
|
|
||||||
# first two and last two exyz
|
|
||||||
centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
|
|
||||||
return numpy.array(centers, dtype=float)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dxyz_limits(self) -> tuple[NDArray, NDArray]:
|
|
||||||
"""
|
|
||||||
Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
|
|
||||||
ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
|
|
||||||
weighted average is performed when shifting).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
|
|
||||||
"""
|
|
||||||
d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
|
|
||||||
d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
|
|
||||||
return d_min, d_max
|
|
||||||
|
|
||||||
def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
|
|
||||||
"""
|
|
||||||
Returns edges for which_shifts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of 3 ndarrays of cell edges
|
|
||||||
"""
|
|
||||||
if which_shifts is None:
|
|
||||||
return self.exyz
|
|
||||||
dxyz = self.dxyz_with_ghost
|
|
||||||
shifts = self.shifts[which_shifts, :]
|
|
||||||
|
|
||||||
# If shift is negative, use left cell's dx to determine shift
|
|
||||||
for a in range(3):
|
|
||||||
if shifts[a] < 0:
|
|
||||||
dxyz[a] = numpy.roll(dxyz[a], 1)
|
|
||||||
|
|
||||||
return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
|
|
||||||
|
|
||||||
def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
|
|
||||||
"""
|
|
||||||
Returns cell sizes for `which_shifts`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of 3 ndarrays of cell sizes
|
|
||||||
"""
|
|
||||||
if which_shifts is None:
|
|
||||||
return self.dxyz
|
|
||||||
shifts = self.shifts[which_shifts, :]
|
|
||||||
dxyz = self.dxyz_with_ghost
|
|
||||||
|
|
||||||
# If shift is negative, use left cell's dx to determine size
|
|
||||||
sdxyz = []
|
|
||||||
for a in range(3):
|
|
||||||
if shifts[a] < 0:
|
|
||||||
roll_dxyz = numpy.roll(dxyz[a], 1)
|
|
||||||
abs_shift = numpy.abs(shifts[a])
|
|
||||||
sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
|
|
||||||
else:
|
|
||||||
sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
|
|
||||||
|
|
||||||
return sdxyz
|
|
||||||
|
|
||||||
def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
|
|
||||||
"""
|
|
||||||
Returns cell centers for `which_shifts`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
which_shifts: Which grid (which shifts) to use, or `None` for unshifted
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of 3 ndarrays of cell centers
|
|
||||||
"""
|
|
||||||
if which_shifts is None:
|
|
||||||
return self.xyz
|
|
||||||
exyz = self.shifted_exyz(which_shifts)
|
|
||||||
dxyz = self.shifted_dxyz(which_shifts)
|
|
||||||
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
|
|
||||||
|
|
||||||
def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
|
|
||||||
"""
|
|
||||||
Return cell widths, with each dimension shifted by the corresponding shifts.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
|
|
||||||
"""
|
|
||||||
if self.num_grids != 3:
|
|
||||||
raise GridError('Autoshifting requires exactly 3 grids')
|
|
||||||
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
|
|
||||||
|
|
||||||
def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray:
|
|
||||||
"""
|
|
||||||
Allocate an ndarray for storing grid data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fill_value: Value to initialize the grid to. If None, an
|
|
||||||
uninitialized array is returned.
|
|
||||||
dtype: Numpy dtype for the array. Default is `numpy.float32`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The allocated array
|
|
||||||
"""
|
|
||||||
if fill_value is None:
|
|
||||||
return numpy.empty(self.cell_data_shape, dtype=dtype)
|
|
||||||
else:
|
|
||||||
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pixel_edge_coordinates: Sequence[ArrayLike],
|
pixel_edge_coordinates: Sequence[ArrayLike],
|
||||||
@ -277,11 +94,12 @@ class Grid:
|
|||||||
Raises:
|
Raises:
|
||||||
`GridError` on invalid input
|
`GridError` on invalid input
|
||||||
"""
|
"""
|
||||||
self.exyz = [numpy.unique(pixel_edge_coordinates[i]) for i in range(3)]
|
edge_arrs = [numpy.array(cc, copy=False) for cc in pixel_edge_coordinates]
|
||||||
|
self.exyz = [numpy.unique(edges) for edges in edge_arrs]
|
||||||
self.shifts = numpy.array(shifts, dtype=float)
|
self.shifts = numpy.array(shifts, dtype=float)
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
if len(self.exyz[i]) != len(pixel_edge_coordinates[i]):
|
if self.exyz[i].size != edge_arrs[i].size:
|
||||||
warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2)
|
warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2)
|
||||||
|
|
||||||
if isinstance(periodic, bool):
|
if isinstance(periodic, bool):
|
||||||
|
@ -5,8 +5,10 @@ import numpy
|
|||||||
from numpy.typing import NDArray, ArrayLike
|
from numpy.typing import NDArray, ArrayLike
|
||||||
|
|
||||||
from . import GridError
|
from . import GridError
|
||||||
|
from .base import GridBase
|
||||||
|
|
||||||
|
|
||||||
|
class GridPosMixin(GridBase):
|
||||||
def ind2pos(
|
def ind2pos(
|
||||||
self,
|
self,
|
||||||
ind: NDArray,
|
ind: NDArray,
|
||||||
|
@ -7,6 +7,7 @@ import numpy
|
|||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from . import GridError
|
from . import GridError
|
||||||
|
from .position import GridPosMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import matplotlib.axes
|
import matplotlib.axes
|
||||||
@ -18,6 +19,7 @@ if TYPE_CHECKING:
|
|||||||
# .visualize_isosurface uses mpl_toolkits.mplot3d
|
# .visualize_isosurface uses mpl_toolkits.mplot3d
|
||||||
|
|
||||||
|
|
||||||
|
class GridReadMixin(GridPosMixin):
|
||||||
def get_slice(
|
def get_slice(
|
||||||
self,
|
self,
|
||||||
cell_data: NDArray,
|
cell_data: NDArray,
|
||||||
@ -72,7 +74,7 @@ def get_slice(
|
|||||||
|
|
||||||
# Extract grid values from planes above and below visualized slice
|
# Extract grid values from planes above and below visualized slice
|
||||||
sliced_grid = numpy.zeros(self.shape[surface])
|
sliced_grid = numpy.zeros(self.shape[surface])
|
||||||
for ci, weight in zip(centers, w):
|
for ci, weight in zip(centers, w, strict=True):
|
||||||
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
|
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
|
||||||
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
|
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
|
||||||
|
|
||||||
@ -162,6 +164,7 @@ def visualize_isosurface(
|
|||||||
import skimage.measure
|
import skimage.measure
|
||||||
# Claims to be unused, but needed for subplot(projection='3d')
|
# Claims to be unused, but needed for subplot(projection='3d')
|
||||||
from mpl_toolkits.mplot3d import Axes3D
|
from mpl_toolkits.mplot3d import Axes3D
|
||||||
|
del Axes3D # imported for side effects only
|
||||||
|
|
||||||
# Get data from cell_data
|
# Get data from cell_data
|
||||||
grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period]
|
grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period]
|
||||||
@ -193,7 +196,7 @@ def visualize_isosurface(
|
|||||||
ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min())
|
ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min())
|
||||||
zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min())
|
zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min())
|
||||||
# Comment or uncomment following both lines to test the fake bounding box:
|
# Comment or uncomment following both lines to test the fake bounding box:
|
||||||
for xb, yb, zb in zip(xbs, ybs, zbs):
|
for xb, yb, zb in zip(xbs, ybs, zbs, strict=True):
|
||||||
ax.plot([xb], [yb], [zb], 'w')
|
ax.plot([xb], [yb], [zb], 'w')
|
||||||
|
|
||||||
if finalize:
|
if finalize:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import pytest # type: ignore
|
import pytest # type: ignore
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.testing import assert_allclose, assert_array_equal
|
from numpy.testing import assert_allclose #, assert_array_equal
|
||||||
|
|
||||||
from .. import Grid
|
from .. import Grid
|
||||||
|
|
||||||
|
@ -53,3 +53,47 @@ visualization-isosurface = [
|
|||||||
"skimage>=0.13",
|
"skimage>=0.13",
|
||||||
"mpl_toolkits",
|
"mpl_toolkits",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
exclude = [
|
||||||
|
".git",
|
||||||
|
"dist",
|
||||||
|
]
|
||||||
|
line-length = 145
|
||||||
|
indent-width = 4
|
||||||
|
lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
||||||
|
lint.select = [
|
||||||
|
"NPY", "E", "F", "W", "B", "ANN", "UP", "SLOT", "SIM", "LOG",
|
||||||
|
"C4", "ISC", "PIE", "PT", "RET", "TCH", "PTH", "INT",
|
||||||
|
"ARG", "PL", "R", "TRY",
|
||||||
|
"G010", "G101", "G201", "G202",
|
||||||
|
"Q002", "Q003", "Q004",
|
||||||
|
]
|
||||||
|
lint.ignore = [
|
||||||
|
#"ANN001", # No annotation
|
||||||
|
"ANN002", # *args
|
||||||
|
"ANN003", # **kwargs
|
||||||
|
"ANN401", # Any
|
||||||
|
"ANN101", # self: Self
|
||||||
|
"SIM108", # single-line if / else assignment
|
||||||
|
"RET504", # x=y+z; return x
|
||||||
|
"PIE790", # unnecessary pass
|
||||||
|
"ISC003", # non-implicit string concatenation
|
||||||
|
"C408", # dict(x=y) instead of {'x': y}
|
||||||
|
"PLR09", # Too many xxx
|
||||||
|
"PLR2004", # magic number
|
||||||
|
"PLC0414", # import x as x
|
||||||
|
"TRY003", # Long exception message
|
||||||
|
"PTH123", # open()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = [
|
||||||
|
"matplotlib",
|
||||||
|
"matplotlib.axes",
|
||||||
|
"matplotlib.figure",
|
||||||
|
"mpl_toolkits.mplot3d",
|
||||||
|
]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
Loading…
Reference in New Issue
Block a user