diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 171b4cc..d547794 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -16,7 +16,6 @@ Dependencies: - skimage [Grid.visualize_isosurface()] """ from .error import GridError -from .direction import Direction from .grid import Grid __author__ = 'Jan Petykiewicz' diff --git a/gridlock/draw.py b/gridlock/draw.py index 15a6bff..ca00e6b 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -4,19 +4,17 @@ Drawing-related methods for Grid class from typing import List, Optional, Union, Sequence, Callable import numpy # type: ignore -from numpy import diff, floor, ceil, zeros, hstack, newaxis - from float_raster import raster -from . import GridError, Direction -from ._helpers import is_scalar +from . import GridError eps_callable_t = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray] def draw_polygons(self, - surface_normal: Union[Direction, int], + cell_data: numpy.ndarray, + surface_normal: int, center: numpy.ndarray, polygons: Sequence[numpy.ndarray], thickness: float, @@ -26,8 +24,8 @@ def draw_polygons(self, Draw polygons on an axis-aligned plane. Args: - surface_normal: Axis normal to the plane we're drawing on. Can be a `Direction` or - integer in `range(3)` + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. center: 3-element ndarray or list specifying an offset applied to all the polygons polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon (non-closed, clockwise). If Nx3, the surface_normal coordinate is ignored. Each @@ -41,11 +39,6 @@ def draw_polygons(self, Raises: GridError """ - # Turn surface_normal into its integer representation - if isinstance(surface_normal, Direction): - surface_normal = surface_normal.value - assert(isinstance(surface_normal, int)) - if surface_normal not in range(3): raise GridError('Invalid surface_normal direction') @@ -55,7 +48,7 @@ def draw_polygons(self, surface = numpy.delete(range(3), surface_normal) for i, polygon in enumerate(polygons): - malformed = 'Malformed polygon: (%i)' % i + malformed = f'Malformed polygon: ({i})' if polygon.shape[1] not in (2, 3): raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') if polygon.shape[1] == 3: @@ -64,12 +57,12 @@ def draw_polygons(self, if not polygon.shape[0] > 2: raise GridError(malformed + 'must consist of more than 2 points') if polygon.ndim > 2 and not numpy.unique(polygon[:, surface_normal]).size == 1: - raise GridError(malformed + 'must be in plane with surface normal %s' - % 'xyz'[surface_normal]) + raise GridError(malformed + 'must be in plane with surface normal ' + + 'xyz'[surface_normal]) # Broadcast eps where necessary - if is_scalar(eps): - eps = [eps] * len(self.grids) + if numpy.size(eps) == 1: + eps = [eps] * len(cell_data) elif isinstance(eps, numpy.ndarray): raise GridError('ndarray not supported for eps') @@ -91,8 +84,8 @@ def draw_polygons(self, s_max = self.shifts.max(axis=0) bdi_min = self.pos2ind(bd_min + s_min, None, round_ind=False, check_bounds=False) - buf bdi_max = self.pos2ind(bd_max + s_max, None, round_ind=False, check_bounds=False) + buf - bdi_min = numpy.maximum(floor(bdi_min), 0).astype(int) - bdi_max = numpy.minimum(ceil(bdi_max), self.shape - 1).astype(int) + bdi_min = numpy.maximum(numpy.floor(bdi_min), 0).astype(int) + bdi_max = numpy.minimum(numpy.ceil(bdi_max), self.shape - 1).astype(int) # 3) Adjust polygons for center polygons = [poly + center[surface] for poly in polygons] @@ -103,7 +96,7 @@ def draw_polygons(self, return numpy.insert(v_2d, surface_normal, (val,)) # iterate over grids - for i, grid in enumerate(self.grids): + for i, grid in enumerate(cell_data): # ## Evaluate or expand eps[i] if callable(eps[i]): # meshgrid over the (shifted) domain @@ -113,14 +106,14 @@ def draw_polygons(self, # evaluate on the meshgrid eps_i = eps[i](x0, y0, z0) if not numpy.isfinite(eps_i).all(): - raise GridError('Non-finite values in eps[%u]' % i) - elif not is_scalar(eps[i]): - raise GridError('Unsupported eps[{}]: {}'.format(i, type(eps[i]))) + raise GridError(f'Non-finite values in eps[{i}]') + elif numpy.size(eps[i]) != 1: + raise GridError(f'Unsupported eps[{i}]: {type(eps[i])}') else: # eps[i] is scalar non-callable eps_i = eps[i] - w_xy = zeros((bdi_max - bdi_min + 1)[surface].astype(int)) + w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) # Draw each polygon separately for polygon in polygons: @@ -182,15 +175,16 @@ def draw_polygons(self, w_z[zi_bot] = zi_top_f - zi_bot_f # 3) Generate total weight function - w = (w_xy[:, :, newaxis] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,))) + w = (w_xy[:, :, None] * w_z).transpose(numpy.insert([0, 1], surface_normal, (2,))) # ## Modify the grid g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) - self.grids[g_slice] = (1 - w) * self.grids[g_slice] + w * eps_i + cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * eps_i def draw_polygon(self, - surface_normal: Union[Direction, int], + cell_data: numpy.ndarray, + surface_normal: int, center: numpy.ndarray, polygon: numpy.ndarray, thickness: float, @@ -200,20 +194,21 @@ def draw_polygon(self, Draw a polygon on an axis-aligned plane. Args: - surface_normal: Axis normal to the plane we're drawing on. Can be a Direction or - integer in range(3) + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. center: 3-element ndarray or list specifying an offset applied to the polygon polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, - clockwise). If Nx3, the surface_normal coordinate is ignored. Must have at + clockwise). If Nx3, the `surface_normal` coordinate is ignored. Must have at least 3 vertices. thickness: Thickness of the layer to draw eps: Value to draw with ('epsilon'). See `draw_polygons()` for details. """ - self.draw_polygons(surface_normal, center, [polygon], thickness, eps) + self.draw_polygons(cell_data, surface_normal, center, [polygon], thickness, eps) def draw_slab(self, - surface_normal: Union[Direction, int], + cell_data: numpy.ndarray, + surface_normal: int, center: numpy.ndarray, thickness: float, eps: Union[List[Union[float, eps_callable_t]], float, eps_callable_t], @@ -222,24 +217,22 @@ def draw_slab(self, Draw an axis-aligned infinite slab. Args: - surface_normal: Axis normal to the plane we're drawing on. Can be a `Direction` or - integer in `range(3)` + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. center: Surface_normal coordinate at the center of the slab thickness: Thickness of the layer to draw eps: Value to draw with ('epsilon'). See `draw_polygons()` for details. """ # Turn surface_normal into its integer representation - if isinstance(surface_normal, Direction): - surface_normal = surface_normal.value if surface_normal not in range(3): raise GridError('Invalid surface_normal direction') - if not is_scalar(center): + if numpy.size(center) != 1: center = numpy.squeeze(center) if len(center) == 3: center = center[surface_normal] else: - raise GridError('Bad center: {}'.format(center)) + raise GridError(f'Bad center: {center}') # Find center of slab center_shift = self.center @@ -260,10 +253,11 @@ def draw_slab(self, [xyz_max[0], xyz_min[1]], [xyz_min[0], xyz_min[1]]], dtype=float) - self.draw_polygon(surface_normal, center_shift, p, thickness, eps) + self.draw_polygon(cell_data, surface_normal, center_shift, p, thickness, eps) def draw_cuboid(self, + cell_data: numpy.ndarray, center: numpy.ndarray, dimensions: numpy.ndarray, eps: Union[List[Union[float, eps_callable_t]], float, eps_callable_t], @@ -272,6 +266,7 @@ def draw_cuboid(self, Draw an axis-aligned cuboid Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) center: 3-element ndarray or list specifying the cuboid's center dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge sizes of the cuboid @@ -282,11 +277,12 @@ def draw_cuboid(self, [+dimensions[0], -dimensions[1]], [-dimensions[0], -dimensions[1]]], dtype=float) / 2.0 thickness = dimensions[2] - self.draw_polygon(Direction.z, center, p, thickness, eps) + self.draw_polygon(cell_data, 2, center, p, thickness, eps) def draw_cylinder(self, - surface_normal: Union[Direction, int], + cell_data: numpy.ndarray, + surface_normal: int, center: numpy.ndarray, radius: float, thickness: float, @@ -297,8 +293,8 @@ def draw_cylinder(self, Draw an axis-aligned cylinder. Approximated by a num_points-gon Args: - surface_normal: Axis normal to the plane we're drawing on. Can be a `Direction` or - integer in `range(3)` + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + surface_normal: Axis normal to the plane we're drawing on. Integer in `range(3)`. center: 3-element ndarray or list specifying the cylinder's center radius: cylinder radius thickness: Thickness of the layer to draw @@ -308,13 +304,14 @@ def draw_cylinder(self, theta = numpy.linspace(0, 2*numpy.pi, num_points, endpoint=False) x = radius * numpy.sin(theta) y = radius * numpy.cos(theta) - polygon = hstack((x[:, newaxis], y[:, newaxis])) - self.draw_polygon(surface_normal, center, polygon, thickness, eps) + polygon = numpy.hstack((x[:, None], y[:, None])) + self.draw_polygon(cell_data, surface_normal, center, polygon, thickness, eps) def draw_extrude_rectangle(self, + cell_data: numpy.ndarray, rectangle: numpy.ndarray, - direction: Union[Direction, int], + direction: int, polarity: int, distance: float, ) -> None: @@ -322,23 +319,19 @@ def draw_extrude_rectangle(self, Extrude a rectangle of a previously-drawn structure along an axis. Args: + cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) rectangle: 2x3 ndarray or list specifying the rectangle's corners - direction: Direction to extrude in. Direction enum or int in range(3) + direction: Direction to extrude in. Integer in `range(3)`. polarity: +1 or -1, direction along axis to extrude in distance: How far to extrude """ - # Turn extrude_direction into its integer representation - if isinstance(direction, Direction): - direction = direction.value - assert(isinstance(direction, int)) - s = numpy.sign(polarity) rectangle = numpy.array(rectangle, dtype=float) if s == 0: raise GridError('0 is not a valid polarity') if direction not in range(3): - raise GridError('Invalid direction: {}'.format(direction)) + raise GridError(f'Invalid direction: {direction}') if rectangle[0, direction] != rectangle[1, direction]: raise GridError('Rectangle entries along extrusion direction do not match.') @@ -347,18 +340,18 @@ def draw_extrude_rectangle(self, surface = numpy.delete(range(3), direction) - dim = numpy.fabs(diff(rectangle, axis=0).T)[surface] + dim = numpy.fabs(numpy.diff(rectangle, axis=0).T)[surface] p = numpy.vstack((numpy.array([-1, -1, 1, 1], dtype=float) * dim[0]/2.0, numpy.array([-1, 1, 1, -1], dtype=float) * dim[1]/2.0)).T thickness = distance eps_func = [] - for i, grid in enumerate(self.grids): + for i, grid in enumerate(cell_data): z = self.pos2ind(rectangle[0, :], i, round_ind=False, check_bounds=False)[direction] - ind = [int(floor(z)) if i == direction else slice(None) for i in range(3)] + ind = [int(numpy.floor(z)) if i == direction else slice(None) for i in range(3)] - fpart = z - floor(z) + fpart = z - numpy.floor(z) mult = [1-fpart, fpart][::s] # reverses if s negative eps = mult[0] * grid[tuple(ind)] @@ -375,5 +368,5 @@ def draw_extrude_rectangle(self, eps_func.append(f_eps) - self.draw_polygon(direction, center, p, thickness, eps_func) + self.draw_polygon(cell_data, direction, center, p, thickness, eps_func) diff --git a/gridlock/examples/ex0.py b/gridlock/examples/ex0.py new file mode 100644 index 0000000..0e1cd86 --- /dev/null +++ b/gridlock/examples/ex0.py @@ -0,0 +1,44 @@ +import numpy # type: ignore +from gridlock import Grid + + +if __name__ == '__main__': + # xyz = [numpy.arange(-5.0, 6.0), numpy.arange(-4.0, 5.0), [-1.0, 1.0]] + # eg = Grid(xyz) + # egc = Grid.allocate(0.0) + # # eg.draw_slab(egc, surface_normal=2, center=0, thickness=10, eps=2) + # eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=4, + # thickness=10, num_points=1000, eps=1) + # eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) + + # xyz2 = [numpy.arange(-5.0, 6.0), [-1.0, 1.0], numpy.arange(-4.0, 5.0)] + # eg2 = Grid(xyz2) + # eg2c = Grid.allocate(0.0) + # # eg2.draw_slab(eg2c, surface_normal=2, center=0, thickness=10, eps=2) + # eg2.draw_cylinder(eg2c, surface_normal=1, center=[0, 0, 0], + # radius=4, thickness=10, num_points=1000, eps=1.0) + # eg2.visualize_slice(eg2c, surface_normal=1, center=0, which_shifts=1) + + # n = 20 + # m = 3 + # r1 = numpy.fromfunction(lambda x: numpy.sign(x - n) * 2 ** (abs(x - n)/m), (2*n, )) + # print(r1) + # xyz3 = [r1, numpy.linspace(-5.5, 5.5, 30), numpy.linspace(-5.5, 5.5, 10)] + # xyz3 = [numpy.linspace(-5.5, 5.5, 10), + # numpy.linspace(-5.5, 5.5, 10), + # numpy.linspace(-5.5, 5.5, 10)] + + half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5] + xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x, + numpy.linspace(-5.5, 5.5, 10), + numpy.linspace(-5.5, 5.5, 10)] + eg = Grid(xyz3) + egc = eg.allocate(0) + # eg.draw_slab(Direction.z, 0, 10, 2) + eg.save('/home/jan/Desktop/test.pickle') + eg.draw_cylinder(egc, surface_normal=2, center=[0, 0, 0], radius=2.0, + thickness=10, num_poitns=1000, eps=1) + eg.draw_extrude_rectangle(egc, rectangle=[[-2, 1, -1], [0, 1, 1]], + direction=1, poalarity=+1, distance=5) + eg.visualize_slice(egc, surface_normal=2, center=0, which_shifts=2) + eg.visualize_isosurface(egc, which_shifts=2) diff --git a/gridlock/grid.py b/gridlock/grid.py index 7dc62c6..0a77ba1 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Callable, Dict, Optional, Union, Sequence, ClassVar +from typing import List, Tuple, Callable, Dict, Optional, Union, Sequence, ClassVar, TypeVar import numpy # type: ignore from numpy import diff, floor, ceil, zeros, hstack, newaxis @@ -7,23 +7,27 @@ import pickle import warnings import copy -from . import GridError, Direction -from ._helpers import is_scalar +from . import GridError -__author__ = 'Jan Petykiewicz' - eps_callable_type = Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray] +T = TypeVar('T', bound='Grid') class Grid: """ - Simulation grid generator intended for electromagnetic simulations. - Can be used to generate non-uniform rectangular grids (the entire grid + Simulation grid metadata for finite-difference simulations. + + Can be used to generate non-uniform rectangular grids (the entire grid is generated based on the coordinates of the boundary points). Also does straightforward natural <-> grid unit conversion. - `self.grids[i][a,b,c]` contains the value of epsilon for the cell located around + This class handles data describing the grid, and should be paired with a + (separate) ndarray that contains the actual data in each cell. The `allocate()` + method can be used to create this ndarray. + + The resulting `cell_data[i, a, b, c]` should correspond to the value in the + `i`-th grid, in the cell centered around ``` (xyz[0][a] + dxyz[0][a] * shifts[i, 0], xyz[1][b] + dxyz[1][b] * shifts[i, 1], @@ -47,9 +51,6 @@ class Grid: exyz: List[numpy.ndarray] """Cell edges. Monotonically increasing without duplicates.""" - grids: numpy.ndarray - """epsilon (or mu, or whatever) grids. shape is (num_grids, X, Y, Z)""" - periodic: List[bool] """For each axis, determines how far the rightmost boundary gets shifted. """ @@ -81,7 +82,7 @@ class Grid: Returns: List of 3 ndarrays of cell sizes """ - return [diff(self.exyz[a]) for a in range(3)] + return [numpy.diff(ee) for ee in self.exyz] @property def xyz(self) -> List[numpy.ndarray]: @@ -103,6 +104,20 @@ class Grid: """ return numpy.array([coord.size - 1 for coord in self.exyz], dtype=int) + @property + def num_grids(self) -> int: + """ + The number of grids (number of shifts) + """ + return self.shifts.shape[0] + + @property + def cell_data_shape(self): + """ + The shape of the cell_data ndarray (num_grids, *self.shape). + """ + return numpy.hstack((self.num_grids, self.shape)) + @property def dxyz_with_ghost(self) -> List[numpy.ndarray]: """ @@ -117,7 +132,7 @@ class Grid: list of [dxs, dys, dzs] with each element same length as elements of `self.xyz` """ el = [0 if p else -1 for p in self.periodic] - return [hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)] + return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el)] @property def center(self) -> numpy.ndarray: @@ -211,23 +226,37 @@ class Grid: dxyz = self.shifted_dxyz(which_shifts) return [exyz[a][:-1] + dxyz[a] / 2.0 for a in range(3)] - def autoshifted_dxyz(self): + def autoshifted_dxyz(self) -> List[numpy.ndarray]: """ Return cell widths, with each dimension shifted by the corresponding shifts. Returns: `[grid.shifted_dxyz(which_shifts=a)[a] for a in range(3)]` """ - if len(self.grids) != 3: - raise GridError('autoshifting requires exactly 3 grids') + if self.num_grids != 3: + raise GridError('Autoshifting requires exactly 3 grids') return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] + def allocate(self, fill_value: Optional[float] = 1.0, dtype=numpy.float64) -> numpy.ndarray: + """ + Allocate an ndarray for storing grid data. + + Args: + fill_value: Value to initialize the grid to. If None, an + uninitialized array is returned. + dtype: Numpy dtype for the array. Default is `numpy.float64`. + + Returns: + The allocated array + """ + if fill_value is None: + return numpy.empty(self.cell_data_shape) + else: + return numpy.full(self.cell_data_shape, fill_value) def __init__(self, pixel_edge_coordinates: Sequence[numpy.ndarray], shifts: numpy.ndarray = Yee_Shifts_E, - initial: Union[float, numpy.ndarray] = 1.0, - num_grids: Optional[int] = None, periodic: Union[bool, Sequence[bool]] = False, ) -> None: """ @@ -238,12 +267,6 @@ class Grid: x=`x1`, the second has edges x=`x1` and x=`x2`, etc.) shifts: Nx3 array containing `[x, y, z]` offsets for each of N grids. E-field Yee shifts are used by default. - initial: Grids are initialized to this value. If scalar, all grids are initialized - with ndarrays full of the scalar. If a list of scalars, `grid[i]` is initialized to an - ndarray full of `initial[i]`. If a list of ndarrays of the same shape as the grids, `grid[i]` - is set to `initial[i]`. Default `1.0`. - num_grids: How many grids to create. Must be <= `shifts.shape[0]`. - Default is `shifts.shape[0]` periodic: Specifies how the sizes of edge cells are calculated; see main class documentation. List of 3 bool, or a single bool that gets broadcast. Default `False`. @@ -255,7 +278,7 @@ class Grid: for i in range(3): if len(self.exyz[i]) != len(pixel_edge_coordinates[i]): - warnings.warn('Dimension {} had duplicate edge coordinates'.format(i), stacklevel=2) + warnings.warn(f'Dimension {i} had duplicate edge coordinates', stacklevel=2) if isinstance(periodic, bool): self.periodic = [periodic] * 3 @@ -264,10 +287,10 @@ class Grid: if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' - ' The given shifts has shape {}'.format(self.shifts.shape)) + f' The given shifts has shape {self.shifts.shape}') if self.shifts.shape[1] != 3: raise GridError('Misshapen shifts; second axis size should be 3,' - ' shape is {}'.format(self.shifts.shape)) + f' shape is {self.shifts.shape}') if (numpy.abs(self.shifts) > 1).any(): raise GridError('Only shifts in the range [-1, 1] are currently supported') @@ -276,33 +299,6 @@ class Grid: # TODO: Test negative shifts warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2) - num_shifts = self.shifts.shape[0] - if num_grids is None: - num_grids = num_shifts - elif num_grids > num_shifts: - raise GridError('Number of grids exceeds number of shifts (%u)' % num_shifts) - - grids_shape = hstack((num_grids, self.shape)) - if isinstance(initial, (float, int)): - if isinstance(initial, int): - warnings.warn('Initial value is an int, grids will be integer-typed!', stacklevel=2) - self.grids = numpy.full(grids_shape, initial) - else: - if len(initial) < num_grids: - raise GridError('Too few initial grids specified!') - - self.grids = numpy.empty(grids_shape) - for i in range(num_grids): - if is_scalar(initial[i]): - if initial[i] is not None: - if isinstance(initial[i], int): - warnings.warn('Initial value is an int, grid {} will be integer-typed!'.format(i), stacklevel=2) - self.grids[i] = numpy.full(self.shape, initial[i]) - else: - if not numpy.array_equal(initial[i].shape, self.shape): - raise GridError('Initial grid sizes must match given coordinates') - self.grids[i] = initial[i] - @staticmethod def load(filename: str) -> 'Grid': """ @@ -318,17 +314,21 @@ class Grid: g.__dict__.update(tmp_dict) return g - def save(self, filename: str): + def save(self: T, filename: str) -> T: """ Save to file. Args: filename: Filename to save to. + + Returns: + self """ with open(filename, 'wb') as f: pickle.dump(self.__dict__, f, protocol=2) + return self - def copy(self): + def copy(self: T) -> T: """ Returns: Deep copy of the grid. diff --git a/gridlock/position.py b/gridlock/position.py index 282824f..1224a12 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -4,7 +4,6 @@ Position-related methods for Grid class from typing import List, Optional import numpy # type: ignore -from numpy import zeros from . import GridError @@ -47,7 +46,7 @@ def ind2pos(self, low_bound = -0.5 high_bound = -0.5 if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): - raise GridError('Position outside of grid: {}'.format(ind)) + raise GridError(f'Position outside of grid: {ind}') if round_ind: rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) @@ -85,19 +84,19 @@ def pos2ind(self, """ r = numpy.squeeze(r) if r.size != 3: - raise GridError('r must be 3-element vector: {}'.format(r)) + raise GridError(f'r must be 3-element vector: {r}') if (which_shifts is not None) and (which_shifts >= self.shifts.shape[0]): - raise GridError('Invalid which_shifts: {}'.format(which_shifts)) + raise GridError(f'Invalid which_shifts: {which_shifts}') sexyz = self.shifted_exyz(which_shifts) if check_bounds: for a in range(3): if self.shape[a] > 1 and (r[a] < sexyz[a][0] or r[a] > sexyz[a][-1]): - raise GridError('Position[{}] outside of grid!'.format(a)) + raise GridError(f'Position[{a}] outside of grid!') - grid_pos = zeros((3,)) + grid_pos = numpy.zeros((3,)) for a in range(3): xi = numpy.digitize(r[a], sexyz[a]) - 1 # Figure out which cell we're in xi_clipped = numpy.clip(xi, 0, sexyz[a].size - 2) # Clip back into grid bounds diff --git a/gridlock/read.py b/gridlock/read.py index 32e145e..aa059d5 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -4,10 +4,8 @@ Readback and visualization methods for Grid class from typing import Dict, Optional, Union, Any import numpy # type: ignore -from numpy import floor, ceil, zeros -from . import GridError, Direction -from ._helpers import is_scalar +from . import GridError # .visualize_* uses matplotlib # .visualize_isosurface uses skimage @@ -15,7 +13,8 @@ from ._helpers import is_scalar def get_slice(self, - surface_normal: Union[Direction, int], + cell_data: numpy.ndarray, + surface_normal: int, center: float, which_shifts: int = 0, sample_period: int = 1 @@ -25,8 +24,8 @@ def get_slice(self, Interpolates if given a position between two planes. Args: - surface_normal: Axis normal to the plane we're displaying. Can be a `Direction` or - integer in `range(3)` + cell_data: Cell data to slice + surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. center: Scalar specifying position along surface_normal axis. which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) @@ -34,19 +33,16 @@ def get_slice(self, Returns: Array containing the portion of the grid. """ - if not is_scalar(center) and numpy.isreal(center): + if numpy.size(center) != 1 or not numpy.isreal(center): raise GridError('center must be a real scalar') sp = round(sample_period) if sp <= 0: raise GridError('sample_period must be positive') - if not is_scalar(which_shifts) or which_shifts < 0: + if numpy.size(which_shifts) != 1 or which_shifts < 0: raise GridError('Invalid which_shifts') - # Turn surface_normal into its integer representation - if isinstance(surface_normal, Direction): - surface_normal = surface_normal.value if surface_normal not in range(3): raise GridError('Invalid surface_normal direction') @@ -56,9 +52,9 @@ def get_slice(self, center3 = numpy.insert([0, 0], surface_normal, (center,)) center_index = self.pos2ind(center3, which_shifts, round_ind=False, check_bounds=False)[surface_normal] - centers = numpy.unique([floor(center_index), ceil(center_index)]).astype(int) + centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) if len(centers) == 2: - fpart = center_index - floor(center_index) + fpart = center_index - numpy.floor(center_index) w = [1 - fpart, fpart] # longer distance -> less weight else: w = [1] @@ -68,10 +64,10 @@ def get_slice(self, raise GridError('Coordinate of selected plane must be within simulation domain') # Extract grid values from planes above and below visualized slice - sliced_grid = zeros(self.shape[surface]) + sliced_grid = numpy.zeros(self.shape[surface]) for ci, weight in zip(centers, w): s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) - sliced_grid += weight * self.grids[which_shifts][tuple(s)] + sliced_grid += weight * cell_data[which_shifts][tuple(s)] # Remove extra dimensions sliced_grid = numpy.squeeze(sliced_grid) @@ -80,7 +76,8 @@ def get_slice(self, def visualize_slice(self, - surface_normal: Union[Direction, int], + cell_data: numpy.ndarray, + surface_normal: int, center: float, which_shifts: int = 0, sample_period: int = 1, @@ -92,8 +89,7 @@ def visualize_slice(self, Interpolates if given a position between two planes. Args: - surface_normal: Axis normal to the plane we're displaying. Can be a `Direction` or - integer in `range(3)` + surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`. center: Scalar specifying position along surface_normal axis. which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) @@ -101,14 +97,11 @@ def visualize_slice(self, """ from matplotlib import pyplot - # Set surface normal to its integer value - if isinstance(surface_normal, Direction): - surface_normal = surface_normal.value - if pcolormesh_args is None: pcolormesh_args = {} - grid_slice = self.get_slice(surface_normal=surface_normal, + grid_slice = self.get_slice(cell_data=cell_data, + surface_normal=surface_normal, center=center, which_shifts=which_shifts, sample_period=sample_period) @@ -130,6 +123,7 @@ def visualize_slice(self, def visualize_isosurface(self, + cell_data: numpy.ndarray, level: Optional[float] = None, which_shifts: int = 0, sample_period: int = 1, @@ -140,6 +134,7 @@ def visualize_isosurface(self, Draw an isosurface plot of the device. Args: + cell_data: Cell data to visualize level: Value at which to find isosurface. Default (None) uses mean value in grid. which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) @@ -151,8 +146,8 @@ def visualize_isosurface(self, # Claims to be unused, but needed for subplot(projection='3d') from mpl_toolkits.mplot3d import Axes3D - # Get data from self.grids - grid = self.grids[which_shifts][::sample_period, ::sample_period, ::sample_period] + # Get data from cell_data + grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period] if level is None: level = grid.mean()