snapshot 2021-10-24 18:42:26.970173

lethe/LATEST
jan 3 years ago
parent d054fc5a94
commit 0f04325c74

@ -16,7 +16,6 @@ Dependencies:
- skimage [Grid.visualize_isosurface()]
"""
from .error import GridError
from .direction import Direction
from .grid import Grid
__author__ = 'Jan Petykiewicz'

@ -4,19 +4,17 @@ Drawing-related methods for Grid class
from typing import List, Optional, Union, Sequence, Callable
import numpy # type: ignore
from numpy import diff, floor, ceil, zeros, hstack, newaxis
from float_raster import raster
from . import GridError, Direction
from ._helpers import is_scalar
from . import GridError
eps_callable_t = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray]
def draw_polygons(self,
surface_normal: Union[Direction, int],
cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray,
polygons: Sequence[numpy.ndarray],
thickness: float,
@ -26,8 +24,8 @@ def draw_polygons(self,
Draw polygons on an axis-aligned plane.
Args:
surface_normal: Axis normal to the plane we're drawing on. Can be a `Direction` or
integer in `range(3)`
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying an offset applied to all the polygons
polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon
(non-closed, clockwise). If Nx3, the surface_normal coordinate is ignored. Each
@ -41,11 +39,6 @@ def draw_polygons(self,
Raises:
GridError
"""
# Turn surface_normal into its integer representation
if isinstance(surface_normal, Direction):
surface_normal = surface_normal.value
assert(isinstance(surface_normal, int))
if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction')
@ -55,7 +48,7 @@ def draw_polygons(self,
surface = numpy.delete(range(3), surface_normal)
for i, polygon in enumerate(polygons):
malformed = 'Malformed polygon: (%i)' % i
malformed = f'Malformed polygon: ({i})'
if polygon.shape[1] not in (2, 3):
raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray')
if polygon.shape[1] == 3:
@ -64,12 +57,12 @@ def draw_polygons(self,
if not polygon.shape[0] > 2:
raise GridError(malformed + 'must consist of more than 2 points')
if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1:
raise GridError(malformed + 'must be in plane with surface normal %s'
% 'xyz'[surface_normal])
raise GridError(malformed + 'must be in plane with surface normal '
+ 'xyz'[surface_normal])
# Broadcast eps where necessary
if is_scalar(eps):
eps = [eps] * len(self.grids)
if numpy.size(eps) == 1:
eps = [eps] * len(cell_data)
elif isinstance(eps, numpy.ndarray):
raise GridError('ndarray not supported for eps')
@ -91,8 +84,8 @@ def draw_polygons(self,
s_max = self.shifts.max(axis=0)
bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf
bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf
bdi_min = numpy.maximum(floor(bdi_min), 0).astype(int)
bdi_max = numpy.minimum(ceil(bdi_max), self.shape - 1).astype(int)
bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int)
bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int)
# 3) Adjust polygons for center
polygons = [poly + center[surface] for poly in polygons]
@ -103,7 +96,7 @@ def draw_polygons(self,
return numpy.insert(v_2d, surface_normal, (val,))
# iterate over grids
for i, grid in enumerate(self.grids):
for i, grid in enumerate(cell_data):
# ## Evaluate or expand eps[i]
if callable(eps[i]):
# meshgrid over the (shifted) domain
@ -113,14 +106,14 @@ def draw_polygons(self,
# evaluate on the meshgrid
eps_i = eps[i](x0, y0, z0)
if not numpy.isfinite(eps_i).all():
raise GridError('Non-finite values in eps[%u]' % i)
elif not is_scalar(eps[i]):
raise GridError('Unsupported eps[{}]: {}'.format(i, type(eps[i])))
raise GridError(f'Non-finite values in eps[{i}]')
elif numpy.size(eps[i]) != 1:
raise GridError(f'Unsupported eps[{i}]: {type(eps[i])}')
else:
# eps[i] is scalar non-callable
eps_i = eps[i]
w_xy = zeros((bdi_max - bdi_min + 1)[surface].astype(int))
w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int))
# Draw each polygon separately
for polygon in polygons:
@ -182,15 +175,16 @@ def draw_polygons(self,
w_z[zi_bot] = zi_top_f - zi_bot_f
# 3) Generate total weight function
w = (w_xy[:, :, newaxis] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,)))
w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,)))
# ## Modify the grid
g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3))
self.grids[g_slice] = (1 - w) * self.grids[g_slice] + w * eps_i
cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * eps_i
def draw_polygon(self,
surface_normal: Union[Direction, int],
cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray,
polygon: numpy.ndarray,
thickness: float,
@ -200,20 +194,21 @@ def draw_polygon(self,
Draw a polygon on an axis-aligned plane.
Args:
surface_normal: Axis normal to the plane we're drawing on. Can be a Direction or
integer in range(3)
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying an offset applied to the polygon
polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed,
clockwise). If Nx3, the surface_normal coordinate is ignored. Must have at
clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at
least 3 vertices.
thickness: Thickness of the layer to draw
eps: Value to draw with ('epsilon'). See `draw_polygons()` for details.
"""
self.draw_polygons(surface_normal, center, [polygon], thickness, eps)
self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, eps)
def draw_slab(self,
surface_normal: Union[Direction, int],
cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray,
thickness: float,
eps: Union[List[Union[float, eps_callable_t]], float, eps_callable_t],
@ -222,24 +217,22 @@ def draw_slab(self,
Draw an axis-aligned infinite slab.
Args:
surface_normal: Axis normal to the plane we're drawing on. Can be a `Direction` or
integer in `range(3)`
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: Surface_normal coordinate at the center of the slab
thickness: Thickness of the layer to draw
eps: Value to draw with ('epsilon'). See `draw_polygons()` for details.
"""
# Turn surface_normal into its integer representation
if isinstance(surface_normal, Direction):
surface_normal = surface_normal.value
if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction')
if not is_scalar(center):
if numpy.size(center) != 1:
center = numpy.squeeze(center)
if len(center) == 3:
center = center[surface_normal]
else:
raise GridError('Bad center: {}'.format(center))
raise GridError(f'Bad center: {center}')
# Find center of slab
center_shift = self.center
@ -260,10 +253,11 @@ def draw_slab(self,
[xyz_max[0], xyz_min[1]],
[xyz_min[0], xyz_min[1]]], dtype=float)
self.draw_polygon(surface_normal, center_shift, p, thickness, eps)
self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, eps)
def draw_cuboid(self,
cell_data: numpy.ndarray,
center: numpy.ndarray,
dimensions: numpy.ndarray,
eps: Union[List[Union[float, eps_callable_t]], float, eps_callable_t],
@ -272,6 +266,7 @@ def draw_cuboid(self,
Draw an axis-aligned cuboid
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
center: 3-element ndarray or list specifying the cuboid's center
dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge
sizes of the cuboid
@ -282,11 +277,12 @@ def draw_cuboid(self,
[+dimensions[0], -dimensions[1]],
[-dimensions[0], -dimensions[1]]], dtype=float) / 2.0
thickness = dimensions[2]
self.draw_polygon(Direction.z, center, p, thickness, eps)
self.draw_polygon(cell_data, 2, center, p, thickness, eps)
def draw_cylinder(self,
surface_normal: Union[Direction, int],
cell_data: numpy.ndarray,
surface_normal: int,
center: numpy.ndarray,
radius: float,
thickness: float,
@ -297,8 +293,8 @@ def draw_cylinder(self,
Draw an axis-aligned cylinder. Approximated by a num_points-gon
Args:
surface_normal: Axis normal to the plane we're drawing on. Can be a `Direction` or
integer in `range(3)`
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`.
center: 3-element ndarray or list specifying the cylinder's center
radius: cylinder radius
thickness: Thickness of the layer to draw
@ -308,13 +304,14 @@ def draw_cylinder(self,
theta = numpy.linspace(0, 2*numpy.pi, num_points, endpoint=False)
x = radius * numpy.sin(theta)
y = radius * numpy.cos(theta)
polygon = hstack((x[:, newaxis], y[:, newaxis]))
self.draw_polygon(surface_normal, center, polygon, thickness, eps)
polygon = numpy.hstack((x[:, None], y[:, None]))
self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, eps)
def draw_extrude_rectangle(self,
cell_data: numpy.ndarray,
rectangle: numpy.ndarray,
direction: Union[Direction, int],
direction: int,
polarity: int,
distance: float,
) -> None:
@ -322,23 +319,19 @@ def draw_extrude_rectangle(self,
Extrude a rectangle of a previously-drawn structure along an axis.
Args:
cell_data: Cell data to modify (e.g. created by `Grid.allocate()`)
rectangle: 2x3 ndarray or list specifying the rectangle's corners
direction: Direction to extrude in. Direction enum or int in range(3)
direction: Direction to extrude in. Integer in `range(3)`.
polarity: +1 or -1, direction along axis to extrude in
distance: How far to extrude
"""
# Turn extrude_direction into its integer representation
if isinstance(direction, Direction):
direction = direction.value
assert(isinstance(direction, int))
s = numpy.sign(polarity)
rectangle = numpy.array(rectangle, dtype=float)
if s == 0:
raise GridError('0 is not a valid polarity')
if direction not in range(3):
raise GridError('Invalid direction: {}'.format(direction))
raise GridError(f'Invalid direction: {direction}')
if rectangle[0, direction] != rectangle[1, direction]:
raise GridError('Rectangle entries along extrusion direction do not match.')
@ -347,18 +340,18 @@ def draw_extrude_rectangle(self,
surface = numpy.delete(range(3), direction)
dim = numpy.fabs(diff(rectangle, axis=0).T)[surface]
dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface]
p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0]/2.0,
numpy.array([-1, 1, 1, -1], dtype=float) * dim[1]/2.0)).T
thickness = distance
eps_func = []
for i, grid in enumerate(self.grids):
for i, grid in enumerate(cell_data):
z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction]
ind = [int(floor(z)) if i == direction else slice(None) for i in range(3)]
ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)]
fpart = z - floor(z)
fpart = z - numpy.floor(z)
mult = [1-fpart, fpart][::s] # reverses if s negative
eps = mult[0] * grid[tuple(ind)]
@ -375,5 +368,5 @@ def draw_extrude_rectangle(self,
eps_func.append(f_eps)
self.draw_polygon(direction, center, p, thickness, eps_func)
self.draw_polygon(cell_data, direction, center, p, thickness, eps_func)

@ -0,0 +1,44 @@
import numpy # type: ignore
from gridlock import Grid
if __name__ == '__main__':
# xyz = [numpy.arange(-5.0, 6.0), numpy.arange(-4.0, 5.0), [-1.0, 1.0]]
# eg = Grid(xyz)
# egc = Grid.allocate(0.0)
# # eg.draw_slab(egc, surface_normal=2, center=0, thickness=10, eps=2)
# eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=4,
# thickness=10, num_points=1000, eps=1)
# eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
# xyz2 = [numpy.arange(-5.0, 6.0), [-1.0, 1.0], numpy.arange(-4.0, 5.0)]
# eg2 = Grid(xyz2)
# eg2c = Grid.allocate(0.0)
# # eg2.draw_slab(eg2c, surface_normal=2, center=0, thickness=10, eps=2)
# eg2.draw_cylinder(eg2c, surface_normal=1, center=[0, 0, 0],
# radius=4, thickness=10, num_points=1000, eps=1.0)
# eg2.visualize_slice(eg2c, surface_normal=1, center=0, which_shifts=1)
# n = 20
# m = 3
# r1 = numpy.fromfunction(lambda x: numpy.sign(x - n) * 2 ** (abs(x - n)/m), (2*n, ))
# print(r1)
# xyz3 = [r1, numpy.linspace(-5.5, 5.5, 30), numpy.linspace(-5.5, 5.5, 10)]
# xyz3 = [numpy.linspace(-5.5, 5.5, 10),
# numpy.linspace(-5.5, 5.5, 10),
# numpy.linspace(-5.5, 5.5, 10)]
half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5]
xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x,
numpy.linspace(-5.5, 5.5, 10),
numpy.linspace(-5.5, 5.5, 10)]
eg = Grid(xyz3)
egc = eg.allocate(0)
# eg.draw_slab(Direction.z, 0, 10, 2)
eg.save('/home/jan/Desktop/test.pickle')
eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0,
thickness=10, num_poitns=1000, eps=1)
eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]],
direction=1, poalarity=+1, distance=5)
eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2)
eg.visualize_isosurface(egc, which_shifts=2)

@ -1,4 +1,4 @@
from typing import List, Tuple, Callable, Dict, Optional, Union, Sequence, ClassVar
from typing import List, Tuple, Callable, Dict, Optional, Union, Sequence, ClassVar, TypeVar
import numpy # type: ignore
from numpy import diff, floor, ceil, zeros, hstack, newaxis
@ -7,23 +7,27 @@ import pickle
import warnings
import copy
from . import GridError, Direction
from ._helpers import is_scalar
from . import GridError
__author__ = 'Jan Petykiewicz'
eps_callable_type = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray]
T = TypeVar('T', bound='Grid')
class Grid:
"""
Simulation grid generator intended for electromagnetic simulations.
Can be used to generate non-uniform rectangular grids (the entire grid
Simulation grid metadata for finite-difference simulations.
Can be used to generate non-uniform rectangular grids (the entire grid
is generated based on the coordinates of the boundary points). Also does
straightforward natural <-> grid unit conversion.
`self.grids[i][a,b,c]` contains the value of epsilon for the cell located around
This class handles data describing the grid, and should be paired with a
(separate) ndarray that contains the actual data in each cell. The `allocate()`
method can be used to create this ndarray.
The resulting `cell_data[i, a, b, c]` should correspond to the value in the
`i`-th grid, in the cell centered around
```
(xyz[0][a] + dxyz[0][a] * shifts[i, 0],
xyz[1][b] + dxyz[1][b] * shifts[i, 1],
@ -47,9 +51,6 @@ class Grid:
exyz: List[numpy.ndarray]
"""Cell edges. Monotonically increasing without duplicates."""
grids: numpy.ndarray
"""epsilon (or mu, or whatever) grids. shape is (num_grids, X, Y, Z)"""
periodic: List[bool]
"""For each axis, determines how far the rightmost boundary gets shifted. """
@ -81,7 +82,7 @@ class Grid:
Returns:
List of 3 ndarrays of cell sizes
"""
return [diff(self.exyz[a]) for a in range(3)]
return [numpy.diff(ee) for ee in self.exyz]
@property
def xyz(self) -> List[numpy.ndarray]:
@ -103,6 +104,20 @@ class Grid:
"""
return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int)
@property
def num_grids(self) -> int:
"""
The number of grids (number of shifts)
"""
return self.shifts.shape[0]
@property
def cell_data_shape(self):
"""
The shape of the cell_data ndarray (num_grids, *self.shape).
"""
return numpy.hstack((self.num_grids, self.shape))
@property
def dxyz_with_ghost(self) -> List[numpy.ndarray]:
"""
@ -117,7 +132,7 @@ class Grid:
list of [dxs, dys, dzs] with each element same length as elements of `self.xyz`
"""
el = [0 if p else -1 for p in self.periodic]
return [hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)]
return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)]
@property
def center(self) -> numpy.ndarray:
@ -211,23 +226,37 @@ class Grid:
dxyz = self.shifted_dxyz(which_shifts)
return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)]
def autoshifted_dxyz(self):
def autoshifted_dxyz(self) -> List[numpy.ndarray]:
"""
Return cell widths, with each dimension shifted by the corresponding shifts.
Returns:
`[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]`
"""
if len(self.grids) != 3:
raise GridError('autoshifting requires exactly 3 grids')
if self.num_grids != 3:
raise GridError('Autoshifting requires exactly 3 grids')
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
def allocate(self, fill_value: Optional[float] = 1.0, dtype=numpy.float64) -> numpy.ndarray:
"""
Allocate an ndarray for storing grid data.
Args:
fill_value: Value to initialize the grid to. If None, an
uninitialized array is returned.
dtype: Numpy dtype for the array. Default is `numpy.float64`.
Returns:
The allocated array
"""
if fill_value is None:
return numpy.empty(self.cell_data_shape)
else:
return numpy.full(self.cell_data_shape, fill_value)
def __init__(self,
pixel_edge_coordinates: Sequence[numpy.ndarray],
shifts: numpy.ndarray = Yee_Shifts_E,
initial: Union[float, numpy.ndarray] = 1.0,
num_grids: Optional[int] = None,
periodic: Union[bool, Sequence[bool]] = False,
) -> None:
"""
@ -238,12 +267,6 @@ class Grid:
x=`x1`, the second has edges x=`x1` and x=`x2`, etc.)
shifts: Nx3 array containing `[x, y, z]` offsets for each of N grids.
E-field Yee shifts are used by default.
initial: Grids are initialized to this value. If scalar, all grids are initialized
with ndarrays full of the scalar. If a list of scalars, `grid[i]` is initialized to an
ndarray full of `initial[i]`. If a list of ndarrays of the same shape as the grids, `grid[i]`
is set to `initial[i]`. Default `1.0`.
num_grids: How many grids to create. Must be <= `shifts.shape[0]`.
Default is `shifts.shape[0]`
periodic: Specifies how the sizes of edge cells are calculated; see main class
documentation. List of 3 bool, or a single bool that gets broadcast. Default `False`.
@ -255,7 +278,7 @@ class Grid:
for i in range(3):
if len(self.exyz[i]) != len(pixel_edge_coordinates[i]):
warnings.warn('Dimension {} had duplicate edge coordinates'.format(i), stacklevel=2)
warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2)
if isinstance(periodic, bool):
self.periodic = [periodic] * 3
@ -264,10 +287,10 @@ class Grid:
if len(self.shifts.shape) != 2:
raise GridError('Misshapen shifts: shifts must have two axes! '
' The given shifts has shape {}'.format(self.shifts.shape))
f' The given shifts has shape {self.shifts.shape}')
if self.shifts.shape[1] != 3:
raise GridError('Misshapen shifts; second axis size should be 3,'
' shape is {}'.format(self.shifts.shape))
f' shape is {self.shifts.shape}')
if (numpy.abs(self.shifts) > 1).any():
raise GridError('Only shifts in the range [-1, 1] are currently supported')
@ -276,33 +299,6 @@ class Grid:
# TODO: Test negative shifts
warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2)
num_shifts = self.shifts.shape[0]
if num_grids is None:
num_grids = num_shifts
elif num_grids > num_shifts:
raise GridError('Number of grids exceeds number of shifts (%u)' % num_shifts)
grids_shape = hstack((num_grids, self.shape))
if isinstance(initial, (float, int)):
if isinstance(initial, int):
warnings.warn('Initial value is an int, grids will be integer-typed!', stacklevel=2)
self.grids = numpy.full(grids_shape, initial)
else:
if len(initial) < num_grids:
raise GridError('Too few initial grids specified!')
self.grids = numpy.empty(grids_shape)
for i in range(num_grids):
if is_scalar(initial[i]):
if initial[i] is not None:
if isinstance(initial[i], int):
warnings.warn('Initial value is an int, grid {} will be integer-typed!'.format(i), stacklevel=2)
self.grids[i] = numpy.full(self.shape, initial[i])
else:
if not numpy.array_equal(initial[i].shape, self.shape):
raise GridError('Initial grid sizes must match given coordinates')
self.grids[i] = initial[i]
@staticmethod
def load(filename: str) -> 'Grid':
"""
@ -318,17 +314,21 @@ class Grid:
g.__dict__.update(tmp_dict)
return g
def save(self, filename: str):
def save(self: T, filename: str) -> T:
"""
Save to file.
Args:
filename: Filename to save to.
Returns:
self
"""
with open(filename, 'wb') as f:
pickle.dump(self.__dict__, f, protocol=2)
return self
def copy(self):
def copy(self: T) -> T:
"""
Returns:
Deep copy of the grid.

@ -4,7 +4,6 @@ Position-related methods for Grid class
from typing import List, Optional
import numpy # type: ignore
from numpy import zeros
from . import GridError
@ -47,7 +46,7 @@ def ind2pos(self,
low_bound = -0.5
high_bound = -0.5
if (ind < low_bound).any() or (ind > self.shape - high_bound).any():
raise GridError('Position outside of grid: {}'.format(ind))
raise GridError(f'Position outside of grid: {ind}')
if round_ind:
rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1)
@ -85,19 +84,19 @@ def pos2ind(self,
"""
r = numpy.squeeze(r)
if r.size != 3:
raise GridError('r must be 3-element vector: {}'.format(r))
raise GridError(f'r must be 3-element vector: {r}')
if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]):
raise GridError('Invalid which_shifts: {}'.format(which_shifts))
raise GridError(f'Invalid which_shifts: {which_shifts}')
sexyz = self.shifted_exyz(which_shifts)
if check_bounds:
for a in range(3):
if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]):
raise GridError('Position[{}] outside of grid!'.format(a))
raise GridError(f'Position[{a}] outside of grid!')
grid_pos = zeros((3,))
grid_pos = numpy.zeros((3,))
for a in range(3):
xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in
xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds

