From 646911c4b5f707fc35174a1addcada4c5e4caadf Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 01:57:39 -0700 Subject: [PATCH] type annotation improvements --- gridlock/base.py | 14 ++++++-------- gridlock/draw.py | 9 ++++----- gridlock/examples/ex0.py | 2 +- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/gridlock/base.py b/gridlock/base.py index 6bd5fb8..aca9c69 100644 --- a/gridlock/base.py +++ b/gridlock/base.py @@ -1,8 +1,7 @@ -from typing import ClassVar, Self, Protocol -from collections.abc import Callable, Sequence +from typing import Protocol import numpy -from numpy.typing import NDArray, ArrayLike +from numpy.typing import NDArray from . import GridError @@ -38,7 +37,7 @@ class GridBase(Protocol): return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] @property - def shape(self) -> NDArray[numpy.int_]: + def shape(self) -> NDArray[numpy.intp]: """ The number of cells in x, y, and z @@ -55,7 +54,7 @@ class GridBase(Protocol): return self.shifts.shape[0] @property - def cell_data_shape(self): + def cell_data_shape(self) -> NDArray[numpy.intp]: """ The shape of the cell_data ndarray (num_grids, *self.shape). """ @@ -180,7 +179,7 @@ class GridBase(Protocol): raise GridError('Autoshifting requires exactly 3 grids') return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] - def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray: + def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray: """ Allocate an ndarray for storing grid data. @@ -194,5 +193,4 @@ class GridBase(Protocol): """ if fill_value is None: return numpy.empty(self.cell_data_shape, dtype=dtype) - else: - return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) + return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) diff --git a/gridlock/draw.py b/gridlock/draw.py index e68fdc6..aa45200 100644 --- a/gridlock/draw.py +++ b/gridlock/draw.py @@ -1,7 +1,6 @@ """ Drawing-related methods for Grid class """ -from typing import Union from collections.abc import Sequence, Callable import numpy @@ -20,7 +19,7 @@ from .position import GridPosMixin foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] -foreground_t = Union[float, foreground_callable_t] +foreground_t = float | foreground_callable_t class GridDrawMixin(GridPosMixin): @@ -166,7 +165,7 @@ class GridDrawMixin(GridPosMixin): # 2) Generate weights in z-direction w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) - def get_zi(offset, i=i, w_z=w_z): + def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001 edges = self.shifted_exyz(i)[surface_normal] point = center[surface_normal] + offset grid_coord = numpy.digitize(point, edges) - 1 @@ -384,10 +383,10 @@ class GridDrawMixin(GridPosMixin): ind[direction] += 1 # type: ignore #(known safe) foreground += mult[1] * grid[tuple(ind)] - def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]: + def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001 # transform from natural position to index xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) - for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=int) + for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64) # reshape to original shape and keep only in-plane components qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) return foreground[qi, ri] diff --git a/gridlock/examples/ex0.py b/gridlock/examples/ex0.py index 7dd4355..b96cdca 100644 --- a/gridlock/examples/ex0.py +++ b/gridlock/examples/ex0.py @@ -29,7 +29,7 @@ if __name__ == '__main__': # numpy.linspace(-5.5, 5.5, 10)] half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5] - xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x, + xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x), numpy.linspace(-5.5, 5.5, 10), numpy.linspace(-5.5, 5.5, 10)] eg = Grid(xyz3)