Add get_slice method for easy interpolation
This commit is contained in:
parent
55d9f33090
commit
0a95bcda1b
@ -677,14 +677,14 @@ class Grid(object):
|
||||
|
||||
self.draw_polygon(direction, center, p, thickness, eps_func)
|
||||
|
||||
def visualize_slice(self,
|
||||
def get_slice(self,
|
||||
surface_normal: Direction or int,
|
||||
center: float,
|
||||
which_shifts: int=0,
|
||||
sample_period: int=1,
|
||||
finalize: bool=True):
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1
|
||||
) -> numpy.ndarray:
|
||||
"""
|
||||
Visualize a slice of a grid.
|
||||
Retrieve a slice of a grid.
|
||||
Interpolates if given a position between two planes.
|
||||
|
||||
:param surface_normal: Axis normal to the plane we're displaying. Can be a Direction or
|
||||
@ -692,9 +692,8 @@ class Grid(object):
|
||||
:param center: Scalar specifying position along surface_normal axis.
|
||||
:param which_shifts: Which grid to display. Default is the first grid (0).
|
||||
:param sample_period: Period for down-sampling the image. Default 1 (disabled)
|
||||
:return Array containing the portion of the grid.
|
||||
"""
|
||||
from matplotlib import pyplot
|
||||
|
||||
if not is_scalar(center) and numpy.isreal(center):
|
||||
raise GridError('center must be a real scalar')
|
||||
|
||||
@ -720,22 +719,48 @@ class Grid(object):
|
||||
centers = numpy.unique([floor(center_index), ceil(center_index)]).astype(int)
|
||||
if len(centers) == 2:
|
||||
fpart = center_index - floor(center_index)
|
||||
w = [1-fpart, fpart] # longer distance -> less weight
|
||||
w = [1 - fpart, fpart] # longer distance -> less weight
|
||||
else:
|
||||
w = [1]
|
||||
|
||||
c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1])
|
||||
if center < c_min or center > c_max:
|
||||
raise GridError('Coordinate of visualized plane must be within simulation domain')
|
||||
raise GridError('Coordinate of selected plane must be within simulation domain')
|
||||
|
||||
# Extract grid values from planes above and below visualized slice
|
||||
eps = zeros(self.shape[surface])
|
||||
sliced_grid = zeros(self.shape[surface])
|
||||
for ci, weight in zip(centers, w):
|
||||
s = tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))
|
||||
eps += weight * self.grids[which_shifts][tuple(s)]
|
||||
sliced_grid += weight * self.grids[which_shifts][tuple(s)]
|
||||
|
||||
# Remove extra dimensions
|
||||
eps = numpy.squeeze(eps)
|
||||
sliced_grid = numpy.squeeze(sliced_grid)
|
||||
|
||||
return sliced_grid
|
||||
|
||||
def visualize_slice(self,
|
||||
surface_normal: Direction or int,
|
||||
center: float,
|
||||
which_shifts: int = 0,
|
||||
sample_period: int = 1,
|
||||
finalize: bool = True):
|
||||
"""
|
||||
Visualize a slice of a grid.
|
||||
Interpolates if given a position between two planes.
|
||||
|
||||
:param surface_normal: Axis normal to the plane we're displaying. Can be a Direction or
|
||||
integer in range(3)
|
||||
:param center: Scalar specifying position along surface_normal axis.
|
||||
:param which_shifts: Which grid to display. Default is the first grid (0).
|
||||
:param sample_period: Period for down-sampling the image. Default 1 (disabled)
|
||||
:param finalize: Whether to call pyplot.show() after constructing the plot.
|
||||
"""
|
||||
from matplotlib import pyplot
|
||||
|
||||
grid_slice = self.get_slice(surface_normal=surface_normal,
|
||||
center=center,
|
||||
which_shifts=which_shifts,
|
||||
sample_period=sample_period)
|
||||
|
||||
surface = numpy.delete(range(3), surface_normal)
|
||||
|
||||
@ -744,7 +769,7 @@ class Grid(object):
|
||||
x_label, y_label = ('xyz'[a] for a in surface)
|
||||
|
||||
pyplot.figure()
|
||||
pyplot.pcolormesh(xmesh, ymesh, eps)
|
||||
pyplot.pcolormesh(xmesh, ymesh, grid_slice)
|
||||
pyplot.colorbar()
|
||||
pyplot.gca().set_aspect('equal', adjustable='box')
|
||||
pyplot.xlabel(x_label)
|
||||
|
Loading…
Reference in New Issue
Block a user