return figure and axes after plotting

This commit is contained in:
Jan Petykiewicz 2024-07-18 00:17:58 -07:00
parent 3e4e6eead3
commit d44e02e2f7

View File

@ -90,8 +90,8 @@ def visualize_slice(
which_shifts: int = 0, which_shifts: int = 0,
sample_period: int = 1, sample_period: int = 1,
finalize: bool = True, finalize: bool = True,
pcolormesh_args: Optional[Dict[str, Any]] = None, pcolormesh_args: dict[str, Any] | None = None,
) -> None: ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
""" """
Visualize a slice of a grid. Visualize a slice of a grid.
Interpolates if given a position between two planes. Interpolates if given a position between two planes.
@ -102,6 +102,9 @@ def visualize_slice(
which_shifts: Which grid to display. Default is the first grid (0). which_shifts: Which grid to display. Default is the first grid (0).
sample_period: Period for down-sampling the image. Default 1 (disabled) sample_period: Period for down-sampling the image. Default 1 (disabled)
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
Returns:
(Figure, Axes)
""" """
from matplotlib import pyplot from matplotlib import pyplot
@ -120,15 +123,17 @@ def visualize_slice(
xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij')
x_label, y_label = ('xyz'[a] for a in surface) x_label, y_label = ('xyz'[a] for a in surface)
pyplot.figure() fig, ax = pyplot.subplots()
pyplot.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args)
pyplot.colorbar() fig.colorbar(mappable)
pyplot.gca().set_aspect('equal', adjustable='box') ax.set_aspect('equal', adjustable='box')
pyplot.xlabel(x_label) ax.set_xlabel(x_label)
pyplot.ylabel(y_label) ax.set_ylabel(y_label)
if finalize: if finalize:
pyplot.show() pyplot.show()
return fig, ax
def visualize_isosurface( def visualize_isosurface(
self, self,
@ -138,7 +143,7 @@ def visualize_isosurface(
sample_period: int = 1, sample_period: int = 1,
show_edges: bool = True, show_edges: bool = True,
finalize: bool = True, finalize: bool = True,
) -> None: ) -> tuple['matplotlib.axes.Axes', 'matplotlib.figure.Figure']:
""" """
Draw an isosurface plot of the device. Draw an isosurface plot of the device.
@ -149,6 +154,9 @@ def visualize_isosurface(
sample_period: Period for down-sampling the image. Default 1 (disabled) sample_period: Period for down-sampling the image. Default 1 (disabled)
show_edges: Whether to draw triangle edges. Default `True` show_edges: Whether to draw triangle edges. Default `True`
finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True`
Returns:
(Figure, Axes)
""" """
from matplotlib import pyplot from matplotlib import pyplot
import skimage.measure import skimage.measure
@ -190,3 +198,5 @@ def visualize_isosurface(
if finalize: if finalize:
pyplot.show() pyplot.show()
return fig, ax