Compare commits
	
		
			10 Commits
		
	
	
		
			9ab97e763c
			...
			8e7e0edb1f
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 8e7e0edb1f | |||
| e5fdc3ce23 | |||
| 646911c4b5 | |||
| e256f56f2b | |||
| c32d94ed85 | |||
| 8c33a39c02 | |||
| f84a75f35a | |||
| 5a20339eab | |||
| e29c0901bd | |||
| a15e4bc05e | 
@ -15,8 +15,8 @@ Dependencies:
 | 
			
		||||
- mpl_toolkits.mplot3d  [Grid.visualize_isosurface()]
 | 
			
		||||
- skimage               [Grid.visualize_isosurface()]
 | 
			
		||||
"""
 | 
			
		||||
from .error import GridError
 | 
			
		||||
from .grid import Grid
 | 
			
		||||
from .error import GridError as GridError
 | 
			
		||||
from .grid import Grid as Grid
 | 
			
		||||
 | 
			
		||||
__author__ = 'Jan Petykiewicz'
 | 
			
		||||
__version__ = '1.1'
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										196
									
								
								gridlock/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										196
									
								
								gridlock/base.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,196 @@
 | 
			
		||||
from typing import Protocol
 | 
			
		||||
 | 
			
		||||
import numpy
 | 
			
		||||
from numpy.typing import NDArray
 | 
			
		||||
 | 
			
		||||
from . import GridError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GridBase(Protocol):
 | 
			
		||||
    exyz: list[NDArray]
 | 
			
		||||
    """Cell edges. Monotonically increasing without duplicates."""
 | 
			
		||||
 | 
			
		||||
    periodic: list[bool]
 | 
			
		||||
    """For each axis, determines how far the rightmost boundary gets shifted. """
 | 
			
		||||
 | 
			
		||||
    shifts: NDArray
 | 
			
		||||
    """Offsets `[[x0, y0, z0], [x1, y1, z1], ...]` for grid `0,1,...`"""
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def dxyz(self) -> list[NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Cell sizes for each axis, no shifts applied
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of 3 ndarrays of cell sizes
 | 
			
		||||
        """
 | 
			
		||||
        return [numpy.diff(ee) for ee in self.exyz]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def xyz(self) -> list[NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Cell centers for each axis, no shifts applied
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of 3 ndarrays of cell edges
 | 
			
		||||
        """
 | 
			
		||||
        return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def shape(self) -> NDArray[numpy.intp]:
 | 
			
		||||
        """
 | 
			
		||||
        The number of cells in x, y, and z
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            ndarray of [x_centers.size, y_centers.size, z_centers.size]
 | 
			
		||||
        """
 | 
			
		||||
        return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def num_grids(self) -> int:
 | 
			
		||||
        """
 | 
			
		||||
        The number of grids (number of shifts)
 | 
			
		||||
        """
 | 
			
		||||
        return self.shifts.shape[0]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def cell_data_shape(self) -> NDArray[numpy.intp]:
 | 
			
		||||
        """
 | 
			
		||||
        The shape of the cell_data ndarray (num_grids, *self.shape).
 | 
			
		||||
        """
 | 
			
		||||
        return numpy.hstack((self.num_grids, self.shape))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def dxyz_with_ghost(self) -> list[NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Gives dxyz with an additional 'ghost' cell at the end, whose value depends
 | 
			
		||||
         on whether or not the axis has periodic boundary conditions. See main description
 | 
			
		||||
         above to learn why this is necessary.
 | 
			
		||||
 | 
			
		||||
         If periodic, final edge shifts same amount as first
 | 
			
		||||
         Otherwise, final edge shifts same amount as second-to-last
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
 | 
			
		||||
        """
 | 
			
		||||
        el = [0 if p else -1 for p in self.periodic]
 | 
			
		||||
        return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def center(self) -> NDArray[numpy.float64]:
 | 
			
		||||
        """
 | 
			
		||||
        Center position of the entire grid, no shifts applied
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            ndarray of [x_center, y_center, z_center]
 | 
			
		||||
        """
 | 
			
		||||
        # center is just average of first and last xyz, which is just the average of the
 | 
			
		||||
        #  first two and last two exyz
 | 
			
		||||
        centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
 | 
			
		||||
        return numpy.array(centers, dtype=float)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def dxyz_limits(self) -> tuple[NDArray, NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
 | 
			
		||||
         ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
 | 
			
		||||
         weighted average is performed when shifting).
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
 | 
			
		||||
        """
 | 
			
		||||
        d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
 | 
			
		||||
        d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
 | 
			
		||||
        return d_min, d_max
 | 
			
		||||
 | 
			
		||||
    def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns edges for which_shifts.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            which_shifts: Which grid (which shifts) to use, or `None` for unshifted
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of 3 ndarrays of cell edges
 | 
			
		||||
        """
 | 
			
		||||
        if which_shifts is None:
 | 
			
		||||
            return self.exyz
 | 
			
		||||
        dxyz = self.dxyz_with_ghost
 | 
			
		||||
        shifts = self.shifts[which_shifts, :]
 | 
			
		||||
 | 
			
		||||
        # If shift is negative, use left cell's dx to determine shift
 | 
			
		||||
        for a in range(3):
 | 
			
		||||
            if shifts[a] < 0:
 | 
			
		||||
                dxyz[a] = numpy.roll(dxyz[a], 1)
 | 
			
		||||
 | 
			
		||||
        return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
 | 
			
		||||
 | 
			
		||||
    def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns cell sizes for `which_shifts`.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            which_shifts: Which grid (which shifts) to use, or `None` for unshifted
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of 3 ndarrays of cell sizes
 | 
			
		||||
        """
 | 
			
		||||
        if which_shifts is None:
 | 
			
		||||
            return self.dxyz
 | 
			
		||||
        shifts = self.shifts[which_shifts, :]
 | 
			
		||||
        dxyz = self.dxyz_with_ghost
 | 
			
		||||
 | 
			
		||||
        # If shift is negative, use left cell's dx to determine size
 | 
			
		||||
        sdxyz = []
 | 
			
		||||
        for a in range(3):
 | 
			
		||||
            if shifts[a] < 0:
 | 
			
		||||
                roll_dxyz = numpy.roll(dxyz[a], 1)
 | 
			
		||||
                abs_shift = numpy.abs(shifts[a])
 | 
			
		||||
                sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
 | 
			
		||||
            else:
 | 
			
		||||
                sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
 | 
			
		||||
 | 
			
		||||
        return sdxyz
 | 
			
		||||
 | 
			
		||||
    def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns cell centers for `which_shifts`.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            which_shifts: Which grid (which shifts) to use, or `None` for unshifted
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of 3 ndarrays of cell centers
 | 
			
		||||
        """
 | 
			
		||||
        if which_shifts is None:
 | 
			
		||||
            return self.xyz
 | 
			
		||||
        exyz = self.shifted_exyz(which_shifts)
 | 
			
		||||
        dxyz = self.shifted_dxyz(which_shifts)
 | 
			
		||||
        return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
 | 
			
		||||
 | 
			
		||||
    def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
 | 
			
		||||
        """
 | 
			
		||||
        Return cell widths, with each dimension shifted by the corresponding shifts.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
 | 
			
		||||
        """
 | 
			
		||||
        if self.num_grids != 3:
 | 
			
		||||
            raise GridError('Autoshifting requires exactly 3 grids')
 | 
			
		||||
        return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
 | 
			
		||||
 | 
			
		||||
    def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray:
 | 
			
		||||
        """
 | 
			
		||||
        Allocate an ndarray for storing grid data.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            fill_value: Value to initialize the grid to. If None, an
 | 
			
		||||
                uninitialized array is returned.
 | 
			
		||||
            dtype: Numpy dtype for the array. Default is `numpy.float32`.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            The allocated array
 | 
			
		||||
        """
 | 
			
		||||
        if fill_value is None:
 | 
			
		||||
            return numpy.empty(self.cell_data_shape, dtype=dtype)
 | 
			
		||||
        return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
 | 
			
		||||
							
								
								
									
										683
									
								
								gridlock/draw.py
									
									
									
									
									
								
							
							
						
						
									
										683
									
								
								gridlock/draw.py
									
									
									
									
									
								
							@ -1,13 +1,14 @@
 | 
			
		||||
"""
 | 
			
		||||
Drawing-related methods for Grid class
 | 
			
		||||
"""
 | 
			
		||||
from typing import Union, Sequence, Callable
 | 
			
		||||
from collections.abc import Sequence, Callable
 | 
			
		||||
 | 
			
		||||
import numpy
 | 
			
		||||
from numpy.typing import NDArray, ArrayLike
 | 
			
		||||
from float_raster import raster
 | 
			
		||||
 | 
			
		||||
from . import GridError
 | 
			
		||||
from .position import GridPosMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# NOTE: Maybe it would make sense to create a GridDrawer class
 | 
			
		||||
@ -17,375 +18,379 @@ from . import GridError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
 | 
			
		||||
foreground_t = Union[float, foreground_callable_t]
 | 
			
		||||
foreground_t = float | foreground_callable_t
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def draw_polygons(
 | 
			
		||||
        self,
 | 
			
		||||
        cell_data: NDArray,
 | 
			
		||||
        surface_normal: int,
 | 
			
		||||
        center: ArrayLike,
 | 
			
		||||
        polygons: Sequence[NDArray],
 | 
			
		||||
        thickness: float,
 | 
			
		||||
        foreground: Sequence[foreground_t] | foreground_t,
 | 
			
		||||
        ) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Draw polygons on an axis-aligned plane.
 | 
			
		||||
class GridDrawMixin(GridPosMixin):
 | 
			
		||||
    def draw_polygons(
 | 
			
		||||
            self,
 | 
			
		||||
            cell_data: NDArray,
 | 
			
		||||
            surface_normal: int,
 | 
			
		||||
            center: ArrayLike,
 | 
			
		||||
            polygons: Sequence[ArrayLike],
 | 
			
		||||
            thickness: float,
 | 
			
		||||
            foreground: Sequence[foreground_t] | foreground_t,
 | 
			
		||||
            ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Draw polygons on an axis-aligned plane.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
        surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
 | 
			
		||||
        center: 3-element ndarray or list specifying an offset applied to all the polygons
 | 
			
		||||
        polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
 | 
			
		||||
            (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each
 | 
			
		||||
            polygon must have at least 3 vertices.
 | 
			
		||||
        thickness: Thickness of the layer to draw
 | 
			
		||||
        foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list
 | 
			
		||||
            of any of these (1 per grid). Callable values should take an ndarray the shape of the
 | 
			
		||||
            grid and return an ndarray of equal shape containing the foreground value at the given x, y,
 | 
			
		||||
            and z (natural, not grid coordinates).
 | 
			
		||||
        Args:
 | 
			
		||||
            cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
            surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
 | 
			
		||||
            center: 3-element ndarray or list specifying an offset applied to all the polygons
 | 
			
		||||
            polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
 | 
			
		||||
                (non-closed, clockwise). If Nx3, the `surface_normal` coordinate is ignored. Each
 | 
			
		||||
                polygon must have at least 3 vertices.
 | 
			
		||||
            thickness: Thickness of the layer to draw
 | 
			
		||||
            foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list
 | 
			
		||||
                of any of these (1 per grid). Callable values should take an ndarray the shape of the
 | 
			
		||||
                grid and return an ndarray of equal shape containing the foreground value at the given x, y,
 | 
			
		||||
                and z (natural, not grid coordinates).
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
        GridError
 | 
			
		||||
    """
 | 
			
		||||
    if surface_normal not in range(3):
 | 
			
		||||
        raise GridError('Invalid surface_normal direction')
 | 
			
		||||
 | 
			
		||||
    center = numpy.squeeze(center)
 | 
			
		||||
 | 
			
		||||
    # Check polygons, and remove redundant coordinates
 | 
			
		||||
    surface = numpy.delete(range(3), surface_normal)
 | 
			
		||||
 | 
			
		||||
    for i, polygon in enumerate(polygons):
 | 
			
		||||
        malformed = f'Malformed polygon: ({i})'
 | 
			
		||||
        if polygon.shape[1] not in (2, 3):
 | 
			
		||||
            raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
 | 
			
		||||
        if polygon.shape[1] == 3:
 | 
			
		||||
            polygon = polygon[surface, :]
 | 
			
		||||
 | 
			
		||||
        if not polygon.shape[0] > 2:
 | 
			
		||||
            raise GridError(malformed + 'must consist of more than 2 points')
 | 
			
		||||
        if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1:
 | 
			
		||||
            raise GridError(malformed + 'must be in plane with surface normal '
 | 
			
		||||
                            + 'xyz'[surface_normal])
 | 
			
		||||
 | 
			
		||||
    # Broadcast foreground where necessary
 | 
			
		||||
    foregrounds: Sequence[foreground_callable_t] | Sequence[float]
 | 
			
		||||
    if numpy.size(foreground) == 1:     # type: ignore
 | 
			
		||||
        foregrounds = [foreground] * len(cell_data)  # type: ignore
 | 
			
		||||
    elif isinstance(foreground, numpy.ndarray):
 | 
			
		||||
        raise GridError('ndarray not supported for foreground')
 | 
			
		||||
    else:
 | 
			
		||||
        foregrounds = foreground    # type: ignore
 | 
			
		||||
 | 
			
		||||
    # ## Compute sub-domain of the grid occupied by polygons
 | 
			
		||||
    # 1) Compute outer bounds (bd) of polygons
 | 
			
		||||
    bd_2d_min = numpy.array([0, 0])
 | 
			
		||||
    bd_2d_max = numpy.array([0, 0])
 | 
			
		||||
    for polygon in polygons:
 | 
			
		||||
        bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
 | 
			
		||||
        bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
 | 
			
		||||
    bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
 | 
			
		||||
    bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center
 | 
			
		||||
 | 
			
		||||
    # 2) Find indices (bdi) just outside bd elements
 | 
			
		||||
    buf = 2  # size of safety buffer
 | 
			
		||||
    # Use s_min and s_max with unshifted pos2ind to get absolute limits on
 | 
			
		||||
    #  the indices the polygons might affect
 | 
			
		||||
    s_min = self.shifts.min(axis=0)
 | 
			
		||||
    s_max = self.shifts.max(axis=0)
 | 
			
		||||
    bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf
 | 
			
		||||
    bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf
 | 
			
		||||
    bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int)
 | 
			
		||||
    bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
 | 
			
		||||
 | 
			
		||||
    # 3) Adjust polygons for center
 | 
			
		||||
    polygons = [poly + center[surface] for poly in polygons]
 | 
			
		||||
 | 
			
		||||
    # ## Generate weighing function
 | 
			
		||||
    def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
 | 
			
		||||
        v_2d = numpy.array(vector, dtype=float)
 | 
			
		||||
        return numpy.insert(v_2d, surface_normal, (val,))
 | 
			
		||||
 | 
			
		||||
    # iterate over grids
 | 
			
		||||
    for i, _ in enumerate(cell_data):
 | 
			
		||||
        # ## Evaluate or expand foregrounds[i]
 | 
			
		||||
        foregrounds_i = foregrounds[i]
 | 
			
		||||
        if callable(foregrounds_i):
 | 
			
		||||
            # meshgrid over the (shifted) domain
 | 
			
		||||
            domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)]
 | 
			
		||||
            (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij')
 | 
			
		||||
 | 
			
		||||
            # evaluate on the meshgrid
 | 
			
		||||
            foreground_val = foregrounds_i(x0, y0, z0)
 | 
			
		||||
            if not numpy.isfinite(foreground_val).all():
 | 
			
		||||
                raise GridError(f'Non-finite values in foreground[{i}]')
 | 
			
		||||
        elif numpy.size(foregrounds_i) != 1:
 | 
			
		||||
            raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}')
 | 
			
		||||
        else:
 | 
			
		||||
            # foreground[i] is scalar non-callable
 | 
			
		||||
            foreground_val = foregrounds_i
 | 
			
		||||
 | 
			
		||||
        w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int))
 | 
			
		||||
 | 
			
		||||
        # Draw each polygon separately
 | 
			
		||||
        for polygon in polygons:
 | 
			
		||||
 | 
			
		||||
            # Get the boundaries of the polygon
 | 
			
		||||
            pbd_min = polygon.min(axis=0)
 | 
			
		||||
            pbd_max = polygon.max(axis=0)
 | 
			
		||||
 | 
			
		||||
            # Find indices in w_xy just outside polygon
 | 
			
		||||
            #  using per-grid xy-weights (self.shifted_xyz())
 | 
			
		||||
            corner_min = self.pos2ind(to_3d(pbd_min), i,
 | 
			
		||||
                                      check_bounds=False)[surface].astype(int)
 | 
			
		||||
            corner_max = self.pos2ind(to_3d(pbd_max), i,
 | 
			
		||||
                                      check_bounds=False)[surface].astype(int)
 | 
			
		||||
 | 
			
		||||
            # Find indices in w_xy which are modified by polygon
 | 
			
		||||
            # First for the edge coordinates (+1 since we're indexing edges)
 | 
			
		||||
            edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max)]
 | 
			
		||||
            # Then for the pixel centers (-bdi_min since we're
 | 
			
		||||
            #  calculating weights within a subspace)
 | 
			
		||||
            centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface],
 | 
			
		||||
                                                                    corner_max - bdi_min[surface]))
 | 
			
		||||
 | 
			
		||||
            aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices))
 | 
			
		||||
            w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y)
 | 
			
		||||
 | 
			
		||||
        # Clamp overlapping polygons to 1
 | 
			
		||||
        w_xy = numpy.minimum(w_xy, 1.0)
 | 
			
		||||
 | 
			
		||||
        # 2) Generate weights in z-direction
 | 
			
		||||
        w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
 | 
			
		||||
 | 
			
		||||
        def get_zi(offset, i=i, w_z=w_z):
 | 
			
		||||
            edges = self.shifted_exyz(i)[surface_normal]
 | 
			
		||||
            point = center[surface_normal] + offset
 | 
			
		||||
            grid_coord = numpy.digitize(point, edges) - 1
 | 
			
		||||
            w_coord = grid_coord - bdi_min[surface_normal]
 | 
			
		||||
 | 
			
		||||
            if w_coord < 0:
 | 
			
		||||
                w_coord = 0
 | 
			
		||||
                f = 0
 | 
			
		||||
            elif w_coord >= w_z.size:
 | 
			
		||||
                w_coord = w_z.size - 1
 | 
			
		||||
                f = 1
 | 
			
		||||
            else:
 | 
			
		||||
                dz = self.shifted_dxyz(i)[surface_normal][grid_coord]
 | 
			
		||||
                f = (point - edges[grid_coord]) / dz
 | 
			
		||||
            return f, w_coord
 | 
			
		||||
 | 
			
		||||
        zi_top_f, zi_top = get_zi(+thickness / 2.0)
 | 
			
		||||
        zi_bot_f, zi_bot = get_zi(-thickness / 2.0)
 | 
			
		||||
 | 
			
		||||
        w_z[zi_bot + 1:zi_top] = 1
 | 
			
		||||
 | 
			
		||||
        if zi_bot < zi_top:
 | 
			
		||||
            w_z[zi_top] = zi_top_f
 | 
			
		||||
            w_z[zi_bot] = 1 - zi_bot_f
 | 
			
		||||
        else:
 | 
			
		||||
            w_z[zi_bot] = zi_top_f - zi_bot_f
 | 
			
		||||
 | 
			
		||||
        # 3) Generate total weight function
 | 
			
		||||
        w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,)))
 | 
			
		||||
 | 
			
		||||
        # ## Modify the grid
 | 
			
		||||
        g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
 | 
			
		||||
        cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def draw_polygon(
 | 
			
		||||
        self,
 | 
			
		||||
        cell_data: NDArray,
 | 
			
		||||
        surface_normal: int,
 | 
			
		||||
        center: ArrayLike,
 | 
			
		||||
        polygon: ArrayLike,
 | 
			
		||||
        thickness: float,
 | 
			
		||||
        foreground: Sequence[foreground_t] | foreground_t,
 | 
			
		||||
        ) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Draw a polygon on an axis-aligned plane.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
        surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
 | 
			
		||||
        center: 3-element ndarray or list specifying an offset applied to the polygon
 | 
			
		||||
        polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
 | 
			
		||||
            clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at
 | 
			
		||||
            least 3 vertices.
 | 
			
		||||
        thickness: Thickness of the layer to draw
 | 
			
		||||
        foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
 | 
			
		||||
    """
 | 
			
		||||
    self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def draw_slab(
 | 
			
		||||
        self,
 | 
			
		||||
        cell_data: NDArray,
 | 
			
		||||
        surface_normal: int,
 | 
			
		||||
        center: ArrayLike,
 | 
			
		||||
        thickness: float,
 | 
			
		||||
        foreground: Sequence[foreground_t] | foreground_t,
 | 
			
		||||
        ) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Draw an axis-aligned infinite slab.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
        surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
 | 
			
		||||
        center: `surface_normal` coordinate value at the center of the slab
 | 
			
		||||
        thickness: Thickness of the layer to draw
 | 
			
		||||
        foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
 | 
			
		||||
    """
 | 
			
		||||
    # Turn surface_normal into its integer representation
 | 
			
		||||
    if surface_normal not in range(3):
 | 
			
		||||
        raise GridError('Invalid surface_normal direction')
 | 
			
		||||
 | 
			
		||||
    if numpy.size(center) != 1:
 | 
			
		||||
        Raises:
 | 
			
		||||
            GridError
 | 
			
		||||
        """
 | 
			
		||||
        if surface_normal not in range(3):
 | 
			
		||||
            raise GridError('Invalid surface_normal direction')
 | 
			
		||||
        center = numpy.squeeze(center)
 | 
			
		||||
        if len(center) == 3:
 | 
			
		||||
            center = center[surface_normal]
 | 
			
		||||
        poly_list = [numpy.array(poly, copy=False) for poly in polygons]
 | 
			
		||||
 | 
			
		||||
        # Check polygons, and remove redundant coordinates
 | 
			
		||||
        surface = numpy.delete(range(3), surface_normal)
 | 
			
		||||
 | 
			
		||||
        for ii in range(len(poly_list)):
 | 
			
		||||
            polygon = poly_list[ii]
 | 
			
		||||
            malformed = f'Malformed polygon: ({ii})'
 | 
			
		||||
            if polygon.shape[1] not in (2, 3):
 | 
			
		||||
                raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
 | 
			
		||||
            if polygon.shape[1] == 3:
 | 
			
		||||
                polygon = polygon[surface, :]
 | 
			
		||||
                poly_list[ii] = polygon
 | 
			
		||||
 | 
			
		||||
            if not polygon.shape[0] > 2:
 | 
			
		||||
                raise GridError(malformed + 'must consist of more than 2 points')
 | 
			
		||||
            if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1:
 | 
			
		||||
                raise GridError(malformed + 'must be in plane with surface normal '
 | 
			
		||||
                                + 'xyz'[surface_normal])
 | 
			
		||||
 | 
			
		||||
        # Broadcast foreground where necessary
 | 
			
		||||
        foregrounds: Sequence[foreground_callable_t] | Sequence[float]
 | 
			
		||||
        if numpy.size(foreground) == 1:     # type: ignore
 | 
			
		||||
            foregrounds = [foreground] * len(cell_data)  # type: ignore
 | 
			
		||||
        elif isinstance(foreground, numpy.ndarray):
 | 
			
		||||
            raise GridError('ndarray not supported for foreground')
 | 
			
		||||
        else:
 | 
			
		||||
            raise GridError(f'Bad center: {center}')
 | 
			
		||||
            foregrounds = foreground    # type: ignore
 | 
			
		||||
 | 
			
		||||
    # Find center of slab
 | 
			
		||||
    center_shift = self.center
 | 
			
		||||
    center_shift[surface_normal] = center
 | 
			
		||||
        # ## Compute sub-domain of the grid occupied by polygons
 | 
			
		||||
        # 1) Compute outer bounds (bd) of polygons
 | 
			
		||||
        bd_2d_min = numpy.array([0, 0])
 | 
			
		||||
        bd_2d_max = numpy.array([0, 0])
 | 
			
		||||
        for polygon in poly_list:
 | 
			
		||||
            bd_2d_min = numpy.minimum(bd_2d_min, polygon.min(axis=0))
 | 
			
		||||
            bd_2d_max = numpy.maximum(bd_2d_max, polygon.max(axis=0))
 | 
			
		||||
        bd_min = numpy.insert(bd_2d_min, surface_normal, -thickness / 2.0) + center
 | 
			
		||||
        bd_max = numpy.insert(bd_2d_max, surface_normal, +thickness / 2.0) + center
 | 
			
		||||
 | 
			
		||||
    surface = numpy.delete(range(3), surface_normal)
 | 
			
		||||
        # 2) Find indices (bdi) just outside bd elements
 | 
			
		||||
        buf = 2  # size of safety buffer
 | 
			
		||||
        # Use s_min and s_max with unshifted pos2ind to get absolute limits on
 | 
			
		||||
        #  the indices the polygons might affect
 | 
			
		||||
        s_min = self.shifts.min(axis=0)
 | 
			
		||||
        s_max = self.shifts.max(axis=0)
 | 
			
		||||
        bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf
 | 
			
		||||
        bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf
 | 
			
		||||
        bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int)
 | 
			
		||||
        bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
 | 
			
		||||
 | 
			
		||||
    xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface]
 | 
			
		||||
    xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface]
 | 
			
		||||
        # 3) Adjust polygons for center
 | 
			
		||||
        poly_list = [poly + center[surface] for poly in poly_list]
 | 
			
		||||
 | 
			
		||||
    dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float)
 | 
			
		||||
        # ## Generate weighing function
 | 
			
		||||
        def to_3d(vector: NDArray, val: float = 0.0) -> NDArray[numpy.float64]:
 | 
			
		||||
            v_2d = numpy.array(vector, dtype=float)
 | 
			
		||||
            return numpy.insert(v_2d, surface_normal, (val,))
 | 
			
		||||
 | 
			
		||||
    xyz_min -= 4 * dxyz
 | 
			
		||||
    xyz_max += 4 * dxyz
 | 
			
		||||
        # iterate over grids
 | 
			
		||||
        foreground_val: NDArray | float
 | 
			
		||||
        for i, _ in enumerate(cell_data):
 | 
			
		||||
            # ## Evaluate or expand foregrounds[i]
 | 
			
		||||
            foregrounds_i = foregrounds[i]
 | 
			
		||||
            if callable(foregrounds_i):
 | 
			
		||||
                # meshgrid over the (shifted) domain
 | 
			
		||||
                domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k] + 1] for k in range(3)]
 | 
			
		||||
                (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij')
 | 
			
		||||
 | 
			
		||||
    p = numpy.array([[xyz_min[0], xyz_max[1]],
 | 
			
		||||
                     [xyz_max[0], xyz_max[1]],
 | 
			
		||||
                     [xyz_max[0], xyz_min[1]],
 | 
			
		||||
                     [xyz_min[0], xyz_min[1]]], dtype=float)
 | 
			
		||||
                # evaluate on the meshgrid
 | 
			
		||||
                foreground_val = foregrounds_i(x0, y0, z0)
 | 
			
		||||
                if not numpy.isfinite(foreground_val).all():
 | 
			
		||||
                    raise GridError(f'Non-finite values in foreground[{i}]')
 | 
			
		||||
            elif numpy.size(foregrounds_i) != 1:
 | 
			
		||||
                raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}')
 | 
			
		||||
            else:
 | 
			
		||||
                # foreground[i] is scalar non-callable
 | 
			
		||||
                foreground_val = foregrounds_i
 | 
			
		||||
 | 
			
		||||
    self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground)
 | 
			
		||||
            w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int))
 | 
			
		||||
 | 
			
		||||
            # Draw each polygon separately
 | 
			
		||||
            for polygon in poly_list:
 | 
			
		||||
 | 
			
		||||
                # Get the boundaries of the polygon
 | 
			
		||||
                pbd_min = polygon.min(axis=0)
 | 
			
		||||
                pbd_max = polygon.max(axis=0)
 | 
			
		||||
 | 
			
		||||
                # Find indices in w_xy just outside polygon
 | 
			
		||||
                #  using per-grid xy-weights (self.shifted_xyz())
 | 
			
		||||
                corner_min = self.pos2ind(to_3d(pbd_min), i,
 | 
			
		||||
                                          check_bounds=False)[surface].astype(int)
 | 
			
		||||
                corner_max = self.pos2ind(to_3d(pbd_max), i,
 | 
			
		||||
                                          check_bounds=False)[surface].astype(int)
 | 
			
		||||
 | 
			
		||||
                # Find indices in w_xy which are modified by polygon
 | 
			
		||||
                # First for the edge coordinates (+1 since we're indexing edges)
 | 
			
		||||
                edge_slices = [numpy.s_[i:f + 2] for i, f in zip(corner_min, corner_max, strict=True)]
 | 
			
		||||
                # Then for the pixel centers (-bdi_min since we're
 | 
			
		||||
                #  calculating weights within a subspace)
 | 
			
		||||
                centers_slice = tuple(numpy.s_[i:f + 1] for i, f in zip(corner_min - bdi_min[surface],
 | 
			
		||||
                                                                        corner_max - bdi_min[surface], strict=True))
 | 
			
		||||
 | 
			
		||||
                aa_x, aa_y = (self.shifted_exyz(i)[a][s] for a, s in zip(surface, edge_slices, strict=True))
 | 
			
		||||
                w_xy[centers_slice] += raster(polygon.T, aa_x, aa_y)
 | 
			
		||||
 | 
			
		||||
            # Clamp overlapping polygons to 1
 | 
			
		||||
            w_xy = numpy.minimum(w_xy, 1.0)
 | 
			
		||||
 | 
			
		||||
            # 2) Generate weights in z-direction
 | 
			
		||||
            w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
 | 
			
		||||
 | 
			
		||||
            def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]:          # noqa: ANN001
 | 
			
		||||
                edges = self.shifted_exyz(i)[surface_normal]
 | 
			
		||||
                point = center[surface_normal] + offset
 | 
			
		||||
                grid_coord = numpy.digitize(point, edges) - 1
 | 
			
		||||
                w_coord = grid_coord - bdi_min[surface_normal]
 | 
			
		||||
 | 
			
		||||
                if w_coord < 0:
 | 
			
		||||
                    w_coord = 0
 | 
			
		||||
                    f = 0
 | 
			
		||||
                elif w_coord >= w_z.size:
 | 
			
		||||
                    w_coord = w_z.size - 1
 | 
			
		||||
                    f = 1
 | 
			
		||||
                else:
 | 
			
		||||
                    dz = self.shifted_dxyz(i)[surface_normal][grid_coord]
 | 
			
		||||
                    f = (point - edges[grid_coord]) / dz
 | 
			
		||||
                return f, w_coord
 | 
			
		||||
 | 
			
		||||
            zi_top_f, zi_top = get_zi(+thickness / 2.0)
 | 
			
		||||
            zi_bot_f, zi_bot = get_zi(-thickness / 2.0)
 | 
			
		||||
 | 
			
		||||
            w_z[zi_bot + 1:zi_top] = 1
 | 
			
		||||
 | 
			
		||||
            if zi_bot < zi_top:
 | 
			
		||||
                w_z[zi_top] = zi_top_f
 | 
			
		||||
                w_z[zi_bot] = 1 - zi_bot_f
 | 
			
		||||
            else:
 | 
			
		||||
                w_z[zi_bot] = zi_top_f - zi_bot_f
 | 
			
		||||
 | 
			
		||||
            # 3) Generate total weight function
 | 
			
		||||
            w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,)))
 | 
			
		||||
 | 
			
		||||
            # ## Modify the grid
 | 
			
		||||
            g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
 | 
			
		||||
            cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def draw_cuboid(
 | 
			
		||||
        self,
 | 
			
		||||
        cell_data: NDArray,
 | 
			
		||||
        center: ArrayLike,
 | 
			
		||||
        dimensions: ArrayLike,
 | 
			
		||||
        foreground: Sequence[foreground_t] | foreground_t,
 | 
			
		||||
        ) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Draw an axis-aligned cuboid
 | 
			
		||||
    def draw_polygon(
 | 
			
		||||
            self,
 | 
			
		||||
            cell_data: NDArray,
 | 
			
		||||
            surface_normal: int,
 | 
			
		||||
            center: ArrayLike,
 | 
			
		||||
            polygon: ArrayLike,
 | 
			
		||||
            thickness: float,
 | 
			
		||||
            foreground: Sequence[foreground_t] | foreground_t,
 | 
			
		||||
            ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Draw a polygon on an axis-aligned plane.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
        center: 3-element ndarray or list specifying the cuboid's center
 | 
			
		||||
        dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge
 | 
			
		||||
             sizes of the cuboid
 | 
			
		||||
        foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
 | 
			
		||||
    """
 | 
			
		||||
    dimensions = numpy.array(dimensions, copy=False)
 | 
			
		||||
    p = numpy.array([[-dimensions[0], +dimensions[1]],
 | 
			
		||||
                     [+dimensions[0], +dimensions[1]],
 | 
			
		||||
                     [+dimensions[0], -dimensions[1]],
 | 
			
		||||
                     [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0
 | 
			
		||||
    thickness = dimensions[2]
 | 
			
		||||
    self.draw_polygon(cell_data, 2, center, p, thickness, foreground)
 | 
			
		||||
        Args:
 | 
			
		||||
            cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
            surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
 | 
			
		||||
            center: 3-element ndarray or list specifying an offset applied to the polygon
 | 
			
		||||
            polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
 | 
			
		||||
                clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at
 | 
			
		||||
                least 3 vertices.
 | 
			
		||||
            thickness: Thickness of the layer to draw
 | 
			
		||||
            foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
 | 
			
		||||
        """
 | 
			
		||||
        self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, foreground)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def draw_cylinder(
 | 
			
		||||
        self,
 | 
			
		||||
        cell_data: NDArray,
 | 
			
		||||
        surface_normal: int,
 | 
			
		||||
        center: ArrayLike,
 | 
			
		||||
        radius: float,
 | 
			
		||||
        thickness: float,
 | 
			
		||||
        num_points: int,
 | 
			
		||||
        foreground: Sequence[foreground_t] | foreground_t,
 | 
			
		||||
        ) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Draw an axis-aligned cylinder. Approximated by a num_points-gon
 | 
			
		||||
    def draw_slab(
 | 
			
		||||
            self,
 | 
			
		||||
            cell_data: NDArray,
 | 
			
		||||
            surface_normal: int,
 | 
			
		||||
            center: ArrayLike,
 | 
			
		||||
            thickness: float,
 | 
			
		||||
            foreground: Sequence[foreground_t] | foreground_t,
 | 
			
		||||
            ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Draw an axis-aligned infinite slab.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
        surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
 | 
			
		||||
        center: 3-element ndarray or list specifying the cylinder's center
 | 
			
		||||
        radius: cylinder radius
 | 
			
		||||
        thickness: Thickness of the layer to draw
 | 
			
		||||
        num_points: The circle is approximated by a polygon with `num_points` vertices
 | 
			
		||||
        foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
 | 
			
		||||
    """
 | 
			
		||||
    theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)
 | 
			
		||||
    x = radius * numpy.sin(theta)
 | 
			
		||||
    y = radius * numpy.cos(theta)
 | 
			
		||||
    polygon = numpy.hstack((x[:, None], y[:, None]))
 | 
			
		||||
    self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground)
 | 
			
		||||
        Args:
 | 
			
		||||
            cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
            surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
 | 
			
		||||
            center: `surface_normal` coordinate value at the center of the slab
 | 
			
		||||
            thickness: Thickness of the layer to draw
 | 
			
		||||
            foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
 | 
			
		||||
        """
 | 
			
		||||
        # Turn surface_normal into its integer representation
 | 
			
		||||
        if surface_normal not in range(3):
 | 
			
		||||
            raise GridError('Invalid surface_normal direction')
 | 
			
		||||
 | 
			
		||||
        if numpy.size(center) != 1:
 | 
			
		||||
            center = numpy.squeeze(center)
 | 
			
		||||
            if len(center) == 3:
 | 
			
		||||
                center = center[surface_normal]
 | 
			
		||||
            else:
 | 
			
		||||
                raise GridError(f'Bad center: {center}')
 | 
			
		||||
 | 
			
		||||
        # Find center of slab
 | 
			
		||||
        center_shift = self.center
 | 
			
		||||
        center_shift[surface_normal] = center
 | 
			
		||||
 | 
			
		||||
        surface = numpy.delete(range(3), surface_normal)
 | 
			
		||||
 | 
			
		||||
        xyz_min = numpy.array([self.xyz[a][0] for a in range(3)], dtype=float)[surface]
 | 
			
		||||
        xyz_max = numpy.array([self.xyz[a][-1] for a in range(3)], dtype=float)[surface]
 | 
			
		||||
 | 
			
		||||
        dxyz = numpy.array([max(self.dxyz[i]) for i in surface], dtype=float)
 | 
			
		||||
 | 
			
		||||
        xyz_min -= 4 * dxyz
 | 
			
		||||
        xyz_max += 4 * dxyz
 | 
			
		||||
 | 
			
		||||
        p = numpy.array([[xyz_min[0], xyz_max[1]],
 | 
			
		||||
                         [xyz_max[0], xyz_max[1]],
 | 
			
		||||
                         [xyz_max[0], xyz_min[1]],
 | 
			
		||||
                         [xyz_min[0], xyz_min[1]]], dtype=float)
 | 
			
		||||
 | 
			
		||||
        self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, foreground)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def draw_extrude_rectangle(
 | 
			
		||||
        self,
 | 
			
		||||
        cell_data: NDArray,
 | 
			
		||||
        rectangle: ArrayLike,
 | 
			
		||||
        direction: int,
 | 
			
		||||
        polarity: int,
 | 
			
		||||
        distance: float,
 | 
			
		||||
        ) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Extrude a rectangle of a previously-drawn structure along an axis.
 | 
			
		||||
    def draw_cuboid(
 | 
			
		||||
            self,
 | 
			
		||||
            cell_data: NDArray,
 | 
			
		||||
            center: ArrayLike,
 | 
			
		||||
            dimensions: ArrayLike,
 | 
			
		||||
            foreground: Sequence[foreground_t] | foreground_t,
 | 
			
		||||
            ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Draw an axis-aligned cuboid
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
        rectangle: 2x3 ndarray or list specifying the rectangle's corners
 | 
			
		||||
        direction: Direction to extrude in. Integer in `range(3)`.
 | 
			
		||||
        polarity: +1 or -1, direction along axis to extrude in
 | 
			
		||||
        distance: How far to extrude
 | 
			
		||||
    """
 | 
			
		||||
    s = numpy.sign(polarity)
 | 
			
		||||
        Args:
 | 
			
		||||
            cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
            center: 3-element ndarray or list specifying the cuboid's center
 | 
			
		||||
            dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge
 | 
			
		||||
                 sizes of the cuboid
 | 
			
		||||
            foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
 | 
			
		||||
        """
 | 
			
		||||
        dimensions = numpy.array(dimensions, copy=False)
 | 
			
		||||
        p = numpy.array([[-dimensions[0], +dimensions[1]],
 | 
			
		||||
                         [+dimensions[0], +dimensions[1]],
 | 
			
		||||
                         [+dimensions[0], -dimensions[1]],
 | 
			
		||||
                         [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0
 | 
			
		||||
        thickness = dimensions[2]
 | 
			
		||||
        self.draw_polygon(cell_data, 2, center, p, thickness, foreground)
 | 
			
		||||
 | 
			
		||||
    rectangle = numpy.array(rectangle, dtype=float)
 | 
			
		||||
    if s == 0:
 | 
			
		||||
        raise GridError('0 is not a valid polarity')
 | 
			
		||||
    if direction not in range(3):
 | 
			
		||||
        raise GridError(f'Invalid direction: {direction}')
 | 
			
		||||
    if rectangle[0, direction] != rectangle[1, direction]:
 | 
			
		||||
        raise GridError('Rectangle entries along extrusion direction do not match.')
 | 
			
		||||
 | 
			
		||||
    center = rectangle.sum(axis=0) / 2.0
 | 
			
		||||
    center[direction] += s * distance / 2.0
 | 
			
		||||
    def draw_cylinder(
 | 
			
		||||
            self,
 | 
			
		||||
            cell_data: NDArray,
 | 
			
		||||
            surface_normal: int,
 | 
			
		||||
            center: ArrayLike,
 | 
			
		||||
            radius: float,
 | 
			
		||||
            thickness: float,
 | 
			
		||||
            num_points: int,
 | 
			
		||||
            foreground: Sequence[foreground_t] | foreground_t,
 | 
			
		||||
            ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Draw an axis-aligned cylinder. Approximated by a num_points-gon
 | 
			
		||||
 | 
			
		||||
    surface = numpy.delete(range(3), direction)
 | 
			
		||||
        Args:
 | 
			
		||||
            cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
            surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
 | 
			
		||||
            center: 3-element ndarray or list specifying the cylinder's center
 | 
			
		||||
            radius: cylinder radius
 | 
			
		||||
            thickness: Thickness of the layer to draw
 | 
			
		||||
            num_points: The circle is approximated by a polygon with `num_points` vertices
 | 
			
		||||
            foreground: Value to draw with ('brush color'). See `draw_polygons()` for details.
 | 
			
		||||
        """
 | 
			
		||||
        theta = numpy.linspace(0, 2 * numpy.pi, num_points, endpoint=False)
 | 
			
		||||
        x = radius * numpy.sin(theta)
 | 
			
		||||
        y = radius * numpy.cos(theta)
 | 
			
		||||
        polygon = numpy.hstack((x[:, None], y[:, None]))
 | 
			
		||||
        self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, foreground)
 | 
			
		||||
 | 
			
		||||
    dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface]
 | 
			
		||||
    p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5,
 | 
			
		||||
                      numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T
 | 
			
		||||
    thickness = distance
 | 
			
		||||
 | 
			
		||||
    foreground_func = []
 | 
			
		||||
    for i, grid in enumerate(cell_data):
 | 
			
		||||
        z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction]
 | 
			
		||||
    def draw_extrude_rectangle(
 | 
			
		||||
            self,
 | 
			
		||||
            cell_data: NDArray,
 | 
			
		||||
            rectangle: ArrayLike,
 | 
			
		||||
            direction: int,
 | 
			
		||||
            polarity: int,
 | 
			
		||||
            distance: float,
 | 
			
		||||
            ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Extrude a rectangle of a previously-drawn structure along an axis.
 | 
			
		||||
 | 
			
		||||
        ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)]
 | 
			
		||||
        Args:
 | 
			
		||||
            cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
 | 
			
		||||
            rectangle: 2x3 ndarray or list specifying the rectangle's corners
 | 
			
		||||
            direction: Direction to extrude in. Integer in `range(3)`.
 | 
			
		||||
            polarity: +1 or -1, direction along axis to extrude in
 | 
			
		||||
            distance: How far to extrude
 | 
			
		||||
        """
 | 
			
		||||
        s = numpy.sign(polarity)
 | 
			
		||||
 | 
			
		||||
        fpart = z - numpy.floor(z)
 | 
			
		||||
        mult = [1 - fpart, fpart][::s]  # reverses if s negative
 | 
			
		||||
        rectangle = numpy.array(rectangle, dtype=float)
 | 
			
		||||
        if s == 0:
 | 
			
		||||
            raise GridError('0 is not a valid polarity')
 | 
			
		||||
        if direction not in range(3):
 | 
			
		||||
            raise GridError(f'Invalid direction: {direction}')
 | 
			
		||||
        if rectangle[0, direction] != rectangle[1, direction]:
 | 
			
		||||
            raise GridError('Rectangle entries along extrusion direction do not match.')
 | 
			
		||||
 | 
			
		||||
        foreground = mult[0] * grid[tuple(ind)]
 | 
			
		||||
        ind[direction] += 1                         # type: ignore #(known safe)
 | 
			
		||||
        foreground += mult[1] * grid[tuple(ind)]
 | 
			
		||||
        center = rectangle.sum(axis=0) / 2.0
 | 
			
		||||
        center[direction] += s * distance / 2.0
 | 
			
		||||
 | 
			
		||||
        def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]:
 | 
			
		||||
            # transform from natural position to index
 | 
			
		||||
            xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
 | 
			
		||||
                                for qrs in zip(xs.flat, ys.flat, zs.flat)], dtype=int)
 | 
			
		||||
            # reshape to original shape and keep only in-plane components
 | 
			
		||||
            qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
 | 
			
		||||
            return foreground[qi, ri]
 | 
			
		||||
        surface = numpy.delete(range(3), direction)
 | 
			
		||||
 | 
			
		||||
        foreground_func.append(f_foreground)
 | 
			
		||||
        dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface]
 | 
			
		||||
        p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0] * 0.5,
 | 
			
		||||
                          numpy.array([-1, 1, 1, -1], dtype=float) * dim[1] * 0.5)).T
 | 
			
		||||
        thickness = distance
 | 
			
		||||
 | 
			
		||||
    self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func)
 | 
			
		||||
        foreground_func = []
 | 
			
		||||
        for i, grid in enumerate(cell_data):
 | 
			
		||||
            z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction]
 | 
			
		||||
 | 
			
		||||
            ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)]
 | 
			
		||||
 | 
			
		||||
            fpart = z - numpy.floor(z)
 | 
			
		||||
            mult = [1 - fpart, fpart][::s]  # reverses if s negative
 | 
			
		||||
 | 
			
		||||
            foreground = mult[0] * grid[tuple(ind)]
 | 
			
		||||
            ind[direction] += 1                         # type: ignore #(known safe)
 | 
			
		||||
            foreground += mult[1] * grid[tuple(ind)]
 | 
			
		||||
 | 
			
		||||
            def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]:            # noqa: ANN001
 | 
			
		||||
                # transform from natural position to index
 | 
			
		||||
                xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
 | 
			
		||||
                                    for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64)
 | 
			
		||||
                # reshape to original shape and keep only in-plane components
 | 
			
		||||
                qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
 | 
			
		||||
                return foreground[qi, ri]
 | 
			
		||||
 | 
			
		||||
            foreground_func.append(f_foreground)
 | 
			
		||||
 | 
			
		||||
        self.draw_polygon(cell_data, direction, center, p, thickness, foreground_func)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -29,7 +29,7 @@ if __name__ == '__main__':
 | 
			
		||||
    #         numpy.linspace(-5.5, 5.5, 10)]
 | 
			
		||||
 | 
			
		||||
    half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5]
 | 
			
		||||
    xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x,
 | 
			
		||||
    xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x),
 | 
			
		||||
            numpy.linspace(-5.5, 5.5, 10),
 | 
			
		||||
            numpy.linspace(-5.5, 5.5, 10)]
 | 
			
		||||
    eg = Grid(xyz3)
 | 
			
		||||
