331 lines
12 KiB
Python
331 lines
12 KiB
Python
from typing import Sequence, Callable, Any
|
|
from pathlib import Path
|
|
import threading
|
|
import textwrap
|
|
|
|
from PyQt6 import QtCore
|
|
|
|
import numpy
|
|
from numpy.typing import NDArray
|
|
import polars
|
|
from polars import col
|
|
from matplotlib import pyplot, gridspec, ticker
|
|
from matplotlib.figure import Figure
|
|
from matplotlib.axes import Axes
|
|
|
|
|
|
def twrap(text: str, **kwargs) -> str:
    """
    Wrap label text, allowing line breaks after underscores as well as at whitespace.

    Underscores are temporarily swapped for hyphens so that `textwrap.fill()` treats
    them as break opportunities. Afterwards each hyphen in the wrapped text is replaced
    by the character it stood for in `text`, so genuine hyphens are preserved rather
    than being turned into underscores.

    Args:
        text: Text to wrap.
        **kwargs: Passed to `textwrap.fill()`. `width` defaults to 15.

    Returns:
        The wrapped text, with newlines inserted.
    """
    kwargs.setdefault('width', 15)
    wrapped = textwrap.fill(text.replace('_', '-'), **kwargs)

    # Every '_' or '-' of `text` appears as '-' in `wrapped`, in the same order, so walk
    # the originals in step to restore them. If `wrapped` contains more hyphens than
    # `text` supplied (e.g. from a user-provided indent or placeholder), the surplus
    # hyphens are left untouched.
    originals = (ch for ch in text if ch in '_-')
    return ''.join(next(originals, ch) if ch == '-' else ch for ch in wrapped)
