diff --git a/miscplot/variability.py b/miscplot/variability.py index aff667d..b380aac 100644 --- a/miscplot/variability.py +++ b/miscplot/variability.py @@ -1,7 +1,10 @@ -from typing import Sequence, Callable, Any, TYPE_CHECKING +from typing import Sequence, Callable, Any +from pathlib import Path import threading import textwrap +from PyQt6 import QtCore + import numpy from numpy.typing import NDArray import polars @@ -11,10 +14,6 @@ from matplotlib.figure import Figure from matplotlib.axes import Axes -if TYPE_CHECKING: - import pandas - - def twrap(text: str, **kwargs) -> str: kwargs.setdefault('width', 15) intxt = text.replace('_', '-') @@ -60,17 +59,18 @@ def variability_plot( vert_groups = set(vert_groups) df = polars.DataFrame(data_table) +# zero_bad: bool = True, +# if zero_bad: +# df.filter(col(data_col) != 0) - # Drop nulls and nans so that the boxplots don't disappear - df = polars.drop_nulls(subset=[data_col]).drop_nans(subset=[data_col]) - - # Assign category indicies (x_pos) df = df.sort(groups) df_groups = df.select(groups).unique(maintain_order=True).with_row_index(name='x_pos') df = df.join(df_groups, on=groups, maintain_order='left') - max_group_length = df.group_by(groups).len().select('len').max()[0, 0] # How many points in the largest x_pos + max_group_length = df.group_by(groups).len().select('len').max()[0, 0] + + label_stack = get_label_stack(df_groups, groups, wrap_fn) + size_lists = get_text_sizes(label_stack) - # Add jitter to the scatterplots-plots jitter = 0.2 rng = numpy.random.default_rng(seed=0) jitter_offsets = rng.uniform(low=-jitter, high=jitter, size=max_group_length) @@ -84,12 +84,9 @@ def variability_plot( x_data = numpy.concatenate(x_lists) y_data = numpy.concatenate(y_lists) - # Get label contents and measure their sizes - label_stack = get_label_stack(df_groups, groups, wrap_fn) - size_lists = get_text_sizes(label_stack) + y_ratios = [mainplot_ratios[1]] + get_label_y_ratios(groups, vert_groups, size_lists) - # Build the figure and all axes fig = pyplot.figure() gs = gridspec.GridSpec( nrows = 1 + len(groups), @@ -109,9 +106,6 @@ def variability_plot( label_axes.append( fig.add_subplot(gs[ii, 0], sharex=ax)) header_axes.append(fig.add_subplot(gs[ii, 1])) - # - # Draw all the data - # if dotprops: if not isinstance(dotprops, dict): dotprops = {} @@ -156,9 +150,6 @@ def variability_plot( if mask_up.any(): ax.scatter(x_data[mask_up], numpy.full(mask_up.sum(), ylim[1]), color='red', marker='^') - # - # Draw all the labels - # max_x_pos = num_dsets - 1 major_xticks = [] minor_xticks = [] @@ -200,9 +191,6 @@ def variability_plot( textobjs.append(textrefs) - # - # Set limits and grid on the main plot - # ax.set_xlim(-0.5, num_dsets - 0.5) if ylim is not None: ax.set_ylim(ylim) @@ -217,9 +205,6 @@ def variability_plot( ax.set_title(data_col) ax.yaxis.set_minor_locator(ticker.AutoMinorLocator()) - # - # Add text resizing handlers to make sure labels are sized relative to their containing axes - # def resize_labels(event) -> None: # Resize labels margin_frac = 0.9 @@ -268,10 +253,24 @@ def debounce(func: Callable, delay_s: float = 0.05) -> Callable: return debounced_func +def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]: + fig, ax = pyplot.subplots() + text_obj = ax.text(0, 0, 'placeholder') + renderer = fig.canvas.get_renderer() + + size_lists = [] + for ll, level in enumerate(label_stack): + sizes = [] + for xmin, xmax, text_value in level: + text_obj.set_text(text_value) + tbox = text_obj.get_window_extent(renderer=renderer) + sizes.append((xmax - xmin + 1, tbox.width, tbox.height)) + size_lists.append(numpy.array(sizes)) + pyplot.close(fig) + return size_lists + + def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: Callable) -> label_stack_t: - """ - For each level, get (xmin_inclusive, xmax_inclusive, wrapped_text_for_label) for all the labels - """ label_stack = [] for ll, level in enumerate(groups): spans = df_groups.group_by(groups[:ll + 1], maintain_order=True).agg( @@ -290,32 +289,7 @@ def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: return label_stack -def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]: - """ - Transform the label stack (see `get_label_stack` into a stack of (allocated x-span, unrotated x-size, unrotated y-size) - """ - fig, ax = pyplot.subplots() - text_obj = ax.text(0, 0, 'placeholder') - renderer = fig.canvas.get_renderer() - - size_lists = [] - for ll, level in enumerate(label_stack): - sizes = [] - for xmin, xmax, text_value in level: - text_obj.set_text(text_value) - tbox = text_obj.get_window_extent(renderer=renderer) - sizes.append((xmax - xmin + 1, tbox.width, tbox.height)) - size_lists.append(numpy.array(sizes)) - pyplot.close(fig) - return size_lists - - def get_label_y_ratios(groups: Sequence[str], vert_groups: set[str], size_lists: list[NDArray[numpy.float64]]) -> list[float]: - """ - For each level, figure out max(rotated_x_size / x_span) and max(rotated_y_size). - Normalize so that the sum of y-values is equal to the number of levels. - Output order is reversed so that the bottom labels (most general) come last. - """ grouping_rotated = numpy.array([grouping in vert_groups for grouping in groups], dtype=bool) level_dims = [] for sizes, rotated in zip(size_lists, grouping_rotated, strict=True): diff --git a/miscplot/wmap.py b/miscplot/wmap.py index 58ce050..d371812 100644 --- a/miscplot/wmap.py +++ b/miscplot/wmap.py @@ -1,5 +1,7 @@ from typing import Any, Callable +from PyQt6 import QtCore + from matplotlib.figure import Figure from matplotlib.axes import Axes from matplotlib import pyplot