type annotation improvements

This commit is contained in:
Jan Petykiewicz 2024-07-29 01:57:39 -07:00
parent e256f56f2b
commit 646911c4b5
3 changed files with 11 additions and 14 deletions

View File

@ -1,8 +1,7 @@
from typing import ClassVar, Self, Protocol from typing import Protocol
from collections.abc import Callable, Sequence
import numpy import numpy
from numpy.typing import NDArray, ArrayLike from numpy.typing import NDArray
from . import GridError from . import GridError
@ -38,7 +37,7 @@ class GridBase(Protocol):
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)] return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
@property @property
def shape(self) -> NDArray[numpy.int_]: def shape(self) -> NDArray[numpy.intp]:
""" """
The number of cells in x, y, and z The number of cells in x, y, and z
@ -55,7 +54,7 @@ class GridBase(Protocol):
return self.shifts.shape[0] return self.shifts.shape[0]
@property @property
def cell_data_shape(self): def cell_data_shape(self) -> NDArray[numpy.intp]:
""" """
The shape of the cell_data ndarray (num_grids, *self.shape). The shape of the cell_data ndarray (num_grids, *self.shape).
""" """
@ -180,7 +179,7 @@ class GridBase(Protocol):
raise GridError('Autoshifting requires exactly 3 grids') raise GridError('Autoshifting requires exactly 3 grids')
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)] return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray: def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray:
""" """
Allocate an ndarray for storing grid data. Allocate an ndarray for storing grid data.
@ -194,5 +193,4 @@ class GridBase(Protocol):
""" """
if fill_value is None: if fill_value is None:
return numpy.empty(self.cell_data_shape, dtype=dtype) return numpy.empty(self.cell_data_shape, dtype=dtype)
else:
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype) return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)

View File

@ -1,7 +1,6 @@
""" """
Drawing-related methods for Grid class Drawing-related methods for Grid class
""" """
from typing import Union
from collections.abc import Sequence, Callable from collections.abc import Sequence, Callable
import numpy import numpy
@ -20,7 +19,7 @@ from .position import GridPosMixin
foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray] foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
foreground_t = Union[float, foreground_callable_t] foreground_t = float | foreground_callable_t
class GridDrawMixin(GridPosMixin): class GridDrawMixin(GridPosMixin):
@ -166,7 +165,7 @@ class GridDrawMixin(GridPosMixin):
# 2) Generate weights in z-direction # 2) Generate weights in z-direction
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], )) w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
def get_zi(offset, i=i, w_z=w_z): def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001
edges = self.shifted_exyz(i)[surface_normal] edges = self.shifted_exyz(i)[surface_normal]
point = center[surface_normal] + offset point = center[surface_normal] + offset
grid_coord = numpy.digitize(point, edges) - 1 grid_coord = numpy.digitize(point, edges) - 1
@ -384,10 +383,10 @@ class GridDrawMixin(GridPosMixin):
ind[direction] += 1 # type: ignore #(known safe) ind[direction] += 1 # type: ignore #(known safe)
foreground += mult[1] * grid[tuple(ind)] foreground += mult[1] * grid[tuple(ind)]
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]: def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
# transform from natural position to index # transform from natural position to index
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i) xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=int) for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64)
# reshape to original shape and keep only in-plane components # reshape to original shape and keep only in-plane components
qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface) qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
return foreground[qi, ri] return foreground[qi, ri]

View File

@ -29,7 +29,7 @@ if __name__ == '__main__':
# numpy.linspace(-5.5, 5.5, 10)] # numpy.linspace(-5.5, 5.5, 10)]
half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5] half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5]
xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x, xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x),
numpy.linspace(-5.5, 5.5, 10), numpy.linspace(-5.5, 5.5, 10),
numpy.linspace(-5.5, 5.5, 10)] numpy.linspace(-5.5, 5.5, 10)]
eg = Grid(xyz3) eg = Grid(xyz3)