@ -37,8 +37,8 @@ if __name__ == '__main__':
 | 
			
		||||
    # eg.draw_slab(Direction.z, 0, 10, 2)
 | 
			
		||||
    eg.save('/home/jan/Desktop/test.pickle')
 | 
			
		||||
    eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0,
 | 
			
		||||
                     thickness=10, num_poitns=1000, foreground=1)
 | 
			
		||||
                     thickness=10, num_points=1000, foreground=1)
 | 
			
		||||
    eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]],
 | 
			
		||||
                              direction=1, poalarity=+1, distance=5)
 | 
			
		||||
                              direction=1, polarity=+1, distance=5)
 | 
			
		||||
    eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
 | 
			
		||||
    eg.visualize_isosurface(egc, which_shifts=2)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										200
									
								
								gridlock/grid.py
									
									
									
									
									
								
							
							
						
						
									
										200
									
								
								gridlock/grid.py
									
									
									
									
									
								
							@ -1,4 +1,5 @@
 | 
			
		||||
from typing import Callable, Sequence, ClassVar, Self
 | 
			
		||||
from typing import ClassVar, Self
 | 
			
		||||
from collections.abc import Callable, Sequence
 | 
			
		||||
 | 
			
		||||
import numpy
 | 
			
		||||
from numpy.typing import NDArray, ArrayLike
 | 
			
		||||
