ctypes approach

This commit is contained in:
jan 2024-12-21 13:56:51 -08:00
commit 320958d888
5 changed files with 118 additions and 95 deletions

View file

@ -7,87 +7,43 @@ pub mod elements;
pub mod library;
//use ndarray;
use numpy::{PyArray1, PyUntypedArray, PyUntypedArrayMethods, PyArrayDescrMethods, PyArrayMethods, dtype};
use pyo3::prelude::{Python, pymodule, PyModule, PyResult, Bound, wrap_pyfunction, pyfunction, PyModuleMethods, PyAnyMethods};
use pyo3::exceptions::{PyValueError, PyTypeError};
use rust_util::ToInt2BE;
use rust_util::ToInt4BE;
#[pymodule]
fn klamath_rs_ext(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(arr_to_int2, m)?)?;
m.add_function(wrap_pyfunction!(arr_to_int4, m)?)?;
Ok(())
}
#[pyfunction]
fn arr_to_int2(py: Python<'_>, pyarr: &Bound<'_, PyUntypedArray>) -> PyResult<()> {
use rust_util::ToInt2BE;
assert!(pyarr.is_c_contiguous(), "Array must be c-contiguous!");
macro_rules! i2if {
( $el_type:expr, $tt:ty ) => {
if $el_type.is_equiv_to(&dtype::<$tt>(py)) {
let arr = pyarr.downcast::<PyArray1<$tt>>()?;
let mut array = unsafe { arr.as_array_mut() };
for xx in array.iter_mut() {
*xx = <$tt>::convert_to_i2be(*xx).map_err(
|e| PyValueError::new_err(format!("Invalid value for 2-byte int: {}", e))
)?;
macro_rules! mkfun {
( $fname:ident, $tt:ty, $elfn:ident ) => {
#[no_mangle]
pub extern "C" fn $fname(arr: *mut $tt, size: usize) -> $tt {
let sl = unsafe { std::slice::from_raw_parts_mut(arr, size) };
for xx in sl.iter_mut() {
let res = <$tt>::$elfn(*xx);
match res {
Err(cc) => return cc,
Ok(cc) => { *xx = cc; },
}
return Ok(())
}
0 as $tt
}
}
let el_type = pyarr.dtype();
i2if!(el_type, f64);
i2if!(el_type, f32);
i2if!(el_type, i64);
i2if!(el_type, u64);
i2if!(el_type, i32);
i2if!(el_type, u32);
i2if!(el_type, i16);
i2if!(el_type, u16);
Err(PyTypeError::new_err(format!("arr_to_int2 not implemented for type {:?}", el_type)))
}
#[pyfunction]
fn arr_to_int4(py: Python<'_>, pyarr: &Bound<'_, PyUntypedArray>) -> PyResult<()> {
use rust_util::ToInt4BE;
mkfun!(f64_to_i16, f64, convert_to_i2be);
mkfun!(f32_to_i16, f32, convert_to_i2be);
mkfun!(i64_to_i16, i64, convert_to_i2be);
mkfun!(u64_to_i16, u64, convert_to_i2be);
mkfun!(i32_to_i16, i32, convert_to_i2be);
mkfun!(u32_to_i16, u32, convert_to_i2be);
mkfun!(i16_to_i16, i16, convert_to_i2be);
mkfun!(u16_to_i16, u16, convert_to_i2be);
assert!(pyarr.is_c_contiguous(), "Array must be c-contiguous!");
macro_rules! i4if {
( $el_type:expr, $tt:ty ) => {
if $el_type.is_equiv_to(&dtype::<$tt>(py)) {
let arr = pyarr.downcast::<PyArray1<$tt>>()?;
let mut array = unsafe { arr.as_array_mut() };
for xx in array.iter_mut() {
*xx = <$tt>::convert_to_i4be(*xx).map_err(
|e| PyValueError::new_err(format!("Invalid value for 4-byte int: {}", e))
)?;
}
return Ok(())
}
}
}
let el_type = pyarr.dtype();
i4if!(el_type, f64);
i4if!(el_type, f32);
i4if!(el_type, i64);
i4if!(el_type, u64);
i4if!(el_type, i32);
i4if!(el_type, u32);
Err(PyTypeError::new_err(format!("arr_to_int4 not implemented for type {:?}", el_type)))
}
mkfun!(f64_to_i32, f64, convert_to_i4be);
mkfun!(f32_to_i32, f32, convert_to_i4be);
mkfun!(i64_to_i32, i64, convert_to_i4be);
mkfun!(u64_to_i32, u64, convert_to_i4be);
mkfun!(i32_to_i32, i32, convert_to_i4be);
mkfun!(u32_to_i32, u32, convert_to_i4be);
mod rust_util {