diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 281c2b1..120291f 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -15,9 +15,24 @@ Dependencies: - mpl_toolkits.mplot3d [Grid.visualize_isosurface()] - skimage [Grid.visualize_isosurface()] """ -from .error import GridError as GridError +from .utils import ( + GridError as GridError, + + Extent as Extent, + ExtentProtocol as ExtentProtocol, + ExtentDict as ExtentDict, + + Slab as Slab, + SlabProtocol as SlabProtocol, + SlabDict as SlabDict, + + Plane as Plane, + PlaneProtocol as PlaneProtocol, + PlaneDict as PlaneDict, + ) from .grid import Grid as Grid + __author__ = 'Jan Petykiewicz' __version__ = '1.2' version = __version__ diff --git a/gridlock/direction.py b/gridlock/direction.py deleted file mode 100644 index b93b122..0000000 --- a/gridlock/direction.py +++ /dev/null @@ -1,10 +0,0 @@ -from enum import Enum - - -class Direction(Enum): - """ - Enum for axis->integer mapping - """ - x = 0 - y = 1 - z = 2 diff --git a/gridlock/draw.py b/gridlock/draw.py index 52e02f5..4930231 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -7,7 +7,7 @@ import numpy from numpy.typing import NDArray, ArrayLike from float_raster import raster -from . import GridError +from .utils import GridError, Slab, SlabDict, SlabProtocol, Extent, ExtentDict, ExtentProtocol from .position import GridPosMixin @@ -21,27 +21,26 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_t = float | foreground_callable_t + class GridDrawMixin(GridPosMixin): def draw_polygons( self, cell_data: NDArray, - surface_normal: int, - center: ArrayLike, + slab: SlabProtocol | SlabDict, polygons: Sequence[ArrayLike], - thickness: float, foreground: Sequence[foreground_t] | foreground_t, + *, + offset2d: ArrayLike = (0, 0), ) -> None: """ Draw polygons on an axis-aligned plane. Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. center: 3-element ndarray or list specifying an offset applied to all the polygons polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon - (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each + (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each polygon must have at least 3 vertices. - thickness: Thickness of the layer to draw foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list of any of these (1 per grid). Callable values should take an ndarray the shape of the grid and return an ndarray of equal shape containing the foreground value at the given x, y, @@ -50,13 +49,13 @@ class GridDrawMixin(GridPosMixin): Raises: GridError """ - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') - center = numpy.squeeze(center) + if isinstance(slab, dict): + slab = Slab(**slab) + poly_list = [numpy.asarray(poly) for poly in polygons] # Check polygons, and remove redundant coordinates - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), slab.axis) for ii in range(len(poly_list)): polygon = poly_list[ii] @@ -69,9 +68,8 @@ class GridDrawMixin(GridPosMixin): if not polygon.shape[0] > 2: raise GridError(malformed + 'must consist of more than 2 points') - if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: - raise GridError(malformed + 'must be in plane with surface normal ' - + 'xyz'[surface_normal]) + if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1: + raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] @@ -87,10 +85,10 @@ class GridDrawMixin(GridPosMixin): bd_2d_min = numpy.array([0, 0]) bd_2d_max = numpy.array([0, 0]) for polygon in poly_list: - bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) - bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) - bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center - bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center + bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0)) + offset2d + bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0)) + offset2d + bd_min = numpy.insert(bd_2d_min, slab.axis, slab.min) + bd_max = numpy.insert(bd_2d_max, slab.axis, slab.max) # 2) Find indices (bdi) just outside bd elements buf = 2 # size of safety buffer @@ -103,13 +101,13 @@ class GridDrawMixin(GridPosMixin): bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) - # 3) Adjust polygons for center - poly_list = [poly + center[surface] for poly in poly_list] + # 3) Adjust polygons for offset2d + poly_list = [poly + offset2d for poly in poly_list] # ## Generate weighing function def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]: v_2d = numpy.array(vector, dtype=float) - return numpy.insert(v_2d, surface_normal, (val,)) + return numpy.insert(v_2d, slab.axis, (val,)) # iterate over grids foreground_val: NDArray | float @@ -162,13 +160,12 @@ class GridDrawMixin(GridPosMixin): w_xy = numpy.minimum(w_xy, 1.0) # 2) Generate weights in z-direction - w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) + w_z = numpy.zeros(((bdi_max - bdi_min + 1)[slab.axis], )) - def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001 - edges = self.shifted_exyz(i)[surface_normal] - point = center[surface_normal] + offset + def get_zi(point: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001 + edges = self.shifted_exyz(i)[slab.axis] grid_coord = numpy.digitize(point, edges) - 1 - w_coord = grid_coord - bdi_min[surface_normal] + w_coord = grid_coord - bdi_min[slab.axis] if w_coord < 0: w_coord = 0 @@ -177,12 +174,12 @@ class GridDrawMixin(GridPosMixin): w_coord = w_z.size - 1 f = 1 else: - dz = self.shifted_dxyz(i)[surface_normal][grid_coord] + dz = self.shifted_dxyz(i)[slab.axis][grid_coord] f = (point - edges[grid_coord]) / dz return f, w_coord - zi_top_f, zi_top = get_zi(+thickness / 2.0) - zi_bot_f, zi_bot = get_zi(-thickness / 2.0) + zi_top_f, zi_top = get_zi(slab.max) + zi_bot_f, zi_bot = get_zi(slab.min) w_z[zi_bot + 1:zi_top] = 1 @@ -193,7 +190,7 @@ class GridDrawMixin(GridPosMixin): w_z[zi_bot] = zi_top_f - zi_bot_f # 3) Generate total weight function - w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,))) + w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], slab.axis, (2,))) # ## Modify the grid g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) @@ -203,34 +200,36 @@ class GridDrawMixin(GridPosMixin): def draw_polygon( self, cell_data: NDArray, - surface_normal: int, - center: ArrayLike, + slab: SlabProtocol | SlabDict, polygon: ArrayLike, - thickness: float, foreground: Sequence[foreground_t] | foreground_t, + *, + offset2d: ArrayLike = (0, 0), ) -> None: """ Draw a polygon on an axis-aligned plane. Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying an offset applied to the polygon + slab: `Slab` in which to draw polygons. polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, - clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at + clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at least 3 vertices. - thickness: Thickness of the layer to draw foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ - self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground) + self.draw_polygons( + cell_data = cell_data, + slab = slab, + polygons = [polygon], + foreground = foreground, + offset2d = offset2d, + ) def draw_slab( self, cell_data: NDArray, - surface_normal: int, - center: ArrayLike, - thickness: float, + slab: SlabProtocol | SlabDict, foreground: Sequence[foreground_t] | foreground_t, ) -> None: """ @@ -238,50 +237,45 @@ class GridDrawMixin(GridPosMixin): Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: `surface_normal` coordinate value at the center of the slab + slab: thickness: Thickness of the layer to draw foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ - # Turn surface_normal into its integer representation - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') - - if numpy.size(center) != 1: - center = numpy.squeeze(center) - if len(center) == 3: - center = center[surface_normal] - else: - raise GridError(f'Bad center: {center}') + if isinstance(slab, dict): + slab = Slab(**slab) # Find center of slab center_shift = self.center - center_shift[surface_normal] = center + center_shift[slab.axis] = slab.center - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), slab.axis) + u_min, u_max = self.exyz[surface[0]][[0, -1]] + v_min, v_max = self.exyz[surface[1]][[0, -1]] - xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface] - xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface] + margin = 4 * numpy.max(self.dxyz[surface[0]].max(), + self.dxyz[surface[1]].max()) - dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float) + p = numpy.array([[u_min - margin, v_max + margin], + [u_max + margin, v_max + margin], + [u_max + margin, v_min - margin], + [u_min - margin, v_min - margin]], dtype=float) - xyz_min -= 4 * dxyz - xyz_max += 4 * dxyz - - p = numpy.array([[xyz_min[0], xyz_max[1]], - [xyz_max[0], xyz_max[1]], - [xyz_max[0], xyz_min[1]], - [xyz_min[0], xyz_min[1]]], dtype=float) - - self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground) + self.draw_polygon( + cell_data = cell_data, + slab = slab, + polygon = p, + foreground = foreground, + ) def draw_cuboid( self, cell_data: NDArray, - center: ArrayLike, - dimensions: ArrayLike, foreground: Sequence[foreground_t] | foreground_t, + *, + x: ExtentProtocol | ExtentDict, + y: ExtentProtocol | ExtentDict, + z: ExtentProtocol | ExtentDict, ) -> None: """ Draw an axis-aligned cuboid @@ -293,23 +287,30 @@ class GridDrawMixin(GridPosMixin): sizes of the cuboid foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ - dimensions = numpy.asarray(dimensions) - p = numpy.array([[-dimensions[0], +dimensions[1]], - [+dimensions[0], +dimensions[1]], - [+dimensions[0], -dimensions[1]], - [-dimensions[0], -dimensions[1]]], dtype=float) * 0.5 - thickness = dimensions[2] - self.draw_polygon(cell_data, 2, center, p, thickness, foreground) + if isinstance(x, dict): + x = Extent(**x) + if isinstance(y, dict): + y = Extent(**y) + if isinstance(z, dict): + z = Extent(**z) + + center = numpy.asarray([x.center, y.center, z.center]) + + p = numpy.array([[x.min, y.max], + [x.max, y.max], + [x.max, y.min], + [x.min, y.min]], dtype=float) + slab = Slab(axis=2, center=z.center, span=z.span) + self.draw_polygon(cell_data=cell_data, slab=slab, polygon=p, foreground=foreground) def draw_cylinder( self, cell_data: NDArray, - surface_normal: int, - center: ArrayLike, + h: SlabProtocol | SlabDict, radius: float, - thickness: float, num_points: int, + center2d: ArrayLike, foreground: Sequence[foreground_t] | foreground_t, ) -> None: """ @@ -317,18 +318,19 @@ class GridDrawMixin(GridPosMixin): Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. - center: 3-element ndarray or list specifying the cylinder's center - radius: cylinder radius - thickness: Thickness of the layer to draw + h: + radius: num_points: The circle is approximated by a polygon with `num_points` vertices + center2d: foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ - theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False) - x = radius * numpy.sin(theta) - y = radius * numpy.cos(theta) - polygon = numpy.hstack((x[:, None], y[:, None])) - self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground) + if isinstance(h, dict): + h = Slab(**h) + + theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)[:, None] + xy0 = numpy.hstack((numpy.sin(theta), numpy.cos(theta))) + polygon = radius * xy0 + self.draw_polygon(cell_data=cell_data, slab=h, polygon=polygon, foreground=foreground, offset2d=center2d) def draw_extrude_rectangle( @@ -349,10 +351,10 @@ class GridDrawMixin(GridPosMixin): polarity: +1 or -1, direction along axis to extrude in distance: How far to extrude """ - s = numpy.sign(polarity) + sgn = numpy.sign(polarity) - rectangle = numpy.array(rectangle, dtype=float) - if s == 0: + rectangle = numpy.asarray(rectangle, dtype=float) + if sgn == 0: raise GridError('0 is not a valid polarity') if direction not in range(3): raise GridError(f'Invalid direction: {direction}') @@ -360,37 +362,38 @@ class GridDrawMixin(GridPosMixin): raise GridError('Rectangle entries along extrusion direction do not match.') center = rectangle.sum(axis=0) / 2.0 - center[direction] += s * distance / 2.0 + center[direction] += sgn * distance / 2.0 surface = numpy.delete(range(3), direction) dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] - p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, - numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T + poly = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5, + numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T thickness = distance foreground_func = [] - for i, grid in enumerate(cell_data): - z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] + for ii, grid in enumerate(cell_data): + zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction] - ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] + ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)] - fpart = z - numpy.floor(z) - mult = [1 - fpart, fpart][::s] # reverses if s negative + fpart = zz - numpy.floor(zz) + mult = [1 - fpart, fpart][::sgn] # reverses if s negative foreground = mult[0] * grid[tuple(ind)] ind[direction] += 1 # type: ignore #(known safe) foreground += mult[1] * grid[tuple(ind)] - def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 + def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 # transform from natural position to index - xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) + xyzi = numpy.array([self.pos2ind(qrs, which_shifts=ii) for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64) # reshape to original shape and keep only in-plane components - qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) + qi, ri = (numpy.reshape(xyzi[:, kk], xs.shape) for kk in surface) return foreground[qi, ri] foreground_func.append(f_foreground) - self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func) + slab = Slab(axis=direction, center=center[direction], span=thickness) + self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface]) diff --git a/gridlock/error.py b/gridlock/error.py deleted file mode 100644 index 3974e9c..0000000 --- a/gridlock/error.py +++ /dev/null @@ -1,2 +0,0 @@ -class GridError(Exception): - pass diff --git a/gridlock/examples/ex0.py b/gridlock/examples/ex0.py index d13a8cf..4ff2fb9 100644 --- a/gridlock/examples/ex0.py +++ b/gridlock/examples/ex0.py @@ -6,18 +6,18 @@ if __name__ == '__main__': # xyz = [numpy.arange(-5.0, 6.0), numpy.arange(-4.0, 5.0), [-1.0, 1.0]] # eg = Grid(xyz) # egc = Grid.allocate(0.0) - # # eg.draw_slab(egc, surface_normal=2, center=0, thickness=10, foreground=2) - # eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=4, - # thickness=10, num_points=1000, foreground=1) - # eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) + # # eg.draw_slab(egc, slab=dict(axis=2, center=0, span=10), foreground=2) + # eg.draw_cylinder(egc, h=slab(axis=2, center=0, span=10), + # center2d=[0, 0], radius=4, thickness=10, num_points=1000, foreground=1) + # eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2) # xyz2 = [numpy.arange(-5.0, 6.0), [-1.0, 1.0], numpy.arange(-4.0, 5.0)] # eg2 = Grid(xyz2) # eg2c = Grid.allocate(0.0) - # # eg2.draw_slab(eg2c, surface_normal=2, center=0, thickness=10, foreground=2) - # eg2.draw_cylinder(eg2c, surface_normal=1, center=[0, 0, 0], - # radius=4, thickness=10, num_points=1000, foreground=1.0) - # eg2.visualize_slice(eg2c, surface_normal=1, center=0, which_shifts=1) + # # eg2.draw_slab(eg2c, slab=dict(axis=2, center=0, span=10), foreground=2) + # eg2.draw_cylinder(eg2c, h=slab(axis=1, center=0, span=10), center2d=[0, 0], + # radius=4, num_points=1000, foreground=1.0) + # eg2.visualize_slice(eg2c, plane=dict(y=0), which_shifts=1) # n = 20 # m = 3 @@ -36,9 +36,20 @@ if __name__ == '__main__': egc = eg.allocate(0) # eg.draw_slab(Direction.z, 0, 10, 2) eg.save('/home/jan/Desktop/test.pickle') - eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0, - thickness=10, num_points=1000, foreground=1) - eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]], - direction=1, polarity=+1, distance=5) - eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) + eg.draw_cylinder( + egc, + h=dict(axis='z', center=0, span=10), + center2d=[0, 0], + radius=2.0, + num_points=1000, + foreground=1, + ) + eg.draw_extrude_rectangle( + egc, + rectangle=[[-2, 1, -1], [0, 1, 1]], + direction=1, + polarity=+1, + distance=5, + ) + eg.visualize_slice(egc, plane=dict(z=0), which_shifts=2) eg.visualize_isosurface(egc, which_shifts=2) diff --git a/gridlock/read.py b/gridlock/read.py index 0f46836..707251a 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -6,7 +6,7 @@ from typing import Any, TYPE_CHECKING import numpy from numpy.typing import NDArray -from . import GridError +from .utils import GridError, Plane, PlaneDict, PlaneProtocol from .position import GridPosMixin if TYPE_CHECKING: @@ -23,27 +23,25 @@ class GridReadMixin(GridPosMixin): def get_slice( self, cell_data: NDArray, - surface_normal: int, - center: float, + plane: PlaneProtocol | PlaneDict, which_shifts: int = 0, sample_period: int = 1 ) -> NDArray: """ Retrieve a slice of a grid. - Interpolates if given a position between two planes. + Interpolates if given a position between two grid planes. Args: cell_data: Cell data to slice - surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. - center: Scalar specifying position along surface_normal axis. + plane: Axis and position (`Plane`) of the plane to read. which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) Returns: Array containing the portion of the grid. """ - if numpy.size(center) != 1 or not numpy.isreal(center): - raise GridError('center must be a real scalar') + if isinstance(plane, dict): + plane = Plane(**plane) sp = round(sample_period) if sp <= 0: @@ -52,15 +50,12 @@ class GridReadMixin(GridPosMixin): if numpy.size(which_shifts) != 1 or which_shifts < 0: raise GridError('Invalid which_shifts') - if surface_normal not in range(3): - raise GridError('Invalid surface_normal direction') - - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), plane.axis) # Extract indices and weights of planes - center3 = numpy.insert([0, 0], surface_normal, (center,)) + center3 = numpy.insert([0, 0], plane.axis, (plane.pos,)) center_index = self.pos2ind(center3, which_shifts, - round_ind=False, check_bounds=False)[surface_normal] + round_ind=False, check_bounds=False)[plane.axis] centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) if len(centers) == 2: fpart = center_index - numpy.floor(center_index) @@ -68,14 +63,14 @@ class GridReadMixin(GridPosMixin): else: w = [1] - c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1]) - if center < c_min or center > c_max: + c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1]) + if plane.pos < c_min or plane.pos > c_max: raise GridError('Coordinate of selected plane must be within simulation domain') # Extract grid values from planes above and below visualized slice sliced_grid = numpy.zeros(self.shape[surface]) for ci, weight in zip(centers, w, strict=True): - s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) + s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3)) sliced_grid += weight * cell_data[which_shifts][tuple(s)] # Remove extra dimensions @@ -87,20 +82,19 @@ class GridReadMixin(GridPosMixin): def visualize_slice( self, cell_data: NDArray, - surface_normal: int, - center: float, + plane: PlaneProtocol | PlaneDict, which_shifts: int = 0, sample_period: int = 1, finalize: bool = True, pcolormesh_args: dict[str, Any] | None = None, - ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: + ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ Visualize a slice of a grid. - Interpolates if given a position between two planes. + Interpolates if given a position between two grid planes. Args: - surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. - center: Scalar specifying position along surface_normal axis. + cell_data: Cell data to visualize + plane: Axis and position (`Plane`) of the plane to read. which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` @@ -110,16 +104,20 @@ class GridReadMixin(GridPosMixin): """ from matplotlib import pyplot + if isinstance(plane, dict): + plane = Plane(**plane) + if pcolormesh_args is None: pcolormesh_args = {} - grid_slice = self.get_slice(cell_data=cell_data, - surface_normal=surface_normal, - center=center, - which_shifts=which_shifts, - sample_period=sample_period) + grid_slice = self.get_slice( + cell_data=cell_data, + plane=plane, + which_shifts=which_shifts, + sample_period=sample_period, + ) - surface = numpy.delete(range(3), surface_normal) + surface = numpy.delete(range(3), plane.axis) x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') @@ -145,7 +143,7 @@ class GridReadMixin(GridPosMixin): sample_period: int = 1, show_edges: bool = True, finalize: bool = True, - ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: + ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ Draw an isosurface plot of the device. @@ -183,18 +181,18 @@ class GridReadMixin(GridPosMixin): fig = pyplot.figure() ax = fig.add_subplot(111, projection='3d') if show_edges: - ax.plot_trisurf(xs, ys, faces, zs) + ax.plot_trisurf(xs, ys, faces, zs) # type: ignore else: - ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') + ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none') # type: ignore # Add a fake plot of a cube to force the axes to be equal lengths max_range = numpy.array([xs.max() - xs.min(), ys.max() - ys.min(), zs.max() - zs.min()], dtype=float).max() mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2] - xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min()) - ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min()) - zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min()) + xbs = 0.5 * max_range * mg[0].ravel() + 0.5 * (xs.max() + xs.min()) + ybs = 0.5 * max_range * mg[1].ravel() + 0.5 * (ys.max() + ys.min()) + zbs = 0.5 * max_range * mg[2].ravel() + 0.5 * (zs.max() + zs.min()) # Comment or uncomment following both lines to test the fake bounding box: for xb, yb, zb in zip(xbs, ybs, zbs, strict=True): ax.plot([xb], [yb], [zb], 'w') diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 7a70211..8d9ca92 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -2,7 +2,7 @@ import numpy from numpy.testing import assert_allclose #, assert_array_equal -from .. import Grid +from .. import Grid, Extent #, Slab, Plane def test_draw_oncenter_2x2() -> None: @@ -12,7 +12,13 @@ def test_draw_oncenter_2x2() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[1, 1, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0, span=1), + y=Extent(center=0, span=1), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0.25, 0.25], [0.25, 0.25]])[None, :, :, None] @@ -27,7 +33,13 @@ def test_draw_ongrid_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0, 0, 0], dimensions=[2, 2, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0, span=2), + y=dict(min=-1, max=1), + z=dict(center=0, min=-5), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 1, 1, 0], @@ -44,7 +56,13 @@ def test_draw_xshift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 2, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0.5, span=1.5), + y=dict(min=-1, max=1), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 0.25, 0.25, 0], @@ -61,7 +79,13 @@ def test_draw_yshift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0, 0.5, 0], dimensions=[2, 1.5, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(min=-1, max=1), + y=dict(center=0.5, span=1.5), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 0.25, 1, 0.25], @@ -78,7 +102,13 @@ def test_draw_2shift_4x4() -> None: grid = Grid([xs, ys, zs], shifts=[[0, 0, 0]]) arr = grid.allocate(0) - grid.draw_cuboid(arr, center=[0.5, 0, 0], dimensions=[1.5, 1, 10], foreground=1) + grid.draw_cuboid( + arr, + x=dict(center=0.5, span=1.5), + y=dict(min=-0.5, max=0.5), + z=dict(center=0, span=10), + foreground=1, + ) correct = numpy.array([[0, 0, 0, 0], [0, 0.125, 0.125, 0], diff --git a/gridlock/utils.py b/gridlock/utils.py new file mode 100644 index 0000000..7a12035 --- /dev/null +++ b/gridlock/utils.py @@ -0,0 +1,201 @@ +from typing import Protocol, TypedDict, runtime_checkable +from dataclasses import dataclass + + +class GridError(Exception): + """ Base error type for `gridlock` """ + pass + + +class ExtentDict(TypedDict, total=False): + min: float + center: float + max: float + span: float + + +@runtime_checkable +class ExtentProtocol(Protocol): + center: float + span: float + + @property + def max(self) -> float: ... + + @property + def min(self) -> float: ... + + +@dataclass(init=False, slots=True) +class Extent(ExtentProtocol): + center: float + span: float + + @property + def max(self) -> float: + return self.center + self.span / 2 + + @property + def min(self) -> float: + return self.center - self.span / 2 + + def __init__( + self, + *, + min: float | None = None, + center: float | None = None, + max: float | None = None, + span: float | None = None, + ) -> None: + if sum(cc is None for cc in (min, center, max, span)) != 2: + raise GridError('Exactly two of min, center, max, span must be None!') + + if span is None: + if center is None: + assert min is not None + assert max is not None + assert max >= min + center = 0.5 * (max + min) + span = max - min + elif max is None: + assert min is not None + assert center is not None + span = 2 * (center - min) + elif min is None: + assert center is not None + assert max is not None + span = 2 * (max - center) + else: # noqa: PLR5501 + if center is not None: + pass + elif max is None: + assert min is not None + assert span is not None + center = min + 0.5 * span + elif min is None: + assert max is not None + assert span is not None + center = max - 0.5 * span + + assert center is not None + assert span is not None + if hasattr(center, '__len__'): + assert len(center) == 1 + if hasattr(span, '__len__'): + assert len(span) == 1 + self.center = center + self.span = span + + +class SlabDict(TypedDict, total=False): + min: float + center: float + max: float + span: float + axis: int | str + + +@runtime_checkable +class SlabProtocol(ExtentProtocol, Protocol): + axis: int + center: float + span: float + + @property + def max(self) -> float: ... + + @property + def min(self) -> float: ... + + +@dataclass(init=False, slots=True) +class Slab(Extent, SlabProtocol): + axis: int + + def __init__( + self, + axis: int | str, + *, + min: float | None = None, + center: float | None = None, + max: float | None = None, + span: float | None = None, + ) -> None: + Extent.__init__(self, min=min, center=center, max=max, span=span) + + if isinstance(axis, str): + axis_int = 'xyz'.find(axis.lower()) + else: + axis_int = axis + if axis_int not in range(3): + raise GridError(f'Invalid axis (slab normal direction): {axis}') + self.axis = axis_int + + def as_plane(self, where: str) -> 'Plane': + if where == 'center': + return Plane(axis=self.axis, pos=self.center) + if where == 'min': + return Plane(axis=self.axis, pos=self.min) + if where == 'max': + return Plane(axis=self.axis, pos=self.max) + raise GridError(f'Invalid {where=}') + + +class PlaneDict(TypedDict, total=False): + x: float + y: float + z: float + axis: int + pos: float + + +@runtime_checkable +class PlaneProtocol(Protocol): + axis: int + pos: float + + +@dataclass(init=False, slots=True) +class Plane(PlaneProtocol): + axis: int + pos: float + + def __init__( + self, + *, + axis: int | str | None = None, + pos: float | None = None, + x: float | None = None, + y: float | None = None, + z: float | None = None, + ) -> None: + xx = x + yy = y + zz = z + + if sum(aa is not None for aa in (pos, xx, yy, zz)) != 1: + raise GridError('Exactly one of pos, x, y, z must be non-None!') + if (axis is None) != (pos is None): + raise GridError('Either both or neither of `axis` and `pos` must be defined.') + + if isinstance(axis, str): + axis_int = 'xyz'.find(axis.lower()) + elif axis is None: + axis_int = (xx is None, yy is None, zz is None).index(False) + else: + axis_int = axis + + if axis_int not in range(3): + raise GridError(f'Invalid axis (slab normal direction): {axis=} {x=} {y=} {z=}') + self.axis = axis_int + + if pos is not None: + cpos = pos + else: + cpos = (xx, yy, zz)[axis_int] + assert cpos is not None + + if hasattr(cpos, '__len__'): + assert len(cpos) == 1 + self.pos = cpos +