Skip to content

Viewing arrays with different shapes #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ public interface BooleanNdArray extends NdArray<Boolean> {
*/
BooleanNdArray setBoolean(boolean value, long... coordinates);

@Override
BooleanNdArray withShape(Shape shape);

@Override
BooleanNdArray slice(Index... indices);

Expand Down
3 changes: 3 additions & 0 deletions ndarray/src/main/java/org/tensorflow/ndarray/ByteNdArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ public interface ByteNdArray extends NdArray<Byte> {
*/
ByteNdArray setByte(byte value, long... coordinates);

@Override
ByteNdArray withShape(Shape shape);

@Override
ByteNdArray slice(Index... indices);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ default DoubleStream streamOfDoubles() {
return StreamSupport.stream(scalars().spliterator(), false).mapToDouble(DoubleNdArray::getDouble);
}

@Override
DoubleNdArray withShape(Shape shape);

@Override
DoubleNdArray slice(Index... indices);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ public interface FloatNdArray extends NdArray<Float> {
*/
FloatNdArray setFloat(float value, long... coordinates);

@Override
FloatNdArray withShape(Shape shape);

@Override
FloatNdArray slice(Index... coordinates);

Expand Down
3 changes: 3 additions & 0 deletions ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ default IntStream streamOfInts() {
return StreamSupport.stream(scalars().spliterator(), false).mapToInt(IntNdArray::getInt);
}

@Override
IntNdArray withShape(Shape shape);

@Override
IntNdArray slice(Index... indices);

Expand Down
3 changes: 3 additions & 0 deletions ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ default LongStream streamOfLongs() {
return StreamSupport.stream(scalars().spliterator(), false).mapToLong(LongNdArray::getLong);
}

@Override
LongNdArray withShape(Shape shape);

@Override
LongNdArray slice(Index... indices);

Expand Down
32 changes: 29 additions & 3 deletions ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
*/
package org.tensorflow.ndarray;

import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.index.Index;

import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.index.Index;

/**
* A data structure of N-dimensions.
*
Expand Down Expand Up @@ -101,6 +101,32 @@ public interface NdArray<T> extends Shaped {
*/
NdArraySequence<? extends NdArray<T>> scalars();

/**
* Returns a new N-dimensional view of this array with the given {@code shape}.
*
* <p>The provided {@code shape} must comply to the following characteristics:
* <ul>
* <li>new shape is known (i.e. has no unknown dimension)</li>
* <li>new shape size is equal to the size of the current shape (i.e. same number of elements)</li>
* </ul>
* For example,
* <pre>{@code
* NdArrays.ofInts(Shape.scalar()).withShape(Shape.of(1, 1)); // ok
* NdArrays.ofInts(Shape.of(2, 3).withShape(Shape.of(3, 2)); // ok
* NdArrays.ofInts(Shape.scalar()).withShape(Shape.of(1, 2)); // not ok, sizes are different (1 != 2)
* NdArrays.ofInts(Shape.of(2, 3)).withShape(Shape.unknown()); // not ok, new shape unknown
* }</pre>
*
* <p>Any changes applied to the returned view affect the data of this array as well, as there
* is no copy involved.
*
* @param shape the new shape to apply
* @return a new array viewing the data according to the new shape, or this array if shapes are the same
* @throws IllegalArgumentException if the provided {@code shape} is not compliant
* @throws UnsupportedOperationException if this array does not support this operation
*/
NdArray<T> withShape(Shape shape);

/**
* Creates a multi-dimensional view (or slice) of this array by mapping one or more dimensions
* to the given index selectors.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ public interface ShortNdArray extends NdArray<Short> {
*/
ShortNdArray setShort(short value, long... coordinates);

@Override
ShortNdArray withShape(Shape shape);

@Override
ShortNdArray slice(Index... coordinates);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArraySequence;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.impl.AbstractNdArray;
import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace;
import org.tensorflow.ndarray.impl.sequence.FastElementSequence;
Expand All @@ -43,18 +44,29 @@ public NdArraySequence<U> elements(int dimensionIdx) {
DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1);
try {
DataBufferWindow<? extends DataBuffer<T>> elemWindow = buffer().window(elemDims.physicalSize());
U element = instantiate(elemWindow.buffer(), elemDims);
U element = instantiateView(elemWindow.buffer(), elemDims);
return new FastElementSequence(this, dimensionIdx, element, elemWindow);
} catch (UnsupportedOperationException e) {
// If buffer windows are not supported, fallback to slicing (and slower) sequence
return new SlicingElementSequence<>(this, dimensionIdx, elemDims);
}
}

@Override
public U withShape(Shape shape) {
if (shape == null || shape.isUnknown() || shape.size() != this.shape().size()) {
throw new IllegalArgumentException("Shape " + shape + " cannot be used to reshape ndarray of shape " + this.shape());
}
if (shape.equals(this.shape())) {
return (U)this;
}
return instantiateView(buffer(), DimensionalSpace.create(shape));
}

@Override
public U slice(long position, DimensionalSpace sliceDimensions) {
DataBuffer<T> sliceBuffer = buffer().slice(position, sliceDimensions.physicalSize());
return instantiate(sliceBuffer, sliceDimensions);
return instantiateView(sliceBuffer, sliceDimensions);
}

@Override
Expand Down Expand Up @@ -147,7 +159,7 @@ protected AbstractDenseNdArray(DimensionalSpace dimensions) {

abstract protected DataBuffer<T> buffer();

abstract U instantiate(DataBuffer<T> buffer, DimensionalSpace dimensions);
abstract U instantiateView(DataBuffer<T> buffer, DimensionalSpace dimensions);

long positionOf(long[] coords, boolean isValue) {
if (coords == null || coords.length == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected BooleanDenseNdArray(BooleanDataBuffer buffer, Shape shape) {
}

@Override
BooleanDenseNdArray instantiate(DataBuffer<Boolean> buffer, DimensionalSpace dimensions) {
BooleanDenseNdArray instantiateView(DataBuffer<Boolean> buffer, DimensionalSpace dimensions) {
return new BooleanDenseNdArray((BooleanDataBuffer)buffer, dimensions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected ByteDenseNdArray(ByteDataBuffer buffer, Shape shape) {
}

@Override
ByteDenseNdArray instantiate(DataBuffer<Byte> buffer, DimensionalSpace dimensions) {
ByteDenseNdArray instantiateView(DataBuffer<Byte> buffer, DimensionalSpace dimensions) {
return new ByteDenseNdArray((ByteDataBuffer)buffer, dimensions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ protected DenseNdArray(DataBuffer<T> buffer, Shape shape) {
}

@Override
DenseNdArray<T> instantiate(DataBuffer<T> buffer, DimensionalSpace dimensions) {
DenseNdArray<T> instantiateView(DataBuffer<T> buffer, DimensionalSpace dimensions) {
return new DenseNdArray<>(buffer, dimensions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected DoubleDenseNdArray(DoubleDataBuffer buffer, Shape shape) {
}

@Override
DoubleDenseNdArray instantiate(DataBuffer<Double> buffer, DimensionalSpace dimensions) {
DoubleDenseNdArray instantiateView(DataBuffer<Double> buffer, DimensionalSpace dimensions) {
return new DoubleDenseNdArray((DoubleDataBuffer)buffer, dimensions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected FloatDenseNdArray(FloatDataBuffer buffer, Shape shape) {
}

@Override
FloatDenseNdArray instantiate(DataBuffer<Float> buffer, DimensionalSpace dimensions) {
FloatDenseNdArray instantiateView(DataBuffer<Float> buffer, DimensionalSpace dimensions) {
return new FloatDenseNdArray((FloatDataBuffer) buffer, dimensions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected IntDenseNdArray(IntDataBuffer buffer, Shape shape) {
}

@Override
IntDenseNdArray instantiate(DataBuffer<Integer> buffer, DimensionalSpace dimensions) {
IntDenseNdArray instantiateView(DataBuffer<Integer> buffer, DimensionalSpace dimensions) {
return new IntDenseNdArray((IntDataBuffer)buffer, dimensions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected LongDenseNdArray(LongDataBuffer buffer, Shape shape) {
}

@Override
LongDenseNdArray instantiate(DataBuffer<Long> buffer, DimensionalSpace dimensions) {
LongDenseNdArray instantiateView(DataBuffer<Long> buffer, DimensionalSpace dimensions) {
return new LongDenseNdArray((LongDataBuffer)buffer, dimensions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected ShortDenseNdArray(ShortDataBuffer buffer, Shape shape) {
}

@Override
ShortDenseNdArray instantiate(DataBuffer<Short> buffer, DimensionalSpace dimensions) {
ShortDenseNdArray instantiateView(DataBuffer<Short> buffer, DimensionalSpace dimensions) {
return new ShortDenseNdArray((ShortDataBuffer)buffer, dimensions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ protected long[] getIndicesCoordinates(LongNdArray l) {
*/
public abstract U toDense();

@Override
public U withShape(Shape shape) {
throw new UnsupportedOperationException("Sparse NdArrays cannot be viewed with a different shape");
}

/** {@inheritDoc} */
@Override
public NdArray<T> slice(Index... indices) {
Expand Down
27 changes: 27 additions & 0 deletions ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -384,4 +384,31 @@ public void streamingObjects() {
values = matrix.streamOfObjects().collect(Collectors.toList());
assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L)), values);
}

@Test
public void withShape() {
Shape originalShape = Shape.scalar();
Shape newShape = originalShape.prepend(1).prepend(1); // [1, 1]

NdArray<T> originalArray = allocate(originalShape);
originalArray.setObject(valueOf(10L));
assertEquals(valueOf(10L), originalArray.getObject());

NdArray<T> newArray = originalArray.withShape(newShape);
assertNotNull(newArray);
assertEquals(newShape, newArray.shape());
assertEquals(originalShape, originalArray.shape());
assertEquals(valueOf(10L), newArray.getObject(0, 0));

NdArray<T> sameArray = originalArray.withShape(Shape.scalar());
assertSame(originalArray, sameArray);

assertThrows(IllegalArgumentException.class, () -> originalArray.withShape(Shape.of(2)));
assertThrows(IllegalArgumentException.class, () -> originalArray.withShape(Shape.unknown()));

NdArray<T> originalMatrix = allocate(Shape.of(2, 3));
assertThrows(IllegalArgumentException.class, () -> originalMatrix.withShape(Shape.scalar()));
NdArray<T> newMatrix = originalMatrix.withShape(Shape.of(3, 2));
assertEquals(Shape.of(3, 2), newMatrix.shape());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class SparseNdArrayTest {
Expand Down Expand Up @@ -188,4 +189,10 @@ public void testShort() {
assertEquals((short) 0, instance.getShort(2, 2));
assertEquals((short) 0xff00, instance.getShort(2, 3));
}

@Test
public void withShape() {
NdArray<?> sparseArray = NdArrays.sparseOf(indices, NdArrays.vectorOf(1, 2, 3), shape);
assertThrows(UnsupportedOperationException.class, () -> sparseArray.withShape(shape.prepend(1)));
}
}