@ -4,10 +4,8 @@ Readback and visualization methods for Grid class
from typing import Dict, Optional, Union, Any
import numpy # type: ignore
from numpy import floor, ceil, zeros
from . import GridError, Direction
from ._helpers import is_scalar
from . import GridError
# .visualize_* uses matplotlib
# .visualize_isosurface uses skimage
@ -15,7 +13,8 @@ from ._helpers import is_scalar
def get_slice(self,
surface_normal: Union[Direction, int],
cell_data: numpy.ndarray,
surface_normal: int,
center: float,
which_shifts: int = 0,
sample_period: int = 1
@ -25,8 +24,8 @@ def get_slice(self,
Interpolates if given a position between two planes.
Args:
surface_normal: Axis normal to the plane we're displaying. Can be a `Direction` or
integer in `range(3)`
cell_data: Cell data to slice
surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
center: Scalar specifying position along surface_normal axis.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
@ -34,19 +33,16 @@ def get_slice(self,
Returns:
Array containing the portion of the grid.
"""
if not is_scalar(center) and numpy.isreal(center):
if numpy.size(center) != 1 or not numpy.isreal(center):
raise GridError('center must be a real scalar')
sp = round(sample_period)
if sp <= 0:
raise GridError('sample_period must be positive')
if not is_scalar(which_shifts) or which_shifts < 0:
if numpy.size(which_shifts) != 1 or which_shifts < 0:
raise GridError('Invalid which_shifts')
# Turn surface_normal into its integer representation
if isinstance(surface_normal, Direction):
surface_normal = surface_normal.value
if surface_normal not in range(3):
raise GridError('Invalid surface_normal direction')
@ -56,9 +52,9 @@ def get_slice(self,
center3 = numpy.insert([0, 0], surface_normal, (center,))
center_index = self.pos2ind(center3, which_shifts,
round_ind=False, check_bounds=False)[surface_normal]
centers = numpy.unique([floor(center_index), ceil(center_index)]).astype(int)
centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
if len(centers) == 2:
fpart = center_index - floor(center_index)
fpart = center_index - numpy.floor(center_index)
w = [1 - fpart, fpart] # longer distance -> less weight
else:
w = [1]
@ -68,10 +64,10 @@ def get_slice(self,
raise GridError('Coordinate of selected plane must be within simulation domain')
# Extract grid values from planes above and below visualized slice
sliced_grid = zeros(self.shape[surface])
sliced_grid = numpy.zeros(self.shape[surface])
for ci, weight in zip(centers, w):
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
sliced_grid += weight * self.grids[which_shifts][tuple(s)]
sliced_grid += weight * cell_data[which_shifts][tuple(s)]
# Remove extra dimensions
sliced_grid = numpy.squeeze(sliced_grid)
@ -80,7 +76,8 @@ def get_slice(self,
def visualize_slice(self,
surface_normal: Union[Direction, int],
cell_data: numpy.ndarray,
surface_normal: int,
center: float,
which_shifts: int = 0,
sample_period: int = 1,
@ -92,8 +89,7 @@ def visualize_slice(self,
Interpolates if given a position between two planes.
Args:
surface_normal: Axis normal to the plane we're displaying. Can be a `Direction` or
integer in `range(3)`
surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
center: Scalar specifying position along surface_normal axis.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
@ -101,14 +97,11 @@ def visualize_slice(self,
"""
from matplotlib import pyplot
# Set surface normal to its integer value
if isinstance(surface_normal, Direction):
surface_normal = surface_normal.value
if pcolormesh_args is None:
pcolormesh_args = {}
grid_slice = self.get_slice(surface_normal=surface_normal,
grid_slice = self.get_slice(cell_data=cell_data,
surface_normal=surface_normal,
center=center,
which_shifts=which_shifts,
sample_period=sample_period)
@ -130,6 +123,7 @@ def visualize_slice(self,
def visualize_isosurface(self,
cell_data: numpy.ndarray,
level: Optional[float] = None,
which_shifts: int = 0,
sample_period: int = 1,
@ -140,6 +134,7 @@ def visualize_isosurface(self,
Draw an isosurface plot of the device.
Args:
cell_data: Cell data to visualize
level: Value at which to find isosurface. Default (None) uses mean value in grid.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
@ -151,8 +146,8 @@ def visualize_isosurface(self,
# Claims to be unused, but needed for subplot(projection='3d')
from mpl_toolkits.mplot3d import Axes3D
# Get data from self.grids
grid = self.grids[which_shifts][::sample_period, ::sample_period, ::sample_period]
# Get data from cell_data
grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period]
if level is None:
level = grid.mean()

Loading…
Cancel
Save