Move .grids data into separate cell_data array. Also remove Direction enum

This commit is contained in:
Jan Petykiewicz 2021-10-24 18:59:44 -07:00
parent fbf173072a
commit 551da07f3e
4 changed files with 90 additions and 100 deletions

View File

@ -16,7 +16,6 @@ Dependencies:
- skimage [Grid.visualize_isosurface()] - skimage [Grid.visualize_isosurface()]
""" """
from .error import GridError from .error import GridError
from .direction import Direction
from .grid import Grid from .grid import Grid
__author__ = 'Jan Petykiewicz' __author__ = 'Jan Petykiewicz'

View File

@ -6,15 +6,16 @@ from typing import List, Optional, Union, Sequence, Callable
import numpy # type: ignore import numpy # type: ignore
from float_raster import raster from float_raster import raster
from . import GridError, Direction
from ._helpers import is_scalar from ._helpers import is_scalar
from . import GridError
eps_callable_t = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray] eps_callable_t = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray]
def draw_polygons(self, def draw_polygons(self,
surface_normal: Union[Direction, int], cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray, center: numpy.ndarray,
polygons: Sequence[numpy.ndarray], polygons: Sequence[numpy.ndarray],
thickness: float, thickness: float,
@ -24,8 +25,8 @@ def draw_polygons(self,
Draw polygons on an axis-aligned plane. Draw polygons on an axis-aligned plane.
Args: Args:
surface_normal: Axis normal to the plane we're drawing on. Can be a `Direction` or cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
integer in `range(3)` surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying an offset applied to all the polygons center: 3-element ndarray or list specifying an offset applied to all the polygons
polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
(non-closed, clockwise). If Nx3, the surface_normal coordinate is ignored. Each (non-closed, clockwise). If Nx3, the surface_normal coordinate is ignored. Each
@ -39,11 +40,6 @@ def draw_polygons(self,
Raises: Raises:
GridError GridError
""" """
# Turn surface_normal into its integer representation
if isinstance(surface_normal, Direction):
surface_normal = surface_normal.value
assert(isinstance(surface_normal, int))
if surface_normal not in range(3): if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction') raise GridError('Invalid surface_normal direction')
@ -66,8 +62,8 @@ def draw_polygons(self,
% 'xyz'[surface_normal]) % 'xyz'[surface_normal])
# Broadcast eps where necessary # Broadcast eps where necessary
if is_scalar(eps): if numpy.size(eps) == 1:
eps = [eps] * len(self.grids) eps = [eps] * len(cell_data)
elif isinstance(eps, numpy.ndarray): elif isinstance(eps, numpy.ndarray):
raise GridError('ndarray not supported for eps') raise GridError('ndarray not supported for eps')
@ -101,7 +97,7 @@ def draw_polygons(self,
return numpy.insert(v_2d, surface_normal, (val,)) return numpy.insert(v_2d, surface_normal, (val,))
# iterate over grids # iterate over grids
for i, grid in enumerate(self.grids): for i, grid in enumerate(cell_data):
# ## Evaluate or expand eps[i] # ## Evaluate or expand eps[i]
if callable(eps[i]): if callable(eps[i]):
# meshgrid over the (shifted) domain # meshgrid over the (shifted) domain
@ -184,11 +180,12 @@ def draw_polygons(self,
# ## Modify the grid # ## Modify the grid
g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
self.grids[g_slice] = (1 - w) * self.grids[g_slice] + w * eps_i cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * eps_i
def draw_polygon(self, def draw_polygon(self,
surface_normal: Union[Direction, int], cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray, center: numpy.ndarray,
polygon: numpy.ndarray, polygon: numpy.ndarray,
thickness: float, thickness: float,
@ -198,8 +195,8 @@ def draw_polygon(self,
Draw a polygon on an axis-aligned plane. Draw a polygon on an axis-aligned plane.
Args: Args:
surface_normal: Axis normal to the plane we're drawing on. Can be a Direction or cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
integer in range(3) surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying an offset applied to the polygon center: 3-element ndarray or list specifying an offset applied to the polygon
polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
clockwise). If Nx3, the surface_normal coordinate is ignored. Must have at clockwise). If Nx3, the surface_normal coordinate is ignored. Must have at
@ -207,11 +204,12 @@ def draw_polygon(self,
thickness: Thickness of the layer to draw thickness: Thickness of the layer to draw
eps: Value to draw with ('epsilon'). See `draw_polygons()` for details. eps: Value to draw with ('epsilon'). See `draw_polygons()` for details.
""" """
self.draw_polygons(surface_normal, center, [polygon], thickness, eps) self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, eps)
def draw_slab(self, def draw_slab(self,
surface_normal: Union[Direction, int], cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray, center: numpy.ndarray,
thickness: float, thickness: float,
eps: Union[List[Union[float, eps_callable_t]], float, eps_callable_t], eps: Union[List[Union[float, eps_callable_t]], float, eps_callable_t],
@ -220,15 +218,13 @@ def draw_slab(self,
Draw an axis-aligned infinite slab. Draw an axis-aligned infinite slab.
Args: Args:
surface_normal: Axis normal to the plane we're drawing on. Can be a `Direction` or cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
integer in `range(3)` surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: Surface_normal coordinate at the center of the slab center: Surface_normal coordinate at the center of the slab
thickness: Thickness of the layer to draw thickness: Thickness of the layer to draw
eps: Value to draw with ('epsilon'). See `draw_polygons()` for details. eps: Value to draw with ('epsilon'). See `draw_polygons()` for details.
""" """
# Turn surface_normal into its integer representation # Turn surface_normal into its integer representation
if isinstance(surface_normal, Direction):
surface_normal = surface_normal.value
if surface_normal not in range(3): if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction') raise GridError('Invalid surface_normal direction')
@ -258,10 +254,11 @@ def draw_slab(self,
[xyz_max[0], xyz_min[1]], [xyz_max[0], xyz_min[1]],
[xyz_min[0], xyz_min[1]]], dtype=float) [xyz_min[0], xyz_min[1]]], dtype=float)
self.draw_polygon(surface_normal, center_shift, p, thickness, eps) self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, eps)
def draw_cuboid(self, def draw_cuboid(self,
cell_data: numpy.ndarray,
center: numpy.ndarray, center: numpy.ndarray,
dimensions: numpy.ndarray, dimensions: numpy.ndarray,
eps: Union[List[Union[float, eps_callable_t]], float, eps_callable_t], eps: Union[List[Union[float, eps_callable_t]], float, eps_callable_t],
@ -270,6 +267,7 @@ def draw_cuboid(self,
Draw an axis-aligned cuboid Draw an axis-aligned cuboid
Args: Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
center: 3-element ndarray or list specifying the cuboid's center center: 3-element ndarray or list specifying the cuboid's center
dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge
sizes of the cuboid sizes of the cuboid
@ -280,11 +278,12 @@ def draw_cuboid(self,
[+dimensions[0], -dimensions[1]], [+dimensions[0], -dimensions[1]],
[-dimensions[0], -dimensions[1]]], dtype=float) / 2.0 [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0
thickness = dimensions[2] thickness = dimensions[2]
self.draw_polygon(Direction.z, center, p, thickness, eps) self.draw_polygon(cell_data, 2, center, p, thickness, eps)
def draw_cylinder(self, def draw_cylinder(self,
surface_normal: Union[Direction, int], cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray, center: numpy.ndarray,
radius: float, radius: float,
thickness: float, thickness: float,
@ -295,8 +294,8 @@ def draw_cylinder(self,
Draw an axis-aligned cylinder. Approximated by a num_points-gon Draw an axis-aligned cylinder. Approximated by a num_points-gon
Args: Args:
surface_normal: Axis normal to the plane we're drawing on. Can be a `Direction` or cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
integer in `range(3)` surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying the cylinder's center center: 3-element ndarray or list specifying the cylinder's center
radius: cylinder radius radius: cylinder radius
thickness: Thickness of the layer to draw thickness: Thickness of the layer to draw
@ -306,13 +305,14 @@ def draw_cylinder(self,
theta = numpy.linspace(0, 2*numpy.pi, num_points, endpoint=False) theta = numpy.linspace(0, 2*numpy.pi, num_points, endpoint=False)
x = radius * numpy.sin(theta) x = radius * numpy.sin(theta)
y = radius * numpy.cos(theta) y = radius * numpy.cos(theta)
self.draw_polygon(surface_normal, center, polygon, thickness, eps)
polygon = numpy.hstack((x[:, None], y[:, None])) polygon = numpy.hstack((x[:, None], y[:, None]))
self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, eps)
def draw_extrude_rectangle(self, def draw_extrude_rectangle(self,
cell_data: numpy.ndarray,
rectangle: numpy.ndarray, rectangle: numpy.ndarray,
direction: Union[Direction, int], direction: int,
polarity: int, polarity: int,
distance: float, distance: float,
) -> None: ) -> None:
@ -320,16 +320,12 @@ def draw_extrude_rectangle(self,
Extrude a rectangle of a previously-drawn structure along an axis. Extrude a rectangle of a previously-drawn structure along an axis.
Args: Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
rectangle: 2x3 ndarray or list specifying the rectangle's corners rectangle: 2x3 ndarray or list specifying the rectangle's corners
direction: Direction to extrude in. Direction enum or int in range(3) direction: Direction to extrude in. Integer in `range(3)`.
polarity: +1 or -1, direction along axis to extrude in polarity: +1 or -1, direction along axis to extrude in
distance: How far to extrude distance: How far to extrude
""" """
# Turn extrude_direction into its integer representation
if isinstance(direction, Direction):
direction = direction.value
assert(isinstance(direction, int))
s = numpy.sign(polarity) s = numpy.sign(polarity)
rectangle = numpy.array(rectangle, dtype=float) rectangle = numpy.array(rectangle, dtype=float)
@ -351,7 +347,7 @@ def draw_extrude_rectangle(self,
thickness = distance thickness = distance
eps_func = [] eps_func = []
for i, grid in enumerate(self.grids): for i, grid in enumerate(cell_data):
z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction]
ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)]
@ -373,5 +369,5 @@ def draw_extrude_rectangle(self,
eps_func.append(f_eps) eps_func.append(f_eps)
self.draw_polygon(direction, center, p, thickness, eps_func) self.draw_polygon(cell_data, direction, center, p, thickness, eps_func)

View File

@ -7,8 +7,8 @@ import pickle
import warnings import warnings
import copy import copy
from . import GridError, Direction
from ._helpers import is_scalar from ._helpers import is_scalar
from . import GridError
__author__ = 'Jan Petykiewicz' __author__ = 'Jan Petykiewicz'
@ -18,12 +18,18 @@ eps_callable_type = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], nump
class Grid: class Grid:
""" """
Simulation grid generator intended for electromagnetic simulations. Simulation grid metadata for finite-difference simulations.
Can be used to generate non-uniform rectangular grids (the entire grid
Can be used to generate non-uniform rectangular grids (the entire grid
is generated based on the coordinates of the boundary points). Also does is generated based on the coordinates of the boundary points). Also does
straightforward natural <-> grid unit conversion. straightforward natural <-> grid unit conversion.
`self.grids[i][a,b,c]` contains the value of epsilon for the cell located around This class handles data describing the grid, and should be paired with a
(separate) ndarray that contains the actual data in each cell. The `allocate()`
method can be used to create this ndarray.
The resulting `cell_data[i, a, b, c]` should correspond to the value in the
`i`-th grid, in the cell centered around
``` ```
(xyz[0][a] + dxyz[0][a] * shifts[i, 0], (xyz[0][a] + dxyz[0][a] * shifts[i, 0],
xyz[1][b] + dxyz[1][b] * shifts[i, 1], xyz[1][b] + dxyz[1][b] * shifts[i, 1],
@ -47,9 +53,6 @@ class Grid:
exyz: List[numpy.ndarray] exyz: List[numpy.ndarray]
"""Cell edges. Monotonically increasing without duplicates.""" """Cell edges. Monotonically increasing without duplicates."""
grids: numpy.ndarray
"""epsilon (or mu, or whatever) grids. shape is (num_grids, X, Y, Z)"""
periodic: List[bool] periodic: List[bool]
"""For each axis, determines how far the rightmost boundary gets shifted. """ """For each axis, determines how far the rightmost boundary gets shifted. """
@ -103,6 +106,20 @@ class Grid:
""" """
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int) return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
@property
def num_grids(self) -> int:
"""
The number of grids (number of shifts)
"""
return self.shifts.shape[0]
@property
def cell_data_shape(self):
"""
The shape of the cell_data ndarray (num_grids, *self.shape).
"""
return numpy.hstack((self.num_grids, self.shape))
@property @property
def dxyz_with_ghost(self) -> List[numpy.ndarray]: def dxyz_with_ghost(self) -> List[numpy.ndarray]:
""" """
@ -218,16 +235,30 @@ class Grid:
Returns: Returns:
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]` `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
""" """
if len(self.grids) != 3: if self.num_grids != 3:
raise GridError('autoshifting requires exactly 3 grids') raise GridError('Autoshifting requires exactly 3 grids')
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
def allocate(self, fill_value: Optional[float] = 1.0, dtype=numpy.float64) -> numpy.ndarray:
"""
Allocate an ndarray for storing grid data.
Args:
fill_value: Value to initialize the grid to. If None, an
uninitialized array is returned.
dtype: Numpy dtype for the array. Default is `numpy.float64`.
Returns:
The allocated array
"""
if fill_value is None:
return numpy.empty(self.cell_data_shape)
else:
return numpy.full(self.cell_data_shape, fill_value)
def __init__(self, def __init__(self,
pixel_edge_coordinates: Sequence[numpy.ndarray], pixel_edge_coordinates: Sequence[numpy.ndarray],
shifts: numpy.ndarray = Yee_Shifts_E, shifts: numpy.ndarray = Yee_Shifts_E,
initial: Union[float, numpy.ndarray] = 1.0,
num_grids: Optional[int] = None,
periodic: Union[bool, Sequence[bool]] = False, periodic: Union[bool, Sequence[bool]] = False,
) -> None: ) -> None:
""" """
@ -238,12 +269,6 @@ class Grid:
x=`x1`, the second has edges x=`x1` and x=`x2`, etc.) x=`x1`, the second has edges x=`x1` and x=`x2`, etc.)
shifts: Nx3 array containing `[x, y, z]` offsets for each of N grids. shifts: Nx3 array containing `[x, y, z]` offsets for each of N grids.
E-field Yee shifts are used by default. E-field Yee shifts are used by default.
initial: Grids are initialized to this value. If scalar, all grids are initialized
with ndarrays full of the scalar. If a list of scalars, `grid[i]` is initialized to an
ndarray full of `initial[i]`. If a list of ndarrays of the same shape as the grids, `grid[i]`
is set to `initial[i]`. Default `1.0`.
num_grids: How many grids to create. Must be <= `shifts.shape[0]`.
Default is `shifts.shape[0]`
periodic: Specifies how the sizes of edge cells are calculated; see main class periodic: Specifies how the sizes of edge cells are calculated; see main class
documentation. List of 3 bool, or a single bool that gets broadcast. Default `False`. documentation. List of 3 bool, or a single bool that gets broadcast. Default `False`.
@ -276,33 +301,6 @@ class Grid:
# TODO: Test negative shifts # TODO: Test negative shifts
warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2) warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2)
num_shifts = self.shifts.shape[0]
if num_grids is None:
num_grids = num_shifts
elif num_grids > num_shifts:
raise GridError('Number of grids exceeds number of shifts (%u)' % num_shifts)
grids_shape = hstack((num_grids, self.shape))
if isinstance(initial, (float, int)):
if isinstance(initial, int):
warnings.warn('Initial value is an int, grids will be integer-typed!', stacklevel=2)
self.grids = numpy.full(grids_shape, initial)
else:
if len(initial) < num_grids:
raise GridError('Too few initial grids specified!')
self.grids = numpy.empty(grids_shape)
for i in range(num_grids):
if is_scalar(initial[i]):
if initial[i] is not None:
if isinstance(initial[i], int):
warnings.warn('Initial value is an int, grid {} will be integer-typed!'.format(i), stacklevel=2)
self.grids[i] = numpy.full(self.shape, initial[i])
else:
if not numpy.array_equal(initial[i].shape, self.shape):
raise GridError('Initial grid sizes must match given coordinates')
self.grids[i] = initial[i]
@staticmethod @staticmethod
def load(filename: str) -> 'Grid': def load(filename: str) -> 'Grid':
""" """

View File

@ -5,8 +5,8 @@ from typing import Dict, Optional, Union, Any
import numpy # type: ignore import numpy # type: ignore
from . import GridError, Direction
from ._helpers import is_scalar from ._helpers import is_scalar
from . import GridError
# .visualize_* uses matplotlib # .visualize_* uses matplotlib
# .visualize_isosurface uses skimage # .visualize_isosurface uses skimage
@ -14,7 +14,8 @@ from ._helpers import is_scalar
def get_slice(self, def get_slice(self,
surface_normal: Union[Direction, int], cell_data: numpy.ndarray,
surface_normal: int,
center: float, center: float,
which_shifts: int = 0, which_shifts: int = 0,
sample_period: int = 1 sample_period: int = 1
@ -24,8 +25,8 @@ def get_slice(self,
Interpolates if given a position between two planes. Interpolates if given a position between two planes.
Args: Args:
surface_normal: Axis normal to the plane we're displaying. Can be a `Direction` or cell_data: Cell data to slice
integer in `range(3)` surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
center: Scalar specifying position along surface_normal axis. center: Scalar specifying position along surface_normal axis.
which_shifts: Which grid to display. Default is the first grid (0). which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled) sample_period: Period for down-sampling the image. Default 1 (disabled)
@ -43,9 +44,6 @@ def get_slice(self,
if not is_scalar(which_shifts) or which_shifts < 0: if not is_scalar(which_shifts) or which_shifts < 0:
raise GridError('Invalid which_shifts') raise GridError('Invalid which_shifts')
# Turn surface_normal into its integer representation
if isinstance(surface_normal, Direction):
surface_normal = surface_normal.value
if surface_normal not in range(3): if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction') raise GridError('Invalid surface_normal direction')
@ -70,7 +68,7 @@ def get_slice(self,
sliced_grid = numpy.zeros(self.shape[surface]) sliced_grid = numpy.zeros(self.shape[surface])
for ci, weight in zip(centers, w): for ci, weight in zip(centers, w):
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
sliced_grid += weight * self.grids[which_shifts][tuple(s)] sliced_grid += weight * cell_data[which_shifts][tuple(s)]
# Remove extra dimensions # Remove extra dimensions
sliced_grid = numpy.squeeze(sliced_grid) sliced_grid = numpy.squeeze(sliced_grid)
@ -79,7 +77,8 @@ def get_slice(self,
def visualize_slice(self, def visualize_slice(self,
surface_normal: Union[Direction, int], cell_data: numpy.ndarray,
surface_normal: int,
center: float, center: float,
which_shifts: int = 0, which_shifts: int = 0,
sample_period: int = 1, sample_period: int = 1,
@ -91,8 +90,7 @@ def visualize_slice(self,
Interpolates if given a position between two planes. Interpolates if given a position between two planes.
Args: Args:
surface_normal: Axis normal to the plane we're displaying. Can be a `Direction` or surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
integer in `range(3)`
center: Scalar specifying position along surface_normal axis. center: Scalar specifying position along surface_normal axis.
which_shifts: Which grid to display. Default is the first grid (0). which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled) sample_period: Period for down-sampling the image. Default 1 (disabled)
@ -100,14 +98,11 @@ def visualize_slice(self,
""" """
from matplotlib import pyplot from matplotlib import pyplot
# Set surface normal to its integer value
if isinstance(surface_normal, Direction):
surface_normal = surface_normal.value
if pcolormesh_args is None: if pcolormesh_args is None:
pcolormesh_args = {} pcolormesh_args = {}
grid_slice = self.get_slice(surface_normal=surface_normal, grid_slice = self.get_slice(cell_data=cell_data,
surface_normal=surface_normal,
center=center, center=center,
which_shifts=which_shifts, which_shifts=which_shifts,
sample_period=sample_period) sample_period=sample_period)
@ -129,6 +124,7 @@ def visualize_slice(self,
def visualize_isosurface(self, def visualize_isosurface(self,
cell_data: numpy.ndarray,
level: Optional[float] = None, level: Optional[float] = None,
which_shifts: int = 0, which_shifts: int = 0,
sample_period: int = 1, sample_period: int = 1,
@ -139,6 +135,7 @@ def visualize_isosurface(self,
Draw an isosurface plot of the device. Draw an isosurface plot of the device.
Args: Args:
cell_data: Cell data to visualize
level: Value at which to find isosurface. Default (None) uses mean value in grid. level: Value at which to find isosurface. Default (None) uses mean value in grid.
which_shifts: Which grid to display. Default is the first grid (0). which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled) sample_period: Period for down-sampling the image. Default 1 (disabled)
@ -150,8 +147,8 @@ def visualize_isosurface(self,
# Claims to be unused, but needed for subplot(projection='3d') # Claims to be unused, but needed for subplot(projection='3d')
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
# Get data from self.grids # Get data from cell_data
grid = self.grids[which_shifts][::sample_period, ::sample_period, ::sample_period] grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period]
if level is None: if level is None:
level = grid.mean() level = grid.mean()