From e7e42a2ef85f25e2ab6e0d18d5bffe46e88243a1 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Sun, 28 Jul 2024 23:04:12 -0700 Subject: [PATCH] Allow NDArray inputs to pack_* and avoid unnecesary copies --- klamath/basic.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/klamath/basic.py b/klamath/basic.py index 9d7bdcb..6d4d924 100644 --- a/klamath/basic.py +++ b/klamath/basic.py @@ -92,15 +92,15 @@ def pack_bitarray(data: int) -> bytes: return struct.pack('>H', data) -def pack_int2(data: Sequence[int]) -> bytes: - arr = numpy.array(data) +def pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: + arr = numpy.array(data, copy=False) if (arr > 32767).any() or (arr < -32768).any(): raise KlamathError(f'int2 data out of range: {arr}') return arr.astype('>i2').tobytes() -def pack_int4(data: Sequence[int]) -> bytes: - arr = numpy.array(data) +def pack_int4(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: + arr = numpy.array(data, copy=False) if (arr > 2147483647).any() or (arr < -2147483648).any(): raise KlamathError(f'int4 data out of range: {arr}') return arr.astype('>i4').tobytes() @@ -164,8 +164,8 @@ def encode_real8(fnums: NDArray[numpy.float64]) -> NDArray[numpy.uint64]: return real8.astype(numpy.uint64, copy=False) -def pack_real8(data: Sequence[float]) -> bytes: - return encode_real8(numpy.array(data)).astype('>u8').tobytes() +def pack_real8(data: NDArray[numpy.floating] | Sequence[float] | float) -> bytes: + return encode_real8(numpy.array(data, copy=False)).astype('>u8').tobytes() def pack_ascii(data: bytes) -> bytes: