pyo3 variant working
This commit is contained in:
parent
ad1c2f1c35
commit
ba07d253d2
6 changed files with 234 additions and 2 deletions
6
klamath_rs_ext/__init__.py
Normal file
6
klamath_rs_ext/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from .basic import pack_int2 as pack_int2
|
||||
from .basic import pack_int4 as pack_int4
|
||||
|
||||
|
||||
__version__ = 0.1
|
||||
|
||||
56
klamath_rs_ext/basic.py
Normal file
56
klamath_rs_ext/basic.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
from collections.abc import Sequence
|
||||
|
||||
import numpy
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .klamath_rs_ext import arr_to_int2, arr_to_int4
|
||||
|
||||
|
||||
def pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
|
||||
arr = numpy.asarray(data)
|
||||
|
||||
if arr.dtype in (
|
||||
numpy.float64, numpy.float32,
|
||||
numpy.int64, numpy.uint64,
|
||||
numpy.int32, numpy.uint32,
|
||||
numpy.int16, numpy.uint16,
|
||||
):
|
||||
arr = numpy.require(arr, requirements=('C_CONTIGUOUS', 'ALIGNED', 'WRITEABLE', 'OWNDATA'))
|
||||
if arr is data:
|
||||
arr = numpy.array(arr, copy=True)
|
||||
arr_to_int2(arr)
|
||||
i2arr = arr.view('>i2')[::arr.itemsize // 2]
|
||||
return i2arr.tobytes()
|
||||
|
||||
if arr.dtype == numpy.dtype('>i2'):
|
||||
return arr.tobytes()
|
||||
|
||||
if (arr > 32767).any() or (arr < -32768).any():
|
||||
raise Exception(f'int2 data out of range: {arr}')
|
||||
|
||||
return arr.astype('>i2').tobytes()
|
||||
|
||||
|
||||
def pack_int4(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
|
||||
arr = numpy.asarray(data)
|
||||
|
||||
if arr.dtype in (
|
||||
numpy.float64, numpy.float32,
|
||||
numpy.int64, numpy.uint64,
|
||||
numpy.int32, numpy.uint32,
|
||||
):
|
||||
arr = numpy.require(arr, requirements=('C_CONTIGUOUS', 'ALIGNED', 'WRITEABLE', 'OWNDATA'))
|
||||
if arr is data:
|
||||
arr = numpy.array(arr, copy=True)
|
||||
arr_to_int4(arr)
|
||||
i4arr = arr.view('>i4')[::arr.itemsize // 4]
|
||||
return i4arr.tobytes()
|
||||
|
||||
if arr.dtype == numpy.dtype('>i4'):
|
||||
return arr.tobytes()
|
||||
|
||||
if (arr > 2147483647).any() or (arr < -2147483648).any():
|
||||
raise Exception(f'int4 data out of range: {arr}')
|
||||
|
||||
return arr.astype('>i4').tobytes()
|
||||
|
||||
0
klamath_rs_ext/py.typed
Normal file
0
klamath_rs_ext/py.typed
Normal file
Loading…
Add table
Add a link
Reference in a new issue