# miscplot/miscplot/variability.py

from typing import Sequence, Callable, Any
from pathlib import Path
import threading
import textwrap
from PyQt6 import QtCore
import numpy
from numpy.typing import NDArray
import polars
from polars import col
from matplotlib import pyplot, gridspec, ticker
from matplotlib.figure import Figure
from matplotlib.axes import Axes
def twrap(text: str, **kwargs) -> str:
    """
    Wrap `text` with `textwrap.fill`, also allowing line breaks after underscores.

    `textwrap` only breaks words after hyphens, so underscores are temporarily turned into
    hyphens for wrapping and turned back afterwards. Only the characters that really were
    underscores are restored, so hyphens already present in `text` are left untouched.

    Args:
        text: Text to wrap.
        **kwargs: Passed to `textwrap.fill`; `width` defaults to 15.
            NOTE: indent/placeholder options that insert hyphens are not accounted for.

    Returns:
        The wrapped text, lines separated by newlines.
    """
    kwargs.setdefault('width', 15)
    # textwrap only rewrites/drops whitespace, so the non-whitespace characters of the output
    # appear in the same order as in the input; remember which of those were underscores.
    was_underscore = [ch == '_' for ch in text if not ch.isspace()]
    wrapped = textwrap.fill(text.replace('_', '-'), **kwargs)

    out = []
    idx = 0
    for ch in wrapped:
        if ch.isspace():
            out.append(ch)
            continue
        if ch == '-' and idx < len(was_underscore) and was_underscore[idx]:
            ch = '_'
        out.append(ch)
        idx += 1
    return ''.join(out)
