diff --git a/gridlock/read.py b/gridlock/read.py index bbd4f39..2fe35d5 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -90,8 +90,8 @@ def visualize_slice( which_shifts: int = 0, sample_period: int = 1, finalize: bool = True, - pcolormesh_args: Optional[Dict[str, Any]] = None, - ) -> None: + pcolormesh_args: dict[str, Any] | None = None, + ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: """ Visualize a slice of a grid. Interpolates if given a position between two planes. @@ -102,6 +102,9 @@ def visualize_slice( which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + + Returns: + (Figure, Axes) """ from matplotlib import pyplot @@ -120,15 +123,17 @@ def visualize_slice( xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) - pyplot.figure() - pyplot.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) - pyplot.colorbar() - pyplot.gca().set_aspect('equal', adjustable='box') - pyplot.xlabel(x_label) - pyplot.ylabel(y_label) + fig, ax = pyplot.subplots() + mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) + fig.colorbar(mappable) + ax.set_aspect('equal', adjustable='box') + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) if finalize: pyplot.show() + return fig, ax + def visualize_isosurface( self, @@ -138,7 +143,7 @@ def visualize_isosurface( sample_period: int = 1, show_edges: bool = True, finalize: bool = True, - ) -> None: + ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']: """ Draw an isosurface plot of the device. @@ -149,6 +154,9 @@ def visualize_isosurface( sample_period: Period for down-sampling the image. Default 1 (disabled) show_edges: Whether to draw triangle edges. Default `True` finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` + + Returns: + (Figure, Axes) """ from matplotlib import pyplot import skimage.measure @@ -190,3 +198,5 @@ def visualize_isosurface( if finalize: pyplot.show() + + return fig, ax