Browse Source

Add get_slice method for easy interpolation

tags/v0.2
jan 4 years ago
parent
commit
0a95bcda1b
1 changed files with 46 additions and 21 deletions
  1. +46
    -21
      gridlock/grid.py

+ 46
- 21
gridlock/grid.py View File

@@ -677,24 +677,23 @@ class Grid(object):

self.draw_polygon(direction, center, p, thickness, eps_func)

def visualize_slice(self,
surface_normal: Direction or int,
center: float,
which_shifts: int=0,
sample_period: int=1,
finalize: bool=True):
def get_slice(self,
surface_normal: Direction or int,
center: float,
which_shifts: int = 0,
sample_period: int = 1
) -> numpy.ndarray:
"""
Visualize a slice of a grid.
Interpolates if given a position between two planes.
Retrieve a slice of a grid.
Interpolates if given a position between two planes.

:param surface_normal: Axis normal to the plane we're displaying. Can be a Direction or
integer in range(3)
:param center: Scalar specifying position along surface_normal axis.
:param which_shifts: Which grid to display. Default is the first grid (0).
:param sample_period: Period for down-sampling the image. Default 1 (disabled)
:param surface_normal: Axis normal to the plane we're displaying. Can be a Direction or
integer in range(3)
:param center: Scalar specifying position along surface_normal axis.
:param which_shifts: Which grid to display. Default is the first grid (0).
:param sample_period: Period for down-sampling the image. Default 1 (disabled)
:return Array containing the portion of the grid.
"""
from matplotlib import pyplot

if not is_scalar(center) and numpy.isreal(center):
raise GridError('center must be a real scalar')

@@ -720,22 +719,48 @@ class Grid(object):
centers = numpy.unique([floor(center_index), ceil(center_index)]).astype(int)
if len(centers) == 2:
fpart = center_index - floor(center_index)
w = [1-fpart, fpart] # longer distance -> less weight
w = [1 - fpart, fpart] # longer distance -> less weight
else:
w = [1]

c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1])
if center < c_min or center > c_max:
raise GridError('Coordinate of visualized plane must be within simulation domain')
raise GridError('Coordinate of selected plane must be within simulation domain')

# Extract grid values from planes above and below visualized slice
eps = zeros(self.shape[surface])
sliced_grid = zeros(self.shape[surface])
for ci, weight in zip(centers, w):
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
eps += weight * self.grids[which_shifts][tuple(s)]
sliced_grid += weight * self.grids[which_shifts][tuple(s)]

# Remove extra dimensions
eps = numpy.squeeze(eps)
sliced_grid = numpy.squeeze(sliced_grid)

return sliced_grid

def visualize_slice(self,
surface_normal: Direction or int,
center: float,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True):
"""
Visualize a slice of a grid.
Interpolates if given a position between two planes.

:param surface_normal: Axis normal to the plane we're displaying. Can be a Direction or
integer in range(3)
:param center: Scalar specifying position along surface_normal axis.
:param which_shifts: Which grid to display. Default is the first grid (0).
:param sample_period: Period for down-sampling the image. Default 1 (disabled)
:param finalize: Whether to call pyplot.show() after constructing the plot.
"""
from matplotlib import pyplot

grid_slice = self.get_slice(surface_normal=surface_normal,
center=center,
which_shifts=which_shifts,
sample_period=sample_period)

surface = numpy.delete(range(3), surface_normal)

@@ -744,7 +769,7 @@ class Grid(object):
x_label, y_label = ('xyz'[a] for a in surface)

pyplot.figure()
pyplot.pcolormesh(xmesh, ymesh, eps)
pyplot.pcolormesh(xmesh, ymesh, grid_slice)
pyplot.colorbar()
pyplot.gca().set_aspect('equal', adjustable='box')
pyplot.xlabel(x_label)


Loading…
Cancel
Save