100 lines
2.9 KiB
Python
Raw Normal View History

2024-12-21 09:50:19 -08:00
from collections.abc import Sequence
2024-12-21 13:56:51 -08:00
import ctypes
from pathlib import Path
from itertools import chain
2024-12-21 09:50:19 -08:00
import numpy
from numpy.typing import NDArray
2024-12-21 13:56:51 -08:00
# The native conversion library is expected to sit next to this module.
so_path = Path(__file__).resolve().parent / 'libklamath_rs_ext.so'
clib = ctypes.CDLL(so_path)

# numpy scalar type -> (symbol prefix used by the native library, ctypes element type)
_ELEMENT_INFO = {
    numpy.float64: ('f64', ctypes.c_double),
    numpy.float32: ('f32', ctypes.c_float),
    numpy.int64: ('i64', ctypes.c_int64),
    numpy.int32: ('i32', ctypes.c_int32),
    numpy.int16: ('i16', ctypes.c_int16),
    numpy.uint64: ('u64', ctypes.c_uint64),
    numpy.uint32: ('u32', ctypes.c_uint32),
    numpy.uint16: ('u16', ctypes.c_uint16),
}

# Converters to 2-byte integers, keyed by the numpy scalar type of the input.
CONV_TABLE_i16 = {
    scalar_type: getattr(clib, f'{prefix}_to_i16')
    for scalar_type, (prefix, _ctype) in _ELEMENT_INFO.items()
}

# Converters to 4-byte integers; there are no entries for the 16-bit inputs.
CONV_TABLE_i32 = {
    scalar_type: getattr(clib, f'{prefix}_to_i32')
    for scalar_type, (prefix, _ctype) in _ELEMENT_INFO.items()
    if scalar_type not in (numpy.int16, numpy.uint16)
}

# Each function's restype is set to the ctypes type matching its input element.
# NOTE(review): whether the native functions really return a value of that type
# is not visible here -- confirm against the library's source.
for _table in (CONV_TABLE_i16, CONV_TABLE_i32):
    for _scalar_type, _func in _table.items():
        _func.restype = _ELEMENT_INFO[_scalar_type][1]

# Every converter takes (pointer to element buffer, element count); the pointer
# type is derived from the restype assigned above.
for fn in chain(CONV_TABLE_i16.values(), CONV_TABLE_i32.values()):
    fn.argtypes = [ctypes.POINTER(fn.restype), ctypes.c_size_t]
2024-12-21 09:50:19 -08:00
def pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
arr = numpy.asarray(data)
2024-12-21 13:56:51 -08:00
if arr.dtype in CONV_TABLE_i16.keys():
2024-12-21 09:50:19 -08:00
arr = numpy.require(arr, requirements=('C_CONTIGUOUS', 'ALIGNED', 'WRITEABLE', 'OWNDATA'))
if arr is data:
arr = numpy.array(arr, copy=True)
2024-12-21 13:56:51 -08:00
fn = CONV_TABLE_i16[arr.dtype]
result = fn(arr.ctypes.data_as(fn.argtypes[0]), arr.size)
2024-12-21 09:50:19 -08:00
i2arr = arr.view('>i2')[::arr.itemsize // 2]
return i2arr.tobytes()
if arr.dtype == numpy.dtype('>i2'):
return arr.tobytes()
if (arr > 32767).any() or (arr < -32768).any():
raise Exception(f'int2 data out of range: {arr}')
return arr.astype('>i2').tobytes()
def pack_int4(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
arr = numpy.asarray(data)
2024-12-21 13:56:51 -08:00
if arr.dtype in CONV_TABLE_i32.keys():
2024-12-21 09:50:19 -08:00
arr = numpy.require(arr, requirements=('C_CONTIGUOUS', 'ALIGNED', 'WRITEABLE', 'OWNDATA'))
if arr is data:
arr = numpy.array(arr, copy=True)
2024-12-21 13:56:51 -08:00
fn = CONV_TABLE_i32[arr.dtype]
result = fn(arr.ctypes.data_as(fn.argtypes[0]), arr.size)
2024-12-21 09:50:19 -08:00
i4arr = arr.view('>i4')[::arr.itemsize // 4]
return i4arr.tobytes()
if arr.dtype == numpy.dtype('>i4'):
return arr.tobytes()
if (arr > 2147483647).any() or (arr < -2147483648).any():
raise Exception(f'int4 data out of range: {arr}')
return arr.astype('>i4').tobytes()