diff --git a/meanas/fdmath/vectorization.py b/meanas/fdmath/vectorization.py index fef3c5e..3871801 100644 --- a/meanas/fdmath/vectorization.py +++ b/meanas/fdmath/vectorization.py @@ -28,14 +28,16 @@ def vec(f: cfdfield_t) -> vcfdfield_t: def vec(f: ArrayLike) -> vfdfield_t | vcfdfield_t: pass -def vec(f: fdfield_t | cfdfield_t | ArrayLike | None) -> vfdfield_t | vcfdfield_t | None: +def vec( + f: fdfield_t | cfdfield_t | ArrayLike | None, + ) -> vfdfield_t | vcfdfield_t | None: """ - Create a 1D ndarray from a 3D vector field which spans a 1-3D region. + Create a 1D ndarray from a vector field which spans a 1-3D region. Returns `None` if called with `f=None`. Args: - f: A vector field, `[f_x, f_y, f_z]` where each `f_` component is a 1- to + f: A vector field, e.g. `[f_x, f_y, f_z]` where each `f_` component is a 1- to 3-D ndarray (`f_*` should all be the same size). Doesn't fail with `f=None`. Returns: @@ -47,33 +49,38 @@ def vec(f: fdfield_t | cfdfield_t | ArrayLike | None) -> vfdfield_t | vcfdfield_ @overload -def unvec(v: None, shape: Sequence[int]) -> None: +def unvec(v: None, shape: Sequence[int], nvdim: int) -> None: pass @overload -def unvec(v: vfdfield_t, shape: Sequence[int]) -> fdfield_t: +def unvec(v: vfdfield_t, shape: Sequence[int], nvdim: int) -> fdfield_t: pass @overload -def unvec(v: vcfdfield_t, shape: Sequence[int]) -> cfdfield_t: +def unvec(v: vcfdfield_t, shape: Sequence[int], nvdim: int) -> cfdfield_t: pass -def unvec(v: vfdfield_t | vcfdfield_t | None, shape: Sequence[int]) -> fdfield_t | cfdfield_t | None: +def unvec( + v: vfdfield_t | vcfdfield_t | None, + shape: Sequence[int], + nvdim: int = 3, + ) -> fdfield_t | cfdfield_t | None: """ - Perform the inverse of vec(): take a 1D ndarray and output a 3D field - of form `[f_x, f_y, f_z]` where each of `f_*` is a len(shape)-dimensional + Perform the inverse of vec(): take a 1D ndarray and output an `nvdim`-component field + of form e.g. `[f_x, f_y, f_z]` (`nvdim=3`) where each of `f_*` is a len(shape)-dimensional ndarray. Returns `None` if called with `v=None`. Args: - v: 1D ndarray representing a 3D vector field of shape shape (or None) + v: 1D ndarray representing a vector field of shape shape (or None) shape: shape of the vector field + nvdim: Number of components in each vector Returns: `[f_x, f_y, f_z]` where each `f_` is a `len(shape)` dimensional ndarray (or `None`) """ if v is None: return None - return v.reshape((3, *shape), order='C') + return v.reshape((nvdim, *shape), order='C')