|
|
|
|
|
|
def variability_plot(
        data_table: 'dict | Sequence | NDArray | polars.Series | pandas.DataFrame',
        data_col: str,
        groups: Sequence[str],
        vert_groups: Sequence[str] = (),
        *,
        wrap_fn: Callable[[str], str] = twrap,
        mainplot_ratios: tuple[float, float] = (10, 10),
        ylim: tuple[float, float] | None = None,
        dotprops: dict[str, Any] | bool = True,
        boxprops: dict[str, Any] | bool = True,
        meanprops: dict[str, Any] | bool = True,
        ) -> tuple[Figure, Axes, list[Axes], list[Axes]]:
    """
    Create a variability plot (categorical box & scatter plot)

    The figure has one main data axes on top, and below it one row per entry in
    `groups`: a label axes (sharing the x-axis with the data axes) and, to its right,
    a header axes holding the group column's name.

    Args:
        data_table: Dataset to plot. Passed directly to `polars.DataFrame()`.
        data_col: Column to use for box/scatterplot value.
        groups: Columns to group by. Coarsest grouping should be first, and will appear
            furthest from the scatterplot, at the bottom of the figure.
        vert_groups: Labels for these column names will be rotated (text will run vertically).
        wrap_fn: Function called to wrap label text (i.e. insert newlines).
            Default wraps at 15 characters, preferentially on underscores or whitespace.
        mainplot_ratios: Scale factors setting the size of the main axes, relative to the size
            of other axes. Default is (10, 10).
        ylim: Y-limits for the scatter/box plot. Points which fall outside the limits are drawn
            as red triangles at the edges.
        dotprops: Passed as kwargs to scatterplot. `True` uses the defaults; any falsy value
            (including an empty dict) skips the scatterplot. The caller's dict is not modified.
        boxprops: Passed as kwargs to boxplot. Same `True`/falsy handling as `dotprops`.
        meanprops: Passed as kwargs to lineplot of means. Same `True`/falsy handling as `dotprops`.

    Returns:
        figure, data axes, label axes, header axes
    """
    vert_groups = set(vert_groups)

    df = polars.DataFrame(data_table)
    # TODO: optional `zero_bad` flag which drops rows where `data_col` is 0 (not implemented)

    # Assign each unique combination of group values an integer x position, in sorted order
    df = df.sort(groups)
    df_groups = df.select(groups).unique(maintain_order=True).with_row_index(name='x_pos')
    df = df.join(df_groups, on=groups, maintain_order='left')
    max_group_length = df.group_by(groups).len().select('len').max()[0, 0]

    label_stack = get_label_stack(df_groups, groups, wrap_fn)
    size_lists = get_text_sizes(label_stack)

    # Fixed seed: the same horizontal jitter is reused for every group and every call
    jitter = 0.2
    rng = numpy.random.default_rng(seed=0)
    jitter_offsets = rng.uniform(low=-jitter, high=jitter, size=max_group_length)

    x_lists = []
    y_lists = []
    for _labels, gdf in df.group_by(groups, maintain_order=True):
        x_lists.append(gdf['x_pos'][0] + jitter_offsets[:gdf.height])
        y_lists.append(gdf[data_col])
    num_dsets = len(x_lists)
    x_data = numpy.concatenate(x_lists)
    y_data = numpy.concatenate(y_lists)

    y_ratios = [mainplot_ratios[1]] + get_label_y_ratios(groups, vert_groups, size_lists)

    fig = pyplot.figure()
    gs = gridspec.GridSpec(
        nrows = 1 + len(groups),
        ncols = 2,
        height_ratios = y_ratios,
        width_ratios = [mainplot_ratios[0], 1],
        hspace = 0,
        wspace = 0.05,
        #left = 0.07,
        right = 0.98,
        )

    ax = fig.add_subplot(gs[0, 0])
    label_axes = []
    header_axes = []
    for ii in range(1, len(groups) + 1):
        label_axes.append( fig.add_subplot(gs[ii, 0], sharex=ax))
        header_axes.append(fig.add_subplot(gs[ii, 1]))

    if dotprops:
        # Work on a copy so that filling in defaults never alters the caller's dict
        dotprops = dict(dotprops) if isinstance(dotprops, dict) else {}
        dotprops.setdefault('alpha', 0.7)
        dotprops.setdefault('color', 'black')
        # Marker size is a default too, so a caller-supplied 's' overrides it instead of
        # colliding with a hard-coded keyword argument
        dotprops.setdefault('s', numpy.ones_like(y_data))
        _dotplt = ax.scatter(x_data, y_data, **dotprops)

    if boxprops:
        boxprops = dict(boxprops) if isinstance(boxprops, dict) else {}
        boxprops.setdefault('showfliers', False)
        boxprops.setdefault('medianprops', dict(linewidth=3, color='darkred', alpha=0.8))
        boxprops.setdefault('boxprops', dict(linewidth=0.5, color='black'))
        boxprops.setdefault('whiskerprops', dict(linewidth=0.5, color='black'))
        _boxplt = ax.boxplot(y_lists, positions=range(num_dsets), **boxprops)

    if meanprops:
        # Polyline through the group means: flat across the middle half of each x slot,
        # with sloped segments joining neighbouring slots
        means = [yl.mean() for yl in y_lists]
        xy = [(-0.5, means[0])]
        for xx, yy in enumerate(means):
            xy += [(xx - 0.25, yy),
                   (xx + 0.25, yy)]
        # NOTE(review): indentation was lost in the recovered source; this closing point is
        #   placed after the loop (mirroring the opening point at x = -0.5) -- confirm.
        xy += [(xx + 0.5, yy)]
        xy = numpy.array(xy)

        meanprops = dict(meanprops) if isinstance(meanprops, dict) else {}
        meanprops.setdefault('color', 'blue')
        meanprops.setdefault('alpha', 0.8)
        meanprops.setdefault('linewidth', 0.5)
        _meanplt = ax.plot(xy[:, 0], xy[:, 1], **meanprops)

    if ylim is not None:
        # Mark out-of-range points with red triangles pinned to the nearest limit
        mask_dn = y_data < ylim[0]
        mask_up = y_data > ylim[1]
        if mask_dn.any():
            ax.scatter(x_data[mask_dn], numpy.full(mask_dn.sum(), ylim[0]), color='red', marker='v')
        if mask_up.any():
            ax.scatter(x_data[mask_up], numpy.full(mask_up.sum(), ylim[1]), color='red', marker='^')

    max_x_pos = num_dsets - 1
    major_xticks = []
    minor_xticks = []
    textobjs = []
    headerobjs = []
    for ll, level in enumerate(groups):
        # groups[0] (coarsest) goes in the bottom row, so index the axes from the end
        axl = label_axes[-1 - ll]
        axh = header_axes[-1 - ll]
        axl.axis('off')
        axh.axis('off')
        header_txt = axh.text(0, 0.5, level, ha='left', va='center', fontsize=10, weight='bold')
        headerobjs.append((axh, header_txt))

        textrefs = dict(axes=[], texts=[], span_fracs=[])
        for (xmin, xmax, text_value), (xsize, _twidth, _theight) in zip(label_stack[ll], size_lists[ll], strict=True):
            text_obj = axl.text(
                0.5 * (xmin + xmax), 0.5, text_value,
                ha = 'center', va = 'center',
                fontsize = 10,
                rotation = (90 if level in vert_groups else 0),
                clip_on = True,
                )

            textrefs['axes'].append(axl)
            textrefs['texts'].append(text_obj)
            # Fraction of the axes width available to this label
            textrefs['span_fracs'].append(xsize / num_dsets)

            if xmax < max_x_pos:
                # Vertical separator between this label and the next one
                maxpt = xmax + 0.5
                axl.axvline(maxpt, color='gray', linestyle=':', linewidth=0.5)
                # The finest level's boundaries become minor gridlines on the data axes,
                # the second-finest level's become major gridlines
                if ll == len(groups) - 2:
                    major_xticks.append(maxpt)
                elif ll ==len(groups) - 1:
                    minor_xticks.append(maxpt)
            else:
                # NOTE(review): indentation was lost in the recovered source; this branch is
                #   paired with the `xmax < max_x_pos` test, so that the horizontal rules are
                #   drawn once per row (when the last label is reached) -- confirm.
                axl.axline(xy1=(xmin, 0), slope=0, color='gray', linewidth=0.5)
                axl.axline(xy1=(xmin, 1), slope=0, color='gray', linewidth=0.5, alpha=0.5)
                axl.set_ylim(0, 1)

        textobjs.append(textrefs)

    ax.set_xlim(-0.5, num_dsets - 0.5)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xticks(major_xticks, minor=False)
    ax.set_xticks(minor_xticks, minor=True)
    ax.set_xticklabels([], minor=False)
    ax.set_xticklabels([], minor=True)
    ax.tick_params('x', which='both', bottom=False)
    ax.grid(alpha=0.2, which='minor')
    ax.grid(alpha=1, which='major')
    ax.set_ylabel(data_col)
    ax.set_title(data_col)
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())

    def resize_labels(event) -> None:
        # Resize labels: within each row, use one shared font size, chosen so that the
        # tightest-fitting label fills `margin_frac` of its available space (capped at
        # `max_fontsize`)
        margin_frac = 0.9
        max_fontsize = 12
        for level in textobjs:
            ax_sizes = numpy.array([[abox.width, abox.height] for abox in [axl.get_window_extent() for axl in level['axes']]])
            tx_sizes = numpy.array([[tbox.width, tbox.height] for tbox in [txt.get_window_extent() for txt in level['texts']]])
            cur_fontsize = level['texts'][0].get_fontsize()
            ax_sizes[:, 0] *= level['span_fracs']

            scales = margin_frac * ax_sizes / tx_sizes
            tgt_fontsize = min(cur_fontsize * scales.min(), max_fontsize)
            for txt in level['texts']:
                txt.set_fontsize(tgt_fontsize)
        fig.canvas.draw_idle()

    def resize_headers(event) -> None:
        # Resize headers: one shared font size for all headers, chosen the same way
        margin_frac = 0.9
        max_fontsize = 12
        hax_sizes = numpy.array([[abox.width, abox.height] for abox in [axh.get_window_extent() for axh, _ in headerobjs]])
        htx_sizes = numpy.array([[tbox.width, tbox.height] for tbox in [txh.get_window_extent() for _, txh in headerobjs]])
        cur_fontsize = headerobjs[0][1].get_fontsize()
        scales = margin_frac * hax_sizes / htx_sizes
        tgt_fontsize = min(cur_fontsize * scales.min(), max_fontsize)
        for _, txt in headerobjs:
            txt.set_fontsize(tgt_fontsize)
        fig.canvas.draw_idle()

    # Debounced so that a burst of resize events triggers a single refit
    fig.canvas.mpl_connect('resize_event', debounce(resize_labels))
    fig.canvas.mpl_connect('resize_event', debounce(resize_headers))
    return fig, ax, label_axes, header_axes
