Compare commits

..

8 Commits

View File

@ -87,6 +87,7 @@ class GridReadMixin(GridPosMixin):
sample_period: int = 1, sample_period: int = 1,
finalize: bool = True, finalize: bool = True,
pcolormesh_args: dict[str, Any] | None = None, pcolormesh_args: dict[str, Any] | None = None,
ax: 'matplotlib.axes.Axes | None' = None,
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
""" """
Visualize a slice of a grid. Visualize a slice of a grid.
@ -98,6 +99,8 @@ class GridReadMixin(GridPosMixin):
which_shifts: Which grid to display. Default is the first grid (0). which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled) sample_period: Period for down-sampling the image. Default 1 (disabled)
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
pcolormesh_args: Args passed through to matplotlib `pcolormesh()`
ax: If provided, plot to these axes (instead of creating a new figure & axes)
Returns: Returns:
(Figure, Axes) (Figure, Axes)
@ -111,10 +114,10 @@ class GridReadMixin(GridPosMixin):
pcolormesh_args = {} pcolormesh_args = {}
grid_slice = self.get_slice( grid_slice = self.get_slice(
cell_data=cell_data, cell_data = cell_data,
plane=plane, plane = plane,
which_shifts=which_shifts, which_shifts = which_shifts,
sample_period=sample_period, sample_period = sample_period,
) )
surface = numpy.delete(range(3), plane.axis) surface = numpy.delete(range(3), plane.axis)
@ -123,12 +126,93 @@ class GridReadMixin(GridPosMixin):
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
x_label, y_label = ('xyz'[a] for a in surface) x_label, y_label = ('xyz'[a] for a in surface)
if ax is None:
fig, ax = pyplot.subplots() fig, ax = pyplot.subplots()
else:
fig = ax.figure
mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args)
fig.colorbar(mappable) fig.colorbar(mappable)
ax.set_aspect('equal', adjustable='box') ax.set_aspect('equal', adjustable='box')
ax.set_xlabel(x_label) ax.set_xlabel(x_label)
ax.set_ylabel(y_label) ax.set_ylabel(y_label)
if finalize:
pyplot.show()
return fig, ax
def visualize_edges(
self,
cell_data: NDArray,
plane: PlaneProtocol | PlaneDict,
which_shifts: int = 0,
sample_period: int = 1,
finalize: bool = True,
contour_args: dict[str, Any] | None = None,
ax: 'matplotlib.axes.Axes | None' = None,
level_fraction: float = 0.7,
) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']:
"""
Visualize the edges of a grid slice.
This is intended as an overlay on top of visualize_slice (e.g. showing epsilon boundaries
on an E-field plot).
Interpolates if given a position between two grid planes.
Args:
cell_data: Cell data to visualize
plane: Axis and position (`Plane`) of the plane to read.
which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled)
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
contour_args: Args passed through to matplotlib `pcolormesh()`
ax: If provided, plot to these axes (instead of creating a new figure & axes)
level_fraction: Value between 0 and 1 which tunes how many contours are generated.
1 indicates that every possible step should have its own contour.
Returns:
(Figure, Axes)
"""
from matplotlib import pyplot
if level_fraction > 1:
raise GridError(f'{level_fraction=} must be between 0 and 1')
if isinstance(plane, dict):
plane = Plane(**plane)
if contour_args is None:
contour_args = dict(alpha=0.8, colors='gray')
grid_slice = self.get_slice(
cell_data = cell_data,
plane = plane,
which_shifts = which_shifts,
sample_period = sample_period,
)
cvals, cval_counts = numpy.unique(grid_slice, return_counts=True)
if cvals.size == 1:
levels = [cvals[0] + 1]
else:
cval_order = numpy.argsort(cval_counts)[::-1]
level_count = 2
while cval_counts[cval_order[:level_count]].sum() < level_fraction:
level_count += 1
ctr_levels = cvals[cval_order[:level_count]]
levels = numpy.diff(ctr_levels[::-1]) + ctr_levels[:0:-1]
surface = numpy.delete(range(3), plane.axis)
if ax is None:
fig, ax = pyplot.subplots()
else:
fig = ax.figure
xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface)
xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij')
mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args)
if finalize: if finalize:
pyplot.show() pyplot.show()