From 481b56874ee9c42f8534a378fa85e89c1e523d93 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 20 Apr 2026 10:52:45 -0700 Subject: [PATCH] [draw] fix extrude without out-of-bounds slice --- gridlock/draw.py | 23 +++++++++++++---------- gridlock/test/test_grid.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/gridlock/draw.py b/gridlock/draw.py index 864468f..321ec15 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -76,10 +76,10 @@ class GridDrawMixin(GridPosMixin): # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] - if numpy.size(foreground) == 1: # type: ignore - foregrounds = [foreground] * len(cell_data) # type: ignore - elif isinstance(foreground, numpy.ndarray): + if isinstance(foreground, numpy.ndarray): raise GridError('ndarray not supported for foreground') + if callable(foreground) or numpy.isscalar(foreground): + foregrounds = [foreground] * len(cell_data) # type: ignore[list-item] else: foregrounds = foreground # type: ignore @@ -376,15 +376,18 @@ class GridDrawMixin(GridPosMixin): foreground_func = [] for ii, grid in enumerate(cell_data): zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction] - - ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)] - fpart = zz - numpy.floor(zz) - mult = [1 - fpart, fpart][::sgn] # reverses if s negative + low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1)) + high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1)) - foreground = mult[0] * grid[tuple(ind)] - ind[direction] += 1 # type: ignore #(known safe) - foreground += mult[1] * grid[tuple(ind)] + low_ind = [low if dd == direction else slice(None) for dd in range(3)] + high_ind = [high if dd == direction else slice(None) for dd in range(3)] + + if low == high: + foreground = grid[tuple(low_ind)] + else: + mult = [1 - fpart, fpart][::sgn] # reverses if s negative + foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)] def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 # transform from natural position to index diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index 2cb60c5..e7b3b28 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -238,6 +238,23 @@ def test_get_slice_uses_shifted_grid_bounds() -> None: grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0) +def test_draw_extrude_rectangle_uses_boundary_slice() -> None: + grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) + cell_data = grid.allocate(0) + source = numpy.array([[1, 2], + [3, 4]], dtype=float) + cell_data[0, :, :, 1] = source + + grid.draw_extrude_rectangle( + cell_data, + rectangle=[[0, 0, 2], [2, 2, 2]], + direction=2, + polarity=-1, + distance=2, + ) + + assert_allclose(cell_data[0, :, :, 0], source) + assert_allclose(cell_data[0, :, :, 1], source) def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None: