diff --git a/klamath_rs_ext/basic.py b/klamath_rs_ext/basic.py index 4540fe6..aabc246 100644 --- a/klamath_rs_ext/basic.py +++ b/klamath_rs_ext/basic.py @@ -32,7 +32,7 @@ CONV_TABLE_i32 = { def pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: - arr = numpy.asarray(data) + arr = numpy.atleast_1d(data) for dtype in CONV_TABLE_i16.keys(): if arr.dtype != dtype: @@ -43,7 +43,8 @@ def pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: arr = numpy.array(arr, copy=True) fn = CONV_TABLE_i16[dtype] - result = fn(ffi.from_buffer(arr), arr.size) + buf = ffi.from_buffer(ffi.typeof(fn).args[0], arr, require_writable=True) + result = fn(buf, arr.size) if result != 0: raise ValueError(f'Invalid value for conversion to Int2: {result}') @@ -61,7 +62,7 @@ def pack_int2(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: def pack_int4(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: - arr = numpy.asarray(data) + arr = numpy.atleast_1d(data) for dtype in CONV_TABLE_i32.keys(): if arr.dtype != dtype: @@ -72,7 +73,8 @@ def pack_int4(data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: arr = numpy.array(arr, copy=True) fn = CONV_TABLE_i32[dtype] - result = fn(ffi.from_buffer(arr), arr.size) + buf = ffi.from_buffer(ffi.typeof(fn).args[0], arr, require_writable=True) + result = fn(buf, arr.size) if result != 0: raise ValueError(f'Invalid value for conversion to Int4: {result}')