diff --git a/miscplot/variability.py b/miscplot/variability.py index 20c3c01..b5df661 100644 --- a/miscplot/variability.py +++ b/miscplot/variability.py @@ -63,6 +63,7 @@ def variability_plot( # Drop nulls and nans so that the boxplots don't disappear df = polars.drop_nulls(subset=[data_col]).drop_nans(subset=[data_col]) + # Assign category indicies (x_pos) df = df.sort(groups) df_groups = df.select(groups).unique(maintain_order=True).with_row_index(name='x_pos') df = df.join(df_groups, on=groups, maintain_order='left') @@ -71,6 +72,7 @@ def variability_plot( label_stack = get_label_stack(df_groups, groups, wrap_fn) size_lists = get_text_sizes(label_stack) + # Add jitter to the scatterplots-plots jitter = 0.2 rng = numpy.random.default_rng(seed=0) jitter_offsets = rng.uniform(low=-jitter, high=jitter, size=max_group_length) @@ -84,9 +86,10 @@ def variability_plot( x_data = numpy.concatenate(x_lists) y_data = numpy.concatenate(y_lists) - + # Get label contents and measure their sizes y_ratios = [mainplot_ratios[1]] + get_label_y_ratios(groups, vert_groups, size_lists) + # Build the figure and all axes fig = pyplot.figure() gs = gridspec.GridSpec( nrows = 1 + len(groups), @@ -106,6 +109,9 @@ def variability_plot( label_axes.append( fig.add_subplot(gs[ii, 0], sharex=ax)) header_axes.append(fig.add_subplot(gs[ii, 1])) + # + # Draw all the data + # if dotprops: if not isinstance(dotprops, dict): dotprops = {} @@ -150,6 +156,9 @@ def variability_plot( if mask_up.any(): ax.scatter(x_data[mask_up], numpy.full(mask_up.sum(), ylim[1]), color='red', marker='^') + # + # Draw all the labels + # max_x_pos = num_dsets - 1 major_xticks = [] minor_xticks = [] @@ -191,6 +200,9 @@ def variability_plot( textobjs.append(textrefs) + # + # Set limits and grid on the main plot + # ax.set_xlim(-0.5, num_dsets - 0.5) if ylim is not None: ax.set_ylim(ylim) @@ -205,6 +217,9 @@ def variability_plot( ax.set_title(data_col) ax.yaxis.set_minor_locator(ticker.AutoMinorLocator()) + # + # Add text resizing handlers to make sure labels are sized relative to their containing axes + # def resize_labels(event) -> None: # Resize labels margin_frac = 0.9 @@ -271,6 +286,9 @@ def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]: def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: Callable) -> label_stack_t: + """ + For each level, get (xmin_inclusive, xmax_inclusive, wrapped_text_for_label) for all the labels + """ label_stack = [] for ll, level in enumerate(groups): spans = df_groups.group_by(groups[:ll + 1], maintain_order=True).agg( @@ -290,6 +308,11 @@ def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: def get_label_y_ratios(groups: Sequence[str], vert_groups: set[str], size_lists: list[NDArray[numpy.float64]]) -> list[float]: + """ + For each level, figure out max(rotated_x_size / x_span) and max(rotated_y_size). + Normalize so that the sum of y-values is equal to the number of levels. + Output order is reversed so that the bottom labels (most general) come last. + """ grouping_rotated = numpy.array([grouping in vert_groups for grouping in groups], dtype=bool) level_dims = [] for sizes, rotated in zip(size_lists, grouping_rotated, strict=True):