Skip to content

Commit 2d4371c

Browse files
committed
Remove extraction from Vec for PyArrayLikeDyn
1 parent 8bfbb27 commit 2d4371c

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

src/array_like.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ where
149149

150150
let py = ob.py();
151151

152-
if matches!(D::NDIM, None | Some(1)) {
152+
if matches!(D::NDIM, Some(1)) {
153153
if let Ok(vec) = ob.extract::<Vec<T>>() {
154154
let array = Array1::from(vec)
155155
.into_dimensionality()

tests/array_like.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use ndarray::array;
2-
use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn};
2+
use numpy::{
3+
get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn,
4+
PyUntypedArrayMethods,
5+
};
36
use pyo3::{
47
ffi::c_str,
58
types::{IntoPyDict, PyAnyMethods, PyDict},
@@ -105,7 +108,9 @@ fn convert_1d_list_on_extract() {
105108
Python::with_gil(|py| {
106109
let py_list = py.eval(c_str!("[1,2,3,4]"), None, None).unwrap();
107110
let extracted_array_1d = py_list.extract::<PyArrayLike1<'_, u32>>().unwrap();
108-
let extracted_array_dyn = py_list.extract::<PyArrayLikeDyn<'_, f64>>().unwrap();
111+
let extracted_array_dyn = py_list
112+
.extract::<PyArrayLikeDyn<'_, f64, AllowTypeChange>>()
113+
.unwrap();
109114

110115
assert_eq!(array![1, 2, 3, 4], extracted_array_1d.as_array());
111116
assert_eq!(
@@ -115,6 +120,25 @@ fn convert_1d_list_on_extract() {
115120
});
116121
}
117122

123+
#[test]
124+
fn preserve_trailing_singleton_dims() {
125+
Python::with_gil(|py| {
126+
let locals = get_np_locals(py);
127+
let py_array = py
128+
.eval(
129+
c_str!("np.array([[1], [2], [3]], dtype='int32')"),
130+
Some(&locals),
131+
None,
132+
)
133+
.unwrap();
134+
let extracted_array = py_array
135+
.extract::<PyArrayLikeDyn<'_, f64, AllowTypeChange>>()
136+
.unwrap();
137+
138+
assert_eq!(extracted_array.shape(), &[3, 1]);
139+
})
140+
}
141+
118142
#[test]
119143
fn unsafe_cast_shall_fail() {
120144
Python::with_gil(|py| {

0 commit comments

Comments
 (0)