Compare commits
1 Commits
master
...
arg_rework
Author | SHA1 | Date | |
---|---|---|---|
34f80202ba |
@ -15,9 +15,24 @@ Dependencies:
|
|||||||
- mpl_toolkits.mplot3d [Grid.visualize_isosurface()]
|
- mpl_toolkits.mplot3d [Grid.visualize_isosurface()]
|
||||||
- skimage [Grid.visualize_isosurface()]
|
- skimage [Grid.visualize_isosurface()]
|
||||||
"""
|
"""
|
||||||
from .error import GridError as GridError
|
from .utils import (
|
||||||
|
GridError as GridError,
|
||||||
|
|
||||||
|
Extent as Extent,
|
||||||
|
ExtentProtocol as ExtentProtocol,
|
||||||
|
ExtentDict as ExtentDict,
|
||||||
|
|
||||||
|
Slab as Slab,
|
||||||
|
SlabProtocol as SlabProtocol,
|
||||||
|
SlabDict as SlabDict,
|
||||||
|
|
||||||
|
Plane as Plane,
|
||||||
|
PlaneProtocol as PlaneProtocol,
|
||||||
|
PlaneDict as PlaneDict,
|
||||||
|
)
|
||||||
from .grid import Grid as Grid
|
from .grid import Grid as Grid
|
||||||
|
|
||||||
|
|
||||||
__author__ = 'Jan Petykiewicz'
|
__author__ = 'Jan Petykiewicz'
|
||||||
__version__ = '1.2'
|
__version__ = '1.2'
|
||||||
version = __version__
|
version = __version__
|
||||||
|
@ -1,10 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
|
|
||||||
|
|
||||||
class Direction(Enum):
|
|
||||||
"""
|
|
||||||
Enum for axis->integer mapping
|
|
||||||
"""
|
|
||||||
x = 0
|
|
||||||
y = 1
|
|
||||||
z = 2
|
|
209
gridlock/draw.py
209
gridlock/draw.py
@ -7,7 +7,7 @@ import numpy
|
|||||||
from numpy.typing import NDArray, ArrayLike
|
from numpy.typing import NDArray, ArrayLike
|
||||||
from float_raster import raster
|
from float_raster import raster
|
||||||
|
|
||||||
from . import GridError
|
from .utils import GridError, Slab, SlabDict, SlabProtocol, Extent, ExtentDict, ExtentProtocol
|
||||||
from .position import GridPosMixin
|
from .position import GridPosMixin
|
||||||
|
|
||||||
|
|
||||||
@ -21,27 +21,26 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
|
|||||||
foreground_t = float | foreground_callable_t
|
foreground_t = float | foreground_callable_t
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GridDrawMixin(GridPosMixin):
|
class GridDrawMixin(GridPosMixin):
|
||||||
def draw_polygons(
|
def draw_polygons(
|
||||||
self,
|
self,
|
||||||
cell_data: NDArray,
|
cell_data: NDArray,
|
||||||
surface_normal: int,
|
slab: SlabProtocol | SlabDict,
|
||||||
center: ArrayLike,
|
|
||||||
polygons: Sequence[ArrayLike],
|
polygons: Sequence[ArrayLike],
|
||||||
thickness: float,
|
|
||||||
foreground: Sequence[foreground_t] | foreground_t,
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
|
*,
|
||||||
|
offset2d: ArrayLike = (0, 0),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Draw polygons on an axis-aligned plane.
|
Draw polygons on an axis-aligned plane.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
|
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
|
||||||
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
|
|
||||||
center: 3-element ndarray or list specifying an offset applied to all the polygons
|
center: 3-element ndarray or list specifying an offset applied to all the polygons
|
||||||
polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
|
polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
|
||||||
(non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each
|
(non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each
|
||||||
polygon must have at least 3 vertices.
|
polygon must have at least 3 vertices.
|
||||||
thickness: Thickness of the layer to draw
|
|
||||||
foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list
|
foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list
|
||||||
of any of these (1 per grid). Callable values should take an ndarray the shape of the
|
of any of these (1 per grid). Callable values should take an ndarray the shape of the
|
||||||
grid and return an ndarray of equal shape containing the foreground value at the given x, y,
|
grid and return an ndarray of equal shape containing the foreground value at the given x, y,
|
||||||
@ -50,13 +49,13 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
Raises:
|
Raises:
|
||||||
GridError
|
GridError
|
||||||
"""
|
"""
|
||||||
if surface_normal not in range(3):
|
if isinstance(slab, dict):
|
||||||
raise GridError('Invalid surface_normal direction')
|
slab = Slab(**slab)
|
||||||
center = numpy.squeeze(center)
|
|
||||||
poly_list = [numpy.asarray(poly) for poly in polygons]
|
poly_list = [numpy.asarray(poly) for poly in polygons]
|
||||||
|
|
||||||
# Check polygons, and remove redundant coordinates
|
# Check polygons, and remove redundant coordinates
|
||||||
surface = numpy.delete(range(3), surface_normal)
|
surface = numpy.delete(range(3), slab.axis)
|
||||||
|
|
||||||
for ii in range(len(poly_list)):
|
for ii in range(len(poly_list)):
|
||||||
polygon = poly_list[ii]
|
polygon = poly_list[ii]
|
||||||
@ -69,9 +68,8 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
|
|
||||||
if not polygon.shape[0] > 2:
|
if not polygon.shape[0] > 2:
|
||||||
raise GridError(malformed + 'must consist of more than 2 points')
|
raise GridError(malformed + 'must consist of more than 2 points')
|
||||||
if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1:
|
if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1:
|
||||||
raise GridError(malformed + 'must be in plane with surface normal '
|
raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis])
|
||||||
+ 'xyz'[surface_normal])
|
|
||||||
|
|
||||||
# Broadcast foreground where necessary
|
# Broadcast foreground where necessary
|
||||||
foregrounds: Sequence[foreground_callable_t] | Sequence[float]
|
foregrounds: Sequence[foreground_callable_t] | Sequence[float]
|
||||||
@ -87,10 +85,10 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
bd_2d_min = numpy.array([0, 0])
|
bd_2d_min = numpy.array([0, 0])
|
||||||
bd_2d_max = numpy.array([0, 0])
|
bd_2d_max = numpy.array([0, 0])
|
||||||
for polygon in poly_list:
|
for polygon in poly_list:
|
||||||
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
|
bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) + offset2d
|
||||||
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
|
bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) + offset2d
|
||||||
bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
|
bd_min = numpy.insert(bd_2d_min, slab.axis, slab.min)
|
||||||
bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center
|
bd_max = numpy.insert(bd_2d_max, slab.axis, slab.max)
|
||||||
|
|
||||||
# 2) Find indices (bdi) just outside bd elements
|
# 2) Find indices (bdi) just outside bd elements
|
||||||
buf = 2 # size of safety buffer
|
buf = 2 # size of safety buffer
|
||||||
@ -103,13 +101,13 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int)
|
bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int)
|
||||||
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
|
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
|
||||||
|
|
||||||
# 3) Adjust polygons for center
|
# 3) Adjust polygons for offset2d
|
||||||
poly_list = [poly + center[surface] for poly in poly_list]
|
poly_list = [poly + offset2d for poly in poly_list]
|
||||||
|
|
||||||
# ## Generate weighing function
|
# ## Generate weighing function
|
||||||
def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
|
def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
|
||||||
v_2d = numpy.array(vector, dtype=float)
|
v_2d = numpy.array(vector, dtype=float)
|
||||||
return numpy.insert(v_2d, surface_normal, (val,))
|
return numpy.insert(v_2d, slab.axis, (val,))
|
||||||
|
|
||||||
# iterate over grids
|
# iterate over grids
|
||||||
foreground_val: NDArray | float
|
foreground_val: NDArray | float
|
||||||
@ -162,13 +160,12 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
w_xy = numpy.minimum(w_xy, 1.0)
|
w_xy = numpy.minimum(w_xy, 1.0)
|
||||||
|
|
||||||
# 2) Generate weights in z-direction
|
# 2) Generate weights in z-direction
|
||||||
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
|
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[slab.axis], ))
|
||||||
|
|
||||||
def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001
|
def get_zi(point: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001
|
||||||
edges = self.shifted_exyz(i)[surface_normal]
|
edges = self.shifted_exyz(i)[slab.axis]
|
||||||
point = center[surface_normal] + offset
|
|
||||||
grid_coord = numpy.digitize(point, edges) - 1
|
grid_coord = numpy.digitize(point, edges) - 1
|
||||||
w_coord = grid_coord - bdi_min[surface_normal]
|
w_coord = grid_coord - bdi_min[slab.axis]
|
||||||
|
|
||||||
if w_coord < 0:
|
if w_coord < 0:
|
||||||
w_coord = 0
|
w_coord = 0
|
||||||
@ -177,12 +174,12 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
w_coord = w_z.size - 1
|
w_coord = w_z.size - 1
|
||||||
f = 1
|
f = 1
|
||||||
else:
|
else:
|
||||||
dz = self.shifted_dxyz(i)[surface_normal][grid_coord]
|
dz = self.shifted_dxyz(i)[slab.axis][grid_coord]
|
||||||
f = (point - edges[grid_coord]) / dz
|
f = (point - edges[grid_coord]) / dz
|
||||||
return f, w_coord
|
return f, w_coord
|
||||||
|
|
||||||
zi_top_f, zi_top = get_zi(+thickness / 2.0)
|
zi_top_f, zi_top = get_zi(slab.max)
|
||||||
zi_bot_f, zi_bot = get_zi(-thickness / 2.0)
|
zi_bot_f, zi_bot = get_zi(slab.min)
|
||||||
|
|
||||||
w_z[zi_bot + 1:zi_top] = 1
|
w_z[zi_bot + 1:zi_top] = 1
|
||||||
|
|
||||||
@ -193,7 +190,7 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
w_z[zi_bot] = zi_top_f - zi_bot_f
|
w_z[zi_bot] = zi_top_f - zi_bot_f
|
||||||
|
|
||||||
# 3) Generate total weight function
|
# 3) Generate total weight function
|
||||||
w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,)))
|
w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], slab.axis, (2,)))
|
||||||
|
|
||||||
# ## Modify the grid
|
# ## Modify the grid
|
||||||
g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
|
g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
|
||||||
@ -203,34 +200,36 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
def draw_polygon(
|
def draw_polygon(
|
||||||
self,
|
self,
|
||||||
cell_data: NDArray,
|
cell_data: NDArray,
|
||||||
surface_normal: int,
|
slab: SlabProtocol | SlabDict,
|
||||||
center: ArrayLike,
|
|
||||||
polygon: ArrayLike,
|
polygon: ArrayLike,
|
||||||
thickness: float,
|
|
||||||
foreground: Sequence[foreground_t] | foreground_t,
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
|
*,
|
||||||
|
offset2d: ArrayLike = (0, 0),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Draw a polygon on an axis-aligned plane.
|
Draw a polygon on an axis-aligned plane.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
|
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
|
||||||
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
|
slab: `Slab` in which to draw polygons.
|
||||||
center: 3-element ndarray or list specifying an offset applied to the polygon
|
|
||||||
polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
|
polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
|
||||||
clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at
|
clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at
|
||||||
least 3 vertices.
|
least 3 vertices.
|
||||||
thickness: Thickness of the layer to draw
|
|
||||||
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
|
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
|
||||||
"""
|
"""
|
||||||
self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground)
|
self.draw_polygons(
|
||||||
|
cell_data = cell_data,
|
||||||
|
slab = slab,
|
||||||
|
polygons = [polygon],
|
||||||
|
foreground = foreground,
|
||||||
|
offset2d = offset2d,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def draw_slab(
|
def draw_slab(
|
||||||
self,
|
self,
|
||||||
cell_data: NDArray,
|
cell_data: NDArray,
|
||||||
surface_normal: int,
|
slab: SlabProtocol | SlabDict,
|
||||||
center: ArrayLike,
|
|
||||||
thickness: float,
|
|
||||||
foreground: Sequence[foreground_t] | foreground_t,
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -238,50 +237,45 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
|
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
|
||||||
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
|
slab:
|
||||||
center: `surface_normal` coordinate value at the center of the slab
|
|
||||||
thickness: Thickness of the layer to draw
|
thickness: Thickness of the layer to draw
|
||||||
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
|
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
|
||||||
"""
|
"""
|
||||||
# Turn surface_normal into its integer representation
|
if isinstance(slab, dict):
|
||||||
if surface_normal not in range(3):
|
slab = Slab(**slab)
|
||||||
raise GridError('Invalid surface_normal direction')
|
|
||||||
|
|
||||||
if numpy.size(center) != 1:
|
|
||||||
center = numpy.squeeze(center)
|
|
||||||
if len(center) == 3:
|
|
||||||
center = center[surface_normal]
|
|
||||||
else:
|
|
||||||
raise GridError(f'Bad center: {center}')
|
|
||||||
|
|
||||||
# Find center of slab
|
# Find center of slab
|
||||||
center_shift = self.center
|
center_shift = self.center
|
||||||
center_shift[surface_normal] = center
|
center_shift[slab.axis] = slab.center
|
||||||
|
|
||||||
surface = numpy.delete(range(3), surface_normal)
|
surface = numpy.delete(range(3), slab.axis)
|
||||||
|
u_min, u_max = self.exyz[surface[0]][[0, -1]]
|
||||||
|
v_min, v_max = self.exyz[surface[1]][[0, -1]]
|
||||||
|
|
||||||
xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface]
|
margin = 4 * numpy.max(self.dxyz[surface[0]].max(),
|
||||||
xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface]
|
self.dxyz[surface[1]].max())
|
||||||
|
|
||||||
dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float)
|
p = numpy.array([[u_min - margin, v_max + margin],
|
||||||
|
[u_max + margin, v_max + margin],
|
||||||
|
[u_max + margin, v_min - margin],
|
||||||
|
[u_min - margin, v_min - margin]], dtype=float)
|
||||||
|
|
||||||
xyz_min -= 4 * dxyz
|
self.draw_polygon(
|
||||||
xyz_max += 4 * dxyz
|
cell_data = cell_data,
|
||||||
|
slab = slab,
|
||||||
p = numpy.array([[xyz_min[0], xyz_max[1]],
|
polygon = p,
|
||||||
[xyz_max[0], xyz_max[1]],
|
foreground = foreground,
|
||||||
[xyz_max[0], xyz_min[1]],
|
)
|
||||||
[xyz_min[0], xyz_min[1]]], dtype=float)
|
|
||||||
|
|
||||||
self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground)
|
|
||||||
|
|
||||||
|
|
||||||
def draw_cuboid(
|
def draw_cuboid(
|
||||||
self,
|
self,
|
||||||
cell_data: NDArray,
|
cell_data: NDArray,
|
||||||
center: ArrayLike,
|
|
||||||
dimensions: ArrayLike,
|
|
||||||
foreground: Sequence[foreground_t] | foreground_t,
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
|
*,
|
||||||
|
x: ExtentProtocol | ExtentDict,
|
||||||
|
y: ExtentProtocol | ExtentDict,
|
||||||
|
z: ExtentProtocol | ExtentDict,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Draw an axis-aligned cuboid
|
Draw an axis-aligned cuboid
|
||||||
@ -293,23 +287,30 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
sizes of the cuboid
|
sizes of the cuboid
|
||||||
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
|
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
|
||||||
"""
|
"""
|
||||||
dimensions = numpy.asarray(dimensions)
|
if isinstance(x, dict):
|
||||||
p = numpy.array([[-dimensions[0], +dimensions[1]],
|
x = Extent(**x)
|
||||||
[+dimensions[0], +dimensions[1]],
|
if isinstance(y, dict):
|
||||||
[+dimensions[0], -dimensions[1]],
|
y = Extent(**y)
|
||||||
[-dimensions[0], -dimensions[1]]], dtype=float) * 0.5
|
if isinstance(z, dict):
|
||||||
thickness = dimensions[2]
|
z = Extent(**z)
|
||||||
self.draw_polygon(cell_data, 2, center, p, thickness, foreground)
|
|
||||||
|
center = numpy.asarray([x.center, y.center, z.center])
|
||||||
|
|
||||||
|
p = numpy.array([[x.min, y.max],
|
||||||
|
[x.max, y.max],
|
||||||
|
[x.max, y.min],
|
||||||
|
[x.min, y.min]], dtype=float)
|
||||||
|
slab = Slab(axis=2, center=z.center, span=z.span)
|
||||||
|
self.draw_polygon(cell_data=cell_data, slab=slab, polygon=p, foreground=foreground)
|
||||||
|
|
||||||
|
|
||||||
def draw_cylinder(
|
def draw_cylinder(
|
||||||
self,
|
self,
|
||||||
cell_data: NDArray,
|
cell_data: NDArray,
|
||||||
surface_normal: int,
|
h: SlabProtocol | SlabDict,
|
||||||
center: ArrayLike,
|
|
||||||
radius: float,
|
radius: float,
|
||||||
thickness: float,
|
|
||||||
num_points: int,
|
num_points: int,
|
||||||
|
center2d: ArrayLike,
|
||||||
foreground: Sequence[foreground_t] | foreground_t,
|
foreground: Sequence[foreground_t] | foreground_t,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -317,18 +318,19 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
|
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
|
||||||
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
|
h:
|
||||||
center: 3-element ndarray or list specifying the cylinder's center
|
radius:
|
||||||
radius: cylinder radius
|
|
||||||
thickness: Thickness of the layer to draw
|
|
||||||
num_points: The circle is approximated by a polygon with `num_points` vertices
|
num_points: The circle is approximated by a polygon with `num_points` vertices
|
||||||
|
center2d:
|
||||||
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
|
foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
|
||||||
"""
|
"""
|
||||||
theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)
|
if isinstance(h, dict):
|
||||||
x = radius * numpy.sin(theta)
|
h = Slab(**h)
|
||||||
y = radius * numpy.cos(theta)
|
|
||||||
polygon = numpy.hstack((x[:, None], y[:, None]))
|
theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)[:, None]
|
||||||
self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground)
|
xy0 = numpy.hstack((numpy.sin(theta), numpy.cos(theta)))
|
||||||
|
polygon = radius * xy0
|
||||||
|
self.draw_polygon(cell_data=cell_data, slab=h, polygon=polygon, foreground=foreground, offset2d=center2d)
|
||||||
|
|
||||||
|
|
||||||
def draw_extrude_rectangle(
|
def draw_extrude_rectangle(
|
||||||
@ -349,10 +351,10 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
polarity: +1 or -1, direction along axis to extrude in
|
polarity: +1 or -1, direction along axis to extrude in
|
||||||
distance: How far to extrude
|
distance: How far to extrude
|
||||||
"""
|
"""
|
||||||
s = numpy.sign(polarity)
|
sgn = numpy.sign(polarity)
|
||||||
|
|
||||||
rectangle = numpy.array(rectangle, dtype=float)
|
rectangle = numpy.asarray(rectangle, dtype=float)
|
||||||
if s == 0:
|
if sgn == 0:
|
||||||
raise GridError('0 is not a valid polarity')
|
raise GridError('0 is not a valid polarity')
|
||||||
if direction not in range(3):
|
if direction not in range(3):
|
||||||
raise GridError(f'Invalid direction: {direction}')
|
raise GridError(f'Invalid direction: {direction}')
|
||||||
@ -360,37 +362,38 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
raise GridError('Rectangle entries along extrusion direction do not match.')
|
raise GridError('Rectangle entries along extrusion direction do not match.')
|
||||||
|
|
||||||
center = rectangle.sum(axis=0) / 2.0
|
center = rectangle.sum(axis=0) / 2.0
|
||||||
center[direction] += s * distance / 2.0
|
center[direction] += sgn * distance / 2.0
|
||||||
|
|
||||||
surface = numpy.delete(range(3), direction)
|
surface = numpy.delete(range(3), direction)
|
||||||
|
|
||||||
dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface]
|
dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface]
|
||||||
p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5,
|
poly = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5,
|
||||||
numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T
|
numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T
|
||||||
thickness = distance
|
thickness = distance
|
||||||
|
|
||||||
foreground_func = []
|
foreground_func = []
|
||||||
for i, grid in enumerate(cell_data):
|
for ii, grid in enumerate(cell_data):
|
||||||
z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction]
|
zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction]
|
||||||
|
|
||||||
ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)]
|
ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)]
|
||||||
|
|
||||||
fpart = z - numpy.floor(z)
|
fpart = zz - numpy.floor(zz)
|
||||||
mult = [1 - fpart, fpart][::s] # reverses if s negative
|
mult = [1 - fpart, fpart][::sgn] # reverses if s negative
|
||||||
|
|
||||||
foreground = mult[0] * grid[tuple(ind)]
|
foreground = mult[0] * grid[tuple(ind)]
|
||||||
ind[direction] += 1 # type: ignore #(known safe)
|
ind[direction] += 1 # type: ignore #(known safe)
|
||||||
foreground += mult[1] * grid[tuple(ind)]
|
foreground += mult[1] * grid[tuple(ind)]
|
||||||
|
|
||||||
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
|
def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
|
||||||
# transform from natural position to index
|
# transform from natural position to index
|
||||||
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
|
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=ii)
|
||||||
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64)
|
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64)
|
||||||
# reshape to original shape and keep only in-plane components
|
# reshape to original shape and keep only in-plane components
|
||||||
qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
|
qi, ri = (numpy.reshape(xyzi[:, kk], xs.shape) for kk in surface)
|
||||||
return foreground[qi, ri]
|
return foreground[qi, ri]
|
||||||
|
|
||||||
foreground_func.append(f_foreground)
|
foreground_func.append(f_foreground)
|
||||||
|
|
||||||
self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func)
|
slab = Slab(axis=direction, center=center[direction], span=thickness)
|
||||||
|
self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface])
|
||||||
|
|
||||||
|
@ -1,2 +0,0 @@
|
|||||||
class GridError(Exception):
|
|
||||||
pass
|
|
@ -6,18 +6,18 @@ if __name__ == '__main__':
|
|||||||
# xyz = [numpy.arange(-5.0, 6.0), numpy.arange(-4.0, 5.0), [-1.0, 1.0]]
|
# xyz = [numpy.arange(-5.0, 6.0), numpy.arange(-4.0, 5.0), [-1.0, 1.0]]
|
||||||
# eg = Grid(xyz)
|
# eg = Grid(xyz)
|
||||||
# egc = Grid.allocate(0.0)
|
# egc = Grid.allocate(0.0)
|
||||||
# # eg.draw_slab(egc, surface_normal=2, center=0, thickness=10, foreground=2)
|
# # eg.draw_slab(egc, slab=dict(axis=2, center=0, span=10), foreground=2)
|
||||||
# eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=4,
|
# eg.draw_cylinder(egc, h=slab(axis=2, center=0, span=10),
|
||||||
# thickness=10, num_points=1000, foreground=1)
|
# center2d=[0, 0], radius=4, thickness=10, num_points=1000, foreground=1)
|
||||||
# eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
|
# eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2)
|
||||||
|
|
||||||
# xyz2 = [numpy.arange(-5.0, 6.0), [-1.0, 1.0], numpy.arange(-4.0, 5.0)]
|
# xyz2 = [numpy.arange(-5.0, 6.0), [-1.0, 1.0], numpy.arange(-4.0, 5.0)]
|
||||||
# eg2 = Grid(xyz2)
|
# eg2 = Grid(xyz2)
|
||||||
# eg2c = Grid.allocate(0.0)
|
# eg2c = Grid.allocate(0.0)
|
||||||
# # eg2.draw_slab(eg2c, surface_normal=2, center=0, thickness=10, foreground=2)
|
# # eg2.draw_slab(eg2c, slab=dict(axis=2, center=0, span=10), foreground=2)
|
||||||
# eg2.draw_cylinder(eg2c, surface_normal=1, center=[0, 0, 0],
|
# eg2.draw_cylinder(eg2c, h=slab(axis=1, center=0, span=10), center2d=[0, 0],
|
||||||
# radius=4, thickness=10, num_points=1000, foreground=1.0)
|
# radius=4, num_points=1000, foreground=1.0)
|
||||||
# eg2.visualize_slice(eg2c, surface_normal=1, center=0, which_shifts=1)
|
# eg2.visualize_slice(eg2c, plane=dict(y=0), which_shifts=1)
|
||||||
|
|
||||||
# n = 20
|
# n = 20
|
||||||
# m = 3
|
# m = 3
|
||||||
@ -36,9 +36,20 @@ if __name__ == '__main__':
|
|||||||
egc = eg.allocate(0)
|
egc = eg.allocate(0)
|
||||||
# eg.draw_slab(Direction.z, 0, 10, 2)
|
# eg.draw_slab(Direction.z, 0, 10, 2)
|
||||||
eg.save('/home/jan/Desktop/test.pickle')
|
eg.save('/home/jan/Desktop/test.pickle')
|
||||||
eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0,
|
eg.draw_cylinder(
|
||||||
thickness=10, num_points=1000, foreground=1)
|
egc,
|
||||||
eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]],
|
h=dict(axis='z', center=0, span=10),
|
||||||
direction=1, polarity=+1, distance=5)
|
center2d=[0, 0],
|
||||||
eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
|
radius=2.0,
|
||||||
|
num_points=1000,
|
||||||
|
foreground=1,
|
||||||
|
)
|
||||||
|
eg.draw_extrude_rectangle(
|
||||||
|
egc,
|
||||||
|
rectangle=[[-2, 1, -1], [0, 1, 1]],
|
||||||
|
direction=1,
|
||||||
|
polarity=+1,
|
||||||
|
distance=5,
|
||||||
|
)
|
||||||
|
eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2)
|
||||||
eg.visualize_isosurface(egc, which_shifts=2)
|
eg.visualize_isosurface(egc, which_shifts=2)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, TYPE_CHECKING
|
|||||||
import numpy
|
import numpy
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from . import GridError
|
from .utils import GridError, Plane, PlaneDict, PlaneProtocol
|
||||||
from .position import GridPosMixin
|
from .position import GridPosMixin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -23,27 +23,25 @@ class GridReadMixin(GridPosMixin):
|
|||||||
def get_slice(
|
def get_slice(
|
||||||
self,
|
self,
|
||||||
cell_data: NDArray,
|
cell_data: NDArray,
|
||||||
surface_normal: int,
|
plane: PlaneProtocol | PlaneDict,
|
||||||
center: float,
|
|
||||||
which_shifts: int = 0,
|
which_shifts: int = 0,
|
||||||
sample_period: int = 1
|
sample_period: int = 1
|
||||||
) -> NDArray:
|
) -> NDArray:
|
||||||
"""
|
"""
|
||||||
Retrieve a slice of a grid.
|
Retrieve a slice of a grid.
|
||||||
Interpolates if given a position between two planes.
|
Interpolates if given a position between two grid planes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cell_data: Cell data to slice
|
cell_data: Cell data to slice
|
||||||
surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
|
plane: Axis and position (`Plane`) of the plane to read.
|
||||||
center: Scalar specifying position along surface_normal axis.
|
|
||||||
which_shifts: Which grid to display. Default is the first grid (0).
|
which_shifts: Which grid to display. Default is the first grid (0).
|
||||||
sample_period: Period for down-sampling the image. Default 1 (disabled)
|
sample_period: Period for down-sampling the image. Default 1 (disabled)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Array containing the portion of the grid.
|
Array containing the portion of the grid.
|
||||||
"""
|
"""
|
||||||
if numpy.size(center) != 1 or not numpy.isreal(center):
|
if isinstance(plane, dict):
|
||||||
raise GridError('center must be a real scalar')
|
plane = Plane(**plane)
|
||||||
|
|
||||||
sp = round(sample_period)
|
sp = round(sample_period)
|
||||||
if sp <= 0:
|
if sp <= 0:
|
||||||
@ -52,15 +50,12 @@ class GridReadMixin(GridPosMixin):
|
|||||||
if numpy.size(which_shifts) != 1 or which_shifts < 0:
|
if numpy.size(which_shifts) != 1 or which_shifts < 0:
|
||||||
raise GridError('Invalid which_shifts')
|
raise GridError('Invalid which_shifts')
|
||||||
|
|
||||||
if surface_normal not in range(3):
|
surface = numpy.delete(range(3), plane.axis)
|
||||||
raise GridError('Invalid surface_normal direction')
|
|
||||||
|
|
||||||
surface = numpy.delete(range(3), surface_normal)
|
|
||||||
|
|
||||||
# Extract indices and weights of planes
|
# Extract indices and weights of planes
|
||||||
center3 = numpy.insert([0, 0], surface_normal, (center,))
|
center3 = numpy.insert([0, 0], plane.axis, (plane.pos,))
|
||||||
center_index = self.pos2ind(center3, which_shifts,
|
center_index = self.pos2ind(center3, which_shifts,
|
||||||
round_ind=False, check_bounds=False)[surface_normal]
|
round_ind=False, check_bounds=False)[plane.axis]
|
||||||
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
|
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
|
||||||
if len(centers) == 2:
|
if len(centers) == 2:
|
||||||
fpart = center_index - numpy.floor(center_index)
|
fpart = center_index - numpy.floor(center_index)
|
||||||
@ -68,14 +63,14 @@ class GridReadMixin(GridPosMixin):
|
|||||||
else:
|
else:
|
||||||
w = [1]
|
w = [1]
|
||||||
|
|
||||||
c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1])
|
c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1])
|
||||||
if center < c_min or center > c_max:
|
if plane.pos < c_min or plane.pos > c_max:
|
||||||
raise GridError('Coordinate of selected plane must be within simulation domain')
|
raise GridError('Coordinate of selected plane must be within simulation domain')
|
||||||
|
|
||||||
# Extract grid values from planes above and below visualized slice
|
# Extract grid values from planes above and below visualized slice
|
||||||
sliced_grid = numpy.zeros(self.shape[surface])
|
sliced_grid = numpy.zeros(self.shape[surface])
|
||||||
for ci, weight in zip(centers, w, strict=True):
|
for ci, weight in zip(centers, w, strict=True):
|
||||||
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
|
s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3))
|
||||||
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
|
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
|
||||||
|
|
||||||
# Remove extra dimensions
|
# Remove extra dimensions
|
||||||
@ -87,20 +82,19 @@ class GridReadMixin(GridPosMixin):
|
|||||||
def visualize_slice(
|
def visualize_slice(
|
||||||
self,
|
self,
|
||||||
cell_data: NDArray,
|
cell_data: NDArray,
|
||||||
surface_normal: int,
|
plane: PlaneProtocol | PlaneDict,
|
||||||
center: float,
|
|
||||||
which_shifts: int = 0,
|
which_shifts: int = 0,
|
||||||
sample_period: int = 1,
|
sample_period: int = 1,
|
||||||
finalize: bool = True,
|
finalize: bool = True,
|
||||||
pcolormesh_args: dict[str, Any] | None = None,
|
pcolormesh_args: dict[str, Any] | None = None,
|
||||||
) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
|
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
|
||||||
"""
|
"""
|
||||||
Visualize a slice of a grid.
|
Visualize a slice of a grid.
|
||||||
Interpolates if given a position between two planes.
|
Interpolates if given a position between two grid planes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
|
cell_data: Cell data to visualize
|
||||||
center: Scalar specifying position along surface_normal axis.
|
plane: Axis and position (`Plane`) of the plane to read.
|
||||||
which_shifts: Which grid to display. Default is the first grid (0).
|
which_shifts: Which grid to display. Default is the first grid (0).
|
||||||
sample_period: Period for down-sampling the image. Default 1 (disabled)
|
sample_period: Period for down-sampling the image. Default 1 (disabled)
|
||||||
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
|
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
|
||||||
@ -110,16 +104,20 @@ class GridReadMixin(GridPosMixin):
|
|||||||
"""
|
"""
|
||||||
from matplotlib import pyplot
|
from matplotlib import pyplot
|
||||||
|
|
||||||
|
if isinstance(plane, dict):
|
||||||
|
plane = Plane(**plane)
|
||||||
|
|
||||||
if pcolormesh_args is None:
|
if pcolormesh_args is None:
|
||||||
pcolormesh_args = {}
|
pcolormesh_args = {}
|
||||||
|
|
||||||
grid_slice = self.get_slice(cell_data=cell_data,
|
grid_slice = self.get_slice(
|
||||||
surface_normal=surface_normal,
|
cell_data=cell_data,
|
||||||
center=center,
|
plane=plane,
|
||||||
which_shifts=which_shifts,
|
which_shifts=which_shifts,
|
||||||
sample_period=sample_period)
|
sample_period=sample_period,
|
||||||
|
)
|
||||||
|
|
||||||
surface = numpy.delete(range(3), surface_normal)
|
surface = numpy.delete(range(3), plane.axis)
|
||||||
|
|
||||||
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
|
x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
|
||||||
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
|
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
|
||||||
@ -145,7 +143,7 @@ class GridReadMixin(GridPosMixin):
|
|||||||
sample_period: int = 1,
|
sample_period: int = 1,
|
||||||
show_edges: bool = True,
|
show_edges: bool = True,
|
||||||
finalize: bool = True,
|
finalize: bool = True,
|
||||||
) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
|
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
|
||||||
"""
|
"""
|
||||||
Draw an isosurface plot of the device.
|
Draw an isosurface plot of the device.
|
||||||
|
|
||||||
@ -183,18 +181,18 @@ class GridReadMixin(GridPosMixin):
|
|||||||
fig = pyplot.figure()
|
fig = pyplot.figure()
|
||||||
ax = fig.add_subplot(111, projection='3d')
|
ax = fig.add_subplot(111, projection='3d')
|
||||||
if show_edges:
|
if show_edges:
|
||||||
ax.plot_trisurf(xs, ys, faces, zs)
|
ax.plot_trisurf(xs, ys, faces, zs) # type: ignore
|
||||||
else:
|
else:
|
||||||
ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none')
|
ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') # type: ignore
|
||||||
|
|
||||||
# Add a fake plot of a cube to force the axes to be equal lengths
|
# Add a fake plot of a cube to force the axes to be equal lengths
|
||||||
max_range = numpy.array([xs.max() - xs.min(),
|
max_range = numpy.array([xs.max() - xs.min(),
|
||||||
ys.max() - ys.min(),
|
ys.max() - ys.min(),
|
||||||
zs.max() - zs.min()], dtype=float).max()
|
zs.max() - zs.min()], dtype=float).max()
|
||||||
mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2]
|
mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2]
|
||||||
xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min())
|
xbs = 0.5 * max_range * mg[0].ravel() + 0.5 * (xs.max() + xs.min())
|
||||||
ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min())
|
ybs = 0.5 * max_range * mg[1].ravel() + 0.5 * (ys.max() + ys.min())
|
||||||
zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min())
|
zbs = 0.5 * max_range * mg[2].ravel() + 0.5 * (zs.max() + zs.min())
|
||||||
# Comment or uncomment following both lines to test the fake bounding box:
|
# Comment or uncomment following both lines to test the fake bounding box:
|
||||||
for xb, yb, zb in zip(xbs, ybs, zbs, strict=True):
|
for xb, yb, zb in zip(xbs, ybs, zbs, strict=True):
|
||||||
ax.plot([xb], [yb], [zb], 'w')
|
ax.plot([xb], [yb], [zb], 'w')
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import numpy
|
import numpy
|
||||||
from numpy.testing import assert_allclose #, assert_array_equal
|
from numpy.testing import assert_allclose #, assert_array_equal
|
||||||
|
|
||||||
from .. import Grid
|
from .. import Grid, Extent #, Slab, Plane
|
||||||
|
|
||||||
|
|
||||||
def test_draw_oncenter_2x2() -> None:
|
def test_draw_oncenter_2x2() -> None:
|
||||||
@ -12,7 +12,13 @@ def test_draw_oncenter_2x2() -> None:
|
|||||||
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
|
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
|
||||||
arr = grid.allocate(0)
|
arr = grid.allocate(0)
|
||||||
|
|
||||||
grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[1, 1, 10], foreground=1)
|
grid.draw_cuboid(
|
||||||
|
arr,
|
||||||
|
x=dict(center=0, span=1),
|
||||||
|
y=Extent(center=0, span=1),
|
||||||
|
z=dict(center=0, span=10),
|
||||||
|
foreground=1,
|
||||||
|
)
|
||||||
|
|
||||||
correct = numpy.array([[0.25, 0.25],
|
correct = numpy.array([[0.25, 0.25],
|
||||||
[0.25, 0.25]])[None, :, :, None]
|
[0.25, 0.25]])[None, :, :, None]
|
||||||
@ -27,7 +33,13 @@ def test_draw_ongrid_4x4() -> None:
|
|||||||
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
|
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
|
||||||
arr = grid.allocate(0)
|
arr = grid.allocate(0)
|
||||||
|
|
||||||
grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[2, 2, 10], foreground=1)
|
grid.draw_cuboid(
|
||||||
|
arr,
|
||||||
|
x=dict(center=0, span=2),
|
||||||
|
y=dict(min=-1, max=1),
|
||||||
|
z=dict(center=0, min=-5),
|
||||||
|
foreground=1,
|
||||||
|
)
|
||||||
|
|
||||||
correct = numpy.array([[0, 0, 0, 0],
|
correct = numpy.array([[0, 0, 0, 0],
|
||||||
[0, 1, 1, 0],
|
[0, 1, 1, 0],
|
||||||
@ -44,7 +56,13 @@ def test_draw_xshift_4x4() -> None:
|
|||||||
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
|
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
|
||||||
arr = grid.allocate(0)
|
arr = grid.allocate(0)
|
||||||
|
|
||||||
grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 2, 10], foreground=1)
|
grid.draw_cuboid(
|
||||||
|
arr,
|
||||||
|
x=dict(center=0.5, span=1.5),
|
||||||
|
y=dict(min=-1, max=1),
|
||||||
|
z=dict(center=0, span=10),
|
||||||
|
foreground=1,
|
||||||
|
)
|
||||||
|
|
||||||
correct = numpy.array([[0, 0, 0, 0],
|
correct = numpy.array([[0, 0, 0, 0],
|
||||||
[0, 0.25, 0.25, 0],
|
[0, 0.25, 0.25, 0],
|
||||||
@ -61,7 +79,13 @@ def test_draw_yshift_4x4() -> None:
|
|||||||
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
|
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
|
||||||
arr = grid.allocate(0)
|
arr = grid.allocate(0)
|
||||||
|
|
||||||
grid.draw_cuboid(arr, center=[0, 0.5, 0], dimensions=[2, 1.5, 10], foreground=1)
|
grid.draw_cuboid(
|
||||||
|
arr,
|
||||||
|
x=dict(min=-1, max=1),
|
||||||
|
y=dict(center=0.5, span=1.5),
|
||||||
|
z=dict(center=0, span=10),
|
||||||
|
foreground=1,
|
||||||
|
)
|
||||||
|
|
||||||
correct = numpy.array([[0, 0, 0, 0],
|
correct = numpy.array([[0, 0, 0, 0],
|
||||||
[0, 0.25, 1, 0.25],
|
[0, 0.25, 1, 0.25],
|
||||||
@ -78,7 +102,13 @@ def test_draw_2shift_4x4() -> None:
|
|||||||
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
|
grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]])
|
||||||
arr = grid.allocate(0)
|
arr = grid.allocate(0)
|
||||||
|
|
||||||
grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 1, 10], foreground=1)
|
grid.draw_cuboid(
|
||||||
|
arr,
|
||||||
|
x=dict(center=0.5, span=1.5),
|
||||||
|
y=dict(min=-0.5, max=0.5),
|
||||||
|
z=dict(center=0, span=10),
|
||||||
|
foreground=1,
|
||||||
|
)
|
||||||
|
|
||||||
correct = numpy.array([[0, 0, 0, 0],
|
correct = numpy.array([[0, 0, 0, 0],
|
||||||
[0, 0.125, 0.125, 0],
|
[0, 0.125, 0.125, 0],
|
||||||
|
201
gridlock/utils.py
Normal file
201
gridlock/utils.py
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
from typing import Protocol, TypedDict, runtime_checkable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
class GridError(Exception):
|
||||||
|
""" Base error type for `gridlock` """
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ExtentDict(TypedDict, total=False):
|
||||||
|
min: float
|
||||||
|
center: float
|
||||||
|
max: float
|
||||||
|
span: float
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class ExtentProtocol(Protocol):
|
||||||
|
center: float
|
||||||
|
span: float
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max(self) -> float: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def min(self) -> float: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(init=False, slots=True)
|
||||||
|
class Extent(ExtentProtocol):
|
||||||
|
center: float
|
||||||
|
span: float
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max(self) -> float:
|
||||||
|
return self.center + self.span / 2
|
||||||
|
|
||||||
|
@property
|
||||||
|
def min(self) -> float:
|
||||||
|
return self.center - self.span / 2
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
min: float | None = None,
|
||||||
|
center: float | None = None,
|
||||||
|
max: float | None = None,
|
||||||
|
span: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
if sum(cc is None for cc in (min, center, max, span)) != 2:
|
||||||
|
raise GridError('Exactly two of min, center, max, span must be None!')
|
||||||
|
|
||||||
|
if span is None:
|
||||||
|
if center is None:
|
||||||
|
assert min is not None
|
||||||
|
assert max is not None
|
||||||
|
assert max >= min
|
||||||
|
center = 0.5 * (max + min)
|
||||||
|
span = max - min
|
||||||
|
elif max is None:
|
||||||
|
assert min is not None
|
||||||
|
assert center is not None
|
||||||
|
span = 2 * (center - min)
|
||||||
|
elif min is None:
|
||||||
|
assert center is not None
|
||||||
|
assert max is not None
|
||||||
|
span = 2 * (max - center)
|
||||||
|
else: # noqa: PLR5501
|
||||||
|
if center is not None:
|
||||||
|
pass
|
||||||
|
elif max is None:
|
||||||
|
assert min is not None
|
||||||
|
assert span is not None
|
||||||
|
center = min + 0.5 * span
|
||||||
|
elif min is None:
|
||||||
|
assert max is not None
|
||||||
|
assert span is not None
|
||||||
|
center = max - 0.5 * span
|
||||||
|
|
||||||
|
assert center is not None
|
||||||
|
assert span is not None
|
||||||
|
if hasattr(center, '__len__'):
|
||||||
|
assert len(center) == 1
|
||||||
|
if hasattr(span, '__len__'):
|
||||||
|
assert len(span) == 1
|
||||||
|
self.center = center
|
||||||
|
self.span = span
|
||||||
|
|
||||||
|
|
||||||
|
class SlabDict(TypedDict, total=False):
|
||||||
|
min: float
|
||||||
|
center: float
|
||||||
|
max: float
|
||||||
|
span: float
|
||||||
|
axis: int | str
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class SlabProtocol(ExtentProtocol, Protocol):
|
||||||
|
axis: int
|
||||||
|
center: float
|
||||||
|
span: float
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max(self) -> float: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def min(self) -> float: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(init=False, slots=True)
|
||||||
|
class Slab(Extent, SlabProtocol):
|
||||||
|
axis: int
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
axis: int | str,
|
||||||
|
*,
|
||||||
|
min: float | None = None,
|
||||||
|
center: float | None = None,
|
||||||
|
max: float | None = None,
|
||||||
|
span: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
Extent.__init__(self, min=min, center=center, max=max, span=span)
|
||||||
|
|
||||||
|
if isinstance(axis, str):
|
||||||
|
axis_int = 'xyz'.find(axis.lower())
|
||||||
|
else:
|
||||||
|
axis_int = axis
|
||||||
|
if axis_int not in range(3):
|
||||||
|
raise GridError(f'Invalid axis (slab normal direction): {axis}')
|
||||||
|
self.axis = axis_int
|
||||||
|
|
||||||
|
def as_plane(self, where: str) -> 'Plane':
|
||||||
|
if where == 'center':
|
||||||
|
return Plane(axis=self.axis, pos=self.center)
|
||||||
|
if where == 'min':
|
||||||
|
return Plane(axis=self.axis, pos=self.min)
|
||||||
|
if where == 'max':
|
||||||
|
return Plane(axis=self.axis, pos=self.max)
|
||||||
|
raise GridError(f'Invalid {where=}')
|
||||||
|
|
||||||
|
|
||||||
|
class PlaneDict(TypedDict, total=False):
|
||||||
|
x: float
|
||||||
|
y: float
|
||||||
|
z: float
|
||||||
|
axis: int
|
||||||
|
pos: float
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class PlaneProtocol(Protocol):
|
||||||
|
axis: int
|
||||||
|
pos: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(init=False, slots=True)
|
||||||
|
class Plane(PlaneProtocol):
|
||||||
|
axis: int
|
||||||
|
pos: float
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
axis: int | str | None = None,
|
||||||
|
pos: float | None = None,
|
||||||
|
x: float | None = None,
|
||||||
|
y: float | None = None,
|
||||||
|
z: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
xx = x
|
||||||
|
yy = y
|
||||||
|
zz = z
|
||||||
|
|
||||||
|
if sum(aa is not None for aa in (pos, xx, yy, zz)) != 1:
|
||||||
|
raise GridError('Exactly one of pos, x, y, z must be non-None!')
|
||||||
|
if (axis is None) != (pos is None):
|
||||||
|
raise GridError('Either both or neither of `axis` and `pos` must be defined.')
|
||||||
|
|
||||||
|
if isinstance(axis, str):
|
||||||
|
axis_int = 'xyz'.find(axis.lower())
|
||||||
|
elif axis is None:
|
||||||
|
axis_int = (xx is None, yy is None, zz is None).index(False)
|
||||||
|
else:
|
||||||
|
axis_int = axis
|
||||||
|
|
||||||
|
if axis_int not in range(3):
|
||||||
|
raise GridError(f'Invalid axis (slab normal direction): {axis=} {x=} {y=} {z=}')
|
||||||
|
self.axis = axis_int
|
||||||
|
|
||||||
|
if pos is not None:
|
||||||
|
cpos = pos
|
||||||
|
else:
|
||||||
|
cpos = (xx, yy, zz)[axis_int]
|
||||||
|
assert cpos is not None
|
||||||
|
|
||||||
|
if hasattr(cpos, '__len__'):
|
||||||
|
assert len(cpos) == 1
|
||||||
|
self.pos = cpos
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user