Add get_slice method for easy interpolation

This commit is contained in:
jan 2016-05-25 21:25:43 -07:00
parent 55d9f33090
commit 0a95bcda1b

View File

@ -677,14 +677,14 @@ class Grid(object):
self.draw_polygon(direction, center, p, thickness, eps_func) self.draw_polygon(direction, center, p, thickness, eps_func)
def visualize_slice(self, def get_slice(self,
surface_normal: Direction or int, surface_normal: Direction or int,
center: float, center: float,
which_shifts: int = 0, which_shifts: int = 0,
sample_period: int=1, sample_period: int = 1
finalize: bool=True): ) -> numpy.ndarray:
""" """
Visualize a slice of a grid. Retrieve a slice of a grid.
Interpolates if given a position between two planes. Interpolates if given a position between two planes.
:param surface_normal: Axis normal to the plane we're displaying. Can be a Direction or :param surface_normal: Axis normal to the plane we're displaying. Can be a Direction or
@ -692,9 +692,8 @@ class Grid(object):
:param center: Scalar specifying position along surface_normal axis. :param center: Scalar specifying position along surface_normal axis.
:param which_shifts: Which grid to display. Default is the first grid (0). :param which_shifts: Which grid to display. Default is the first grid (0).
:param sample_period: Period for down-sampling the image. Default 1 (disabled) :param sample_period: Period for down-sampling the image. Default 1 (disabled)
:return Array containing the portion of the grid.
""" """
from matplotlib import pyplot
if not is_scalar(center) and numpy.isreal(center): if not is_scalar(center) and numpy.isreal(center):
raise GridError('center must be a real scalar') raise GridError('center must be a real scalar')
@ -726,16 +725,42 @@ class Grid(object):
c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1]) c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1])
if center < c_min or center > c_max: if center < c_min or center > c_max:
raise GridError('Coordinate of visualized plane must be within simulation domain') raise GridError('Coordinate of selected plane must be within simulation domain')
# Extract grid values from planes above and below visualized slice # Extract grid values from planes above and below visualized slice
eps = zeros(self.shape[surface]) sliced_grid = zeros(self.shape[surface])
for ci, weight in zip(centers, w): for ci, weight in zip(centers, w):
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3)) s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
eps += weight * self.grids[which_shifts][tuple(s)] sliced_grid += weight * self.grids[which_shifts][tuple(s)]
# Remove extra dimensions # Remove extra dimensions
eps = numpy.squeeze(eps) sliced_grid = numpy.squeeze(sliced_grid)
return sliced_grid
def visualize_slice(self,
surface_normal: Direction or int,
center: float,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True):
"""
Visualize a slice of a grid.
Interpolates if given a position between two planes.
:param surface_normal: Axis normal to the plane we're displaying. Can be a Direction or
integer in range(3)
:param center: Scalar specifying position along surface_normal axis.
:param which_shifts: Which grid to display. Default is the first grid (0).
:param sample_period: Period for down-sampling the image. Default 1 (disabled)
:param finalize: Whether to call pyplot.show() after constructing the plot.
"""
from matplotlib import pyplot
grid_slice = self.get_slice(surface_normal=surface_normal,
center=center,
which_shifts=which_shifts,
sample_period=sample_period)
surface = numpy.delete(range(3), surface_normal) surface = numpy.delete(range(3), surface_normal)
@ -744,7 +769,7 @@ class Grid(object):
x_label, y_label = ('xyz'[a] for a in surface) x_label, y_label = ('xyz'[a] for a in surface)
pyplot.figure() pyplot.figure()
pyplot.pcolormesh(xmesh, ymesh, eps) pyplot.pcolormesh(xmesh, ymesh, grid_slice)
pyplot.colorbar() pyplot.colorbar()
pyplot.gca().set_aspect('equal', adjustable='box') pyplot.gca().set_aspect('equal', adjustable='box')
pyplot.xlabel(x_label) pyplot.xlabel(x_label)