def variability_plot(
        data_table: 'dict | Sequence | NDArray | polars.Series | pandas.DataFrame',
        data_col: str,
        groups: Sequence[str],
        vert_groups: Sequence[str] = (),
        *,
        wrap_fn: Callable[[str], str] = twrap,
        mainplot_ratios: tuple[float, float] = (10, 10),
        ylim: tuple[float, float] | None = None,
        dotprops: dict[str, Any] | bool = True,
        boxprops: dict[str, Any] | bool = True,
        meanprops: dict[str, Any] | bool = True,
        ) -> tuple[Figure, Axes, list[Axes], list[Axes]]:
    """
    Create a variability plot (categorical box & scatter plot).

    Each distinct combination of `groups` values gets one x position. Its values are drawn as
    jittered dots plus a box plot, and the per-group means are joined by a line. Below the
    main axes there is one row of labels per grouping level, with that level's column name
    as a header on the right. Label and header font sizes are refit whenever the canvas is
    resized.

    Args:
        data_table: Dataset to plot. Passed directly to `polars.DataFrame()`.
        data_col: Column to use for box/scatterplot value.
        groups: Columns to group by (at least one). Coarsest grouping should be first, and
            will appear furthest from the scatterplot, at the bottom of the figure.
        vert_groups: Labels for these column names will be rotated (text will run vertically).
        wrap_fn: Function called to wrap label text (i.e. insert newlines).
            Default wraps at 15 characters, preferentially on underscores or whitespace.
        mainplot_ratios: Scale factors (width, height) setting the size of the main axes,
            relative to the size of other axes. Default is (10, 10).
        ylim: Y-limits for the scatter/box plot. Points which fall outside the limits are drawn
            as red triangles at the edges.
        dotprops: Passed as kwargs to the scatterplot. A falsy value (including an empty dict)
            disables the dots; `True` uses only the defaults. Dicts are copied, never modified.
        boxprops: Same as `dotprops`, for `Axes.boxplot`.
        meanprops: Same as `dotprops`, for the line plot of means.

    Returns:
        figure, data axes, label axes, header axes

    Raises:
        ValueError: if `groups` is empty or `data_table` has no rows.
    """
    if len(groups) == 0:
        raise ValueError('At least one grouping column is required')
    vert_groups = set(vert_groups)
    df = polars.DataFrame(data_table)
    if df.height == 0:
        raise ValueError('data_table contains no rows to plot')
    # TODO: optionally drop zero-valued (bad) measurements, e.g. behind a `zero_bad` flag:
    #   df = df.filter(col(data_col) != 0)

    df = df.sort(groups)
    # One row per distinct group combination, numbered left-to-right as its x position
    df_groups = df.select(groups).unique(maintain_order=True).with_row_index(name='x_pos')
    df = df.join(df_groups, on=groups, maintain_order='left')
    max_group_length = df.group_by(groups).len().select('len').max()[0, 0]

    label_stack = get_label_stack(df_groups, groups, wrap_fn)
    size_lists = get_text_sizes(label_stack)

    # Fixed seed: the same data always produces the same jitter pattern
    jitter = 0.2
    rng = numpy.random.default_rng(seed=0)
    jitter_offsets = rng.uniform(low=-jitter, high=jitter, size=max_group_length)

    x_lists = []
    y_lists = []
    for _labels, gdf in df.group_by(groups, maintain_order=True):
        x_lists.append(gdf['x_pos'][0] + jitter_offsets[:gdf.height])
        y_lists.append(gdf[data_col])
    num_dsets = len(x_lists)
    x_data = numpy.concatenate(x_lists)
    y_data = numpy.concatenate(y_lists)

    y_ratios = [mainplot_ratios[1]] + get_label_y_ratios(groups, vert_groups, size_lists)

    fig = pyplot.figure()
    gs = gridspec.GridSpec(
        nrows=1 + len(groups),
        ncols=2,
        height_ratios=y_ratios,
        width_ratios=[mainplot_ratios[0], 1],
        hspace=0,
        wspace=0.05,
        right=0.98,
    )
    ax = fig.add_subplot(gs[0, 0])
    label_axes = []
    header_axes = []
    for ii in range(1, len(groups) + 1):
        label_axes.append(fig.add_subplot(gs[ii, 0], sharex=ax))
        header_axes.append(fig.add_subplot(gs[ii, 1]))

    # Copy any caller-supplied props dicts so the setdefault()s below never leak back out
    if dotprops:
        dotprops = dict(dotprops) if isinstance(dotprops, dict) else {}
        dotprops.setdefault('alpha', 0.7)
        dotprops.setdefault('color', 'black')
        ax.scatter(x_data, y_data, s=numpy.ones_like(y_data), **dotprops)

    if boxprops:
        boxprops = dict(boxprops) if isinstance(boxprops, dict) else {}
        boxprops.setdefault('showfliers', False)
        boxprops.setdefault('medianprops', dict(linewidth=3, color='darkred', alpha=0.8))
        boxprops.setdefault('boxprops', dict(linewidth=0.5, color='black'))
        boxprops.setdefault('whiskerprops', dict(linewidth=0.5, color='black'))
        ax.boxplot(y_lists, positions=range(num_dsets), **boxprops)

    if meanprops:
        meanprops = dict(meanprops) if isinstance(meanprops, dict) else {}
        meanprops.setdefault('color', 'blue')
        meanprops.setdefault('alpha', 0.8)
        meanprops.setdefault('linewidth', 0.5)
        means = [yl.mean() for yl in y_lists]
        # Flat segment across each box (x +/- 0.25), consecutive means joined by connectors,
        # and the first/last means extended out to the plot edges.
        xy = [(-0.5, means[0])]
        for xx, yy in enumerate(means):
            xy += [(xx - 0.25, yy), (xx + 0.25, yy)]
        xy.append((num_dsets - 0.5, means[-1]))
        xy = numpy.array(xy)
        ax.plot(xy[:, 0], xy[:, 1], **meanprops)

    if ylim is not None:
        # Mark out-of-range points with red triangles pinned to the limit they exceed
        mask_dn = y_data < ylim[0]
        mask_up = y_data > ylim[1]
        if mask_dn.any():
            ax.scatter(x_data[mask_dn], numpy.full(mask_dn.sum(), ylim[0]), color='red', marker='v')
        if mask_up.any():
            ax.scatter(x_data[mask_up], numpy.full(mask_up.sum(), ylim[1]), color='red', marker='^')

    max_x_pos = num_dsets - 1
    major_xticks = []
    minor_xticks = []
    textobjs = []
    headerobjs = []
    for ll, level in enumerate(groups):
        # Level 0 (coarsest) goes in the bottom row, so index the axes from the end
        axl = label_axes[-1 - ll]
        axh = header_axes[-1 - ll]
        axl.axis('off')
        axh.axis('off')
        header_txt = axh.text(0, 0.5, level, ha='left', va='center', fontsize=10, weight='bold')
        headerobjs.append((axh, header_txt))

        textrefs = dict(axes=[], texts=[], span_fracs=[])
        for (xmin, xmax, text_value), (xsize, _twidth, _theight) in zip(label_stack[ll], size_lists[ll], strict=True):
            text_obj = axl.text(
                0.5 * (xmin + xmax), 0.5, text_value,
                ha='center', va='center',
                fontsize=10,
                rotation=(90 if level in vert_groups else 0),
                clip_on=True,
            )
            textrefs['axes'].append(axl)
            textrefs['texts'].append(text_obj)
            # Fraction of the row's width this label may occupy
            textrefs['span_fracs'].append(xsize / num_dsets)
            if xmax < max_x_pos:
                # Dotted separator after every label except the last in the row; the two
                # finest levels also provide the main plot's major/minor x grid lines.
                maxpt = xmax + 0.5
                axl.axvline(maxpt, color='gray', linestyle=':', linewidth=0.5)
                if ll == len(groups) - 2:
                    major_xticks.append(maxpt)
                elif ll == len(groups) - 1:
                    minor_xticks.append(maxpt)
            else:
                # Last label in the row: draw the row's bottom and top borders once
                axl.axline(xy1=(xmin, 0), slope=0, color='gray', linewidth=0.5)
                axl.axline(xy1=(xmin, 1), slope=0, color='gray', linewidth=0.5, alpha=0.5)
        axl.set_ylim(0, 1)
        textobjs.append(textrefs)

    ax.set_xlim(-0.5, num_dsets - 0.5)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xticks(major_xticks, minor=False)
    ax.set_xticks(minor_xticks, minor=True)
    ax.set_xticklabels([], minor=False)
    ax.set_xticklabels([], minor=True)
    ax.tick_params('x', which='both', bottom=False)
    ax.grid(alpha=0.2, which='minor')
    ax.grid(alpha=1, which='major')
    ax.set_ylabel(data_col)
    ax.set_title(data_col)
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())

    margin_frac = 0.9
    max_fontsize = 12

    def fit_fontsize(axes_list, texts, span_fracs) -> None:
        """Give `texts` one shared font size so each fits within margin_frac of its slot."""
        ax_sizes = numpy.array([[bb.width, bb.height] for bb in (a.get_window_extent() for a in axes_list)])
        tx_sizes = numpy.array([[bb.width, bb.height] for bb in (t.get_window_extent() for t in texts)])
        ax_sizes[:, 0] *= span_fracs
        scales = margin_frac * ax_sizes / tx_sizes
        target = min(texts[0].get_fontsize() * scales.min(), max_fontsize)
        for txt in texts:
            txt.set_fontsize(target)

    def refit_text() -> None:
        for level in textobjs:
            fit_fontsize(level['axes'], level['texts'], level['span_fracs'])
        fit_fontsize(
            [axh for axh, _ in headerobjs],
            [txt for _, txt in headerobjs],
            numpy.ones(len(headerobjs)),
        )
        fig.canvas.draw_idle()

    # Debounce resize events with the canvas's own timer so the refit runs on the GUI thread;
    # matplotlib/Qt objects must not be touched from a worker thread.
    resize_timer = fig.canvas.new_timer(interval=50)
    resize_timer.single_shot = True
    resize_timer.add_callback(refit_text)

    def on_resize(_event) -> None:
        # Restart the countdown on every event; only the last one in a burst triggers a refit
        resize_timer.stop()
        resize_timer.start()

    fig.canvas.mpl_connect('resize_event', on_resize)
    return fig, ax, label_axes, header_axes
