Skip to content

Commit 0157df8

Browse files
committed
Add more reshaping methods
1 parent 0bdce23 commit 0157df8

File tree

2 files changed

+236
-0
lines changed

2 files changed

+236
-0
lines changed

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ pub use dimension::NdIndex;
105105
pub use dimension::IxDynImpl;
106106
pub use indexes::{indices, indices_of};
107107
pub use error::{ShapeError, ErrorKind};
108+
pub use reshape::{ArrayViewTemp, ArrayViewMutTemp, ArrayViewMutTempRepr, Order};
108109
pub use slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex};
109110

110111
use iterators::Baseiter;
@@ -152,6 +153,7 @@ mod linalg_traits;
152153
mod linspace;
153154
mod numeric_util;
154155
mod error;
156+
mod reshape;
155157
mod shape_builder;
156158
mod stacking;
157159
mod zip;

src/reshape.rs

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
use {Array, ArrayBase, ArrayView, ArrayViewMut, Data, DataMut, DataOwned, Dimension,
2+
IntoDimension, ShapeError};
3+
4+
pub enum ArrayViewTemp<'a, A: 'a, D: Dimension> {
5+
Temporary(Array<A, D>),
6+
View(ArrayView<'a, A, D>),
7+
}
8+
9+
impl<'a, A, D> ArrayViewTemp<'a, A, D>
10+
where
11+
D: Dimension,
12+
{
13+
pub fn view(&self) -> ArrayView<A, D> {
14+
match *self {
15+
ArrayViewTemp::Temporary(ref tmp) => tmp.view(),
16+
ArrayViewTemp::View(ref view) => view.view(),
17+
}
18+
}
19+
}
20+
21+
pub struct ArrayViewMutTempRepr<'a, A, Do, Dt>
22+
where
23+
A: 'a,
24+
Do: Dimension,
25+
Dt: Dimension,
26+
{
27+
/// Mutable view of the original array.
28+
view: ArrayViewMut<'a, A, Do>,
29+
/// Temporary owned array that can be modified through `.view_mut()`.
30+
tmp: Array<A, Dt>,
31+
/// Closure that gets called `(drop_hook)(view, tmp)` when `self` is dropped.
32+
drop_hook: Box<FnMut(&mut ArrayViewMut<'a, A, Do>, &mut Array<A, Dt>)>,
33+
}
34+
35+
impl<'a, A, Do, Dt> ArrayViewMutTempRepr<'a, A, Do, Dt>
36+
where
37+
Do: Dimension,
38+
Dt: Dimension,
39+
{
40+
pub fn new(
41+
view: ArrayViewMut<'a, A, Do>,
42+
tmp: Array<A, Dt>,
43+
drop_hook: Box<FnMut(&mut ArrayViewMut<'a, A, Do>, &mut Array<A, Dt>)>,
44+
) -> Self {
45+
ArrayViewMutTempRepr {
46+
view,
47+
tmp,
48+
drop_hook,
49+
}
50+
}
51+
52+
/// Returns a view of the temporary array.
53+
pub fn view(&self) -> ArrayView<A, Dt> {
54+
self.tmp.view()
55+
}
56+
57+
/// Returns a mutable view of the temporary array.
58+
pub fn view_mut(&mut self) -> ArrayViewMut<A, Dt> {
59+
self.tmp.view_mut()
60+
}
61+
}
62+
63+
impl<'a, A, Do, Dt> Drop for ArrayViewMutTempRepr<'a, A, Do, Dt>
64+
where
65+
Do: Dimension,
66+
Dt: Dimension,
67+
{
68+
fn drop(&mut self) {
69+
let ArrayViewMutTempRepr {
70+
ref mut view,
71+
ref mut tmp,
72+
ref mut drop_hook,
73+
} = *self;
74+
(drop_hook)(view, tmp);
75+
}
76+
}
77+
78+
pub enum ArrayViewMutTemp<'a, A, Do, Dt>
79+
where
80+
A: 'a,
81+
Do: Dimension,
82+
Dt: Dimension,
83+
{
84+
Temporary(ArrayViewMutTempRepr<'a, A, Do, Dt>),
85+
View(ArrayViewMut<'a, A, Dt>),
86+
}
87+
88+
impl<'a, A, Do, Dt> ArrayViewMutTemp<'a, A, Do, Dt>
89+
where
90+
Do: Dimension,
91+
Dt: Dimension,
92+
{
93+
pub fn view(&self) -> ArrayView<A, Dt> {
94+
match *self {
95+
ArrayViewMutTemp::Temporary(ref tmp) => tmp.view(),
96+
ArrayViewMutTemp::View(ref view) => view.view(),
97+
}
98+
}
99+
100+
pub fn view_mut(&mut self) -> ArrayViewMut<A, Dt> {
101+
match *self {
102+
ArrayViewMutTemp::Temporary(ref mut tmp) => tmp.view_mut(),
103+
ArrayViewMutTemp::View(ref mut view) => view.view_mut(),
104+
}
105+
}
106+
}
107+
108+
/// This is analogous to the `order` parameter in
109+
/// [`numpy.reshape()`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html#numpy.reshape).
110+
pub enum Order {
111+
/// C-like order
112+
RowMajor,
113+
/// Fortran-like order
114+
ColMajor,
115+
/// Fortran-like order if the array is Fortran contiguous in memory, C-like order otherwise
116+
Automatic,
117+
}
118+
119+
impl<A, S, D> ArrayBase<S, D>
120+
where
121+
S: Data<Elem = A>,
122+
D: Dimension,
123+
{
124+
/// Returns an `ArrayViewTemp` instance with the desired shape.
125+
///
126+
/// The reshaped data can be read by calling `.view()` on the
127+
/// `ArrayViewTemp` instance.
128+
///
129+
/// This method does not require the data to be contiguous in memory.
130+
///
131+
/// **Errors** if `self` doesn't have the same number of elements as `shape`.
132+
pub fn view_with_shape<E>(
133+
&self,
134+
shape: E,
135+
order: Order,
136+
) -> Result<ArrayViewTemp<A, E::Dim>, ShapeError>
137+
where
138+
A: Clone,
139+
E: IntoDimension,
140+
{
141+
match order {
142+
Order::RowMajor => if self.is_standard_layout() {
143+
Ok(ArrayViewTemp::View(self.view().into_shape(shape)?))
144+
} else {
145+
let tmp = Array::from_iter(self.iter().cloned()).into_shape(shape)?;
146+
Ok(ArrayViewTemp::Temporary(tmp))
147+
},
148+
Order::ColMajor => unimplemented!(),
149+
Order::Automatic => {
150+
if self.ndim() > 1 && self.view().reversed_axes().is_standard_layout() {
151+
self.view_with_shape(shape, Order::ColMajor)
152+
} else {
153+
self.view_with_shape(shape, Order::RowMajor)
154+
}
155+
}
156+
}
157+
}
158+
159+
/// Returns an `ArrayViewMutTemp` instance with the desired shape.
160+
///
161+
/// The reshaped data can be read/written by calling `.view_mut()` on the
162+
/// `ArrayViewMutTemp` instance.
163+
///
164+
/// This method does not require the data to be contiguous in memory.
165+
///
166+
/// **Errors** if `self` doesn't have the same number of elements as `shape`.
167+
pub fn view_mut_with_shape<E>(
168+
&mut self,
169+
shape: E,
170+
order: Order,
171+
) -> Result<ArrayViewMutTemp<A, D, E::Dim>, ShapeError>
172+
where
173+
A: Clone,
174+
S: DataMut,
175+
E: IntoDimension,
176+
{
177+
match order {
178+
Order::RowMajor => if self.is_standard_layout() {
179+
Ok(ArrayViewMutTemp::View(self.view_mut().into_shape(shape)?))
180+
} else {
181+
let tmp = Array::from_iter(self.iter().cloned()).into_shape(shape)?;
182+
Ok(ArrayViewMutTemp::Temporary(ArrayViewMutTempRepr::new(
183+
self.view_mut(),
184+
tmp,
185+
Box::new(|view, tmp| {
186+
view.iter_mut()
187+
.zip(tmp.iter())
188+
.for_each(|(o, t)| *o = t.clone())
189+
}),
190+
)))
191+
},
192+
Order::ColMajor => unimplemented!(),
193+
Order::Automatic => {
194+
if self.ndim() > 1 && self.view().reversed_axes().is_standard_layout() {
195+
self.view_mut_with_shape(shape, Order::ColMajor)
196+
} else {
197+
self.view_mut_with_shape(shape, Order::RowMajor)
198+
}
199+
}
200+
}
201+
}
202+
203+
/// Returns a new array with the desired shape.
204+
///
205+
/// This method does not require the data to be contiguous in memory.
206+
///
207+
/// **Errors** if `self` doesn't have the same number of elements as `shape`.
208+
pub fn into_shape_owned<E>(
209+
self,
210+
shape: E,
211+
order: Order,
212+
) -> Result<ArrayBase<S, E::Dim>, ShapeError>
213+
where
214+
A: Clone,
215+
S: DataOwned,
216+
E: IntoDimension,
217+
{
218+
match order {
219+
Order::RowMajor => if self.is_standard_layout() {
220+
self.into_shape(shape)
221+
} else {
222+
unimplemented!()
223+
},
224+
Order::ColMajor => unimplemented!(),
225+
Order::Automatic => {
226+
if self.ndim() > 1 && self.view().reversed_axes().is_standard_layout() {
227+
self.into_shape_owned(shape, Order::ColMajor)
228+
} else {
229+
self.into_shape_owned(shape, Order::RowMajor)
230+
}
231+
}
232+
}
233+
}
234+
}

0 commit comments

Comments
 (0)