Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ license = "BSD-2-Clause"
half = { version = "2.0", default-features = false, optional = true }
libc = "0.2"
nalgebra = { version = ">=0.30, <0.34", default-features = false, optional = true }
faer = { version = "0.21", optional = true }
num-complex = ">= 0.2, < 0.5"
num-integer = "0.1"
num-traits = "0.2"
ndarray = ">= 0.15, < 0.17"
pyo3 = { version = "0.23.4", default-features = false, features = ["macros"] }
rustc-hash = "2.0"

[features]
faer = ["dep:faer"]

[dev-dependencies]
pyo3 = { version = "0.23.3", default-features = false, features = ["auto-initialize"] }
nalgebra = { version = ">=0.30, <0.34", default-features = false, features = ["std"] }
Expand Down
28 changes: 27 additions & 1 deletion src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::{mem, os::raw::c_int, ptr};

use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix1, OwnedRepr};
use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix1, Ix2, OwnedRepr};
use pyo3::{Bound, Python};

use crate::array::{PyArray, PyArrayMethods};
Expand Down Expand Up @@ -90,6 +90,32 @@ impl<T: Element> IntoPyArray for Vec<T> {
}
}

#[cfg(feature = "faer")]
impl<T: Element> IntoPyArray for faer::Mat<T> {
type Item = T;
type Dim = Ix2;

fn into_pyarray<'py>(mut self, py: Python<'py>) -> Bound<'py, PyArray<Self::Item, Self::Dim>> {
let dims = Dim([self.nrows(), self.ncols()]);
let rstride = self.row_stride();
let cstride = self.col_stride();
let strides = [
rstride * mem::size_of::<T>() as npy_intp,
cstride * mem::size_of::<T>() as npy_intp,
];
let data_ptr = self.as_ptr_mut();
unsafe {
PyArray::from_raw_parts(
py,
dims,
strides.as_ptr(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ngoldbaum does PyArray_NewFromDescr copy the strides array or just store the pointer? This looks like a potential use-after-free 🤔

(We already have the same pattern in the other into_pyarray functions in this file, which makes me think it's probably fine? Either that or there's a nasty bug in rust-numpy already...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an explicit copy if strides are passed in:

https://github.com/numpy/numpy/blob/a651643fc9b5699a6dc6f3c85ba499c1f52ce3aa/numpy/_core/src/multiarray/ctors.c#L801-L816

There's another code path that handles structured dtypes with subarrays but it looks like that copies the strides too.

There's a comment here worrying about unaligned input strides:

https://github.com/numpy/numpy/blob/a651643fc9b5699a6dc6f3c85ba499c1f52ce3aa/numpy/_core/src/multiarray/ctors.c#L911-L916

But also I have no idea why an unaligned stride array would ever be a problem.

data_ptr,
PySliceContainer::from(self),
)
}
}
}

impl<A, D> IntoPyArray for ArrayBase<OwnedRepr<A>, D>
where
A: Element,
Expand Down
24 changes: 24 additions & 0 deletions src/slice_container.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,30 @@ impl<T: Send + Sync> From<Vec<T>> for PySliceContainer {
}
}

#[cfg(feature = "faer")]
impl<T: Send + Sync> From<faer::Mat<T>> for PySliceContainer {
fn from(data: faer::Mat<T>) -> Self {
unsafe fn drop_faer_mat<T>(ptr: *mut u8, len_nrows: usize, cap_ncols: usize) {
let _ = faer::mat::MatMut::from_raw_parts_mut(
ptr as *mut T, len_nrows, cap_ncols, 1, cap_ncols as isize);
}

let mut data = mem::ManuallyDrop::new(data);

let ptr = data.as_ptr_mut() as *mut u8;
let len = data.nrows();
let cap = data.ncols();
let drop = drop_faer_mat::<T>;

Self {
ptr,
len,
cap,
drop,
}
}
}

impl<A, D> From<ArrayBase<OwnedRepr<A>, D>> for PySliceContainer
where
A: Send + Sync,
Expand Down
25 changes: 25 additions & 0 deletions tests/to_py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,31 @@ fn slice_container_type_confusion() {
});
}

#[cfg(feature = "faer")]
#[test]
fn faer_mat_to_numpy() {
let faer_mat: faer::Mat<f64> = faer::Scale(2.0) * faer::mat::Mat::<f64>::identity(2, 2);
let faer_mat_wide: faer::Mat<f64> = faer::mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
let faer_mat_tall: faer::Mat<f64> = faer_mat_wide.transpose().to_owned();
Python::with_gil(|py| {
let mat_pyarray = faer_mat.into_pyarray(py);
let mat_wide_pyarray = faer_mat_wide.into_pyarray(py);
let mat_tall_pyarray = faer_mat_tall.into_pyarray(py);
assert_eq!(
mat_pyarray.readonly().as_array(),
array![[2.0f64, 0.0f64], [0.0f64, 2.0f64]]
);
assert_eq!(
mat_wide_pyarray.readonly().as_array(),
array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]]
);
assert_eq!(
mat_tall_pyarray.readonly().as_array(),
array![[1.0f64, 4.0], [2.0, 5.0], [3.0, 6.0]]
);
});
}

#[cfg(feature = "nalgebra")]
#[test]
fn matrix_to_numpy() {
Expand Down
Loading