Compare commits
No commits in common. "483a8319975b52a571c45339995250ddae8159ec" and "48896c952996e2750431e72b2c70f6b1cd8516bd" have entirely different histories.
483a831997
...
48896c9529
@ -1,7 +1,10 @@
|
|||||||
from typing import Sequence, Callable, Any, TYPE_CHECKING
|
from typing import Sequence, Callable, Any
|
||||||
|
from pathlib import Path
|
||||||
import threading
|
import threading
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
from PyQt6 import QtCore
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
import polars
|
import polars
|
||||||
@ -11,10 +14,6 @@ from matplotlib.figure import Figure
|
|||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import pandas
|
|
||||||
|
|
||||||
|
|
||||||
def twrap(text: str, **kwargs) -> str:
|
def twrap(text: str, **kwargs) -> str:
|
||||||
kwargs.setdefault('width', 15)
|
kwargs.setdefault('width', 15)
|
||||||
intxt = text.replace('_', '-')
|
intxt = text.replace('_', '-')
|
||||||
@ -60,17 +59,18 @@ def variability_plot(
|
|||||||
vert_groups = set(vert_groups)
|
vert_groups = set(vert_groups)
|
||||||
|
|
||||||
df = polars.DataFrame(data_table)
|
df = polars.DataFrame(data_table)
|
||||||
|
# zero_bad: bool = True,
|
||||||
|
# if zero_bad:
|
||||||
|
# df.filter(col(data_col) != 0)
|
||||||
|
|
||||||
# Drop nulls and nans so that the boxplots don't disappear
|
|
||||||
df = polars.drop_nulls(subset=[data_col]).drop_nans(subset=[data_col])
|
|
||||||
|
|
||||||
# Assign category indices (x_pos)
|
|
||||||
df = df.sort(groups)
|
df = df.sort(groups)
|
||||||
df_groups = df.select(groups).unique(maintain_order=True).with_row_index(name='x_pos')
|
df_groups = df.select(groups).unique(maintain_order=True).with_row_index(name='x_pos')
|
||||||
df = df.join(df_groups, on=groups, maintain_order='left')
|
df = df.join(df_groups, on=groups, maintain_order='left')
|
||||||
max_group_length = df.group_by(groups).len().select('len').max()[0, 0] # How many points in the largest x_pos
|
max_group_length = df.group_by(groups).len().select('len').max()[0, 0]
|
||||||
|
|
||||||
|
label_stack = get_label_stack(df_groups, groups, wrap_fn)
|
||||||
|
size_lists = get_text_sizes(label_stack)
|
||||||
|
|
||||||
# Add jitter to the scatter plots
|
|
||||||
jitter = 0.2
|
jitter = 0.2
|
||||||
rng = numpy.random.default_rng(seed=0)
|
rng = numpy.random.default_rng(seed=0)
|
||||||
jitter_offsets = rng.uniform(low=-jitter, high=jitter, size=max_group_length)
|
jitter_offsets = rng.uniform(low=-jitter, high=jitter, size=max_group_length)
|
||||||
@ -84,12 +84,9 @@ def variability_plot(
|
|||||||
x_data = numpy.concatenate(x_lists)
|
x_data = numpy.concatenate(x_lists)
|
||||||
y_data = numpy.concatenate(y_lists)
|
y_data = numpy.concatenate(y_lists)
|
||||||
|
|
||||||
# Get label contents and measure their sizes
|
|
||||||
label_stack = get_label_stack(df_groups, groups, wrap_fn)
|
|
||||||
size_lists = get_text_sizes(label_stack)
|
|
||||||
y_ratios = [mainplot_ratios[1]] + get_label_y_ratios(groups, vert_groups, size_lists)
|
y_ratios = [mainplot_ratios[1]] + get_label_y_ratios(groups, vert_groups, size_lists)
|
||||||
|
|
||||||
# Build the figure and all axes
|
|
||||||
fig = pyplot.figure()
|
fig = pyplot.figure()
|
||||||
gs = gridspec.GridSpec(
|
gs = gridspec.GridSpec(
|
||||||
nrows = 1 + len(groups),
|
nrows = 1 + len(groups),
|
||||||
@ -109,9 +106,6 @@ def variability_plot(
|
|||||||
label_axes.append( fig.add_subplot(gs[ii, 0], sharex=ax))
|
label_axes.append( fig.add_subplot(gs[ii, 0], sharex=ax))
|
||||||
header_axes.append(fig.add_subplot(gs[ii, 1]))
|
header_axes.append(fig.add_subplot(gs[ii, 1]))
|
||||||
|
|
||||||
#
|
|
||||||
# Draw all the data
|
|
||||||
#
|
|
||||||
if dotprops:
|
if dotprops:
|
||||||
if not isinstance(dotprops, dict):
|
if not isinstance(dotprops, dict):
|
||||||
dotprops = {}
|
dotprops = {}
|
||||||
@ -156,9 +150,6 @@ def variability_plot(
|
|||||||
if mask_up.any():
|
if mask_up.any():
|
||||||
ax.scatter(x_data[mask_up], numpy.full(mask_up.sum(), ylim[1]), color='red', marker='^')
|
ax.scatter(x_data[mask_up], numpy.full(mask_up.sum(), ylim[1]), color='red', marker='^')
|
||||||
|
|
||||||
#
|
|
||||||
# Draw all the labels
|
|
||||||
#
|
|
||||||
max_x_pos = num_dsets - 1
|
max_x_pos = num_dsets - 1
|
||||||
major_xticks = []
|
major_xticks = []
|
||||||
minor_xticks = []
|
minor_xticks = []
|
||||||
@ -200,9 +191,6 @@ def variability_plot(
|
|||||||
|
|
||||||
textobjs.append(textrefs)
|
textobjs.append(textrefs)
|
||||||
|
|
||||||
#
|
|
||||||
# Set limits and grid on the main plot
|
|
||||||
#
|
|
||||||
ax.set_xlim(-0.5, num_dsets - 0.5)
|
ax.set_xlim(-0.5, num_dsets - 0.5)
|
||||||
if ylim is not None:
|
if ylim is not None:
|
||||||
ax.set_ylim(ylim)
|
ax.set_ylim(ylim)
|
||||||
@ -217,9 +205,6 @@ def variability_plot(
|
|||||||
ax.set_title(data_col)
|
ax.set_title(data_col)
|
||||||
ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())
|
ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())
|
||||||
|
|
||||||
#
|
|
||||||
# Add text resizing handlers to make sure labels are sized relative to their containing axes
|
|
||||||
#
|
|
||||||
def resize_labels(event) -> None:
|
def resize_labels(event) -> None:
|
||||||
# Resize labels
|
# Resize labels
|
||||||
margin_frac = 0.9
|
margin_frac = 0.9
|
||||||
@ -268,10 +253,24 @@ def debounce(func: Callable, delay_s: float = 0.05) -> Callable:
|
|||||||
return debounced_func
|
return debounced_func
|
||||||
|
|
||||||
|
|
||||||
|
def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]:
|
||||||
|
fig, ax = pyplot.subplots()
|
||||||
|
text_obj = ax.text(0, 0, 'placeholder')
|
||||||
|
renderer = fig.canvas.get_renderer()
|
||||||
|
|
||||||
|
size_lists = []
|
||||||
|
for ll, level in enumerate(label_stack):
|
||||||
|
sizes = []
|
||||||
|
for xmin, xmax, text_value in level:
|
||||||
|
text_obj.set_text(text_value)
|
||||||
|
tbox = text_obj.get_window_extent(renderer=renderer)
|
||||||
|
sizes.append((xmax - xmin + 1, tbox.width, tbox.height))
|
||||||
|
size_lists.append(numpy.array(sizes))
|
||||||
|
pyplot.close(fig)
|
||||||
|
return size_lists
|
||||||
|
|
||||||
|
|
||||||
def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: Callable) -> label_stack_t:
|
def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: Callable) -> label_stack_t:
|
||||||
"""
|
|
||||||
For each level, get (xmin_inclusive, xmax_inclusive, wrapped_text_for_label) for all the labels
|
|
||||||
"""
|
|
||||||
label_stack = []
|
label_stack = []
|
||||||
for ll, level in enumerate(groups):
|
for ll, level in enumerate(groups):
|
||||||
spans = df_groups.group_by(groups[:ll + 1], maintain_order=True).agg(
|
spans = df_groups.group_by(groups[:ll + 1], maintain_order=True).agg(
|
||||||
@ -290,32 +289,7 @@ def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn:
|
|||||||
return label_stack
|
return label_stack
|
||||||
|
|
||||||
|
|
||||||
def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]:
|
|
||||||
"""
|
|
||||||
Transform the label stack (see `get_label_stack`) into a stack of (allocated x-span, unrotated x-size, unrotated y-size)
|
|
||||||
"""
|
|
||||||
fig, ax = pyplot.subplots()
|
|
||||||
text_obj = ax.text(0, 0, 'placeholder')
|
|
||||||
renderer = fig.canvas.get_renderer()
|
|
||||||
|
|
||||||
size_lists = []
|
|
||||||
for ll, level in enumerate(label_stack):
|
|
||||||
sizes = []
|
|
||||||
for xmin, xmax, text_value in level:
|
|
||||||
text_obj.set_text(text_value)
|
|
||||||
tbox = text_obj.get_window_extent(renderer=renderer)
|
|
||||||
sizes.append((xmax - xmin + 1, tbox.width, tbox.height))
|
|
||||||
size_lists.append(numpy.array(sizes))
|
|
||||||
pyplot.close(fig)
|
|
||||||
return size_lists
|
|
||||||
|
|
||||||
|
|
||||||
def get_label_y_ratios(groups: Sequence[str], vert_groups: set[str], size_lists: list[NDArray[numpy.float64]]) -> list[float]:
|
def get_label_y_ratios(groups: Sequence[str], vert_groups: set[str], size_lists: list[NDArray[numpy.float64]]) -> list[float]:
|
||||||
"""
|
|
||||||
For each level, figure out max(rotated_x_size / x_span) and max(rotated_y_size).
|
|
||||||
Normalize so that the sum of y-values is equal to the number of levels.
|
|
||||||
Output order is reversed so that the bottom labels (most general) come last.
|
|
||||||
"""
|
|
||||||
grouping_rotated = numpy.array([grouping in vert_groups for grouping in groups], dtype=bool)
|
grouping_rotated = numpy.array([grouping in vert_groups for grouping in groups], dtype=bool)
|
||||||
level_dims = []
|
level_dims = []
|
||||||
for sizes, rotated in zip(size_lists, grouping_rotated, strict=True):
|
for sizes, rotated in zip(size_lists, grouping_rotated, strict=True):
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from PyQt6 import QtCore
|
||||||
|
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from matplotlib import pyplot
|
from matplotlib import pyplot
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user