@ -8,12 +9,15 @@ import warnings
 | 
			
		||||
import copy
 | 
			
		||||
 | 
			
		||||
from . import GridError
 | 
			
		||||
from .draw import GridDrawMixin
 | 
			
		||||
from .read import GridReadMixin
 | 
			
		||||
from .position import GridPosMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Grid:
 | 
			
		||||
class Grid(GridDrawMixin, GridReadMixin, GridPosMixin):
 | 
			
		||||
    """
 | 
			
		||||
    Simulation grid metadata for finite-difference simulations.
 | 
			
		||||
 | 
			
		||||
@ -70,193 +74,6 @@ class Grid:
 | 
			
		||||
        ], dtype=float)
 | 
			
		||||
    """Default shifts for Yee grid H-field"""
 | 
			
		||||
 | 
			
		||||
    from .draw import (
 | 
			
		||||
        draw_polygons, draw_polygon, draw_slab, draw_cuboid,
 | 
			
		||||
        draw_cylinder, draw_extrude_rectangle,
 | 
			
		||||
        )
 | 
			
		||||
    from .read import get_slice, visualize_slice, visualize_isosurface
 | 
			
		||||
    from .position import ind2pos, pos2ind
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def dxyz(self) -> list[NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Cell sizes for each axis, no shifts applied
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of 3 ndarrays of cell sizes
 | 
			
		||||
        """
 | 
			
		||||
        return [numpy.diff(ee) for ee in self.exyz]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def xyz(self) -> list[NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Cell centers for each axis, no shifts applied
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of 3 ndarrays of cell edges
 | 
			
		||||
        """
 | 
			
		||||
        return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def shape(self) -> NDArray[numpy.int_]:
 | 
			
		||||
        """
 | 
			
		||||
        The number of cells in x, y, and z
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            ndarray of [x_centers.size, y_centers.size, z_centers.size]
 | 
			
		||||
        """
 | 
			
		||||
        return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def num_grids(self) -> int:
 | 
			
		||||
        """
 | 
			
		||||
        The number of grids (number of shifts)
 | 
			
		||||
        """
 | 
			
		||||
        return self.shifts.shape[0]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def cell_data_shape(self):
 | 
			
		||||
        """
 | 
			
		||||
        The shape of the cell_data ndarray (num_grids, *self.shape).
 | 
			
		||||
        """
 | 
			
		||||
        return numpy.hstack((self.num_grids, self.shape))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def dxyz_with_ghost(self) -> list[NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Gives dxyz with an additional 'ghost' cell at the end, whose value depends
 | 
			
		||||
         on whether or not the axis has periodic boundary conditions. See main description
 | 
			
		||||
         above to learn why this is necessary.
 | 
			
		||||
 | 
			
		||||
         If periodic, final edge shifts same amount as first
 | 
			
		||||
         Otherwise, final edge shifts same amount as second-to-last
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
 | 
			
		||||
        """
 | 
			
		||||
        el = [0 if p else -1 for p in self.periodic]
 | 
			
		||||
        return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def center(self) -> NDArray[numpy.float64]:
 | 
			
		||||
        """
 | 
			
		||||
        Center position of the entire grid, no shifts applied
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            ndarray of [x_center, y_center, z_center]
 | 
			
		||||
        """
 | 
			
		||||
        # center is just average of first and last xyz, which is just the average of the
 | 
			
		||||
        #  first two and last two exyz
 | 
			
		||||
        centers = [(self.exyz[a][:2] + self.exyz[a][-2:]).sum() / 4.0 for a in range(3)]
 | 
			
		||||
        return numpy.array(centers, dtype=float)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def dxyz_limits(self) -> tuple[NDArray, NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the minimum and maximum cell size for each axis, as a tuple of two 3-element
 | 
			
		||||
         ndarrays. No shifts are applied, so these are extreme bounds on these values (as a
 | 
			
		||||
         weighted average is performed when shifting).
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Tuple of 2 ndarrays, `d_min=[min(dx), min(dy), min(dz)]` and `d_max=[...]`
 | 
			
		||||
        """
 | 
			
		||||
        d_min = numpy.array([min(self.dxyz[a]) for a in range(3)], dtype=float)
 | 
			
		||||
        d_max = numpy.array([max(self.dxyz[a]) for a in range(3)], dtype=float)
 | 
			
		||||
        return d_min, d_max
 | 
			
		||||
 | 
			
		||||
    def shifted_exyz(self, which_shifts: int | None) -> list[NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns edges for which_shifts.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            which_shifts: Which grid (which shifts) to use, or `None` for unshifted
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of 3 ndarrays of cell edges
 | 
			
		||||
        """
 | 
			
		||||
        if which_shifts is None:
 | 
			
		||||
            return self.exyz
 | 
			
		||||
        dxyz = self.dxyz_with_ghost
 | 
			
		||||
        shifts = self.shifts[which_shifts, :]
 | 
			
		||||
 | 
			
		||||
        # If shift is negative, use left cell's dx to determine shift
 | 
			
		||||
        for a in range(3):
 | 
			
		||||
            if shifts[a] < 0:
 | 
			
		||||
                dxyz[a] = numpy.roll(dxyz[a], 1)
 | 
			
		||||
 | 
			
		||||
        return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)]
 | 
			
		||||
 | 
			
		||||
    def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns cell sizes for `which_shifts`.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            which_shifts: Which grid (which shifts) to use, or `None` for unshifted
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of 3 ndarrays of cell sizes
 | 
			
		||||
        """
 | 
			
		||||
        if which_shifts is None:
 | 
			
		||||
            return self.dxyz
 | 
			
		||||
        shifts = self.shifts[which_shifts, :]
 | 
			
		||||
        dxyz = self.dxyz_with_ghost
 | 
			
		||||
 | 
			
		||||
        # If shift is negative, use left cell's dx to determine size
 | 
			
		||||
        sdxyz = []
 | 
			
		||||
        for a in range(3):
 | 
			
		||||
            if shifts[a] < 0:
 | 
			
		||||
                roll_dxyz = numpy.roll(dxyz[a], 1)
 | 
			
		||||
                abs_shift = numpy.abs(shifts[a])
 | 
			
		||||
                sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift))
 | 
			
		||||
            else:
 | 
			
		||||
                sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a])
 | 
			
		||||
 | 
			
		||||
        return sdxyz
 | 
			
		||||
 | 
			
		||||
    def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns cell centers for `which_shifts`.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            which_shifts: Which grid (which shifts) to use, or `None` for unshifted
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of 3 ndarrays of cell centers
 | 
			
		||||
        """
 | 
			
		||||
        if which_shifts is None:
 | 
			
		||||
            return self.xyz
 | 
			
		||||
        exyz = self.shifted_exyz(which_shifts)
 | 
			
		||||
        dxyz = self.shifted_dxyz(which_shifts)
 | 
			
		||||
        return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
 | 
			
		||||
 | 
			
		||||
    def autoshifted_dxyz(self) -> list[NDArray[numpy.float64]]:
 | 
			
		||||
        """
 | 
			
		||||
        Return cell widths, with each dimension shifted by the corresponding shifts.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
 | 
			
		||||
        """
 | 
			
		||||
        if self.num_grids != 3:
 | 
			
		||||
            raise GridError('Autoshifting requires exactly 3 grids')
 | 
			
		||||
        return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
 | 
			
		||||
 | 
			
		||||
    def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray:
 | 
			
		||||
        """
 | 
			
		||||
        Allocate an ndarray for storing grid data.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            fill_value: Value to initialize the grid to. If None, an
 | 
			
		||||
                uninitialized array is returned.
 | 
			
		||||
            dtype: Numpy dtype for the array. Default is `numpy.float32`.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            The allocated array
 | 
			
		||||
        """
 | 
			
		||||
        if fill_value is None:
 | 
			
		||||
            return numpy.empty(self.cell_data_shape, dtype=dtype)
 | 
			
		||||
        else:
 | 
			
		||||
            return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
            self,
 | 
			
		||||
            pixel_edge_coordinates: Sequence[ArrayLike],
 | 
			
		||||
