Compare commits

..

4 Commits

2 changed files with 55 additions and 31 deletions

View File

@ -1,10 +1,7 @@
from typing import Sequence, Callable, Any from typing import Sequence, Callable, Any, TYPE_CHECKING
from pathlib import Path
import threading import threading
import textwrap import textwrap
from PyQt6 import QtCore
import numpy import numpy
from numpy.typing import NDArray from numpy.typing import NDArray
import polars import polars
@ -14,6 +11,10 @@ from matplotlib.figure import Figure
from matplotlib.axes import Axes from matplotlib.axes import Axes
if TYPE_CHECKING:
import pandas
def twrap(text: str, **kwargs) -> str: def twrap(text: str, **kwargs) -> str:
kwargs.setdefault('width', 15) kwargs.setdefault('width', 15)
intxt = text.replace('_', '-') intxt = text.replace('_', '-')
@ -59,18 +60,17 @@ def variability_plot(
vert_groups = set(vert_groups) vert_groups = set(vert_groups)
df = polars.DataFrame(data_table) df = polars.DataFrame(data_table)
# zero_bad: bool = True,
# if zero_bad:
# df.filter(col(data_col) != 0)
# Drop nulls and nans so that the boxplots don't disappear
df = polars.drop_nulls(subset=[data_col]).drop_nans(subset=[data_col])
# Assign category indicies (x_pos)
df = df.sort(groups) df = df.sort(groups)
df_groups = df.select(groups).unique(maintain_order=True).with_row_index(name='x_pos') df_groups = df.select(groups).unique(maintain_order=True).with_row_index(name='x_pos')
df = df.join(df_groups, on=groups, maintain_order='left') df = df.join(df_groups, on=groups, maintain_order='left')
max_group_length = df.group_by(groups).len().select('len').max()[0, 0] max_group_length = df.group_by(groups).len().select('len').max()[0, 0] # How many points in the largest x_pos
label_stack = get_label_stack(df_groups, groups, wrap_fn)
size_lists = get_text_sizes(label_stack)
# Add jitter to the scatterplots-plots
jitter = 0.2 jitter = 0.2
rng = numpy.random.default_rng(seed=0) rng = numpy.random.default_rng(seed=0)
jitter_offsets = rng.uniform(low=-jitter, high=jitter, size=max_group_length) jitter_offsets = rng.uniform(low=-jitter, high=jitter, size=max_group_length)
@ -84,9 +84,12 @@ def variability_plot(
x_data = numpy.concatenate(x_lists) x_data = numpy.concatenate(x_lists)
y_data = numpy.concatenate(y_lists) y_data = numpy.concatenate(y_lists)
# Get label contents and measure their sizes
label_stack = get_label_stack(df_groups, groups, wrap_fn)
size_lists = get_text_sizes(label_stack)
y_ratios = [mainplot_ratios[1]] + get_label_y_ratios(groups, vert_groups, size_lists) y_ratios = [mainplot_ratios[1]] + get_label_y_ratios(groups, vert_groups, size_lists)
# Build the figure and all axes
fig = pyplot.figure() fig = pyplot.figure()
gs = gridspec.GridSpec( gs = gridspec.GridSpec(
nrows = 1 + len(groups), nrows = 1 + len(groups),
@ -106,6 +109,9 @@ def variability_plot(
label_axes.append( fig.add_subplot(gs[ii, 0], sharex=ax)) label_axes.append( fig.add_subplot(gs[ii, 0], sharex=ax))
header_axes.append(fig.add_subplot(gs[ii, 1])) header_axes.append(fig.add_subplot(gs[ii, 1]))
#
# Draw all the data
#
if dotprops: if dotprops:
if not isinstance(dotprops, dict): if not isinstance(dotprops, dict):
dotprops = {} dotprops = {}
@ -150,6 +156,9 @@ def variability_plot(
if mask_up.any(): if mask_up.any():
ax.scatter(x_data[mask_up], numpy.full(mask_up.sum(), ylim[1]), color='red', marker='^') ax.scatter(x_data[mask_up], numpy.full(mask_up.sum(), ylim[1]), color='red', marker='^')
#
# Draw all the labels
#
max_x_pos = num_dsets - 1 max_x_pos = num_dsets - 1
major_xticks = [] major_xticks = []
minor_xticks = [] minor_xticks = []
@ -191,6 +200,9 @@ def variability_plot(
textobjs.append(textrefs) textobjs.append(textrefs)
#
# Set limits and grid on the main plot
#
ax.set_xlim(-0.5, num_dsets - 0.5) ax.set_xlim(-0.5, num_dsets - 0.5)
if ylim is not None: if ylim is not None:
ax.set_ylim(ylim) ax.set_ylim(ylim)
@ -205,6 +217,9 @@ def variability_plot(
ax.set_title(data_col) ax.set_title(data_col)
ax.yaxis.set_minor_locator(ticker.AutoMinorLocator()) ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())
#
# Add text resizing handlers to make sure labels are sized relative to their containing axes
#
def resize_labels(event) -> None: def resize_labels(event) -> None:
# Resize labels # Resize labels
margin_frac = 0.9 margin_frac = 0.9
@ -253,24 +268,10 @@ def debounce(func: Callable, delay_s: float = 0.05) -> Callable:
return debounced_func return debounced_func
def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]:
fig, ax = pyplot.subplots()
text_obj = ax.text(0, 0, 'placeholder')
renderer = fig.canvas.get_renderer()
size_lists = []
for ll, level in enumerate(label_stack):
sizes = []
for xmin, xmax, text_value in level:
text_obj.set_text(text_value)
tbox = text_obj.get_window_extent(renderer=renderer)
sizes.append((xmax - xmin + 1, tbox.width, tbox.height))
size_lists.append(numpy.array(sizes))
pyplot.close(fig)
return size_lists
def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: Callable) -> label_stack_t: def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: Callable) -> label_stack_t:
"""
For each level, get (xmin_inclusive, xmax_inclusive, wrapped_text_for_label) for all the labels
"""
label_stack = [] label_stack = []
for ll, level in enumerate(groups): for ll, level in enumerate(groups):
spans = df_groups.group_by(groups[:ll + 1], maintain_order=True).agg( spans = df_groups.group_by(groups[:ll + 1], maintain_order=True).agg(
@ -289,7 +290,32 @@ def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn:
return label_stack return label_stack
def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]:
"""
Transform the label stack (see `get_label_stack` into a stack of (allocated x-span, unrotated x-size, unrotated y-size)
"""
fig, ax = pyplot.subplots()
text_obj = ax.text(0, 0, 'placeholder')
renderer = fig.canvas.get_renderer()
size_lists = []
for ll, level in enumerate(label_stack):
sizes = []
for xmin, xmax, text_value in level:
text_obj.set_text(text_value)
tbox = text_obj.get_window_extent(renderer=renderer)
sizes.append((xmax - xmin + 1, tbox.width, tbox.height))
size_lists.append(numpy.array(sizes))
pyplot.close(fig)
return size_lists
def get_label_y_ratios(groups: Sequence[str], vert_groups: set[str], size_lists: list[NDArray[numpy.float64]]) -> list[float]: def get_label_y_ratios(groups: Sequence[str], vert_groups: set[str], size_lists: list[NDArray[numpy.float64]]) -> list[float]:
"""
For each level, figure out max(rotated_x_size / x_span) and max(rotated_y_size).
Normalize so that the sum of y-values is equal to the number of levels.
Output order is reversed so that the bottom labels (most general) come last.
"""
grouping_rotated = numpy.array([grouping in vert_groups for grouping in groups], dtype=bool) grouping_rotated = numpy.array([grouping in vert_groups for grouping in groups], dtype=bool)
level_dims = [] level_dims = []
for sizes, rotated in zip(size_lists, grouping_rotated, strict=True): for sizes, rotated in zip(size_lists, grouping_rotated, strict=True):

View File

@ -1,7 +1,5 @@
from typing import Any, Callable from typing import Any, Callable
from PyQt6 import QtCore
from matplotlib.figure import Figure from matplotlib.figure import Figure
from matplotlib.axes import Axes from matplotlib.axes import Axes
from matplotlib import pyplot from matplotlib import pyplot