# miscplot/miscplot/variability.py
#
# 356 lines
# 13 KiB
# Python

from typing import Sequence, Callable, Any
from pathlib import Path
import threading
import textwrap
from PyQt6 import QtCore
import numpy
from numpy.typing import NDArray
import polars
from polars import col
from matplotlib import pyplot, gridspec, ticker
from matplotlib.figure import Figure
from matplotlib.axes import Axes
def twrap(text: str, **kwargs) -> str:
    """
    Wrap `text` using `textwrap.fill`, allowing line breaks after underscores.

    `textwrap` can break after hyphens but not after underscores, so the two
    characters are swapped before wrapping and swapped back afterwards. The swap
    is symmetric, so hyphens already present in `text` are preserved (they come
    back as hyphens, and are not treated as break points).

    Args:
        text: Text to wrap.
        **kwargs: Passed to `textwrap.fill`. `width` defaults to 15.

    Returns:
        The wrapped text, with newlines inserted at the chosen break points.
    """
    kwargs.setdefault('width', 15)
    # A one-way '_' -> '-' replacement cannot be undone without also turning
    # the caller's own hyphens into underscores; a two-way swap is lossless.
    swap = str.maketrans('_-', '-_')
    wrapped = textwrap.fill(text.translate(swap), **kwargs)
    return wrapped.translate(swap)
def _merge_props(props: dict[str, Any] | bool, defaults: dict[str, Any]) -> dict[str, Any]:
    """Return a fresh kwargs dict: `defaults`, overridden by `props` when it is a dict."""
    merged = dict(defaults)
    if isinstance(props, dict):
        merged.update(props)
    return merged


