From 12357cd9740e4d1d286c8bb0cb621cd49eb942c0 Mon Sep 17 00:00:00 2001 From: jan Date: Fri, 20 Dec 2024 19:58:56 -0800 Subject: [PATCH] use klamath_rs_ext if available --- klamath/basic.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/klamath/basic.py b/klamath/basic.py index 86c9d59..135398b 100644 --- a/klamath/basic.py +++ b/klamath/basic.py @@ -17,6 +17,11 @@ logger = logging.getLogger(__name__) class KlamathError(Exception): pass +try: + from klamath_rs_ext import pack_int2, pack_int4 + _USE_RS_EXT = True +except ImportError: + _USE_RS_EXT = False # # Parse functions @@ -93,14 +98,14 @@ def pack_bitarray(data: int) -> bytes: return struct.pack('>H', data) -def pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: +def _pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: arr = numpy.asarray(data) if (arr > 32767).any() or (arr < -32768).any(): raise KlamathError(f'int2 data out of range: {arr}') return arr.astype('>i2').tobytes() -def pack_int4(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: +def _pack_int4(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: arr = numpy.asarray(data) if (arr > 2147483647).any() or (arr < -2147483648).any(): raise KlamathError(f'int4 data out of range: {arr}') @@ -189,3 +194,9 @@ def read(stream: IO[bytes], size: int) -> bytes: if len(data) != size: raise EOFError return data + + +if not _USE_RS_EXT: + pack_int2 = _pack_int2 + pack_int4 = _pack_int4 +