184 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			184 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
 | 
						|
Readback and visualization methods for Grid class
 | 
						|
"""
 | 
						|
from typing import Dict, Optional, Union, Any
 | 
						|
 | 
						|
import numpy        # type: ignore
 | 
						|
 | 
						|
from . import GridError
 | 
						|
 | 
						|
# .visualize_* uses matplotlib
 | 
						|
# .visualize_isosurface uses skimage
 | 
						|
# .visualize_isosurface uses mpl_toolkits.mplot3d
 | 
						|
 | 
						|
 | 
						|
def get_slice(self,
              cell_data: numpy.ndarray,
              surface_normal: int,
              center: float,
              which_shifts: int = 0,
              sample_period: int = 1
              ) -> numpy.ndarray:
    """
    Retrieve a slice of a grid.
    Interpolates if given a position between two planes.

    Args:
        cell_data: Cell data to slice; indexed as `cell_data[which_shifts][i, j, k]`.
        surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
        center: Scalar specifying position along surface_normal axis.
        which_shifts: Which grid to display. Default is the first grid (0).
        sample_period: Period for down-sampling the image. Default 1 (disabled).
            Non-integer values are rounded to the nearest integer.

    Returns:
        Array containing the portion of the grid. If `center` lies between two
        planes, the result is the linear interpolation of those two planes.
        The dtype is at least float64 (complex input stays complex), and any
        size-1 dimensions are squeezed out.

    Raises:
        GridError: if `center` is not a real scalar, `sample_period` rounds to a
            non-positive value, `which_shifts` is invalid, `surface_normal` is not
            in `range(3)`, or `center` lies outside `self.xyz[surface_normal]`'s
            first/last values.
    """
    if numpy.size(center) != 1 or not numpy.isreal(center):
        raise GridError('center must be a real scalar')

    sp = round(sample_period)
    if sp <= 0:
        raise GridError('sample_period must be positive')

    if numpy.size(which_shifts) != 1 or which_shifts < 0:
        raise GridError('Invalid which_shifts')

    if surface_normal not in range(3):
        raise GridError('Invalid surface_normal direction')

    # Validate the plane position up front; pos2ind below is called with
    # check_bounds=False, so it will not catch out-of-range positions itself.
    c_min, c_max = (self.xyz[surface_normal][i] for i in [0, -1])
    if center < c_min or center > c_max:
        raise GridError('Coordinate of selected plane must be within simulation domain')

    # Extract indices and weights of planes.
    # The placeholder coordinates must be floats: numpy.insert casts the inserted
    # value to the dtype of the target array, so an integer [0, 0] would truncate
    # `center` and defeat the interpolation below.
    center3 = numpy.insert([0.0, 0.0], surface_normal, (center,))
    center_index = self.pos2ind(center3, which_shifts,
                                round_ind=False, check_bounds=False)[surface_normal]
    centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int)
    if len(centers) == 2:
        fpart = center_index - numpy.floor(center_index)
        w = [1 - fpart, fpart]  # longer distance -> less weight
    else:
        w = [1]

    # Extract (down-sampled) grid values from planes above and below visualized slice.
    # An integer index on the normal axis drops that axis, leaving a 2D plane.
    plane_data = cell_data[which_shifts]
    planes = [plane_data[tuple(ci if a == surface_normal else numpy.s_[::sp] for a in range(3))]
              for ci in centers]

    # Size the accumulator from the sampled planes (so sample_period > 1 works) and
    # give it a dtype that is at least float64 but can also hold complex data.
    sliced_grid = numpy.zeros(planes[0].shape,
                              dtype=numpy.result_type(plane_data.dtype, numpy.float64))
    for plane, weight in zip(planes, w):
        sliced_grid += weight * plane

    # Remove extra dimensions
    return numpy.squeeze(sliced_grid)
 | 
						|
 | 
						|
 | 
						|
def visualize_slice(self,
                    cell_data: numpy.ndarray,
                    surface_normal: int,
                    center: float,
                    which_shifts: int = 0,
                    sample_period: int = 1,
                    finalize: bool = True,
                    pcolormesh_args: Optional[Dict[str, Any]] = None,
                    ) -> None:
    """
    Visualize a slice of a grid.
    Interpolates if given a position between two planes.

    Args:
        cell_data: Cell data to visualize; forwarded to `self.get_slice`.
        surface_normal: Axis normal to the plane we're displaying. Integer in `range(3)`.
        center: Scalar specifying position along surface_normal axis.
        which_shifts: Which grid to display. Default is the first grid (0).
        sample_period: Period for down-sampling the image. Default 1 (disabled)
        finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
        pcolormesh_args: Extra keyword arguments forwarded to `pyplot.pcolormesh`.
            Default `None` (no extra arguments).
    """
    from matplotlib import pyplot

    if pcolormesh_args is None:
        pcolormesh_args = {}

    grid_slice = self.get_slice(cell_data=cell_data,
                                surface_normal=surface_normal,
                                center=center,
                                which_shifts=which_shifts,
                                sample_period=sample_period)

    surface = numpy.delete(range(3), surface_normal)

    # The slice keeps every `sp`-th cell along each in-plane axis, so the cell
    # edges handed to pcolormesh must be sampled the same way: the leading edge
    # of each kept cell, plus the last grid edge to close the final (possibly
    # partial) cell. For sp == 1 this reproduces the full edge array.
    # Assumes shifted_exyz(...)[a] holds one more edge than there are cells,
    # as required by pcolormesh's default flat shading.
    sp = round(sample_period)
    exyz = self.shifted_exyz(which_shifts)
    x, y = (numpy.append(exyz[a][:-1][::sp], exyz[a][-1]) for a in surface)
    xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
    x_label, y_label = ('xyz'[a] for a in surface)

    pyplot.figure()
    pyplot.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args)
    pyplot.colorbar()
    pyplot.gca().set_aspect('equal', adjustable='box')
    pyplot.xlabel(x_label)
    pyplot.ylabel(y_label)
    if finalize:
        pyplot.show()
 | 
						|
 | 
						|
 | 
						|
def visualize_isosurface(self,
                         cell_data: numpy.ndarray,
                         level: Optional[float] = None,
                         which_shifts: int = 0,
                         sample_period: int = 1,
                         show_edges: bool = True,
                         finalize: bool = True,
                         ) -> None:
    """
    Draw an isosurface plot of the device.

    Args:
        cell_data: Cell data to visualize; indexed as `cell_data[which_shifts][i, j, k]`.
        level: Value at which to find isosurface. Default (None) uses mean value in grid.
        which_shifts: Which grid to display. Default is the first grid (0).
        sample_period: Period for down-sampling the image. Default 1 (disabled)
        show_edges: Whether to draw triangle edges. Default `True`
        finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
    """
    from matplotlib import pyplot
    import skimage.measure
    # Claims to be unused, but needed for subplot(projection='3d')
    from mpl_toolkits.mplot3d import Axes3D

    # Get (down-sampled) data from cell_data
    grid = cell_data[which_shifts][::sample_period, ::sample_period, ::sample_period]
    if level is None:
        level = grid.mean()

    # Find isosurface with marching cubes
    verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level)

    # marching_cubes reports vertices in index units of the down-sampled array;
    # scale them back to full-grid indices before converting index -> position.
    full_index_verts = verts * sample_period
    pos_verts = numpy.array([self.ind2pos(vert, which_shifts, round_ind=False)
                             for vert in full_index_verts], dtype=float)
    xs, ys, zs = (pos_verts[:, a] for a in range(3))

    # Draw the plot
    fig = pyplot.figure()
    ax = fig.add_subplot(111, projection='3d')
    if show_edges:
        ax.plot_trisurf(xs, ys, faces, zs)
    else:
        ax.plot_trisurf(xs, ys, faces, zs, edgecolor='none')

    # Add a fake plot of a cube to force the axes to be equal lengths:
    # plot the 8 corners of a cube (side = largest data extent) centred on the
    # data, in white ('w') so they are effectively invisible.
    max_range = numpy.array([xs.max() - xs.min(),
                             ys.max() - ys.min(),
                             zs.max() - zs.min()], dtype=float).max()
    mg = numpy.mgrid[-1:2:2, -1:2:2, -1:2:2]
    xbs = 0.5 * max_range * mg[0].flatten() + 0.5 * (xs.max() + xs.min())
    ybs = 0.5 * max_range * mg[1].flatten() + 0.5 * (ys.max() + ys.min())
    zbs = 0.5 * max_range * mg[2].flatten() + 0.5 * (zs.max() + zs.min())
    for xb, yb, zb in zip(xbs, ybs, zbs):
        ax.plot([xb], [yb], [zb], 'w')

    if finalize:
        pyplot.show()
 |