def variability_plot(
        data_table: 'dict | Sequence | NDArray | polars.Series | pandas.DataFrame',
        data_col: str,
        groups: Sequence[str],
        vert_groups: Sequence[str] = (),
        *,
        wrap_fn: Callable[[str], str] = twrap,
        mainplot_ratios: tuple[float, float] = (10, 10),
        ylim: tuple[float, float] | None = None,
        dotprops: dict[str, Any] | bool = True,
        boxprops: dict[str, Any] | bool = True,
        meanprops: dict[str, Any] | bool = True,
        ) -> tuple[Figure, Axes, list[Axes], list[Axes]]:
    """
    Create a variability plot (categorical box & scatter plot)

    Args:
        data_table: Dataset to plot. Passed directly to `polars.DataFrame()`.
        data_col: Column to use for box/scatterplot value.
        groups: Columns to group by. Coarsest grouping should be first, and will appear
            furthest from the scatterplot, at the bottom of the figure.
        vert_groups: Labels for these column names will be rotated (text will run vertically).
        wrap_fn: Function called to wrap label text (i.e. insert newlines).
            Default wraps at 15 characters, preferentially on underscores or whitespace.
        mainplot_ratios: Scale factors setting the size of the main axes, relative to the size
            of other axes. Default is (10, 10).
        ylim: Y-limits for the scatter/box plot. Points which fall outside the limits are drawn
            as red triangles at the edges.
        dotprops: Passed as kwargs to scatterplot. `True` uses the defaults; a falsy value
            (e.g. `False`) skips the scatterplot. A passed dict is not modified.
        boxprops: Passed as kwargs to boxplot. Same conventions as `dotprops`.
        meanprops: Passed as kwargs to lineplot of means. Same conventions as `dotprops`.

    Returns:
        figure, data axes, label axes, header axes

    Raises:
        ValueError: If `data_col` contains no non-null, non-NaN values.
    """
    vert_groups = set(vert_groups)
    df = polars.DataFrame(data_table)

    # Drop nulls and nans so that the boxplots don't disappear
    df = df.drop_nulls(subset=[data_col]).drop_nans(subset=[data_col])
    if df.height == 0:
        raise ValueError(f'No valid (non-null, non-NaN) values to plot in column {data_col!r}')

    # Assign category indices (x_pos): one integer x position per unique group combination
    df = df.sort(groups)
    df_groups = df.select(groups).unique(maintain_order=True).with_row_index(name='x_pos')
    df = df.join(df_groups, on=groups, maintain_order='left')
    max_group_length = df.group_by(groups).len().select('len').max()[0, 0]    # How many points in the largest x_pos

    # Add jitter to the scatterplots. The same (seeded) offsets are reused for every x_pos.
    jitter = 0.2
    rng = numpy.random.default_rng(seed=0)
    jitter_offsets = rng.uniform(low=-jitter, high=jitter, size=max_group_length)

    x_lists = []
    y_lists = []
    for _labels, gdf in df.group_by(groups, maintain_order=True):
        x_lists.append(gdf['x_pos'][0] + jitter_offsets[:gdf.height])
        y_lists.append(gdf[data_col])
    num_dsets = len(x_lists)
    x_data = numpy.concatenate(x_lists)
    y_data = numpy.concatenate(y_lists)

    # Get label contents and measure their sizes
    label_stack = get_label_stack(df_groups, groups, wrap_fn)
    size_lists = get_text_sizes(label_stack)
    y_ratios = [mainplot_ratios[1]] + get_label_y_ratios(groups, vert_groups, size_lists)

    # Build the figure and all axes: main plot on top, then one (label, header) row per group level
    fig = pyplot.figure()
    gs = gridspec.GridSpec(
        nrows = 1 + len(groups),
        ncols = 2,
        height_ratios = y_ratios,
        width_ratios = [mainplot_ratios[0], 1],
        hspace = 0,
        wspace = 0.05,
        #left = 0.07,
        right = 0.98,
        )

    ax = fig.add_subplot(gs[0, 0])
    label_axes = []
    header_axes = []
    for ii in range(1, len(groups) + 1):
        label_axes.append(fig.add_subplot(gs[ii, 0], sharex=ax))
        header_axes.append(fig.add_subplot(gs[ii, 1]))

    #
    # Draw all the data
    #
    if dotprops:
        # `s` is a default rather than a fixed argument so that callers may override the marker size
        dot_kwargs = _merge_props(dotprops, dict(
            alpha = 0.7,
            color = 'black',
            s = numpy.ones_like(y_data),
            ))
        ax.scatter(x_data, y_data, **dot_kwargs)

    if boxprops:
        box_kwargs = _merge_props(boxprops, dict(
            showfliers = False,
            medianprops = dict(linewidth=3, color='darkred', alpha=0.8),
            boxprops = dict(linewidth=0.5, color='black'),
            whiskerprops = dict(linewidth=0.5, color='black'),
            ))
        ax.boxplot(y_lists, positions=range(num_dsets), **box_kwargs)

    if meanprops:
        # The mean line is flat within +/-0.25 of each x position and slopes between neighbours;
        # it is extended flat to the left and right edges of the plot.
        means = [yl.mean() for yl in y_lists]
        xy = [(-0.5, means[0])]
        for xx, yy in enumerate(means):
            xy += [(xx - 0.25, yy), (xx + 0.25, yy)]
        xy.append((len(means) - 0.5, means[-1]))
        xy = numpy.array(xy)

        mean_kwargs = _merge_props(meanprops, dict(color='blue', alpha=0.8, linewidth=0.5))
        ax.plot(xy[:, 0], xy[:, 1], **mean_kwargs)

    if ylim is not None:
        # Mark out-of-range points with red triangles pinned to the exceeded limit
        mask_dn = y_data < ylim[0]
        mask_up = y_data > ylim[1]
        if mask_dn.any():
            ax.scatter(x_data[mask_dn], numpy.full(mask_dn.sum(), ylim[0]), color='red', marker='v')
        if mask_up.any():
            ax.scatter(x_data[mask_up], numpy.full(mask_up.sum(), ylim[1]), color='red', marker='^')

    #
    # Draw all the labels
    #
    max_x_pos = num_dsets - 1
    major_xticks = []
    minor_xticks = []
    textobjs = []
    headerobjs = []
    for ll, level in enumerate(groups):
        # Axes lists run top-to-bottom, so the coarsest level (ll == 0) takes the bottom row
        axl = label_axes[-1 - ll]
        axh = header_axes[-1 - ll]

        axl.axis('off')
        axh.axis('off')
        header_txt = axh.text(0, 0.5, level, ha='left', va='center', fontsize=10, weight='bold')
        headerobjs.append((axh, header_txt))

        textrefs = dict(axes=[], texts=[], span_fracs=[])
        for (xmin, xmax, text_value), (xsize, _twidth, _theight) in zip(label_stack[ll], size_lists[ll], strict=True):
            text_obj = axl.text(
                0.5 * (xmin + xmax), 0.5, text_value,
                ha = 'center', va = 'center',
                fontsize = 10,
                rotation = (90 if level in vert_groups else 0),
                clip_on = True,
                )
            textrefs['axes'].append(axl)
            textrefs['texts'].append(text_obj)
            textrefs['span_fracs'].append(xsize / num_dsets)

            if xmax < max_x_pos:
                # Vertical divider between this label and its right-hand neighbour
                maxpt = xmax + 0.5
                axl.axvline(maxpt, color='gray', linestyle=':', linewidth=0.5)

                # The two finest levels also define the grid lines of the main plot
                if ll == len(groups) - 2:
                    major_xticks.append(maxpt)
                elif ll == len(groups) - 1:
                    minor_xticks.append(maxpt)
            else:
                # Only the right-most label of a level reaches max_x_pos, so this runs once
                # per level: draw the horizontal rules bounding the label row.
                axl.axline(xy1=(xmin, 0), slope=0, color='gray', linewidth=0.5)
                axl.axline(xy1=(xmin, 1), slope=0, color='gray', linewidth=0.5, alpha=0.5)

        axl.set_ylim(0, 1)
        textobjs.append(textrefs)

    #
    # Set limits and grid on the main plot
    #
    ax.set_xlim(-0.5, num_dsets - 0.5)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xticks(major_xticks, minor=False)
    ax.set_xticks(minor_xticks, minor=True)
    ax.set_xticklabels([], minor=False)
    ax.set_xticklabels([], minor=True)
    ax.tick_params('x', which='both', bottom=False)
    ax.grid(alpha=0.2, which='minor')
    ax.grid(alpha=1, which='major')
    ax.set_ylabel(data_col)
    ax.set_title(data_col)
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())

    #
    # Add text resizing handlers to make sure labels are sized relative to their containing axes
    #
    def resize_labels(event) -> None:
        # Resize labels: one common font size per level, chosen so the tightest label still fits
        margin_frac = 0.9
        max_fontsize = 12
        for refs in textobjs:
            ax_sizes = numpy.array([[abox.width, abox.height]
                                    for abox in (lax.get_window_extent() for lax in refs['axes'])])
            tx_sizes = numpy.array([[tbox.width, tbox.height]
                                    for tbox in (txt.get_window_extent() for txt in refs['texts'])])
            cur_fontsize = refs['texts'][0].get_fontsize()

            # Each label only owns its share of the axes width
            ax_sizes[:, 0] *= refs['span_fracs']
            scales = margin_frac * ax_sizes / tx_sizes
            tgt_fontsize = min(cur_fontsize * scales.min(), max_fontsize)

            for txt in refs['texts']:
                txt.set_fontsize(tgt_fontsize)
        fig.canvas.draw_idle()

    def resize_headers(event) -> None:
        # Resize headers: a single font size shared by all header texts
        margin_frac = 0.9
        max_fontsize = 12
        hax_sizes = numpy.array([[abox.width, abox.height]
                                 for abox in (hax.get_window_extent() for hax, _ in headerobjs)])
        htx_sizes = numpy.array([[tbox.width, tbox.height]
                                 for tbox in (htxt.get_window_extent() for _, htxt in headerobjs)])
        cur_fontsize = headerobjs[0][1].get_fontsize()

        scales = margin_frac * hax_sizes / htx_sizes
        tgt_fontsize = min(cur_fontsize * scales.min(), max_fontsize)

        for _, txt in headerobjs:
            txt.set_fontsize(tgt_fontsize)
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('resize_event', debounce(resize_labels))
    fig.canvas.mpl_connect('resize_event', debounce(resize_headers))

    return fig, ax, label_axes, header_axes
