diff --git a/miscplot/variability.py b/miscplot/variability.py index b5df661..a1dc99e 100644 --- a/miscplot/variability.py +++ b/miscplot/variability.py @@ -67,10 +67,7 @@ def variability_plot( df = df.sort(groups) df_groups = df.select(groups).unique(maintain_order=True).with_row_index(name='x_pos') df = df.join(df_groups, on=groups, maintain_order='left') - max_group_length = df.group_by(groups).len().select('len').max()[0, 0] - - label_stack = get_label_stack(df_groups, groups, wrap_fn) - size_lists = get_text_sizes(label_stack) + max_group_length = df.group_by(groups).len().select('len').max()[0, 0] # How many points in the largest x_pos # Add jitter to the scatterplots-plots jitter = 0.2 @@ -87,6 +84,8 @@ def variability_plot( y_data = numpy.concatenate(y_lists) # Get label contents and measure their sizes + label_stack = get_label_stack(df_groups, groups, wrap_fn) + size_lists = get_text_sizes(label_stack) y_ratios = [mainplot_ratios[1]] + get_label_y_ratios(groups, vert_groups, size_lists) # Build the figure and all axes @@ -268,23 +267,6 @@ def debounce(func: Callable, delay_s: float = 0.05) -> Callable: return debounced_func -def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]: - fig, ax = pyplot.subplots() - text_obj = ax.text(0, 0, 'placeholder') - renderer = fig.canvas.get_renderer() - - size_lists = [] - for ll, level in enumerate(label_stack): - sizes = [] - for xmin, xmax, text_value in level: - text_obj.set_text(text_value) - tbox = text_obj.get_window_extent(renderer=renderer) - sizes.append((xmax - xmin + 1, tbox.width, tbox.height)) - size_lists.append(numpy.array(sizes)) - pyplot.close(fig) - return size_lists - - def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: Callable) -> label_stack_t: """ For each level, get (xmin_inclusive, xmax_inclusive, wrapped_text_for_label) for all the labels @@ -307,6 +289,26 @@ def get_label_stack(df_groups: polars.DataFrame, groups: Sequence[str], wrap_fn: return label_stack +def get_text_sizes(label_stack: label_stack_t) -> list[NDArray[numpy.float64]]: + """ + Transform the label stack (see `get_label_stack` into a stack of (allocated x-span, unrotated x-size, unrotated y-size) + """ + fig, ax = pyplot.subplots() + text_obj = ax.text(0, 0, 'placeholder') + renderer = fig.canvas.get_renderer() + + size_lists = [] + for ll, level in enumerate(label_stack): + sizes = [] + for xmin, xmax, text_value in level: + text_obj.set_text(text_value) + tbox = text_obj.get_window_extent(renderer=renderer) + sizes.append((xmax - xmin + 1, tbox.width, tbox.height)) + size_lists.append(numpy.array(sizes)) + pyplot.close(fig) + return size_lists + + def get_label_y_ratios(groups: Sequence[str], vert_groups: set[str], size_lists: list[NDArray[numpy.float64]]) -> list[float]: """ For each level, figure out max(rotated_x_size / x_span) and max(rotated_y_size).