diff --git a/gridlock/__init__.py b/gridlock/__init__.py index 759d1c1..120291f 100644 --- a/gridlock/__init__.py +++ b/gridlock/__init__.py @@ -31,9 +31,8 @@ from .utils import ( PlaneDict as PlaneDict, ) from .grid import Grid as Grid -from .data import GridData as GridData __author__ = 'Jan Petykiewicz' -__version__ = '2.2' +__version__ = '1.2' version = __version__ diff --git a/gridlock/base.py b/gridlock/base.py index e68d955..aca9c69 100644 --- a/gridlock/base.py +++ b/gridlock/base.py @@ -76,21 +76,6 @@ class GridBase(Protocol): el = [0 if p else -1 for p in self.periodic] return [numpy.hstack((self.dxyz[a], self.dxyz[a][e])) for a, e in zip(range(3), el, strict=True)] - def _shifted_edge_dxyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: - if which_shifts is None: - return self.dxyz_with_ghost - - shifts = self.shifts[which_shifts, :] - edge_dxyz = [] - for a in range(3): - if shifts[a] < 0: - ghost = self.dxyz[a][-1] if self.periodic[a] else self.dxyz[a][0] - edge_dxyz.append(numpy.hstack((ghost, self.dxyz[a]))) - else: - ghost = self.dxyz[a][0] if self.periodic[a] else self.dxyz[a][-1] - edge_dxyz.append(numpy.hstack((self.dxyz[a], ghost))) - return edge_dxyz - @property def center(self) -> NDArray[numpy.float64]: """ @@ -130,9 +115,15 @@ class GridBase(Protocol): """ if which_shifts is None: return self.exyz - edge_dxyz = self._shifted_edge_dxyz(which_shifts) + dxyz = self.dxyz_with_ghost shifts = self.shifts[which_shifts, :] - return [self.exyz[a] + edge_dxyz[a] * shifts[a] for a in range(3)] + + # If shift is negative, use left cell's dx to determine shift + for a in range(3): + if shifts[a] < 0: + dxyz[a] = numpy.roll(dxyz[a], 1) + + return [self.exyz[a] + dxyz[a] * shifts[a] for a in range(3)] def shifted_dxyz(self, which_shifts: int | None) -> list[NDArray]: """ @@ -146,7 +137,20 @@ class GridBase(Protocol): """ if which_shifts is None: return self.dxyz - return [numpy.diff(exyz) for exyz in self.shifted_exyz(which_shifts)] + shifts = self.shifts[which_shifts, :] + dxyz = self.dxyz_with_ghost + + # If shift is negative, use left cell's dx to determine size + sdxyz = [] + for a in range(3): + if shifts[a] < 0: + roll_dxyz = numpy.roll(dxyz[a], 1) + abs_shift = numpy.abs(shifts[a]) + sdxyz.append(roll_dxyz[:-1] * abs_shift + roll_dxyz[1:] * (1 - abs_shift)) + else: + sdxyz.append(dxyz[a][:-1] * (1 - shifts[a]) + dxyz[a][1:] * shifts[a]) + + return sdxyz def shifted_xyz(self, which_shifts: int | None) -> list[NDArray[numpy.float64]]: """ diff --git a/gridlock/data.py b/gridlock/data.py deleted file mode 100644 index 5e6faa5..0000000 --- a/gridlock/data.py +++ /dev/null @@ -1,176 +0,0 @@ -from dataclasses import dataclass -from typing import Self -from collections.abc import Sequence - -import numpy -from numpy.typing import NDArray, ArrayLike - -from .draw import foreground_t -from .grid import Grid, _grid_from_payload, _load_payload, _payload_scalar_str, _save_npz_payload -from .utils import ( - ExtentDict, - ExtentProtocol, - GridError, - PlaneDict, - PlaneProtocol, - SlabDict, - SlabProtocol, -) - - -@dataclass(slots=True) -class GridData: - grid: Grid - cell_data: NDArray - - def __post_init__(self) -> None: - if tuple(self.cell_data.shape) != tuple(self.grid.cell_data_shape): - raise GridError( - f'cell_data has shape {self.cell_data.shape}, expected {tuple(self.grid.cell_data_shape)}' - ) - - @staticmethod - def load(filename: str) -> 'GridData': - payload = _load_payload(filename) - if _payload_scalar_str(payload, 'kind') != 'grid_data': - raise GridError('Serialized payload does not contain GridData') - if 'cell_data' not in payload: - raise GridError('Serialized GridData payload is missing cell_data') - - return GridData(_grid_from_payload(payload), numpy.array(payload['cell_data'])) - - def save(self, filename: str) -> Self: - payload = self.grid._serialization_payload(kind='grid_data') - payload['cell_data'] = self.cell_data - _save_npz_payload(filename, payload) - return self - - def copy(self) -> Self: - return GridData(self.grid.copy(), self.cell_data.copy()) - - def draw_polygons( - self, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - polygons: Sequence[ArrayLike], - *, - offset2d: ArrayLike = (0, 0), - ) -> Self: - self.grid.draw_polygons(self.cell_data, foreground, slab, polygons, offset2d=offset2d) - return self - - def draw_polygon( - self, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - polygon: ArrayLike, - *, - offset2d: ArrayLike = (0, 0), - ) -> Self: - self.grid.draw_polygon(self.cell_data, foreground, slab, polygon, offset2d=offset2d) - return self - - def draw_slab( - self, - foreground: Sequence[foreground_t] | foreground_t, - slab: SlabProtocol | SlabDict, - ) -> Self: - self.grid.draw_slab(self.cell_data, foreground, slab) - return self - - def draw_cuboid( - self, - foreground: Sequence[foreground_t] | foreground_t, - *, - x: ExtentProtocol | ExtentDict, - y: ExtentProtocol | ExtentDict, - z: ExtentProtocol | ExtentDict, - ) -> Self: - self.grid.draw_cuboid(self.cell_data, foreground, x=x, y=y, z=z) - return self - - def draw_cylinder( - self, - h: SlabProtocol | SlabDict, - radius: float, - num_points: int, - center2d: ArrayLike, - foreground: Sequence[foreground_t] | foreground_t, - ) -> Self: - self.grid.draw_cylinder(self.cell_data, h, radius, num_points, center2d, foreground) - return self - - def draw_extrude_rectangle( - self, - rectangle: ArrayLike, - direction: int, - polarity: int, - distance: float, - ) -> Self: - self.grid.draw_extrude_rectangle(self.cell_data, rectangle, direction, polarity, distance) - return self - - def get_slice( - self, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - ) -> NDArray: - return self.grid.get_slice(self.cell_data, plane, which_shifts=which_shifts, sample_period=sample_period) - - def visualize_slice( - self, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - pcolormesh_args: dict[str, object] | None = None, - ax: object | None = None, - ) -> tuple[object, object]: - return self.grid.visualize_slice( - self.cell_data, - plane, - which_shifts=which_shifts, - sample_period=sample_period, - finalize=finalize, - pcolormesh_args=pcolormesh_args, - ax=ax, - ) - - def visualize_edges( - self, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - contour_args: dict[str, object] | None = None, - ax: object | None = None, - level_fraction: float = 0.7, - ) -> tuple[object, object]: - return self.grid.visualize_edges( - self.cell_data, - plane, - which_shifts=which_shifts, - sample_period=sample_period, - finalize=finalize, - contour_args=contour_args, - ax=ax, - level_fraction=level_fraction, - ) - - def visualize_isosurface( - self, - level: float | None = None, - which_shifts: int = 0, - sample_period: int = 1, - show_edges: bool = True, - finalize: bool = True, - ) -> tuple[object, object]: - return self.grid.visualize_isosurface( - self.cell_data, - level=level, - which_shifts=which_shifts, - sample_period=sample_period, - show_edges=show_edges, - finalize=finalize, - ) diff --git a/gridlock/draw.py b/gridlock/draw.py index 321ec15..0b93d20 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -21,31 +21,30 @@ foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_t = float | foreground_callable_t + class GridDrawMixin(GridPosMixin): def draw_polygons( self, cell_data: NDArray, - foreground: Sequence[foreground_t] | foreground_t, slab: SlabProtocol | SlabDict, polygons: Sequence[ArrayLike], + foreground: Sequence[foreground_t] | foreground_t, *, offset2d: ArrayLike = (0, 0), ) -> None: """ - Draw polygons on an axis-aligned slab. + Draw polygons on an axis-aligned plane. Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + center: 3-element ndarray or list specifying an offset applied to all the polygons + polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon + (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each + polygon must have at least 3 vertices. foreground: Value to draw with ('brush color'). Can be scalar, callable, or a list of any of these (1 per grid). Callable values should take an ndarray the shape of the grid and return an ndarray of equal shape containing the foreground value at the given x, y, and z (natural, not grid coordinates). - slab: `Slab` or slab-like dict specifying the slab in which the polygons will be drawn. - polygons: List of Nx2 or Nx3 ndarrays, each specifying the vertices of a polygon - (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Each - polygon must have at least 3 vertices. - offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly - to the given polygon vertex coordinates. Default (0, 0). Raises: GridError @@ -61,25 +60,23 @@ class GridDrawMixin(GridPosMixin): for ii in range(len(poly_list)): polygon = poly_list[ii] malformed = f'Malformed polygon: ({ii})' - if polygon.ndim != 2: - raise GridError(malformed + 'must be a 2-dimensional ndarray') if polygon.shape[1] not in (2, 3): raise GridError(malformed + 'must be a Nx2 or Nx3 ndarray') if polygon.shape[1] == 3: - if numpy.unique(polygon[:, slab.axis]).size != 1: - raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) - polygon = polygon[:, surface] + polygon = polygon[surface, :] poly_list[ii] = polygon if not polygon.shape[0] > 2: raise GridError(malformed + 'must consist of more than 2 points') + if polygon.ndim > 2 and not numpy.unique(polygon[:, slab.axis]).size == 1: + raise GridError(malformed + 'must be in plane with surface normal ' + 'xyz'[slab.axis]) # Broadcast foreground where necessary foregrounds: Sequence[foreground_callable_t] | Sequence[float] - if isinstance(foreground, numpy.ndarray): + if numpy.size(foreground) == 1: # type: ignore + foregrounds = [foreground] * len(cell_data) # type: ignore + elif isinstance(foreground, numpy.ndarray): raise GridError('ndarray not supported for foreground') - if callable(foreground) or numpy.isscalar(foreground): - foregrounds = [foreground] * len(cell_data) # type: ignore[list-item] else: foregrounds = foreground # type: ignore @@ -203,9 +200,9 @@ class GridDrawMixin(GridPosMixin): def draw_polygon( self, cell_data: NDArray, - foreground: Sequence[foreground_t] | foreground_t, slab: SlabProtocol | SlabDict, polygon: ArrayLike, + foreground: Sequence[foreground_t] | foreground_t, *, offset2d: ArrayLike = (0, 0), ) -> None: @@ -214,13 +211,11 @@ class GridDrawMixin(GridPosMixin): Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) - foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - slab: `Slab` or slab-like dict specifying the slab in which the polygon will be drawn. + slab: `Slab` in which to draw polygons. polygon: Nx2 or Nx3 ndarray specifying the vertices of a polygon (non-closed, clockwise). If Nx3, the `slab.axis`-th coordinate is ignored. Must have at least 3 vertices. - offset2d: 2D offset to apply to polygon coordinates -- this offset is added directly - to the given polygon vertex coordinates. Default (0, 0). + foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. """ self.draw_polygons( cell_data = cell_data, @@ -234,16 +229,17 @@ class GridDrawMixin(GridPosMixin): def draw_slab( self, cell_data: NDArray, - foreground: Sequence[foreground_t] | foreground_t, slab: SlabProtocol | SlabDict, + foreground: Sequence[foreground_t] | foreground_t, ) -> None: """ Draw an axis-aligned infinite slab. Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + slab: + thickness: Thickness of the layer to draw foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - slab: `Slab` or slab-like dict (geometrical slab specification) """ if isinstance(slab, dict): slab = Slab(**slab) @@ -286,10 +282,10 @@ class GridDrawMixin(GridPosMixin): Args: cell_data: Cell data to modify (e.g. created by `Grid.allocate()`) + center: 3-element ndarray or list specifying the cuboid's center + dimensions: 3-element list or ndarray containing the x, y, and z edge-to-edge + sizes of the cuboid foreground: Value to draw with ('brush color'). See `draw_polygons()` for details. - x: `Extent` or extent-like dict specifying the x-extent of the cuboid. - y: `Extent` or extent-like dict specifying the y-extent of the cuboid. - z: `Extent` or extent-like dict specifying the z-extent of the cuboid. """ if isinstance(x, dict): x = Extent(**x) @@ -298,6 +294,8 @@ class GridDrawMixin(GridPosMixin): if isinstance(z, dict): z = Extent(**z) + center = numpy.asarray([x.center, y.center, z.center]) + p = numpy.array([[x.min, y.max], [x.max, y.max], [x.max, y.min], @@ -376,18 +374,15 @@ class GridDrawMixin(GridPosMixin): foreground_func = [] for ii, grid in enumerate(cell_data): zz = self.pos2ind(rectangle[0, :], ii, round_ind=False, check_bounds=False)[direction] + + ind = [int(numpy.floor(zz)) if dd == direction else slice(None) for dd in range(3)] + fpart = zz - numpy.floor(zz) - low = int(numpy.clip(numpy.floor(zz), 0, grid.shape[direction] - 1)) - high = int(numpy.clip(numpy.floor(zz) + 1, 0, grid.shape[direction] - 1)) + mult = [1 - fpart, fpart][::sgn] # reverses if s negative - low_ind = [low if dd == direction else slice(None) for dd in range(3)] - high_ind = [high if dd == direction else slice(None) for dd in range(3)] - - if low == high: - foreground = grid[tuple(low_ind)] - else: - mult = [1 - fpart, fpart][::sgn] # reverses if s negative - foreground = mult[0] * grid[tuple(low_ind)] + mult[1] * grid[tuple(high_ind)] + foreground = mult[0] * grid[tuple(ind)] + ind[direction] += 1 # type: ignore #(known safe) + foreground += mult[1] * grid[tuple(ind)] def f_foreground(xs, ys, zs, ii=ii, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 # transform from natural position to index @@ -401,3 +396,4 @@ class GridDrawMixin(GridPosMixin): slab = Slab(axis=direction, center=center[direction], span=thickness) self.draw_polygon(cell_data, slab=slab, polygon=poly, foreground=foreground_func, offset2d=center[surface]) + diff --git a/gridlock/grid.py b/gridlock/grid.py index eeb9708..5790dbd 100644 --- a/gridlock/grid.py +++ b/gridlock/grid.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Self +from typing import ClassVar, Self from collections.abc import Callable, Sequence import numpy @@ -13,78 +13,8 @@ from .draw import GridDrawMixin from .read import GridReadMixin from .position import GridPosMixin -if TYPE_CHECKING: - from .data import GridData - foreground_callable_type = Callable[[NDArray, NDArray, NDArray], NDArray] -_FORMAT_VERSION = 1 - - -def _is_npz_file(filename: str) -> bool: - with open(filename, 'rb') as f: - return f.read(2) == b'PK' - - -def _save_npz_payload(filename: str, payload: dict[str, Any]) -> None: - with open(filename, 'wb') as f: - numpy.savez_compressed(f, **payload) - - -def _load_payload(filename: str) -> dict[str, Any]: - if _is_npz_file(filename): - with numpy.load(filename, allow_pickle=False) as payload: - return {key: payload[key] for key in payload.files} - - with open(filename, 'rb') as f: - legacy = pickle.load(f) - - if isinstance(legacy, Grid): - return legacy._serialization_payload(kind='grid') - if isinstance(legacy, dict): - grid = Grid([[-1, 1]] * 3) - grid.__dict__.update(legacy) - return grid._serialization_payload(kind='grid') - raise GridError('Unsupported serialized Grid payload') - - -def _payload_scalar_str(payload: dict[str, Any], key: str) -> str: - if key not in payload: - raise GridError(f'Missing serialized key: {key}') - - value = numpy.asarray(payload[key]) - if value.size != 1: - raise GridError(f'Serialized key {key} must be scalar') - return str(value.reshape(())) - - -def _payload_scalar_int(payload: dict[str, Any], key: str) -> int: - if key not in payload: - raise GridError(f'Missing serialized key: {key}') - - value = numpy.asarray(payload[key]) - if value.size != 1: - raise GridError(f'Serialized key {key} must be scalar') - return int(value.reshape(())) - - -def _grid_from_payload(payload: dict[str, Any]) -> 'Grid': - if _payload_scalar_int(payload, 'format_version') != _FORMAT_VERSION: - raise GridError('Unsupported serialized Grid format version') - - exyz = [] - for axis in range(3): - key = f'exyz_{axis}' - if key not in payload: - raise GridError(f'Missing serialized key: {key}') - exyz.append(numpy.array(payload[key], dtype=float)) - - if 'shifts' not in payload or 'periodic' not in payload: - raise GridError('Serialized Grid payload is missing shifts or periodic data') - - shifts = numpy.array(payload['shifts'], dtype=float) - periodic = numpy.array(payload['periodic'], dtype=bool).tolist() - return Grid(exyz, shifts=shifts, periodic=periodic) class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): @@ -165,8 +95,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): `GridError` on invalid input """ edge_arrs = [numpy.array(cc) for cc in pixel_edge_coordinates] - if len(edge_arrs) != 3: - raise GridError('pixel_edge_coordinates must contain exactly 3 coordinate arrays') self.exyz = [numpy.unique(edges) for edges in edge_arrs] self.shifts = numpy.array(shifts, dtype=float) @@ -178,10 +106,6 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): self.periodic = [periodic] * 3 else: self.periodic = list(periodic) - if len(self.periodic) != 3: - raise GridError('periodic must be a bool or a sequence of length 3') - if not all(isinstance(pp, bool | numpy.bool_) for pp in self.periodic): - raise GridError('periodic sequence entries must be bool values') if len(self.shifts.shape) != 2: raise GridError('Misshapen shifts: shifts must have two axes! ' @@ -193,16 +117,9 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): if (numpy.abs(self.shifts) > 1).any(): raise GridError('Only shifts in the range [-1, 1] are currently supported') - def _serialization_payload(self, *, kind: str) -> dict[str, Any]: - payload: dict[str, Any] = { - 'kind': numpy.array(kind), - 'format_version': numpy.array(_FORMAT_VERSION, dtype=int), - 'shifts': self.shifts, - 'periodic': numpy.array(self.periodic, dtype=bool), - } - for axis, exyz in enumerate(self.exyz): - payload[f'exyz_{axis}'] = exyz - return payload + if (self.shifts < 0).any(): + # TODO: Test negative shifts + warnings.warn('Negative shifts are still experimental and mostly untested, be careful!', stacklevel=2) @staticmethod def load(filename: str) -> 'Grid': @@ -212,11 +129,12 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Args: filename: Filename to load from. """ - payload = _load_payload(filename) - kind = _payload_scalar_str(payload, 'kind') - if kind not in ('grid', 'grid_data'): - raise GridError(f'Unsupported serialized kind: {kind}') - return _grid_from_payload(payload) + with open(filename, 'rb') as f: + tmp_dict = pickle.load(f) + + g = Grid([[-1, 1]] * 3) + g.__dict__.update(tmp_dict) + return g def save(self, filename: str) -> Self: """ @@ -228,18 +146,10 @@ class Grid(GridDrawMixin, GridReadMixin, GridPosMixin): Returns: self """ - _save_npz_payload(filename, self._serialization_payload(kind='grid')) + with open(filename, 'wb') as f: + pickle.dump(self.__dict__, f, protocol=2) return self - def with_data( - self, - fill_value: float | None = 1.0, - dtype: type[numpy.number] = numpy.float32, - ) -> 'GridData': - from .data import GridData - - return GridData(self.copy(), self.allocate(fill_value=fill_value, dtype=dtype)) - def copy(self) -> Self: """ Returns: diff --git a/gridlock/position.py b/gridlock/position.py index 6344ea4..b705b99 100644 --- a/gridlock/position.py +++ b/gridlock/position.py @@ -47,13 +47,13 @@ class GridPosMixin(GridBase): else: low_bound = -0.5 high_bound = -0.5 - if (ind < low_bound).any() or (ind > self.shape + high_bound).any(): + if (ind < low_bound).any() or (ind > self.shape - high_bound).any(): raise GridError(f'Position outside of grid: {ind}') if round_ind: rind = numpy.clip(numpy.round(ind).astype(int), 0, self.shape - 1) sxyz = self.shifted_xyz(which_shifts) - position = [sxyz[a][rind[a]] for a in range(3)] + position = [sxyz[a][rind[a]].astype(int) for a in range(3)] else: sexyz = self.shifted_exyz(which_shifts) position = [numpy.interp(ind[a], numpy.arange(sexyz[a].size) - 0.5, sexyz[a]) diff --git a/gridlock/read.py b/gridlock/read.py index f8a40a1..707251a 100644 --- a/gridlock/read.py +++ b/gridlock/read.py @@ -20,26 +20,6 @@ if TYPE_CHECKING: class GridReadMixin(GridPosMixin): - @staticmethod - def _preview_exyz_from_centers(centers: NDArray, fallback_edges: NDArray) -> NDArray[numpy.float64]: - if centers.size > 1: - midpoints = 0.5 * (centers[:-1] + centers[1:]) - first = centers[0] - 0.5 * (centers[1] - centers[0]) - last = centers[-1] + 0.5 * (centers[-1] - centers[-2]) - return numpy.hstack(([first], midpoints, [last])) - return numpy.array([fallback_edges[0], fallback_edges[-1]], dtype=float) - - def _sampled_exyz(self, which_shifts: int, sample_period: int) -> list[NDArray[numpy.float64]]: - if sample_period <= 1: - return self.shifted_exyz(which_shifts) - - shifted_xyz = self.shifted_xyz(which_shifts) - shifted_exyz = self.shifted_exyz(which_shifts) - return [ - self._preview_exyz_from_centers(shifted_xyz[a][::sample_period], shifted_exyz[a]) - for a in range(3) - ] - def get_slice( self, cell_data: NDArray, @@ -73,7 +53,7 @@ class GridReadMixin(GridPosMixin): surface = numpy.delete(range(3), plane.axis) # Extract indices and weights of planes - center3 = numpy.insert([0.0, 0.0], plane.axis, (plane.pos,)) + center3 = numpy.insert([0, 0], plane.axis, (plane.pos,)) center_index = self.pos2ind(center3, which_shifts, round_ind=False, check_bounds=False)[plane.axis] centers = numpy.unique([numpy.floor(center_index), numpy.ceil(center_index)]).astype(int) @@ -83,13 +63,12 @@ class GridReadMixin(GridPosMixin): else: w = [1] - c_min, c_max = (self.shifted_xyz(which_shifts)[plane.axis][i] for i in [0, -1]) + c_min, c_max = (self.xyz[plane.axis][i] for i in [0, -1]) if plane.pos < c_min or plane.pos > c_max: raise GridError('Coordinate of selected plane must be within simulation domain') # Extract grid values from planes above and below visualized slice - sample_shape = tuple(self.shifted_xyz(which_shifts)[a][::sp].size for a in surface) - sliced_grid = numpy.zeros(sample_shape, dtype=numpy.result_type(cell_data.dtype, float)) + sliced_grid = numpy.zeros(self.shape[surface]) for ci, weight in zip(centers, w, strict=True): s = tuple(ci if a == plane.axis else numpy.s_[::sp] for a in range(3)) sliced_grid += weight * cell_data[which_shifts][tuple(s)] @@ -108,7 +87,6 @@ class GridReadMixin(GridPosMixin): sample_period: int = 1, finalize: bool = True, pcolormesh_args: dict[str, Any] | None = None, - ax: 'matplotlib.axes.Axes | None' = None, ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: """ Visualize a slice of a grid. @@ -120,8 +98,6 @@ class GridReadMixin(GridPosMixin): which_shifts: Which grid to display. Default is the first grid (0). sample_period: Period for down-sampling the image. Default 1 (disabled) finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - pcolormesh_args: Args passed through to matplotlib `pcolormesh()` - ax: If provided, plot to these axes (instead of creating a new figure & axes) Returns: (Figure, Axes) @@ -135,109 +111,24 @@ class GridReadMixin(GridPosMixin): pcolormesh_args = {} grid_slice = self.get_slice( - cell_data = cell_data, - plane = plane, - which_shifts = which_shifts, - sample_period = sample_period, + cell_data=cell_data, + plane=plane, + which_shifts=which_shifts, + sample_period=sample_period, ) surface = numpy.delete(range(3), plane.axis) - if sample_period == 1: - x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) - else: - x, y = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) - pcolormesh_args.setdefault('shading', 'nearest') + x, y = (self.shifted_exyz(which_shifts)[a] for a in surface) xmesh, ymesh = numpy.meshgrid(x, y, indexing='ij') x_label, y_label = ('xyz'[a] for a in surface) - if ax is None: - fig, ax = pyplot.subplots() - else: - fig = ax.figure + fig, ax = pyplot.subplots() mappable = ax.pcolormesh(xmesh, ymesh, grid_slice, **pcolormesh_args) fig.colorbar(mappable) ax.set_aspect('equal', adjustable='box') ax.set_xlabel(x_label) ax.set_ylabel(y_label) - - if finalize: - pyplot.show() - - return fig, ax - - - def visualize_edges( - self, - cell_data: NDArray, - plane: PlaneProtocol | PlaneDict, - which_shifts: int = 0, - sample_period: int = 1, - finalize: bool = True, - contour_args: dict[str, Any] | None = None, - ax: 'matplotlib.axes.Axes | None' = None, - level_fraction: float = 0.7, - ) -> tuple['matplotlib.figure.Figure', 'matplotlib.axes.Axes']: - """ - Visualize the edges of a grid slice. - This is intended as an overlay on top of visualize_slice (e.g. showing epsilon boundaries - on an E-field plot). - - Interpolates if given a position between two grid planes. - - Args: - cell_data: Cell data to visualize - plane: Axis and position (`Plane`) of the plane to read. - which_shifts: Which grid to display. Default is the first grid (0). - sample_period: Period for down-sampling the image. Default 1 (disabled) - finalize: Whether to call `pyplot.show()` after constructing the plot. Default `True` - contour_args: Args passed through to matplotlib `pcolormesh()` - ax: If provided, plot to these axes (instead of creating a new figure & axes) - level_fraction: Value between 0 and 1 which tunes how many contours are generated. - 1 indicates that every possible step should have its own contour. - - Returns: - (Figure, Axes) - """ - from matplotlib import pyplot - - if level_fraction > 1: - raise GridError(f'{level_fraction=} must be between 0 and 1') - - if isinstance(plane, dict): - plane = Plane(**plane) - - if contour_args is None: - contour_args = dict(alpha=0.8, colors='gray') - - grid_slice = self.get_slice( - cell_data = cell_data, - plane = plane, - which_shifts = which_shifts, - sample_period = sample_period, - ) - cvals, cval_counts = numpy.unique(grid_slice, return_counts=True) - if cvals.size == 1: - levels = [cvals[0] + 1] - else: - cval_order = numpy.argsort(cval_counts)[::-1] - level_count = 2 - while cval_counts[cval_order[:level_count]].sum() < level_fraction: - level_count += 1 - ctr_levels = cvals[cval_order[:level_count]] - levels = numpy.diff(ctr_levels[::-1]) + ctr_levels[:0:-1] - - surface = numpy.delete(range(3), plane.axis) - - if ax is None: - fig, ax = pyplot.subplots() - else: - fig = ax.figure - xc, yc = (self.shifted_xyz(which_shifts)[a][::sample_period] for a in surface) - xcmesh, ycmesh = numpy.meshgrid(xc, yc, indexing='ij') - - ax.contour(xcmesh, ycmesh, grid_slice, levels=levels, **contour_args) - if finalize: pyplot.show() @@ -282,14 +173,8 @@ class GridReadMixin(GridPosMixin): verts, faces, _normals, _values = skimage.measure.marching_cubes(grid, level) # Convert vertices from index to position - preview_exyz = self._sampled_exyz(which_shifts, sample_period) - pos_verts = numpy.array([ - [ - numpy.interp(verts[i, a], numpy.arange(preview_exyz[a].size) - 0.5, preview_exyz[a]) - for a in range(3) - ] - for i in range(verts.shape[0]) - ], dtype=float) + pos_verts = numpy.array([self.ind2pos(verts[i, :], which_shifts, round_ind=False) + for i in range(verts.shape[0])], dtype=float) xs, ys, zs = (pos_verts[:, a] for a in range(3)) # Draw the plot diff --git a/gridlock/test/test_grid.py b/gridlock/test/test_grid.py index ae0a73a..8d9ca92 100644 --- a/gridlock/test/test_grid.py +++ b/gridlock/test/test_grid.py @@ -1,9 +1,8 @@ -import pytest +# import pytest import numpy from numpy.testing import assert_allclose #, assert_array_equal -import pickle -from .. import Grid, GridData, Extent, GridError, Plane, Slab +from .. import Grid, Extent #, Slab, Plane def test_draw_oncenter_2x2() -> None: @@ -117,350 +116,3 @@ def test_draw_2shift_4x4() -> None: [0, 0.125, 0.125, 0]])[None, :, :, None] assert_allclose(arr, correct) - - -def test_ind2pos_round_preserves_float_centers() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) - - pos = grid.ind2pos(numpy.array([1, 0, 0]), which_shifts=0) - - assert_allclose(pos, [2.0, 1.0, 0.5]) - - -def test_ind2pos_enforces_bounds_for_rounded_and_fractional_indices() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0, 0, 0]]) - - with pytest.raises(GridError): - grid.ind2pos(numpy.array([2, 0, 0]), which_shifts=0, check_bounds=True) - - edge_pos = grid.ind2pos(numpy.array([1.5, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) - assert_allclose(edge_pos, [3.0, 2.0, 1.0]) - - with pytest.raises(GridError): - grid.ind2pos(numpy.array([1.6, 0.5, 0.5]), which_shifts=0, round_ind=False, check_bounds=True) - - -def test_draw_polygon_accepts_coplanar_nx3_vertices() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) - arr_2d = grid.allocate(0) - arr_3d = grid.allocate(0) - slab = dict(axis='z', center=0.5, span=1.0) - - polygon_2d = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float) - polygon_3d = numpy.array([[0, 0, 0.5], - [1, 0, 0.5], - [1, 1, 0.5], - [0, 1, 0.5]], dtype=float) - - grid.draw_polygon(arr_2d, slab=slab, polygon=polygon_2d, foreground=1) - grid.draw_polygon(arr_3d, slab=slab, polygon=polygon_3d, foreground=1) - - assert_allclose(arr_3d, arr_2d) - - -def test_draw_polygon_rejects_noncoplanar_nx3_vertices() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]) - arr = grid.allocate(0) - polygon = numpy.array([[0, 0, 0.5], - [1, 0, 0.5], - [1, 1, 0.75], - [0, 1, 0.5]], dtype=float) - - with pytest.raises(GridError): - grid.draw_polygon(arr, slab=dict(axis='z', center=0.5, span=1.0), polygon=polygon, foreground=1) - - -def test_get_slice_supports_sampling() -> None: - grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) - cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) - - grid_slice = grid.get_slice(cell_data, Plane(z=0.5), sample_period=2) - - assert_allclose(grid_slice, cell_data[0, ::2, ::2, 0]) - - -def test_sampled_visualization_helpers_do_not_error() -> None: - matplotlib = pytest.importorskip('matplotlib') - matplotlib.use('Agg') - from matplotlib import pyplot - - grid = Grid([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1]], shifts=[[0, 0, 0]]) - cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) - - fig_slice, ax_slice = grid.visualize_slice(cell_data, Plane(z=0.5), sample_period=2, finalize=False) - fig_edges, ax_edges = grid.visualize_edges(cell_data, Plane(z=0.5), sample_period=2, finalize=False) - - assert fig_slice is ax_slice.figure - assert fig_edges is ax_edges.figure - - pyplot.close(fig_slice) - pyplot.close(fig_edges) - - -def test_grid_constructor_rejects_invalid_coordinate_count() -> None: - with pytest.raises(GridError): - Grid([[0, 1], [0, 1]], shifts=[[0, 0, 0]]) - - with pytest.raises(GridError): - Grid([[0, 1], [0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) - - -def test_grid_constructor_rejects_invalid_periodic_length() -> None: - with pytest.raises(GridError): - Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]], periodic=[True, False]) - - -def test_extent_and_slab_reject_inverted_geometry() -> None: - with pytest.raises(GridError): - Extent(center=0, min=1) - - with pytest.raises(GridError): - Extent(min=2, max=1) - - with pytest.raises(GridError): - Slab(axis='z', center=1, max=0) - - -def test_extent_accepts_scalar_like_inputs() -> None: - extent = Extent(min=numpy.array([1.0]), span=numpy.array([4.0])) - - assert_allclose([extent.center, extent.span, extent.min, extent.max], [3.0, 4.0, 1.0, 5.0]) - - -def test_get_slice_uses_shifted_grid_bounds() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0.5, 0, 0]]) - cell_data = numpy.arange(numpy.prod(grid.cell_data_shape), dtype=float).reshape(grid.cell_data_shape) - - grid_slice = grid.get_slice(cell_data, Plane(x=2.0), which_shifts=0) - - assert_allclose(grid_slice, cell_data[0, 1, :, :]) - - with pytest.raises(GridError): - grid.get_slice(cell_data, Plane(x=2.1), which_shifts=0) - - -def test_draw_extrude_rectangle_uses_boundary_slice() -> None: - grid = Grid([[0, 1, 2], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) - cell_data = grid.allocate(0) - source = numpy.array([[1, 2], - [3, 4]], dtype=float) - cell_data[0, :, :, 1] = source - - grid.draw_extrude_rectangle( - cell_data, - rectangle=[[0, 0, 2], [2, 2, 2]], - direction=2, - polarity=-1, - distance=2, - ) - - assert_allclose(cell_data[0, :, :, 0], source) - assert_allclose(cell_data[0, :, :, 1], source) - - -def test_sampled_preview_exyz_tracks_nonuniform_centers() -> None: - grid = Grid([[0, 1, 3, 6, 10], [0, 1, 2], [0, 1, 2]], shifts=[[0, 0, 0]]) - - sampled_exyz = grid._sampled_exyz(0, 2) - - assert_allclose(sampled_exyz[0], [-1.5, 2.5, 6.5]) - - -def test_visualize_isosurface_sampling_uses_preview_lattice(monkeypatch: pytest.MonkeyPatch) -> None: - matplotlib = pytest.importorskip('matplotlib') - matplotlib.use('Agg') - skimage_measure = pytest.importorskip('skimage.measure') - from matplotlib import pyplot - from mpl_toolkits.mplot3d.axes3d import Axes3D - - captured: dict[str, numpy.ndarray] = {} - - def fake_marching_cubes(_grid: numpy.ndarray, _level: float) -> tuple[numpy.ndarray, numpy.ndarray, None, None]: - verts = numpy.array([[0.5, 0.5, 0.5], - [0.5, 1.5, 0.5], - [1.5, 0.5, 0.5]], dtype=float) - faces = numpy.array([[0, 1, 2]], dtype=int) - return verts, faces, None, None - - def fake_plot_trisurf( # noqa: ANN202 - _self: object, - xs: numpy.ndarray, - ys: numpy.ndarray, - faces: numpy.ndarray, - zs: numpy.ndarray, - *_args: object, - **_kwargs: object, - ) -> object: - captured['xs'] = numpy.asarray(xs) - captured['ys'] = numpy.asarray(ys) - captured['faces'] = numpy.asarray(faces) - captured['zs'] = numpy.asarray(zs) - return object() - - monkeypatch.setattr(skimage_measure, 'marching_cubes', fake_marching_cubes) - monkeypatch.setattr(Axes3D, 'plot_trisurf', fake_plot_trisurf) - - grid = Grid([numpy.arange(7, dtype=float), numpy.arange(7, dtype=float), numpy.arange(7, dtype=float)], shifts=[[0, 0, 0]]) - cell_data = numpy.zeros(grid.cell_data_shape) - - fig, _ax = grid.visualize_isosurface(cell_data, level=0.5, sample_period=2, finalize=False) - - assert_allclose(captured['xs'], [1.5, 1.5, 3.5]) - assert_allclose(captured['ys'], [1.5, 3.5, 1.5]) - assert_allclose(captured['zs'], [1.5, 1.5, 1.5]) - - pyplot.close(fig) - - -def test_grid_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) - path = tmp_path / 'grid.state' - - grid.save(str(path)) - loaded = Grid.load(str(path)) - - assert path.exists() - for original, restored in zip(grid.exyz, loaded.exyz, strict=True): - assert_allclose(restored, original) - assert_allclose(loaded.shifts, grid.shifts) - assert loaded.periodic == grid.periodic - - -def test_grid_load_supports_legacy_pickle(tmp_path: pytest.TempPathFactory) -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]], periodic=[True, False, False]) - path = tmp_path / 'grid.pickle' - with open(path, 'wb') as f: - pickle.dump(grid.__dict__, f, protocol=2) - - loaded = Grid.load(str(path)) - - for original, restored in zip(grid.exyz, loaded.exyz, strict=True): - assert_allclose(restored, original) - assert_allclose(loaded.shifts, grid.shifts) - assert loaded.periodic == grid.periodic - - -def test_griddata_save_load_round_trip_npz(tmp_path: pytest.TempPathFactory) -> None: - data = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[0.5, 0, 0]]).with_data(fill_value=2.0) - data.cell_data[0, 1, 0, 0] = 5.0 - path = tmp_path / 'griddata.state' - - data.save(str(path)) - loaded = GridData.load(str(path)) - - assert path.exists() - assert_allclose(loaded.cell_data, data.cell_data) - assert_allclose(loaded.grid.shifts, data.grid.shifts) - assert loaded.grid.periodic == data.grid.periodic - - -def test_griddata_rejects_invalid_payload_kind(tmp_path: pytest.TempPathFactory) -> None: - path = tmp_path / 'grid.state' - Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).save(str(path)) - - with pytest.raises(GridError): - GridData.load(str(path)) - - -def test_negative_shift_nonperiodic_edges_and_widths() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - - assert_allclose(grid.shifted_exyz(0)[0], [-0.5, 0.5, 2.0]) - assert_allclose(grid.shifted_dxyz(0)[0], [1.0, 1.5]) - assert_allclose(grid.shifted_xyz(0)[0], [0.0, 1.25]) - - -def test_negative_shift_periodic_edges_and_widths() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[True, False, False]) - - assert_allclose(grid.shifted_exyz(0)[0], [-1.0, 0.5, 2.0]) - assert_allclose(grid.shifted_dxyz(0)[0], [1.5, 1.5]) - - -def test_negative_shift_coordinate_round_trip() -> None: - grid = Grid([[0, 1, 3], [0, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - - ind = grid.pos2ind([1.25, 1.0, 0.5], 0, round_ind=False) - pos = grid.ind2pos(ind, 0, round_ind=False) - - assert_allclose(ind, [1.0, 0.0, 0.0]) - assert_allclose(pos, [1.25, 1.0, 0.5]) - - -def test_negative_shift_draw_cuboid_fractional_fill() -> None: - grid = Grid([[0, 1, 3], [0, 1], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - arr = grid.allocate(0) - - grid.draw_cuboid( - arr, - x=dict(min=0, max=1), - y=dict(min=0, max=1), - z=dict(min=0, max=1), - foreground=1, - ) - - assert_allclose(arr[0, :, 0, 0], [0.5, 1 / 3]) - - -def test_negative_shift_get_slice_uses_shifted_centers() -> None: - grid = Grid([[0, 1, 3], [0, 1, 2], [0, 1]], shifts=[[-0.5, 0, 0]], periodic=[False, False, False]) - cell_data = numpy.zeros(grid.cell_data_shape) - cell_data[0, 1, :, 0] = [7, 9] - x_center = float(grid.shifted_xyz(0)[0][1]) - - grid_slice = grid.get_slice(cell_data, Plane(x=x_center), which_shifts=0) - - assert_allclose(grid_slice, [7, 9]) - - -def test_grid_with_data_returns_griddata() -> None: - grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) - data = grid.with_data(fill_value=2.0) - - assert isinstance(data, GridData) - assert_allclose(data.cell_data, numpy.full(grid.cell_data_shape, 2.0, dtype=numpy.float32)) - - -def test_griddata_constructor_validates_shape() -> None: - grid = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]) - - with pytest.raises(GridError): - GridData(grid, numpy.zeros((1, 1, 1))) - - -def test_griddata_draw_methods_are_chainable() -> None: - data = Grid([[0, 1, 2], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) - - chained = data.draw_cuboid( - foreground=1, - x=dict(min=0, max=1), - y=dict(min=0, max=1), - z=dict(min=0, max=1), - ).draw_polygon( - foreground=0.5, - slab=dict(axis='z', center=0.5, span=1.0), - polygon=numpy.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float), - ) - - assert chained is data - assert data.cell_data.sum() > 0 - - -def test_griddata_read_methods_delegate() -> None: - data = Grid([[0, 1, 2], [0, 1, 2], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=0) - data.cell_data[0, :, :, 0] = numpy.array([[1, 2], [3, 4]], dtype=float) - - assert_allclose( - data.get_slice(Plane(z=0.5)), - data.grid.get_slice(data.cell_data, Plane(z=0.5)), - ) - - -def test_griddata_copy_is_independent() -> None: - data = Grid([[0, 1], [0, 1], [0, 1]], shifts=[[0, 0, 0]]).with_data(fill_value=1.0) - cloned = data.copy() - cloned.cell_data[0, 0, 0, 0] = 5.0 - - assert data is not cloned - assert data.grid is not cloned.grid - assert data.cell_data[0, 0, 0, 0] == 1.0 diff --git a/gridlock/utils.py b/gridlock/utils.py index 585b999..7a12035 100644 --- a/gridlock/utils.py +++ b/gridlock/utils.py @@ -1,30 +1,13 @@ -from typing import Protocol, TypedDict, runtime_checkable, cast +from typing import Protocol, TypedDict, runtime_checkable from dataclasses import dataclass -import numpy - class GridError(Exception): """ Base error type for `gridlock` """ pass -def _coerce_scalar(name: str, value: object) -> float: - arr = numpy.asarray(value) - if arr.size != 1: - raise GridError(f'{name} must be a scalar value') - - try: - return float(arr.reshape(())) - except (TypeError, ValueError) as exc: - raise GridError(f'{name} must be a real scalar value') from exc - - class ExtentDict(TypedDict, total=False): - """ - Geometrical definition of an extent (1D bounded region) - Must contain exactly two of `min`, `max`, `center`, or `span`. - """ min: float center: float max: float @@ -33,9 +16,6 @@ class ExtentDict(TypedDict, total=False): @runtime_checkable class ExtentProtocol(Protocol): - """ - Anything that looks like an `Extent` - """ center: float span: float @@ -48,10 +28,6 @@ class ExtentProtocol(Protocol): @dataclass(init=False, slots=True) class Extent(ExtentProtocol): - """ - Geometrical definition of an extent (1D bounded region) - May be constructed with any two of `min`, `max`, `center`, or `span`. - """ center: float span: float @@ -71,53 +47,47 @@ class Extent(ExtentProtocol): max: float | None = None, span: float | None = None, ) -> None: - values = { - 'min': None if min is None else _coerce_scalar('min', min), - 'center': None if center is None else _coerce_scalar('center', center), - 'max': None if max is None else _coerce_scalar('max', max), - 'span': None if span is None else _coerce_scalar('span', span), - } - if sum(value is not None for value in values.values()) != 2: - raise GridError('Exactly two of min, center, max, span must be provided') + if sum(cc is None for cc in (min, center, max, span)) != 2: + raise GridError('Exactly two of min, center, max, span must be None!') - min_v = values['min'] - center_v = values['center'] - max_v = values['max'] - span_v = values['span'] + if span is None: + if center is None: + assert min is not None + assert max is not None + assert max >= min + center = 0.5 * (max + min) + span = max - min + elif max is None: + assert min is not None + assert center is not None + span = 2 * (center - min) + elif min is None: + assert center is not None + assert max is not None + span = 2 * (max - center) + else: # noqa: PLR5501 + if center is not None: + pass + elif max is None: + assert min is not None + assert span is not None + center = min + 0.5 * span + elif min is None: + assert max is not None + assert span is not None + center = max - 0.5 * span - if span_v is not None and span_v < 0: - raise GridError('span must be non-negative') - - if min_v is not None and max_v is not None: - if max_v < min_v: - raise GridError('max must be greater than or equal to min') - center_v = 0.5 * (max_v + min_v) - span_v = max_v - min_v - elif center_v is not None and min_v is not None: - span_v = 2 * (center_v - min_v) - if span_v < 0: - raise GridError('min must be less than or equal to center') - elif center_v is not None and max_v is not None: - span_v = 2 * (max_v - center_v) - if span_v < 0: - raise GridError('center must be less than or equal to max') - elif min_v is not None and span_v is not None: - center_v = min_v + 0.5 * span_v - elif max_v is not None and span_v is not None: - center_v = max_v - 0.5 * span_v - - if center_v is None or span_v is None: - raise GridError('Unable to construct extent from the provided values') - - self.center = center_v - self.span = span_v + assert center is not None + assert span is not None + if hasattr(center, '__len__'): + assert len(center) == 1 + if hasattr(span, '__len__'): + assert len(span) == 1 + self.center = center + self.span = span class SlabDict(TypedDict, total=False): - """ - Geometrical definition of a slab (3D region bounded on one axis only) - Must contain `axis` plus any two of `min`, `max`, `center`, or `span`. - """ min: float center: float max: float @@ -127,9 +97,6 @@ class SlabDict(TypedDict, total=False): @runtime_checkable class SlabProtocol(ExtentProtocol, Protocol): - """ - Anything that looks like a `Slab` - """ axis: int center: float span: float @@ -143,10 +110,6 @@ class SlabProtocol(ExtentProtocol, Protocol): @dataclass(init=False, slots=True) class Slab(Extent, SlabProtocol): - """ - Geometrical definition of a slab (3D region bounded on one axis only) - May be constructed with `axis` (bounded axis) plus any two of `min`, `max`, `center`, or `span`. - """ axis: int def __init__( @@ -179,10 +142,6 @@ class Slab(Extent, SlabProtocol): class PlaneDict(TypedDict, total=False): - """ - Geometrical definition of a plane (2D unbounded region in 3D space) - Must contain exactly one of `x`, `y`, `z`, or both `axis` and `pos` - """ x: float y: float z: float @@ -192,19 +151,12 @@ class PlaneDict(TypedDict, total=False): @runtime_checkable class PlaneProtocol(Protocol): - """ - Anything that looks like a `Plane` - """ axis: int pos: float @dataclass(init=False, slots=True) class Plane(PlaneProtocol): - """ - Geometrical definition of a plane (2D unbounded region in 3D space) - May be constructed with any of `x=4`, `y=5`, `z=-5`, or `axis=2, pos=-5`. - """ axis: int pos: float @@ -240,9 +192,10 @@ class Plane(PlaneProtocol): if pos is not None: cpos = pos else: - cpos = cast('float', (xx, yy, zz)[axis_int]) + cpos = (xx, yy, zz)[axis_int] assert cpos is not None if hasattr(cpos, '__len__'): assert len(cpos) == 1 self.pos = cpos + diff --git a/pyproject.toml b/pyproject.toml index 03d0d19..1df2e8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ lint.ignore = [ "ANN002", # *args "ANN003", # **kwargs "ANN401", # Any + "ANN101", # self: Self "SIM108", # single-line if / else assignment "RET504", # x=y+z; return x "PIE790", # unnecessary pass