# One entry per grouping level, in the same order as `groups`; each level is a list of
# (xmin_inclusive, xmax_inclusive, wrapped_label_text) tuples, one per label.
label_stack_t = list[list[tuple[int, int, str]]]
def debounce(func: Callable, delay_s: float = 0.05) -> Callable:
    """
    Wrap `func` so that it only runs once calls have stopped for `delay_s` seconds.

    Every call to the returned wrapper cancels the previously scheduled run (if any)
    and schedules a new `threading.Timer`; `func` is invoked from that timer's thread
    with the arguments of the most recent call. The wrapper itself returns `None`.
    """
    pending: list[threading.Timer] = []     # holds at most one timer: the latest scheduled run

    def debounced_func(*args, **kwargs) -> None:
        while pending:
            pending.pop().cancel()
        latest = threading.Timer(delay_s, func, args=args, kwargs=kwargs)
        pending.append(latest)
        latest.start()

    return debounced_func
def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: Callable) -> label_stack_t:
    """
    Build the label rows for every grouping level.

    For level `n`, `df_groups` is grouped by the first `n + 1` entries of `groups`, and each
    resulting group yields one (xmin_inclusive, xmax_inclusive, wrapped_text_for_label) tuple,
    where the x-range is the range of `x_pos` values covered and the text is
    `wrap_fn(str(value))` of that level's column.
    """
    def labels_for(depth: int, column: str) -> list[tuple[int, int, str]]:
        extents = (
            df_groups
            .group_by(groups[:depth + 1], maintain_order=True)
            .agg(
                xmin = col('x_pos').min(),
                xmax = col('x_pos').max(),
                )
            .with_columns(xspan = col('xmax') - col('xmin') + 1)
            )
        # Each dataframe row corresponds to one label (a contiguous run of plot columns)
        return [
            (record['xmin'], record['xmax'], wrap_fn(str(record[column])))
            for record in extents.to_dicts()
            ]

    return [labels_for(depth, column) for depth, column in enumerate(groups)]
