type annotation improvements
This commit is contained in:
parent
e256f56f2b
commit
646911c4b5
@ -1,8 +1,7 @@
|
|||||||
from typing import ClassVar, Self, Protocol
|
from typing import Protocol
|
||||||
from collections.abc import Callable, Sequence
|
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.typing import NDArray, ArrayLike
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from . import GridError
|
from . import GridError
|
||||||
|
|
||||||
@ -38,7 +37,7 @@ class GridBase(Protocol):
|
|||||||
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
|
return [self.exyz[a][:-1] + self.dxyz[a] / 2.0 for a in range(3)]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> NDArray[numpy.int_]:
|
def shape(self) -> NDArray[numpy.intp]:
|
||||||
"""
|
"""
|
||||||
The number of cells in x, y, and z
|
The number of cells in x, y, and z
|
||||||
|
|
||||||
@ -55,7 +54,7 @@ class GridBase(Protocol):
|
|||||||
return self.shifts.shape[0]
|
return self.shifts.shape[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cell_data_shape(self):
|
def cell_data_shape(self) -> NDArray[numpy.intp]:
|
||||||
"""
|
"""
|
||||||
The shape of the cell_data ndarray (num_grids, *self.shape).
|
The shape of the cell_data ndarray (num_grids, *self.shape).
|
||||||
"""
|
"""
|
||||||
@ -180,7 +179,7 @@ class GridBase(Protocol):
|
|||||||
raise GridError('Autoshifting requires exactly 3 grids')
|
raise GridError('Autoshifting requires exactly 3 grids')
|
||||||
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
|
return [self.shifted_dxyz(which_shifts=a)[a] for a in range(3)]
|
||||||
|
|
||||||
def allocate(self, fill_value: float | None = 1.0, dtype=numpy.float32) -> NDArray:
|
def allocate(self, fill_value: float | None = 1.0, dtype: type[numpy.number] = numpy.float32) -> NDArray:
|
||||||
"""
|
"""
|
||||||
Allocate an ndarray for storing grid data.
|
Allocate an ndarray for storing grid data.
|
||||||
|
|
||||||
@ -194,5 +193,4 @@ class GridBase(Protocol):
|
|||||||
"""
|
"""
|
||||||
if fill_value is None:
|
if fill_value is None:
|
||||||
return numpy.empty(self.cell_data_shape, dtype=dtype)
|
return numpy.empty(self.cell_data_shape, dtype=dtype)
|
||||||
else:
|
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
|
||||||
return numpy.full(self.cell_data_shape, fill_value, dtype=dtype)
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Drawing-related methods for Grid class
|
Drawing-related methods for Grid class
|
||||||
"""
|
"""
|
||||||
from typing import Union
|
|
||||||
from collections.abc import Sequence, Callable
|
from collections.abc import Sequence, Callable
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
@ -20,7 +19,7 @@ from .position import GridPosMixin
|
|||||||
|
|
||||||
|
|
||||||
foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
|
foreground_callable_t = Callable[[NDArray, NDArray, NDArray], NDArray]
|
||||||
foreground_t = Union[float, foreground_callable_t]
|
foreground_t = float | foreground_callable_t
|
||||||
|
|
||||||
|
|
||||||
class GridDrawMixin(GridPosMixin):
|
class GridDrawMixin(GridPosMixin):
|
||||||
@ -166,7 +165,7 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
# 2) Generate weights in z-direction
|
# 2) Generate weights in z-direction
|
||||||
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
|
w_z = numpy.zeros(((bdi_max - bdi_min + 1)[surface_normal], ))
|
||||||
|
|
||||||
def get_zi(offset, i=i, w_z=w_z):
|
def get_zi(offset: float, i=i, w_z=w_z) -> tuple[float, int]: # noqa: ANN001
|
||||||
edges = self.shifted_exyz(i)[surface_normal]
|
edges = self.shifted_exyz(i)[surface_normal]
|
||||||
point = center[surface_normal] + offset
|
point = center[surface_normal] + offset
|
||||||
grid_coord = numpy.digitize(point, edges) - 1
|
grid_coord = numpy.digitize(point, edges) - 1
|
||||||
@ -384,10 +383,10 @@ class GridDrawMixin(GridPosMixin):
|
|||||||
ind[direction] += 1 # type: ignore #(known safe)
|
ind[direction] += 1 # type: ignore #(known safe)
|
||||||
foreground += mult[1] * grid[tuple(ind)]
|
foreground += mult[1] * grid[tuple(ind)]
|
||||||
|
|
||||||
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int_]:
|
def f_foreground(xs, ys, zs, i=i, foreground=foreground) -> NDArray[numpy.int64]: # noqa: ANN001
|
||||||
# transform from natural position to index
|
# transform from natural position to index
|
||||||
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
|
xyzi = numpy.array([self.pos2ind(qrs, which_shifts=i)
|
||||||
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=int)
|
for qrs in zip(xs.flat, ys.flat, zs.flat, strict=True)], dtype=numpy.int64)
|
||||||
# reshape to original shape and keep only in-plane components
|
# reshape to original shape and keep only in-plane components
|
||||||
qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
|
qi, ri = (numpy.reshape(xyzi[:, k], xs.shape) for k in surface)
|
||||||
return foreground[qi, ri]
|
return foreground[qi, ri]
|
||||||
|
@ -29,7 +29,7 @@ if __name__ == '__main__':
|
|||||||
# numpy.linspace(-5.5, 5.5, 10)]
|
# numpy.linspace(-5.5, 5.5, 10)]
|
||||||
|
|
||||||
half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5]
|
half_x = [.25, .5, 0.75, 1, 1.25, 1.5, 2, 2.5, 3, 3.5]
|
||||||
xyz3 = [[-x for x in half_x[::-1]] + [0] + half_x,
|
xyz3 = [numpy.array([-x for x in half_x[::-1]] + [0] + half_x),
|
||||||
numpy.linspace(-5.5, 5.5, 10),
|
numpy.linspace(-5.5, 5.5, 10),
|
||||||
numpy.linspace(-5.5, 5.5, 10)]
|
numpy.linspace(-5.5, 5.5, 10)]
|
||||||
eg = Grid(xyz3)
|
eg = Grid(xyz3)
|
||||||
|
Loading…
Reference in New Issue
Block a user