# One inner list per grouping level (coarsest level first); each entry describes a single
# label as (first x position spanned, last x position spanned, wrapped label text).
_label_span_t = tuple[int, int, str]
label_stack_t = list[list[_label_span_t]]
def debounce(func: Callable, delay_s: float = 0.05) -> Callable:
    """
    Wrap `func` so that a burst of calls results in a single delayed call.

    Each call (re)starts a `delay_s`-second timer. `func` runs once calls have stopped for
    that long, and it receives the arguments of the most recent call.

    NOTE: `func` runs on a `threading.Timer` worker thread, not on the caller's thread. Avoid
    this for callbacks that touch GUI objects (e.g. matplotlib/Qt), which are not thread-safe.

    Args:
        func: Callable to debounce.
        delay_s: Quiet period, in seconds, before `func` is invoked.

    Returns:
        A wrapper taking the same arguments as `func` (and carrying its name/docstring). It
        returns None immediately.
    """
    import functools

    timer: threading.Timer | None = None

    @functools.wraps(func)
    def debounced_func(*args, **kwargs) -> None:
        nonlocal timer
        if timer is not None:
            # No-op if the previous timer has already fired
            timer.cancel()
        timer = threading.Timer(delay_s, func, args=args, kwargs=kwargs)
        timer.start()

    return debounced_func