def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]:
    """
    Transform the label stack (see `get_label_stack`) into a stack of
    (allocated x-span, unrotated x-size, unrotated y-size).

    Text is measured by rendering it into a temporary figure, which is always closed
    again before returning (even if measuring fails).

    Args:
        label_stack: Per-level lists of (xmin_inclusive, xmax_inclusive, text).

    Returns:
        One float array per level with shape (n_labels, 3). The columns are the number of
        x positions the label spans, and the width and height of the text's window extent.
    """
    fig, ax = pyplot.subplots()
    try:
        probe = ax.text(0, 0, 'placeholder')    # single text artist, reused for every measurement
        renderer = fig.canvas.get_renderer()

        size_lists = []
        for level in label_stack:
            sizes = []
            for xmin, xmax, text_value in level:
                probe.set_text(text_value)
                tbox = probe.get_window_extent(renderer=renderer)
                sizes.append((xmax - xmin + 1, tbox.width, tbox.height))
            # reshape keeps the (n, 3) layout even for a level with no labels
            size_lists.append(numpy.array(sizes, dtype=float).reshape(-1, 3))
    finally:
        pyplot.close(fig)
    return size_lists
def get_label_y_ratios(groups: Sequence[str], vert_groups: set[str], size_lists: list[NDArray[numpy.float64]]) -> list[float]:
"""
For each level, figure out max(rotated_x_size / x_span) and max(rotated_y_size).
Normalize so that the sum of y-values is equal to the number of levels.
Output order is reversed so that the bottom labels (most general) come last.
"""
grouping_rotated = numpy.array([grouping in vert_groups for grouping in groups], dtype=bool)
level_dims = []
for sizes, rotated in zip(size_lists, grouping_rotated, strict=True):
if rotated:
dxy = sizes[:, [2, 1]].copy()
else:
dxy = sizes[:, [1, 2]].copy()
dxy[:, 0] /= sizes[:, 0]
level_dims.append(dxy.max(axis=0))
scales = numpy.array(level_dims)
ratios = (scales.shape[0] * scales[::-1, 1] / scales[:, 1].sum()).tolist()
return ratios
def _mk_data(filename: str) -> None:
    """
    Make some dummy data and write it to a csv file

    One block of 100 normally-distributed values is generated for every
    (MeasurementType, Device, DeviceVariant) combination, each with its own randomly
    chosen mean and standard deviation. The generator is seeded, so the output is
    reproducible.

    Args:
        filename: Path of the csv file to write.
    """
    rows = []
    rng = numpy.random.default_rng(seed=0)
    for mm in ('liminal', 'transitive', 'extrinsic'):
        for dd in ('elevator', 'snare', 'tibetan_foxhole', 'inverse_thresher'):
            # Devices with an underscore in their name get a different (larger) set of variants
            variants = ('tiny', 'elective', 'baseline') if '_' in dd else ('dormant', 'volatile')
            for vv in variants:
                std = rng.uniform(low=.1, high=1, size=1)
                mean = rng.uniform(low=4, high=6, size=1)
                for qq in rng.standard_normal(size=100) * std + mean:
                    rows.append(dict(MeasurementType=mm, Device=dd, DeviceVariant=vv, MeasuredValue=qq))
    df = polars.DataFrame(rows)
    # Without a target, write_csv() only returns the csv text and no file is created
    df.write_csv(filename)
if __name__ == '__main__':
    # Demo: generate dummy data, read it back from the csv, and show it as a variability plot
    filename = 'dummy_data.csv'
    _mk_data(filename)
    df = polars.read_csv(filename)
    variability_plot(
        df,
        'MeasuredValue',
        ['MeasurementType', 'Device', 'DeviceVariant'],
        vert_groups=['DeviceVariant'],
        ylim=(3, 7),
        )
    pyplot.show(block=True)