diff --git a/gridlock/read.py b/gridlock/read.py index b5840e7..600227d 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -135,12 +135,86 @@ class GridReadMixin(GridPosMixin): ax.set_aspect('equal', adjustable='box') ax.set_xlabel(x_label) ax.set_ylabel(y_label) + if finalize: pyplot.show() return fig, ax + def visualize_edges( + self, + cell_data: NDArray, + plane: PlaneProtocol | PlaneDict, + which_shifts: int = 0, + finalize: bool = True, + contour_args: dict[str, Any] | None = None, + ax: 'matplotlib.axes.Axes' | None = None, + level_fraction: float = 0.7, + ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: + """ + Visualize the edges of a grid slice. + This is intended as an overlay on top of visualize_slice (e.g. showing epsilon boundaries + on an E-field plot). + + Interpolates if given a position between two grid planes. + + Args: + cell_data: Cell data to visualize + plane: Axis and position (`Plane`) of the plane to read. + which_shifts: Which grid to display. Default is the first grid (0). + sample_period: Period for down-sampling the image. Default 1 (disabled) + finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + pcolormesh_args: Args passed through to matplotlib `pcolormesh()` + ax: If provided, plot to these axes (instead of creating a new figure & axes) + level_fraction: Value between 0 and 1 which tunes how many contours are generated. + 1 indicates that every possible step should have its own contour. + + Returns: + (Figure, Axes) + """ + from matplotlib import pyplot + + if level_fraction > 1: + raise GridError(f'{level_fraction=} must be between 0 and 1') + + if isinstance(plane, dict): + plane = Plane(**plane) + + if pcolormesh_args is None: + pcolormesh_args = {} + + grid_slice = self.get_slice( + cell_data = cell_data, + plane = plane, + which_shifts = which_shifts, + sample_period = sample_period, + ) + cvals, cval_counts = numpy.unique(grid_slice, return_counts=True) + if cvals.size == 1: + levels = [cvals[0] + 1] + else: + cval_order = numpy.argsort(cval_counts)[::-1] + level_count = 2 + while cval_counts[cval_order[:level_count]].sum() < level_fraction: + level_count += 1 + ctr_levels = cvals[cval_order[:level_count]] + levels = numpy.diff(ctr_levels[::-1]) + ctr_levels[:0:-1] + + surface = numpy.delete(range(3), plane.axis) + + if ax is None: + fig, ax = pyplot.subplots() + else: + fig = ax.figure + xc, yc = (self.shifted_exyz(which_shifts)[a] for a in surface) + xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') + + mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) + + return fig, ax + + def visualize_isosurface( self, cell_data: NDArray,