57 lines
1.7 KiB
Python
Raw Normal View History

2024-12-21 09:50:19 -08:00
from collections.abc import Sequence
import numpy
from numpy.typing import NDArray
from .klamath_rs_ext import arr_to_int2, arr_to_int4
def pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
arr = numpy.asarray(data)
if arr.dtype in (
numpy.float64, numpy.float32,
numpy.int64, numpy.uint64,
numpy.int32, numpy.uint32,
numpy.int16, numpy.uint16,
):
arr = numpy.require(arr, requirements=('C_CONTIGUOUS', 'ALIGNED', 'WRITEABLE', 'OWNDATA'))
if arr is data:
arr = numpy.array(arr, copy=True)
arr_to_int2(arr)
i2arr = arr.view('>i2')[::arr.itemsize // 2]
return i2arr.tobytes()
if arr.dtype == numpy.dtype('>i2'):
return arr.tobytes()
if (arr > 32767).any() or (arr < -32768).any():
raise Exception(f'int2 data out of range: {arr}')
return arr.astype('>i2').tobytes()
def pack_int4(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
arr = numpy.asarray(data)
if arr.dtype in (
numpy.float64, numpy.float32,
numpy.int64, numpy.uint64,
numpy.int32, numpy.uint32,
):
arr = numpy.require(arr, requirements=('C_CONTIGUOUS', 'ALIGNED', 'WRITEABLE', 'OWNDATA'))
if arr is data:
arr = numpy.array(arr, copy=True)
arr_to_int4(arr)
i4arr = arr.view('>i4')[::arr.itemsize // 4]
return i4arr.tobytes()
if arr.dtype == numpy.dtype('>i4'):
return arr.tobytes()
if (arr > 2147483647).any() or (arr < -2147483648).any():
raise Exception(f'int4 data out of range: {arr}')
return arr.astype('>i4').tobytes()