type annotation improvements
This commit is contained in:
parent
e256f56f2b
commit
646911c4b5
@ -1,8 +1,7 @@
|
||||
from typing import ClassVar, Self, Protocol
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Protocol
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray, ArrayLike
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from . import GridError
|
||||
|
||||
@ -38,7 +37,7 @@ class GridBase(Protocol):
|
||||
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
|
||||
|
||||
@property
|
||||
def shape(self) -> NDArray[numpy.int_]:
|
||||
def shape(self) -> NDArray[numpy.intp]:
|
||||
"""
|
||||
The number of cells in x, y, and z
|
||||
|
||||
@ -55,7 +54,7 @@ class GridBase(Protocol):
|
||||
return self.shifts.shape[0]
|
||||
|
||||
@property
|
||||
def cell_data_shape(self):
|
||||
def cell_data_shape(self) -> NDArray[numpy.intp]:
|
||||
"""
|
||||
The shape of the cell_data ndarray (num_grids, *self.shape).
|
||||
"""
|
||||
@ -180,7 +179,7 @@ class GridBase(Protocol):
|
||||
raise GridError('Autoshifting requires exactly 3 grids')
|
||||
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
|
||||
|
||||
def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray:
|
||||
def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray:
|
||||
"""
|
||||
Allocate an ndarray for storing grid data.
|
||||
|
||||
@ -194,5 +193,4 @@ class GridBase(Protocol):
|
||||
"""
|
||||
if fill_value is None:
|
||||
return numpy.empty(self.cell_data_shape, dtype=dtype)
|
||||
else:
|
||||
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Drawing-related methods for Grid class
|
||||
"""
|
||||
from typing import Union
|
||||
from collections.abc import Sequence, Callable
|
||||
|
||||
import numpy
|
||||
@ -20,7 +19,7 @@ from .position import GridPosMixin
|
||||
|
||||
|
||||
foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
|
||||
foreground_t = Union[float, foreground_callable_t]
|
||||
foreground_t = float | foreground_callable_t
|
||||
|
||||
|
||||
class GridDrawMixin(GridPosMixin):
|
||||
@ -166,7 +165,7 @@ class GridDrawMixin(GridPosMixin):
|
||||
# 2) Generate weights in z-direction
|
||||
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
|
||||
|
||||
def get_zi(offset, i=i, w_z=w_z):
|
||||
def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001
|
||||
edges = self.shifted_exyz(i)[surface_normal]
|
||||
point = center[surface_normal] + offset
|
||||
grid_coord = numpy.digitize(point, edges) - 1
|
||||
@ -384,10 +383,10 @@ class GridDrawMixin(GridPosMixin):
|
||||
ind[direction] += 1 # type: ignore #(known safe)
|
||||
foreground += mult[1] * grid[tuple(ind)]
|
||||
|
||||
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]:
|
||||
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
|
||||
# transform from natural position to index
|
||||
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
|
||||
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=int)
|
||||
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64)
|
||||
# reshape to original shape and keep only in-plane components
|
||||
qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
|
||||
return foreground[qi, ri]
|
||||
|
@ -29,7 +29,7 @@ if __name__ == '__main__':
|
||||
# numpy.linspace(-5.5, 5.5, 10)]
|
||||
|
||||
half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5]
|
||||
xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x,
|
||||
xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x),
|
||||
numpy.linspace(-5.5, 5.5, 10),
|
||||
numpy.linspace(-5.5, 5.5, 10)]
|
||||
eg = Grid(xyz3)
|
||||
|
Loading…
Reference in New Issue
Block a user