diff --git a/gridlock/draw.py b/gridlock/draw.py index 5390871..b2a3245 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -70,10 +70,13 @@ def draw_polygons( + 'xyz'[surface_normal]) # Broadcast foreground where necessary - if numpy.size(foreground) == 1: - foreground = [foreground] * len(cell_data) + foregrounds: Union[Sequence[foreground_callable_t], Sequence[float]] + if numpy.size(foreground) == 1: # type: ignore + foregrounds = [foreground] * len(cell_data) # type: ignore elif isinstance(foreground, numpy.ndarray): raise GridError('ndarray not supported for foreground') + else: + foregrounds = foreground # type: ignore # ## Compute sub-domain of the grid occupied by polygons # 1) Compute outer bounds (bd) of polygons @@ -105,22 +108,23 @@ def draw_polygons( return numpy.insert(v_2d, surface_normal, (val,)) # iterate over grids - for i, grid in enumerate(cell_data): - # ## Evaluate or expand foreground[i] - if callable(foreground[i]): + for i, _ in enumerate(cell_data): + # ## Evaluate or expand foregrounds[i] + foregrounds_i = foregrounds[i] + if callable(foregrounds_i): # meshgrid over the (shifted) domain domain = [self.shifted_xyz(i)[k][bdi_min[k]:bdi_max[k]+1] for k in range(3)] (x0, y0, z0) = numpy.meshgrid(*domain, indexing='ij') # evaluate on the meshgrid - foreground_i = foreground[i](x0, y0, z0) - if not numpy.isfinite(foreground_i).all(): + foreground_val = foregrounds_i(x0, y0, z0) + if not numpy.isfinite(foreground_val).all(): raise GridError(f'Non-finite values in foreground[{i}]') - elif numpy.size(foreground[i]) != 1: - raise GridError(f'Unsupported foreground[{i}]: {type(foreground[i])}') + elif numpy.size(foregrounds_i) != 1: + raise GridError(f'Unsupported foreground[{i}]: {type(foregrounds_i)}') else: # foreground[i] is scalar non-callable - foreground_i = foreground[i] + foreground_val = foregrounds_i w_xy = numpy.zeros((bdi_max - bdi_min + 1)[surface].astype(int)) @@ -188,7 +192,7 @@ def draw_polygons( # ## Modify the grid g_slice = (i,) + tuple(numpy.s_[bdi_min[a]:bdi_max[a] + 1] for a in range(3)) - cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_i + cell_data[g_slice] = (1 - w) * cell_data[g_slice] + w * foreground_val def draw_polygon(