|
|
|
|
|
|
|
|
# One list per grouping level (in the order given by `groups`); each entry is a
# `(xmin, xmax, text)` tuple giving the first and last x position spanned by a label,
# and the (wrapped) label text.
label_stack_t = list[list[tuple[int, int, str]]]
|
|
|
|
def debounce(func: Callable, delay_s: float = 0.05) -> Callable:
    """
    Wrap `func` so that a burst of calls collapses into one delayed call.

    Each call to the wrapper cancels the call scheduled by the previous one (if any)
    and schedules a new one `delay_s` seconds later using a `threading.Timer`, so
    `func` is invoked from the timer's thread with the arguments of the most recent call.

    Args:
        func: Function to call once the calls stop arriving.
        delay_s: Time to wait after a call before invoking `func`, in seconds.

    Returns:
        The wrapper. It returns `None` immediately; the return value of `func` is discarded.
    """
    # Holds at most one entry: the timer for the most recently requested call
    pending: list[threading.Timer] = []

    def debounced_func(*args, **kwargs) -> None:
        while pending:
            pending.pop().cancel()
        scheduled = threading.Timer(delay_s, func, args=args, kwargs=kwargs)
        pending.append(scheduled)
        scheduled.start()

    return debounced_func
|
|
|
|
|
|
def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]:
    """
    Measure the rendered size of every label in `label_stack`.

    The labels are rendered one at a time into a temporary figure, which is closed
    again before returning (even if measuring fails).

    Args:
        label_stack: Labels to measure, as generated by `get_label_stack()`.

    Returns:
        One array per level of `label_stack`, with one row per label. Each row is
        `(xmax - xmin + 1, width, height)`, i.e. the number of x positions spanned by
        the label followed by the width and height of its text's window extent.
    """
    fig, ax = pyplot.subplots()
    try:
        # A single text object is reused for all measurements
        text_obj = ax.text(0, 0, 'placeholder')
        renderer = fig.canvas.get_renderer()

        size_lists = []
        for level in label_stack:
            sizes = []
            for xmin, xmax, text_value in level:
                text_obj.set_text(text_value)
                tbox = text_obj.get_window_extent(renderer=renderer)
                sizes.append((xmax - xmin + 1, tbox.width, tbox.height))
            size_lists.append(numpy.array(sizes))
    finally:
        # Don't leave the scratch figure registered with pyplot if anything above raises
        pyplot.close(fig)
    return size_lists
