1
1
use ndarray:: array;
2
- use numpy:: { get_array_module, AllowTypeChange , PyArrayLike1 , PyArrayLike2 , PyArrayLikeDyn } ;
2
+ use numpy:: {
3
+ get_array_module, AllowTypeChange , PyArrayLike1 , PyArrayLike2 , PyArrayLikeDyn ,
4
+ PyUntypedArrayMethods ,
5
+ } ;
3
6
use pyo3:: {
4
7
ffi:: c_str,
5
8
types:: { IntoPyDict , PyAnyMethods , PyDict } ,
@@ -105,7 +108,9 @@ fn convert_1d_list_on_extract() {
105
108
Python :: with_gil ( |py| {
106
109
let py_list = py. eval ( c_str ! ( "[1,2,3,4]" ) , None , None ) . unwrap ( ) ;
107
110
let extracted_array_1d = py_list. extract :: < PyArrayLike1 < ' _ , u32 > > ( ) . unwrap ( ) ;
108
- let extracted_array_dyn = py_list. extract :: < PyArrayLikeDyn < ' _ , f64 > > ( ) . unwrap ( ) ;
111
+ let extracted_array_dyn = py_list
112
+ . extract :: < PyArrayLikeDyn < ' _ , f64 , AllowTypeChange > > ( )
113
+ . unwrap ( ) ;
109
114
110
115
assert_eq ! ( array![ 1 , 2 , 3 , 4 ] , extracted_array_1d. as_array( ) ) ;
111
116
assert_eq ! (
@@ -115,6 +120,25 @@ fn convert_1d_list_on_extract() {
115
120
} ) ;
116
121
}
117
122
123
+ #[ test]
124
+ fn preserve_trailing_singleton_dims ( ) {
125
+ Python :: with_gil ( |py| {
126
+ let locals = get_np_locals ( py) ;
127
+ let py_array = py
128
+ . eval (
129
+ c_str ! ( "np.array([[1], [2], [3]], dtype='int32')" ) ,
130
+ Some ( & locals) ,
131
+ None ,
132
+ )
133
+ . unwrap ( ) ;
134
+ let extracted_array = py_array
135
+ . extract :: < PyArrayLikeDyn < ' _ , f64 , AllowTypeChange > > ( )
136
+ . unwrap ( ) ;
137
+
138
+ assert_eq ! ( extracted_array. shape( ) , & [ 3 , 1 ] ) ;
139
+ } )
140
+ }
141
+
118
142
#[ test]
119
143
fn unsafe_cast_shall_fail ( ) {
120
144
Python :: with_gil ( |py| {
0 commit comments