diff --git a/gridlock/read.py b/gridlock/read.py index 503e996..707251a 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -87,7 +87,6 @@ class GridReadMixin(GridPosMixin): sample_period: int = 1, finalize: bool = True, pcolormesh_args: dict[str, Any] | None = None, - ax: 'matplotlib.axes.Axes | None' = None, ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ Visualize a slice of a grid. @@ -99,8 +98,6 @@ class GridReadMixin(GridPosMixin): which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - pcolormesh_args: Args passed through to matplotlib `pcolormesh()` - ax: If provided, plot to these axes (instead of creating a new figure & axes) Returns: (Figure, Axes) @@ -114,10 +111,10 @@ class GridReadMixin(GridPosMixin): pcolormesh_args = {} grid_slice = self.get_slice( - cell_data = cell_data, - plane = plane, - which_shifts = which_shifts, - sample_period = sample_period, + cell_data=cell_data, + plane=plane, + which_shifts=which_shifts, + sample_period=sample_period, ) surface = numpy.delete(range(3), plane.axis) @@ -126,93 +123,12 @@ class GridReadMixin(GridPosMixin): xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) - if ax is None: - fig, ax = pyplot.subplots() - else: - fig = ax.figure + fig, ax = pyplot.subplots() mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) fig.colorbar(mappable) ax.set_aspect('equal', adjustable='box') ax.set_xlabel(x_label) ax.set_ylabel(y_label) - - if finalize: - pyplot.show() - - return fig, ax - - - def visualize_edges( - self, - cell_data: NDArray, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - contour_args: dict[str, Any] | None = None, - ax: 'matplotlib.axes.Axes | None' = None, - level_fraction: float = 0.7, - ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: - """ - Visualize the edges of a grid slice. - This is intended as an overlay on top of visualize_slice (e.g. showing epsilon boundaries - on an E-field plot). - - Interpolates if given a position between two grid planes. - - Args: - cell_data: Cell data to visualize - plane: Axis and position (`Plane`) of the plane to read. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - contour_args: Args passed through to matplotlib `pcolormesh()` - ax: If provided, plot to these axes (instead of creating a new figure & axes) - level_fraction: Value between 0 and 1 which tunes how many contours are generated. - 1 indicates that every possible step should have its own contour. - - Returns: - (Figure, Axes) - """ - from matplotlib import pyplot - - if level_fraction > 1: - raise GridError(f'{level_fraction=} must be between 0 and 1') - - if isinstance(plane, dict): - plane = Plane(**plane) - - if contour_args is None: - contour_args = dict(alpha=0.8, colors='gray') - - grid_slice = self.get_slice( - cell_data = cell_data, - plane = plane, - which_shifts = which_shifts, - sample_period = sample_period, - ) - cvals, cval_counts = numpy.unique(grid_slice, return_counts=True) - if cvals.size == 1: - levels = [cvals[0] + 1] - else: - cval_order = numpy.argsort(cval_counts)[::-1] - level_count = 2 - while cval_counts[cval_order[:level_count]].sum() < level_fraction: - level_count += 1 - ctr_levels = cvals[cval_order[:level_count]] - levels = numpy.diff(ctr_levels[::-1]) + ctr_levels[:0:-1] - - surface = numpy.delete(range(3), plane.axis) - - if ax is None: - fig, ax = pyplot.subplots() - else: - fig = ax.figure - xc, yc = (self.shifted_xyz(which_shifts)[a] for a in surface) - xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') - - mappable = ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) - if finalize: pyplot.show()