@ -277,11 +94,12 @@ class Grid:
 | 
			
		||||
        Raises:
 | 
			
		||||
            `GridError` on invalid input
 | 
			
		||||
        """
 | 
			
		||||
        self.exyz = [numpy.unique(pixel_edge_coordinates[i]) for i in range(3)]
 | 
			
		||||
        edge_arrs = [numpy.array(cc, copy=False) for cc in pixel_edge_coordinates]
 | 
			
		||||
        self.exyz = [numpy.unique(edges) for edges in edge_arrs]
 | 
			
		||||
        self.shifts = numpy.array(shifts, dtype=float)
 | 
			
		||||
 | 
			
		||||
        for i in range(3):
 | 
			
		||||
            if len(self.exyz[i]) != len(pixel_edge_coordinates[i]):
 | 
			
		||||
            if self.exyz[i].size != edge_arrs[i].size:
 | 
			
		||||
                warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2)
 | 
			
		||||
 | 
			
		||||
        if isinstance(periodic, bool):
 | 
			
		||||
 | 
			
		||||
@ -5,112 +5,114 @@ import numpy
 | 
			
		||||
from numpy.typing import NDArray, ArrayLike
 | 
			
		||||
 | 
			
		||||
from . import GridError
 | 
			
		||||
from .base import GridBase
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ind2pos(
 | 
			
		||||
        self,
 | 
			
		||||
        ind: NDArray,
 | 
			
		||||
        which_shifts: int | None = None,
 | 
			
		||||
        round_ind: bool = True,
 | 
			
		||||
        check_bounds: bool = True
 | 
			
		||||
        ) -> NDArray[numpy.float64]:
 | 
			
		||||
    """
 | 
			
		||||
    Returns the natural position corresponding to the specified cell center indices.
 | 
			
		||||
     The resulting position is clipped to the bounds of the grid
 | 
			
		||||
    (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`)
 | 
			
		||||
class GridPosMixin(GridBase):
 | 
			
		||||
    def ind2pos(
 | 
			
		||||
            self,
 | 
			
		||||
            ind: NDArray,
 | 
			
		||||
            which_shifts: int | None = None,
 | 
			
		||||
            round_ind: bool = True,
 | 
			
		||||
            check_bounds: bool = True
 | 
			
		||||
            ) -> NDArray[numpy.float64]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the natural position corresponding to the specified cell center indices.
 | 
			
		||||
         The resulting position is clipped to the bounds of the grid
 | 
			
		||||
        (to cell centers if `round_ind=True`, or cell outer edges if `round_ind=False`)
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        ind: Indices of the position. Can be fractional. (3-element ndarray or list)
 | 
			
		||||
        which_shifts: which grid number (`shifts`) to use
 | 
			
		||||
        round_ind: Whether to round ind to the nearest integer position before indexing
 | 
			
		||||
            (default `True`)
 | 
			
		||||
        check_bounds: Whether to raise an `GridError` if the provided ind is outside of
 | 
			
		||||
            the grid, as defined above (centers if `round_ind`, else edges) (default `True`)
 | 
			
		||||
        Args:
 | 
			
		||||
            ind: Indices of the position. Can be fractional. (3-element ndarray or list)
 | 
			
		||||
            which_shifts: which grid number (`shifts`) to use
 | 
			
		||||
            round_ind: Whether to round ind to the nearest integer position before indexing
 | 
			
		||||
                (default `True`)
 | 
			
		||||
            check_bounds: Whether to raise an `GridError` if the provided ind is outside of
 | 
			
		||||
                the grid, as defined above (centers if `round_ind`, else edges) (default `True`)
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        3-element ndarray specifying the natural position
 | 
			
		||||
        Returns:
 | 
			
		||||
            3-element ndarray specifying the natural position
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
        `GridError` if invalid `which_shifts`
 | 
			
		||||
        `GridError` if `check_bounds` and out of bounds
 | 
			
		||||
    """
 | 
			
		||||
    if which_shifts is not None and which_shifts >= self.shifts.shape[0]:
 | 
			
		||||
        raise GridError('Invalid shifts')
 | 
			
		||||
    ind = numpy.array(ind, dtype=float)
 | 
			
		||||
        Raises:
 | 
			
		||||
            `GridError` if invalid `which_shifts`
 | 
			
		||||
            `GridError` if `check_bounds` and out of bounds
 | 
			
		||||
        """
 | 
			
		||||
        if which_shifts is not None and which_shifts >= self.shifts.shape[0]:
 | 
			
		||||
            raise GridError('Invalid shifts')
 | 
			
		||||
        ind = numpy.array(ind, dtype=float)
 | 
			
		||||
 | 
			
		||||
        if check_bounds:
 | 
			
		||||
            if round_ind:
 | 
			
		||||
                low_bound = 0.0
 | 
			
		||||
                high_bound = -1.0
 | 
			
		||||
            else:
 | 
			
		||||
                low_bound = -0.5
 | 
			
		||||
                high_bound = -0.5
 | 
			
		||||
            if (ind < low_bound).any() or (ind > self.shape - high_bound).any():
 | 
			
		||||
                raise GridError(f'Position outside of grid: {ind}')
 | 
			
		||||
 | 
			
		||||
    if check_bounds:
 | 
			
		||||
        if round_ind:
 | 
			
		||||
            low_bound = 0.0
 | 
			
		||||
            high_bound = -1.0
 | 
			
		||||
            rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1)
 | 
			
		||||
            sxyz = self.shifted_xyz(which_shifts)
 | 
			
		||||
            position = [sxyz[a][rind[a]].astype(int) for a in range(3)]
 | 
			
		||||
        else:
 | 
			
		||||
            low_bound = -0.5
 | 
			
		||||
            high_bound = -0.5
 | 
			
		||||
        if (ind < low_bound).any() or (ind > self.shape - high_bound).any():
 | 
			
		||||
            raise GridError(f'Position outside of grid: {ind}')
 | 
			
		||||
            sexyz = self.shifted_exyz(which_shifts)
 | 
			
		||||
            position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a])
 | 
			
		||||
                        for a in range(3)]
 | 
			
		||||
        return numpy.array(position, dtype=float)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def pos2ind(
 | 
			
		||||
            self,
 | 
			
		||||
            r: ArrayLike,
 | 
			
		||||
            which_shifts: int | None,
 | 
			
		||||
            round_ind: bool = True,
 | 
			
		||||
            check_bounds: bool = True
 | 
			
		||||
            ) -> NDArray[numpy.float64]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the cell-center indices corresponding to the specified natural position.
 | 
			
		||||
             The resulting position is clipped to within the outer centers of the grid.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            r: Natural position that we will convert into indices (3-element ndarray or list)
 | 
			
		||||
            which_shifts: which grid number (`shifts`) to use
 | 
			
		||||
            round_ind: Whether to round the returned indices to the nearest integers.
 | 
			
		||||
            check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            3-element ndarray specifying the indices
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            `GridError` if invalid `which_shifts`
 | 
			
		||||
            `GridError` if `check_bounds` and out of bounds
 | 
			
		||||
        """
 | 
			
		||||
        r = numpy.squeeze(r)
 | 
			
		||||
        if r.size != 3:
 | 
			
		||||
            raise GridError(f'r must be 3-element vector: {r}')
 | 
			
		||||
 | 
			
		||||
        if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]):
 | 
			
		||||
            raise GridError(f'Invalid which_shifts: {which_shifts}')
 | 
			
		||||
 | 
			
		||||
    if round_ind:
 | 
			
		||||
        rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1)
 | 
			
		||||
        sxyz = self.shifted_xyz(which_shifts)
 | 
			
		||||
        position = [sxyz[a][rind[a]].astype(int) for a in range(3)]
 | 
			
		||||
    else:
 | 
			
		||||
        sexyz = self.shifted_exyz(which_shifts)
 | 
			
		||||
        position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a])
 | 
			
		||||
                    for a in range(3)]
 | 
			
		||||
    return numpy.array(position, dtype=float)
 | 
			
		||||
 | 
			
		||||
        if check_bounds:
 | 
			
		||||
            for a in range(3):
 | 
			
		||||
                if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]):
 | 
			
		||||
                    raise GridError(f'Position[{a}] outside of grid!')
 | 
			
		||||
 | 
			
		||||
def pos2ind(
 | 
			
		||||
        self,
 | 
			
		||||
        r: ArrayLike,
 | 
			
		||||
        which_shifts: int | None,
 | 
			
		||||
        round_ind: bool = True,
 | 
			
		||||
        check_bounds: bool = True
 | 
			
		||||
        ) -> NDArray[numpy.float64]:
 | 
			
		||||
    """
 | 
			
		||||
    Returns the cell-center indices corresponding to the specified natural position.
 | 
			
		||||
         The resulting position is clipped to within the outer centers of the grid.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        r: Natural position that we will convert into indices (3-element ndarray or list)
 | 
			
		||||
        which_shifts: which grid number (`shifts`) to use
 | 
			
		||||
        round_ind: Whether to round the returned indices to the nearest integers.
 | 
			
		||||
        check_bounds: Whether to throw an `GridError` if `r` is outside the grid edges
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        3-element ndarray specifying the indices
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
        `GridError` if invalid `which_shifts`
 | 
			
		||||
        `GridError` if `check_bounds` and out of bounds
 | 
			
		||||
    """
 | 
			
		||||
    r = numpy.squeeze(r)
 | 
			
		||||
    if r.size != 3:
 | 
			
		||||
        raise GridError(f'r must be 3-element vector: {r}')
 | 
			
		||||
 | 
			
		||||
    if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]):
 | 
			
		||||
        raise GridError(f'Invalid which_shifts: {which_shifts}')
 | 
			
		||||
 | 
			
		||||
    sexyz = self.shifted_exyz(which_shifts)
 | 
			
		||||
 | 
			
		||||
    if check_bounds:
 | 
			
		||||
        grid_pos = numpy.zeros((3,))
 | 
			
		||||
        for a in range(3):
 | 
			
		||||
            if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]):
 | 
			
		||||
                raise GridError(f'Position[{a}] outside of grid!')
 | 
			
		||||
            xi = numpy.digitize(r[a], sexyz[a]) - 1  # Figure out which cell we're in
 | 
			
		||||
            xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2)  # Clip back into grid bounds
 | 
			
		||||
 | 
			
		||||
    grid_pos = numpy.zeros((3,))
 | 
			
		||||
    for a in range(3):
 | 
			
		||||
        xi = numpy.digitize(r[a], sexyz[a]) - 1  # Figure out which cell we're in
 | 
			
		||||
        xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2)  # Clip back into grid bounds
 | 
			
		||||
            # No need to interpolate if round_ind is true or we were outside the grid
 | 
			
		||||
            if round_ind or xi != xi_clipped:
 | 
			
		||||
                grid_pos[a] = xi_clipped
 | 
			
		||||
            else:
 | 
			
		||||
                # Interpolate
 | 
			
		||||
                x = self.shifted_xyz(which_shifts)[a][xi]
 | 
			
		||||
                dx = self.shifted_dxyz(which_shifts)[a][xi]
 | 
			
		||||
                f = (r[a] - x) / dx
 | 
			
		||||
 | 
			
		||||
        # No need to interpolate if round_ind is true or we were outside the grid
 | 
			
		||||
        if round_ind or xi != xi_clipped:
 | 
			
		||||
            grid_pos[a] = xi_clipped
 | 
			
		||||
        else:
 | 
			
		||||
            # Interpolate
 | 
			
		||||
            x = self.shifted_xyz(which_shifts)[a][xi]
 | 
			
		||||
            dx = self.shifted_dxyz(which_shifts)[a][xi]
 | 
			
		||||
            f = (r[a] - x) / dx
 | 
			
		||||
 | 
			
		||||
            # Clip to centers
 | 
			
		||||
            grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1)
 | 
			
		||||
    return grid_pos
 | 
			
		||||
                # Clip to centers
 | 
			
		||||
                grid_pos[a] = numpy.clip(xi + f, 0, self.shape[a] - 1)
 | 
			
		||||
        return grid_pos
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										303
									
								
								gridlock/read.py
									
									
									
									
									
								
							
							
						
						
									
										303
									
								
								gridlock/read.py
									
									
									
									
									
								
							@ -7,6 +7,7 @@ import numpy
 | 
			
		||||
from numpy.typing import NDArray
 | 
			
		||||
 | 
			
		||||
from . import GridError
 | 
			
		||||
from .position import GridPosMixin
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    import matplotlib.axes
 | 
			
		||||
@ -18,185 +19,187 @@ if TYPE_CHECKING:
 | 
			
		||||
# .visualize_isosurface uses mpl_toolkits.mplot3d
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_slice(
 | 
			
		||||
        self,
 | 
			
		||||
        cell_data: NDArray,
 | 
			
		||||
        surface_normal: int,
 | 
			
		||||
        center: float,
 | 
			
		||||
        which_shifts: int = 0,
 | 
			
		||||
        sample_period: int = 1
 | 
			
		||||
        ) -> NDArray:
 | 
			
		||||
    """
 | 
			
		||||
    Retrieve a slice of a grid.
 | 
			
		||||
    Interpolates if given a position between two planes.
 | 
			
		||||
class GridReadMixin(GridPosMixin):
 | 
			
		||||
    def get_slice(
 | 
			
		||||
            self,
 | 
			
		||||
            cell_data: NDArray,
 | 
			
		||||
            surface_normal: int,
 | 
			
		||||
            center: float,
 | 
			
		||||
            which_shifts: int = 0,
 | 
			
		||||
            sample_period: int = 1
 | 
			
		||||
            ) -> NDArray:
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve a slice of a grid.
 | 
			
		||||
        Interpolates if given a position between two planes.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cell_data: Cell data to slice
 | 
			
		||||
        surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
 | 
			
		||||
        center: Scalar specifying position along surface_normal axis.
 | 
			
		||||
        which_shifts: Which grid to display. Default is the first grid (0).
 | 
			
		||||
        sample_period: Period for down-sampling the image. Default 1 (disabled)
 | 
			
		||||
        Args:
 | 
			
		||||
            cell_data: Cell data to slice
 | 
			
		||||
            surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
 | 
			
		||||
            center: Scalar specifying position along surface_normal axis.
 | 
			
		||||
            which_shifts: Which grid to display. Default is the first grid (0).
 | 
			
		||||
            sample_period: Period for down-sampling the image. Default 1 (disabled)
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Array containing the portion of the grid.
 | 
			
		||||
    """
 | 
			
		||||
    if numpy.size(center) != 1 or not numpy.isreal(center):
 | 
			
		||||
        raise GridError('center must be a real scalar')
 | 
			
		||||
        Returns:
 | 
			
		||||
            Array containing the portion of the grid.
 | 
			
		||||
        """
 | 
			
		||||
        if numpy.size(center) != 1 or not numpy.isreal(center):
 | 
			
		||||
            raise GridError('center must be a real scalar')
 | 
			
		||||
 | 
			
		||||
    sp = round(sample_period)
 | 
			
		||||
    if sp <= 0:
 | 
			
		||||
        raise GridError('sample_period must be positive')
 | 
			
		||||
        sp = round(sample_period)
 | 
			
		||||
        if sp <= 0:
 | 
			
		||||
            raise GridError('sample_period must be positive')
 | 
			
		||||
 | 
			
		||||
    if numpy.size(which_shifts) != 1 or which_shifts < 0:
 | 
			
		||||
        raise GridError('Invalid which_shifts')
 | 
			
		||||
        if numpy.size(which_shifts) != 1 or which_shifts < 0:
 | 
			
		||||
            raise GridError('Invalid which_shifts')
 | 
			
		||||
 | 
			
		||||
    if surface_normal not in range(3):
 | 
			
		||||
        raise GridError('Invalid surface_normal direction')
 | 
			
		||||
        if surface_normal not in range(3):
 | 
			
		||||
            raise GridError('Invalid surface_normal direction')
 | 
			
		||||
 | 
			
		||||
    surface = numpy.delete(range(3), surface_normal)
 | 
			
		||||
        surface = numpy.delete(range(3), surface_normal)
 | 
			
		||||
 | 
			
		||||
    # Extract indices and weights of planes
 | 
			
		||||
    center3 = numpy.insert([0, 0], surface_normal, (center,))
 | 
			
		||||
    center_index = self.pos2ind(center3, which_shifts,
 | 
			
		||||
                                round_ind=False, check_bounds=False)[surface_normal]
 | 
			
		||||
    centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
 | 
			
		||||
    if len(centers) == 2:
 | 
			
		||||
        fpart = center_index - numpy.floor(center_index)
 | 
			
		||||
        w = [1 - fpart, fpart]  # longer distance -> less weight
 | 
			
		||||
    else:
 | 
			
		||||
        w = [1]
 | 
			
		||||
        # Extract indices and weights of planes
 | 
			
		||||
        center3 = numpy.insert([0, 0], surface_normal, (center,))
 | 
			
		||||
        center_index = self.pos2ind(center3, which_shifts,
 | 
			
		||||
                                    round_ind=False, check_bounds=False)[surface_normal]
 | 
			
		||||
        centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
 | 
			
		||||
        if len(centers) == 2:
 | 
			
		||||
            fpart = center_index - numpy.floor(center_index)
 | 
			
		||||
            w = [1 - fpart, fpart]  # longer distance -> less weight
 | 
			
		||||
        else:
 | 
			
		||||
            w = [1]
 | 
			
		||||
 | 
			
		||||
    c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1])
 | 
			
		||||
    if center < c_min or center > c_max:
 | 
			
		||||
        raise GridError('Coordinate of selected plane must be within simulation domain')
 | 
			
		||||
        c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1])
 | 
			
		||||
        if center < c_min or center > c_max:
 | 
			
		||||
            raise GridError('Coordinate of selected plane must be within simulation domain')
 | 
			
		||||
 | 
			
		||||
    # Extract grid values from planes above and below visualized slice
 | 
			
		||||
    sliced_grid = numpy.zeros(self.shape[surface])
 | 
			
		||||
    for ci, weight in zip(centers, w):
 | 
			
		||||
        s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
 | 
			
		||||
        sliced_grid += weight * cell_data[which_shifts][tuple(s)]
 | 
			
		||||
        # Extract grid values from planes above and below visualized slice
 | 
			
		||||
        sliced_grid = numpy.zeros(self.shape[surface])
 | 
			
		||||
        for ci, weight in zip(centers, w, strict=True):
 | 
			
		||||
            s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
 | 
			
		||||
            sliced_grid += weight * cell_data[which_shifts][tuple(s)]
 | 
			
		||||
 | 
			
		||||
    # Remove extra dimensions
 | 
			
		||||
    sliced_grid = numpy.squeeze(sliced_grid)
 | 
			
		||||
        # Remove extra dimensions
 | 
			
		||||
        sliced_grid = numpy.squeeze(sliced_grid)
 | 
			
		||||
 | 
			
		||||
    return sliced_grid
 | 
			
		||||
        return sliced_grid
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def visualize_slice(
 | 
			
		||||
        self,
 | 
			
		||||
        cell_data: NDArray,
 | 
			
		||||
        surface_normal: int,
 | 
			
		||||
        center: float,
 | 
			
		||||
        which_shifts: int = 0,
 | 
			
		||||
        sample_period: int = 1,
 | 
			
		||||
        finalize: bool = True,
 | 
			
		||||
        pcolormesh_args: dict[str, Any] | None = None,
 | 
			
		||||
        ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
 | 
			
		||||
    """
 | 
			
		||||
    Visualize a slice of a grid.
 | 
			
		||||
    Interpolates if given a position between two planes.
 | 
			
		||||
    def visualize_slice(
 | 
			
		||||
            self,
 | 
			
		||||
            cell_data: NDArray,
 | 
			
		||||
            surface_normal: int,
 | 
			
		||||
            center: float,
 | 
			
		||||
            which_shifts: int = 0,
 | 
			
		||||
            sample_period: int = 1,
 | 
			
		||||
            finalize: bool = True,
 | 
			
		||||
            pcolormesh_args: dict[str, Any] | None = None,
 | 
			
		||||
            ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
 | 
			
		||||
        """
 | 
			
		||||
        Visualize a slice of a grid.
 | 
			
		||||
        Interpolates if given a position between two planes.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
 | 
			
		||||
        center: Scalar specifying position along surface_normal axis.
 | 
			
		||||
        which_shifts: Which grid to display. Default is the first grid (0).
 | 
			
		||||
        sample_period: Period for down-sampling the image. Default 1 (disabled)
 | 
			
		||||
        finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
 | 
			
		||||
        Args:
 | 
			
		||||
            surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
 | 
			
		||||
            center: Scalar specifying position along surface_normal axis.
 | 
			
		||||
            which_shifts: Which grid to display. Default is the first grid (0).
 | 
			
		||||
            sample_period: Period for down-sampling the image. Default 1 (disabled)
 | 
			
		||||
            finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        (Figure, Axes)
 | 
			
		||||
    """
 | 
			
		||||
    from matplotlib import pyplot
 | 
			
		||||
        Returns:
 | 
			
		||||
            (Figure, Axes)
 | 
			
		||||
        """
 | 
			
		||||
        from matplotlib import pyplot
 | 
			
		||||
 | 
			
		||||
    if pcolormesh_args is None:
 | 
			
		||||
        pcolormesh_args = {}
 | 
			
		||||
        if pcolormesh_args is None:
 | 
			
		||||
            pcolormesh_args = {}
 | 
			
		||||
 | 
			
		||||
    grid_slice = self.get_slice(cell_data=cell_data,
 | 
			
		||||
                                surface_normal=surface_normal,
 | 
			
		||||
                                center=center,
 | 
			
		||||
                                which_shifts=which_shifts,
 | 
			
		||||
                                sample_period=sample_period)
 | 
			
		||||
        grid_slice = self.get_slice(cell_data=cell_data,
 | 
			
		||||
                                    surface_normal=surface_normal,
 | 
			
		||||
                                    center=center,
 | 
			
		||||
                                    which_shifts=which_shifts,
 | 
			
		||||
                                    sample_period=sample_period)
 | 
			
		||||
 | 
			
		||||
    surface = numpy.delete(range(3), surface_normal)
 | 
			
		||||
        surface = numpy.delete(range(3), surface_normal)
 | 
			
		||||
 | 
			
		||||
    x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
 | 
			
		||||
    xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
 | 
			
		||||
    x_label, y_label = ('xyz'[a] for a in surface)
 | 
			
		||||
        x, y = (self.shifted_exyz(which_shifts)[a] for a in surface)
 | 
			
		||||
        xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
 | 
			
		||||
        x_label, y_label = ('xyz'[a] for a in surface)
 | 
			
		||||
 | 
			
		||||
    fig, ax = pyplot.subplots()
 | 
			
		||||
    mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args)
 | 
			
		||||
    fig.colorbar(mappable)
 | 
			
		||||
    ax.set_aspect('equal', adjustable='box')
 | 
			
		||||
    ax.set_xlabel(x_label)
 | 
			
		||||
    ax.set_ylabel(y_label)
 | 
			
		||||
    if finalize:
 | 
			
		||||
        pyplot.show()
 | 
			
		||||
        fig, ax = pyplot.subplots()
 | 
			
		||||
        mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args)
 | 
			
		||||
        fig.colorbar(mappable)
 | 
			
		||||
        ax.set_aspect('equal', adjustable='box')
 | 
			
		||||
        ax.set_xlabel(x_label)
 | 
			
		||||
        ax.set_ylabel(y_label)
 | 
			
		||||
        if finalize:
 | 
			
		||||
            pyplot.show()
 | 
			
		||||
 | 
			
		||||
    return fig, ax
 | 
			
		||||
        return fig, ax
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def visualize_isosurface(
 | 
			
		||||
        self,
 | 
			
		||||
        cell_data: NDArray,
 | 
			
		||||
        level: float | None = None,
 | 
			
		||||
        which_shifts: int = 0,
 | 
			
		||||
        sample_period: int = 1,
 | 
			
		||||
        show_edges: bool = True,
 | 
			
		||||
        finalize: bool = True,
 | 
			
		||||
        ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
 | 
			
		||||
    """
 | 
			
		||||
    Draw an isosurface plot of the device.
 | 
			
		||||
    def visualize_isosurface(
 | 
			
		||||
            self,
 | 
			
		||||
            cell_data: NDArray,
 | 
			
		||||
            level: float | None = None,
 | 
			
		||||
            which_shifts: int = 0,
 | 
			
		||||
            sample_period: int = 1,
 | 
			
		||||
            show_edges: bool = True,
 | 
			
		||||
            finalize: bool = True,
 | 
			
		||||
            ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
 | 
			
		||||
        """
 | 
			
		||||
        Draw an isosurface plot of the device.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cell_data: Cell data to visualize
 | 
			
		||||
        level: Value at which to find isosurface. Default (None) uses mean value in grid.
 | 
			
		||||
        which_shifts: Which grid to display. Default is the first grid (0).
 | 
			
		||||
        sample_period: Period for down-sampling the image. Default 1 (disabled)
 | 
			
		||||
        show_edges: Whether to draw triangle edges. Default `True`
 | 
			
		||||
        finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
 | 
			
		||||
        Args:
 | 
			
		||||
            cell_data: Cell data to visualize
 | 
			
		||||
            level: Value at which to find isosurface. Default (None) uses mean value in grid.
 | 
			
		||||
            which_shifts: Which grid to display. Default is the first grid (0).
 | 
			
		||||
            sample_period: Period for down-sampling the image. Default 1 (disabled)
 | 
			
		||||
            show_edges: Whether to draw triangle edges. Default `True`
 | 
			
		||||
            finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        (Figure, Axes)
 | 
			
		||||
    """
 | 
			
		||||
    from matplotlib import pyplot
 | 
			
		||||
    import skimage.measure
 | 
			
		||||
    # Claims to be unused, but needed for subplot(projection='3d')
 | 
			
		||||
    from mpl_toolkits.mplot3d import Axes3D
 | 
			
		||||
        Returns:
 | 
			
		||||
            (Figure, Axes)
 | 
			
		||||
        """
 | 
			
		||||
        from matplotlib import pyplot
 | 
			
		||||
        import skimage.measure
 | 
			
		||||
        # Claims to be unused, but needed for subplot(projection='3d')
 | 
			
		||||
        from mpl_toolkits.mplot3d import Axes3D
 | 
			
		||||
        del Axes3D      # imported for side effects only
 | 
			
		||||
 | 
			
		||||
    # Get data from cell_data
 | 
			
		||||
    grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period]
 | 
			
		||||
    if level is None:
 | 
			
		||||
        level = grid.mean()
 | 
			
		||||
        # Get data from cell_data
 | 
			
		||||
        grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period]
 | 
			
		||||
        if level is None:
 | 
			
		||||
            level = grid.mean()
 | 
			
		||||
 | 
			
		||||
    # Find isosurface with marching cubes
 | 
			
		||||
    verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
 | 
			
		||||
        # Find isosurface with marching cubes
 | 
			
		||||
        verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)
 | 
			
		||||
 | 
			
		||||
    # Convert vertices from index to position
 | 
			
		||||
    pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False)
 | 
			
		||||
                             for i in range(verts.shape[0])], dtype=float)
 | 
			
		||||
    xs, ys, zs = (pos_verts[:, a] for a in range(3))
 | 
			
		||||
        # Convert vertices from index to position
 | 
			
		||||
        pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False)
 | 
			
		||||
                                 for i in range(verts.shape[0])], dtype=float)
 | 
			
		||||
        xs, ys, zs = (pos_verts[:, a] for a in range(3))
 | 
			
		||||
 | 
			
		||||
    # Draw the plot
 | 
			
		||||
    fig = pyplot.figure()
 | 
			
		||||
    ax = fig.add_subplot(111, projection='3d')
 | 
			
		||||
    if show_edges:
 | 
			
		||||
        ax.plot_trisurf(xs, ys, faces, zs)
 | 
			
		||||
    else:
 | 
			
		||||
        ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none')
 | 
			
		||||
        # Draw the plot
 | 
			
		||||
        fig = pyplot.figure()
 | 
			
		||||
        ax = fig.add_subplot(111, projection='3d')
 | 
			
		||||
        if show_edges:
 | 
			
		||||
            ax.plot_trisurf(xs, ys, faces, zs)
 | 
			
		||||
        else:
 | 
			
		||||
            ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none')
 | 
			
		||||
 | 
			
		||||
    # Add a fake plot of a cube to force the axes to be equal lengths
 | 
			
		||||
    max_range = numpy.array([xs.max() - xs.min(),
 | 
			
		||||
                             ys.max() - ys.min(),
 | 
			
		||||
                             zs.max() - zs.min()], dtype=float).max()
 | 
			
		||||
    mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2]
 | 
			
		||||
    xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min())
 | 
			
		||||
    ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min())
 | 
			
		||||
    zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min())
 | 
			
		||||
    # Comment or uncomment following both lines to test the fake bounding box:
 | 
			
		||||
    for xb, yb, zb in zip(xbs, ybs, zbs):
 | 
			
		||||
        ax.plot([xb], [yb], [zb], 'w')
 | 
			
		||||
        # Add a fake plot of a cube to force the axes to be equal lengths
 | 
			
		||||
        max_range = numpy.array([xs.max() - xs.min(),
 | 
			
		||||
                                 ys.max() - ys.min(),
 | 
			
		||||
                                 zs.max() - zs.min()], dtype=float).max()
 | 
			
		||||
        mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2]
 | 
			
		||||
        xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min())
 | 
			
		||||
        ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min())
 | 
			
		||||
        zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min())
 | 
			
		||||
        # Comment or uncomment following both lines to test the fake bounding box:
 | 
			
		||||
        for xb, yb, zb in zip(xbs, ybs, zbs, strict=True):
 | 
			
		||||
            ax.plot([xb], [yb], [zb], 'w')
 | 
			
		||||
 | 
			
		||||
    if finalize:
 | 
			
		||||
        pyplot.show()
 | 
			
		||||
        if finalize:
 | 
			
		||||
            pyplot.show()
 | 
			
		||||
 | 
			
		||||
    return fig, ax
 | 
			
		||||
        return fig, ax
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
import pytest       # type: ignore
 | 
			
		||||
import numpy
 | 
			
		||||
from numpy.testing import assert_allclose, assert_array_equal
 | 
			
		||||
from numpy.testing import assert_allclose       #, assert_array_equal
 | 
			
		||||
 | 
			
		||||
from .. import Grid
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -53,3 +53,47 @@ visualization-isosurface = [
 | 
			
		||||
    "skimage>=0.13",
 | 
			
		||||
    "mpl_toolkits",
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[tool.ruff]
 | 
			
		||||
exclude = [
 | 
			
		||||
    ".git",
 | 
			
		||||
    "dist",
 | 
			
		||||
    ]
 | 
			
		||||
line-length = 145
 | 
			
		||||
indent-width = 4
 | 
			
		||||
lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
 | 
			
		||||
lint.select = [
 | 
			
		||||
    "NPY", "E", "F", "W", "B", "ANN", "UP", "SLOT", "SIM", "LOG",
 | 
			
		||||
    "C4", "ISC", "PIE", "PT", "RET", "TCH", "PTH", "INT",
 | 
			
		||||
    "ARG", "PL", "R", "TRY",
 | 
			
		||||
    "G010", "G101", "G201", "G202",
 | 
			
		||||
    "Q002", "Q003", "Q004",
 | 
			
		||||
    ]
 | 
			
		||||
lint.ignore = [
 | 
			
		||||
    #"ANN001",   # No annotation
 | 
			
		||||
    "ANN002",   # *args
 | 
			
		||||
    "ANN003",   # **kwargs
 | 
			
		||||
    "ANN401",   # Any
 | 
			
		||||
    "ANN101",   # self: Self
 | 
			
		||||
    "SIM108",   # single-line if / else assignment
 | 
			
		||||
    "RET504",   # x=y+z; return x
 | 
			
		||||
    "PIE790",   # unnecessary pass
 | 
			
		||||
    "ISC003",   # non-implicit string concatenation
 | 
			
		||||
    "C408",     # dict(x=y) instead of {'x': y}
 | 
			
		||||
    "PLR09",    # Too many xxx
 | 
			
		||||
    "PLR2004",  # magic number
 | 
			
		||||
    "PLC0414",  # import x as x
 | 
			
		||||
    "TRY003",   # Long exception message
 | 
			
		||||
    "PTH123",   # open()
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[[tool.mypy.overrides]]
 | 
			
		||||
module = [
 | 
			
		||||
    "matplotlib",
 | 
			
		||||
    "matplotlib.axes",
 | 
			
		||||
    "matplotlib.figure",
 | 
			
		||||
    "mpl_toolkits.mplot3d",
 | 
			
		||||
    ]
 | 
			
		||||
ignore_missing_imports = true
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user