def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]:
    """
    Measure the rendered extent of every label in `label_stack`.

    A temporary pyplot figure is created only to obtain a renderer. It is closed again
    afterwards, even if measuring fails. Text is measured without an explicit fontsize,
    i.e. at matplotlib's default.

    Args:
        label_stack: Per level, a list of (xmin, xmax, text) label entries.

    Returns:
        One float array per level, shape (n_labels, 3), with columns
        (number of x positions spanned, text width, text height). Width and height are in
        display units from `Text.get_window_extent`.
    """
    fig, ax = pyplot.subplots()
    try:
        # Reuse a single Text artist, only swapping its string, for every measurement
        text_obj = ax.text(0, 0, 'placeholder')
        renderer = fig.canvas.get_renderer()
        size_lists = []
        for level in label_stack:
            sizes = []
            for xmin, xmax, text_value in level:
                text_obj.set_text(text_value)
                tbox = text_obj.get_window_extent(renderer=renderer)
                sizes.append((xmax - xmin + 1, tbox.width, tbox.height))
            # reshape keeps an empty level 2-D, i.e. shape (0, 3)
            size_lists.append(numpy.array(sizes, dtype=numpy.float64).reshape(-1, 3))
    finally:
        pyplot.close(fig)
    return size_lists
def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: Callable) -> label_stack_t:
    """
    Compute the label rows drawn underneath the main plot.

    Args:
        df_groups: One row per distinct combination of `groups`, expected in plotting order,
            with an integer `x_pos` column giving each combination's x position.
        groups: Grouping columns, coarsest first.
        wrap_fn: Applied to the string form of each label value.

    Returns:
        One list per entry of `groups`, in the same order. For level `ll` there is one entry
        per distinct combination of `groups[:ll + 1]`, in first-appearance order, given as
        (min x_pos, max x_pos, wrap_fn(str(value of groups[ll]))).
    """
    label_stack = []
    for ll, level in enumerate(groups):
        # Grouping by all coarser levels as well keeps a value that repeats under different
        # parents as separate labels.
        spans = df_groups.group_by(groups[:ll + 1], maintain_order=True).agg(
            xmin=col('x_pos').min(),
            xmax=col('x_pos').max(),
        )
        label_stack.append([
            (row['xmin'], row['xmax'], wrap_fn(str(row[level])))
            for row in spans.to_dicts()
        ])
    return label_stack
def get_label_y_ratios(groups: Sequence[str], vert_groups: set[str], size_lists: list[NDArray[numpy.float64]]) -> list[float]:
grouping_rotated = numpy.array([grouping in vert_groups for grouping in groups], dtype=bool)
level_dims = []
for sizes, rotated in zip(size_lists, grouping_rotated, strict=True):
if rotated:
dxy = sizes[:, [2, 1]].copy()
else:
dxy = sizes[:, [1, 2]].copy()
dxy[:, 0] /= sizes[:, 0]
level_dims.append(dxy.max(axis=0))
scales = numpy.array(level_dims)
ratios = (scales.shape[0] * scales[::-1, 1] / scales[:, 1].sum()).tolist()
return ratios
def _mk_data(filename: str) -> None:
    """
    Make some dummy data and write it to a csv file.

    The data covers 3 measurement types x 4 devices x 2-3 variants per device. Each
    combination gets 100 normally distributed values, with a mean drawn from [4, 6) and a
    standard deviation drawn from [0.1, 1). The RNG is seeded, so the output is reproducible.

    Args:
        filename: Path of the csv file to (over)write.
    """
    rows = []
    rng = numpy.random.default_rng(seed=0)
    for mm in ('liminal', 'transitive', 'extrinsic'):
        for dd in ('elevator', 'snare', 'tibetan_foxhole', 'inverse_thresher'):
            # Devices with an underscore in their name get three variants, the others two
            variants = ('tiny', 'elective', 'baseline') if '_' in dd else ('dormant', 'volatile')
            for vv in variants:
                std = rng.uniform(low=.1, high=1, size=1)
                mean = rng.uniform(low=4, high=6, size=1)
                for qq in rng.standard_normal(size=100) * std + mean:
                    rows.append(dict(MeasurementType=mm, Device=dd, DeviceVariant=vv, MeasuredValue=qq))
    df = polars.DataFrame(rows)
    # write_csv() with no target returns the CSV text instead of writing it, so pass the path
    df.write_csv(filename)
if __name__ == '__main__':
    # Demo: generate a synthetic dataset on disk, load it back, and plot it three levels deep
    # (rotating the finest level's labels and clamping the y range).
    filename = 'dummy_data.csv'
    _mk_data(filename)
    df = polars.read_csv(filename)
    demo_groups = ['MeasurementType', 'Device', 'DeviceVariant']
    variability_plot(
        df,
        'MeasuredValue',
        demo_groups,
        vert_groups=['DeviceVariant'],
        ylim=(3, 7),
    )
    pyplot.show(block=True)