use klamath_rs_ext if available

This commit is contained in:
jan 2024-12-20 19:58:56 -08:00
parent 4ffb87d361
commit 12357cd974

View File

@ -17,6 +17,11 @@ logger = logging.getLogger(__name__)
class KlamathError(Exception): class KlamathError(Exception):
pass pass
try:
from klamath_rs_ext import pack_int2, pack_int4
_USE_RS_EXT = True
except ImportError:
_USE_RS_EXT = False
# #
# Parse functions # Parse functions
@ -93,14 +98,14 @@ def pack_bitarray(data: int) -> bytes:
return struct.pack('>H', data) return struct.pack('>H', data)
def pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: def _pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
arr = numpy.asarray(data) arr = numpy.asarray(data)
if (arr > 32767).any() or (arr < -32768).any(): if (arr > 32767).any() or (arr < -32768).any():
raise KlamathError(f'int2 data out of range: {arr}') raise KlamathError(f'int2 data out of range: {arr}')
return arr.astype('>i2').tobytes() return arr.astype('>i2').tobytes()
def pack_int4(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: def _pack_int4(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
arr = numpy.asarray(data) arr = numpy.asarray(data)
if (arr > 2147483647).any() or (arr < -2147483648).any(): if (arr > 2147483647).any() or (arr < -2147483648).any():
raise KlamathError(f'int4 data out of range: {arr}') raise KlamathError(f'int4 data out of range: {arr}')
@ -189,3 +194,9 @@ def read(stream: IO[bytes], size: int) -> bytes:
if len(data) != size: if len(data) != size:
raise EOFError raise EOFError
return data return data
if not _USE_RS_EXT:
pack_int2 = _pack_int2
pack_int4 = _pack_int4