|
|
|
|
|
|
def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: Callable) -> label_stack_t:
    """
    Build the label rows drawn underneath the plot, one row per grouping level.

    Args:
        df_groups: One row per plotted x position. Must contain the `groups` columns
            and an `x_pos` column.
        groups: Grouping columns. Level `n` groups by the first `n + 1` of them.
        wrap_fn: Called on the stringified group value to produce the label text.

    Returns:
        One list per entry in `groups`. Each holds `(xmin, xmax, text)` tuples giving the
        smallest and largest `x_pos` in the group and its wrapped label text.
    """
    def labels_at_depth(depth: int) -> list[tuple[int, int, str]]:
        keys = groups[:depth]
        extents = (
            df_groups
            .group_by(keys, maintain_order=True)
            .agg(
                xmin = col('x_pos').min(),
                xmax = col('x_pos').max(),
                )
            .with_columns(
                xspan = col('xmax') - col('xmin') + 1,
                )
            )
        # The label shows the value of the finest key at this depth;
        # each row of `extents` becomes one label
        return [
            (rec['xmin'], rec['xmax'], wrap_fn(str(rec[keys[-1]])))
            for rec in extents.to_dicts()
            ]

    return [labels_at_depth(depth) for depth in range(1, len(groups) + 1)]
|
|
|
|
|
|
def get_label_y_ratios(groups: Sequence[str], vert_groups: set[str], size_lists: list[NDArray[numpy.float64]]) -> list[float]:
|
|
grouping_rotated = numpy.array([grouping in vert_groups for grouping in groups], dtype=bool)
|
|
level_dims = []
|
|
for sizes, rotated in zip(size_lists, grouping_rotated, strict=True):
|
|
if rotated:
|
|
dxy = sizes[:, [2, 1]].copy()
|
|
else:
|
|
dxy = sizes[:, [1, 2]].copy()
|
|
dxy[:, 0] /= sizes[:, 0]
|
|
level_dims.append(dxy.max(axis=0))
|
|
scales = numpy.array(level_dims)
|
|
ratios = (scales.shape[0] * scales[::-1, 1] / scales[:, 1].sum()).tolist()
|
|
return ratios
|
|
|
|
|
|
def _mk_data(filename: str) -> None:
    """
    Make some dummy data and write it to a csv file

    For each combination of measurement type, device and device variant, 100 normally
    distributed values are generated, with a per-combination mean drawn from [4, 6)
    and standard deviation drawn from [0.1, 1). The random seed is fixed, so the
    generated data is the same on every call.

    Args:
        filename: Path of the csv file to write.
    """
    rows = []
    rng = numpy.random.default_rng(seed=0)
    for mm in ('liminal', 'transitive', 'extrinsic'):
        for dd in ('elevator', 'snare', 'tibetan_foxhole', 'inverse_thresher'):
            # Devices with an underscore in their name get a different set of variants
            for vv in (('tiny', 'elective', 'baseline') if '_' in dd else ('dormant', 'volatile')):
                std = rng.uniform(low=.1, high=1, size=1)
                mean = rng.uniform(low=4, high=6, size=1)
                for qq in rng.standard_normal(size=100) * std + mean:
                    rows.append(dict(MeasurementType=mm, Device=dd, DeviceVariant=vv, MeasuredValue=qq))
    df = polars.DataFrame(rows)
    # Without a target, `write_csv()` only returns the csv text and no file is created
    df.write_csv(filename)
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: generate a dummy dataset, read it back, and show its variability plot
    filename = 'dummy_data.csv'
    _mk_data(filename)
    df = polars.read_csv(filename)
    variability_plot(
        df,
        data_col = 'MeasuredValue',
        groups = ['MeasurementType', 'Device', 'DeviceVariant'],
        vert_groups = ['DeviceVariant'],
        ylim = (3, 7),
        )

    pyplot.show(block=True)
|