diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/BooleanNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/BooleanNdArray.java index 5b4bedb..bd16a9a 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/BooleanNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/BooleanNdArray.java @@ -68,6 +68,9 @@ public interface BooleanNdArray extends NdArray { */ BooleanNdArray setBoolean(boolean value, long... coordinates); + @Override + BooleanNdArray withShape(Shape shape); + @Override BooleanNdArray slice(Index... indices); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/ByteNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/ByteNdArray.java index 0e6f118..47e5a0d 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/ByteNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/ByteNdArray.java @@ -68,6 +68,9 @@ public interface ByteNdArray extends NdArray { */ ByteNdArray setByte(byte value, long... coordinates); + @Override + ByteNdArray withShape(Shape shape); + @Override ByteNdArray slice(Index... indices); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java index b0e6dab..da42bab 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java @@ -83,6 +83,9 @@ default DoubleStream streamOfDoubles() { return StreamSupport.stream(scalars().spliterator(), false).mapToDouble(DoubleNdArray::getDouble); } + @Override + DoubleNdArray withShape(Shape shape); + @Override DoubleNdArray slice(Index... indices); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/FloatNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/FloatNdArray.java index 8d4fbf5..34e4201 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/FloatNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/FloatNdArray.java @@ -68,6 +68,9 @@ public interface FloatNdArray extends NdArray { */ FloatNdArray setFloat(float value, long... coordinates); + @Override + FloatNdArray withShape(Shape shape); + @Override FloatNdArray slice(Index... coordinates); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java index e6a5cf0..71f19b9 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java @@ -83,6 +83,9 @@ default IntStream streamOfInts() { return StreamSupport.stream(scalars().spliterator(), false).mapToInt(IntNdArray::getInt); } + @Override + IntNdArray withShape(Shape shape); + @Override IntNdArray slice(Index... indices); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java index e7bd266..a55b2ab 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java @@ -83,6 +83,9 @@ default LongStream streamOfLongs() { return StreamSupport.stream(scalars().spliterator(), false).mapToLong(LongNdArray::getLong); } + @Override + LongNdArray withShape(Shape shape); + @Override LongNdArray slice(Index... indices); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java index f1e84d4..a75da48 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java @@ -16,14 +16,14 @@ */ package org.tensorflow.ndarray; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.index.Index; + import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.stream.Stream; import java.util.stream.StreamSupport; -import org.tensorflow.ndarray.buffer.DataBuffer; -import org.tensorflow.ndarray.index.Index; - /** * A data structure of N-dimensions. * @@ -101,6 +101,32 @@ public interface NdArray extends Shaped { */ NdArraySequence> scalars(); + /** + * Returns a new N-dimensional view of this array with the given {@code shape}. + * + *

The provided {@code shape} must comply to the following characteristics: + *

    + *
  • new shape is known (i.e. has no unknown dimension)
  • + *
  • new shape size is equal to the size of the current shape (i.e. same number of elements)
  • + *
+ * For example, + *
{@code
+   *    NdArrays.ofInts(Shape.scalar()).withShape(Shape.of(1, 1));  // ok
+   *    NdArrays.ofInts(Shape.of(2, 3).withShape(Shape.of(3, 2));   // ok
+   *    NdArrays.ofInts(Shape.scalar()).withShape(Shape.of(1, 2));  // not ok, sizes are different (1 != 2)
+   *    NdArrays.ofInts(Shape.of(2, 3)).withShape(Shape.unknown()); // not ok, new shape unknown
+   * }
+ * + *

Any changes applied to the returned view affect the data of this array as well, as there + * is no copy involved. + * + * @param shape the new shape to apply + * @return a new array viewing the data according to the new shape, or this array if shapes are the same + * @throws IllegalArgumentException if the provided {@code shape} is not compliant + * @throws UnsupportedOperationException if this array does not support this operation + */ + NdArray withShape(Shape shape); + /** * Creates a multi-dimensional view (or slice) of this array by mapping one or more dimensions * to the given index selectors. diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/ShortNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/ShortNdArray.java index f9335b4..92b608f 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/ShortNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/ShortNdArray.java @@ -68,6 +68,9 @@ public interface ShortNdArray extends NdArray { */ ShortNdArray setShort(short value, long... coordinates); + @Override + ShortNdArray withShape(Shape shape); + @Override ShortNdArray slice(Index... coordinates); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java index 30af952..baaf23e 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java @@ -18,6 +18,7 @@ import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.impl.AbstractNdArray; import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; import org.tensorflow.ndarray.impl.sequence.FastElementSequence; @@ -43,7 +44,7 @@ public NdArraySequence elements(int dimensionIdx) { DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1); try { DataBufferWindow> elemWindow = buffer().window(elemDims.physicalSize()); - U element = instantiate(elemWindow.buffer(), elemDims); + U element = instantiateView(elemWindow.buffer(), elemDims); return new FastElementSequence(this, dimensionIdx, element, elemWindow); } catch (UnsupportedOperationException e) { // If buffer windows are not supported, fallback to slicing (and slower) sequence @@ -51,10 +52,21 @@ public NdArraySequence elements(int dimensionIdx) { } } + @Override + public U withShape(Shape shape) { + if (shape == null || shape.isUnknown() || shape.size() != this.shape().size()) { + throw new IllegalArgumentException("Shape " + shape + " cannot be used to reshape ndarray of shape " + this.shape()); + } + if (shape.equals(this.shape())) { + return (U)this; + } + return instantiateView(buffer(), DimensionalSpace.create(shape)); + } + @Override public U slice(long position, DimensionalSpace sliceDimensions) { DataBuffer sliceBuffer = buffer().slice(position, sliceDimensions.physicalSize()); - return instantiate(sliceBuffer, sliceDimensions); + return instantiateView(sliceBuffer, sliceDimensions); } @Override @@ -147,7 +159,7 @@ protected AbstractDenseNdArray(DimensionalSpace dimensions) { abstract protected DataBuffer buffer(); - abstract U instantiate(DataBuffer buffer, DimensionalSpace dimensions); + abstract U instantiateView(DataBuffer buffer, DimensionalSpace dimensions); long positionOf(long[] coords, boolean isValue) { if (coords == null || coords.length == 0) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java index 0764146..9c134b5 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java @@ -73,7 +73,7 @@ protected BooleanDenseNdArray(BooleanDataBuffer buffer, Shape shape) { } @Override - BooleanDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { + BooleanDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { return new BooleanDenseNdArray((BooleanDataBuffer)buffer, dimensions); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java index 172432b..a2525c6 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java @@ -73,7 +73,7 @@ protected ByteDenseNdArray(ByteDataBuffer buffer, Shape shape) { } @Override - ByteDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { + ByteDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { return new ByteDenseNdArray((ByteDataBuffer)buffer, dimensions); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java index 819d95d..18b3755 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java @@ -45,7 +45,7 @@ protected DenseNdArray(DataBuffer buffer, Shape shape) { } @Override - DenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { + DenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { return new DenseNdArray<>(buffer, dimensions); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java index f54b8d0..a967ce1 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java @@ -73,7 +73,7 @@ protected DoubleDenseNdArray(DoubleDataBuffer buffer, Shape shape) { } @Override - DoubleDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { + DoubleDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { return new DoubleDenseNdArray((DoubleDataBuffer)buffer, dimensions); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java index 196b5ef..a04c192 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java @@ -73,7 +73,7 @@ protected FloatDenseNdArray(FloatDataBuffer buffer, Shape shape) { } @Override - FloatDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { + FloatDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { return new FloatDenseNdArray((FloatDataBuffer) buffer, dimensions); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java index a7af498..e1a726f 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java @@ -73,7 +73,7 @@ protected IntDenseNdArray(IntDataBuffer buffer, Shape shape) { } @Override - IntDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { + IntDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { return new IntDenseNdArray((IntDataBuffer)buffer, dimensions); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java index cd56dad..802cbcd 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java @@ -73,7 +73,7 @@ protected LongDenseNdArray(LongDataBuffer buffer, Shape shape) { } @Override - LongDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { + LongDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { return new LongDenseNdArray((LongDataBuffer)buffer, dimensions); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java index 291c01a..434b260 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java @@ -73,7 +73,7 @@ protected ShortDenseNdArray(ShortDataBuffer buffer, Shape shape) { } @Override - ShortDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { + ShortDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { return new ShortDenseNdArray((ShortDataBuffer)buffer, dimensions); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java index 8e3892d..e4a2fba 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java @@ -212,6 +212,11 @@ protected long[] getIndicesCoordinates(LongNdArray l) { */ public abstract U toDense(); + @Override + public U withShape(Shape shape) { + throw new UnsupportedOperationException("Sparse NdArrays cannot be viewed with a different shape"); + } + /** {@inheritDoc} */ @Override public NdArray slice(Index... indices) { diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java index 8a09ec7..577c2b7 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java @@ -384,4 +384,31 @@ public void streamingObjects() { values = matrix.streamOfObjects().collect(Collectors.toList()); assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L)), values); } + + @Test + public void withShape() { + Shape originalShape = Shape.scalar(); + Shape newShape = originalShape.prepend(1).prepend(1); // [1, 1] + + NdArray originalArray = allocate(originalShape); + originalArray.setObject(valueOf(10L)); + assertEquals(valueOf(10L), originalArray.getObject()); + + NdArray newArray = originalArray.withShape(newShape); + assertNotNull(newArray); + assertEquals(newShape, newArray.shape()); + assertEquals(originalShape, originalArray.shape()); + assertEquals(valueOf(10L), newArray.getObject(0, 0)); + + NdArray sameArray = originalArray.withShape(Shape.scalar()); + assertSame(originalArray, sameArray); + + assertThrows(IllegalArgumentException.class, () -> originalArray.withShape(Shape.of(2))); + assertThrows(IllegalArgumentException.class, () -> originalArray.withShape(Shape.unknown())); + + NdArray originalMatrix = allocate(Shape.of(2, 3)); + assertThrows(IllegalArgumentException.class, () -> originalMatrix.withShape(Shape.scalar())); + NdArray newMatrix = originalMatrix.withShape(Shape.of(3, 2)); + assertEquals(Shape.of(3, 2), newMatrix.shape()); + } } diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/SparseNdArrayTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/SparseNdArrayTest.java index 43779b3..0c5d6b3 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/SparseNdArrayTest.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/SparseNdArrayTest.java @@ -25,6 +25,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class SparseNdArrayTest { @@ -188,4 +189,10 @@ public void testShort() { assertEquals((short) 0, instance.getShort(2, 2)); assertEquals((short) 0xff00, instance.getShort(2, 3)); } + + @Test + public void withShape() { + NdArray sparseArray = NdArrays.sparseOf(indices, NdArrays.vectorOf(1, 2, 3), shape); + assertThrows(UnsupportedOperationException.class, () -> sparseArray.withShape(shape.prepend(1))); + } }