diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java index 64ab7eb..7be6f90 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray; +import java.util.function.Consumer; import org.tensorflow.ndarray.buffer.BooleanDataBuffer; import org.tensorflow.ndarray.buffer.ByteDataBuffer; import org.tensorflow.ndarray.buffer.DataBuffer; @@ -25,6 +26,9 @@ import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray; import org.tensorflow.ndarray.impl.dense.DenseNdArray; @@ -33,7 +37,10 @@ import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; import org.tensorflow.ndarray.impl.dense.ShortDenseNdArray; +import org.tensorflow.ndarray.impl.dense.hydrator.DenseNdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.hydrator.DoubleDenseNdArrayHydrator; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; import org.tensorflow.ndarray.impl.sparse.BooleanSparseNdArray; import org.tensorflow.ndarray.impl.sparse.ByteSparseNdArray; import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray; @@ -41,6 +48,8 @@ import org.tensorflow.ndarray.impl.sparse.IntSparseNdArray; import org.tensorflow.ndarray.impl.sparse.LongSparseNdArray; import org.tensorflow.ndarray.impl.sparse.ShortSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.hydrator.DoubleSparseNdArrayHydrator; +import org.tensorflow.ndarray.impl.sparse.hydrator.SparseNdArrayHydrator; /** Utility class for instantiating {@link NdArray} objects. */ public final class NdArrays { @@ -555,6 +564,20 @@ public static DoubleNdArray ofDoubles(Shape shape) { return wrap(shape, DataBuffers.ofDoubles(shape.size())); } + /** + * Creates an N-dimensional array of doubles of the given shape, hydrating it with data after its allocation + * + * @param shape shape of the array + * @param hydrate initialize the data of the created array, using a hydrator + * @return new double N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static DoubleNdArray ofDoubles(Shape shape, Consumer hydrate) { + DoubleDenseNdArray array = (DoubleDenseNdArray)ofDoubles(shape); + hydrate.accept(new DoubleDenseNdArrayHydrator(array)); + return array; + } + /** * Wraps a buffer in a double N-dimensional array of a given shape. * @@ -568,6 +591,23 @@ public static DoubleNdArray wrap(Shape shape, DoubleDataBuffer buffer) { return DoubleDenseNdArray.create(buffer, shape); } + /** + * Creates an Sparse array of doubles of the given shape, hydrating it with data after its allocation + * + * @param shape shape of the array + * @param numValues number of double value actually set in the array, others defaulting to the zero value + * @param hydrate initialize the data of the created array, using a hydrator + * @return new double N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static DoubleSparseNdArray sparseOfDoubles(Shape shape, long numValues, Consumer hydrate) { + LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions())); + DoubleNdArray values = ofDoubles(Shape.of(numValues)); + DoubleSparseNdArray array = DoubleSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + hydrate.accept(new DoubleSparseNdArrayHydrator(array)); + return array; + } + /** * Creates a Sparse array of double values with a default value of zero * @@ -756,6 +796,21 @@ public static NdArray ofObjects(Class clazz, Shape shape) { return wrap(shape, DataBuffers.ofObjects(clazz, shape.size())); } + /** + * Creates an N-dimensional array of objects of the given shape, hydrating it with data after its allocation + * + * @param clazz class of the data to be stored in this array + * @param shape shape of the array + * @param hydrate initialize the data of the created array, using a hydrator + * @return new N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static NdArray ofObjects(Class clazz, Shape shape, Consumer> hydrate) { + AbstractDenseNdArray array = (AbstractDenseNdArray)ofObjects(clazz, shape); + hydrate.accept(new DenseNdArrayHydrator(array)); + return array; + } + /** * Wraps a buffer in an N-dimensional array of a given shape. * @@ -770,6 +825,24 @@ public static NdArray wrap(Shape shape, DataBuffer buffer) { return DenseNdArray.wrap(buffer, shape); } + /** + * Creates an Sparse array of objects of the given shape, hydrating it with data after its allocation + * + * @param type the class type represented by this sparse array. + * @param shape shape of the array + * @param numValues number of values actually set in the array, others defaulting to the zero value + * @param hydrate initialize the data of the created array, using a hydrator + * @return new N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static NdArray sparseOfObjects(Class type, Shape shape, long numValues, Consumer> hydrate) { + LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions())); + NdArray values = ofObjects(type, Shape.of(numValues)); + AbstractSparseNdArray array = (AbstractSparseNdArray)sparseOfObjects(type, indices, values, shape); + hydrate.accept(new SparseNdArrayHydrator(array)); + return array; + } + /** * Creates a Sparse array of values with a null default value * diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java new file mode 100644 index 0000000..9440e42 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java @@ -0,0 +1,172 @@ +package org.tensorflow.ndarray.hydrator; + +import org.tensorflow.ndarray.DoubleNdArray; + +/** + * Specialization of the {@link NdArrayHydrator} API for hydrating arrays of doubles. + * + * @see NdArrayHydrator + */ +public interface DoubleNdArrayHydrator { + + /** + * An API for hydrate an {@link DoubleNdArray} using scalar values + */ + interface Scalars { + + /** + * Position the hydrator to the given {@code coordinates} to write the next scalars. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a scalar + */ + Scalars at(long... coordinates); + + /** + * Set a double value as the next scalar value in the hydrated array. + * + * @param scalar next scalar value + * @return this API + * @throws IllegalArgumentException if {@code scalar} is null + */ + Scalars put(double scalar); + } + + /** + * An API for hydrate an {@link DoubleNdArray} using vectors, i.e. a list of scalars + */ + interface Vectors { + + /** + * Position the hydrator to the given {@code coordinates} to write the next vectors. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a vector + */ + Vectors at(long... coordinates); + + /** + * Set a list of doubles as the next vector in the hydrated array. + * + * @param vector next vector values + * @return this API + * @throws IllegalArgumentException if {@code vector} is empty or its length is greater than the size of the dimension + * {@code n-1}, given {@code n} the rank of the hydrated array + */ + Vectors put(double... vector); + } + + /** + * An API for hydrate an {@link DoubleNdArray} using n-dimensional elements (sub-arrays). + */ + interface Elements { + + /** + * Position the hydrator to the given {@code coordinates} to write the next elements. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of an element of the hydrated array + */ + Elements at(long... coordinates); + + /** + * Set a n-dimensional array of doubles as the next element in the hydrated array. + * + * @param element array containing the next element values + * @return this API + * @throws IllegalArgumentException if {@code element} is null or its shape is incompatible with the current hydrator position + */ + Elements put(DoubleNdArray element); + } + + /** + * Start to hydrate the targeted array with scalars. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first scalar of this array. + * + * Example of usage: + *
{@code
+   *    DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2), hydrator -> {
+   *        hydrator.byScalars()
+   *          .put(10.0)
+   *          .put(20.0)
+   *          .put(30.0)
+   *          .at(2, 1)
+   *          .put(40.0);
+   *    });
+   *    // -> [[10.0, 20.0], [30.0, 0.0], [0.0, 40.0]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Scalars} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of a scalar + */ + Scalars byScalars(long... coordinates); + + /** + * Start to hydrate the targeted array with vectors. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first scalar of the first vector of this array. + * + * Example of usage: + *
{@code
+   *    DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2), hydrator -> {
+   *        hydrator.byVectors()
+   *          .put(10.0, 20.0)
+   *          .put(30.0)
+   *          .at(2)
+   *          .put(40.0, 50.0);
+   *    });
+   *    // -> [[10.0, 20.0], [30.0, null], [40.0, 50.0]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Vectors} instance + * @throws IllegalArgumentException if hydrated array is of rank-0 or if {@code coordinates} are set but are not one of a vector + */ + Vectors byVectors(long... coordinates); + + /** + * Start to hydrate the targeted array with n-dimensional elements. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first element in the first (0) dimension of the hydrated array. + * + * Example of usage: + *
{@code
+   *    DoubleNdArray vector = NdArrays.vectorOf(10.0, 20.0);
+   *    DoubleNdArray scalar = NdArrays.scalarOf(30.0);
+   *
+   *    DoubleNdArray array = NdArrays.ofDoubles(Shape.of(4, 2), hydrator -> {
+   *        hydrator.byElements()
+   *          .put(vector)
+   *          .put(vector)
+   *          .at(2, 1)
+   *          .put(scalar)
+   *          .at(3)
+   *          .put(vector);
+   *    });
+   *    // -> [[10.0, 20.0], [10.0, 20.0], [0.0, 30.0], [10.0, 20.0]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Elements} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of an element of the hydrated array + */ + Elements byElements(long... coordinates); + + /** + * Creates an API to hydrate the targeted array with {@code Double} boxed type. + * + * Note that sticking to primitive types improve I/O performances overall, so only rely boxed types if the data is already + * available in that format. + * + * @return a hydrator supporting {@code Double} boxed type + */ + NdArrayHydrator boxed(); +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java new file mode 100644 index 0000000..fc64f89 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java @@ -0,0 +1,177 @@ +package org.tensorflow.ndarray.hydrator; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.buffer.DataBuffer; + +/** + * Interface for initializing the data of a {@link NdArray} that has just been allocated. + * + * While it is always possible to set the data of a read-write NdArray using standard output methods, + * like {@link NdArray#write(DataBuffer)} or {@link NdArray#copyTo(NdArray)}, the hydrator API focuses on + * sequential per-element initialization, similar to standard Java arrays. + * + * Since the hydrator API is only accessible right after the array have been allocated, it can be used to + * initialize data-sensitive arrays, like {@link org.tensorflow.ndarray.SparseNdArray}, which can be only + * written once and stay read-only thereafter. + * + * @param the type of data of the {@link NdArray} to initialize + */ +public interface NdArrayHydrator { + + /** + * An API for hydrate an {@link NdArray} using scalar values + * + * @param the type of data of the {@link NdArray} to initialize + */ + interface Scalars { + + /** + * Position the hydrator to the given {@code coordinates} to write the next scalars. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a scalar + */ + Scalars at(long... coordinates); + + /** + * Set an object as the next scalar value in the hydrated array. + * + * @param scalar next scalar value + * @return this API + * @throws IllegalArgumentException if {@code scalar} is null + */ + Scalars put(T scalar); + } + + /** + * An API for hydrate an {@link NdArray} using vectors, i.e. a list of scalars + * + * @param the type of data of the {@link NdArray} to initialize + */ + interface Vectors { + + /** + * Position the hydrator to the given {@code coordinates} to write the next vectors. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a vector + */ + Vectors at(long... coordinates); + + /** + * Set a list of objects as the next vector in the hydrated array. + * + * @param vector next vector values + * @return this API + * @throws IllegalArgumentException if {@code vector} is empty or its length is greater than the size of the dimension + * {@code n-1}, given {@code n} the rank of the hydrated array + */ + Vectors put(T... vector); + } + + /** + * An API for hydrate an {@link NdArray} using n-dimensional elements (sub-arrays). + * + * @param the type of data of the {@link NdArray} to initialize + */ + interface Elements { + + /** + * Position the hydrator to the given {@code coordinates} to write the next elements. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of an element of the hydrated array + */ + Elements at(long... coordinates); + + /** + * Set a n-dimensional array of objects as the next element in the hydrated array. + * + * @param element array containing the next element values + * @return this API + * @throws IllegalArgumentException if {@code element} is null or its shape is incompatible with the current hydrator position + */ + Elements put(NdArray element); + } + + /** + * Start to hydrate the targeted array with scalars. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first scalar of this array. + * + * Example of usage: + *
{@code
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(3, 2), hydrator -> {
+   *        hydrator.byScalars()
+   *          .put("Cat")
+   *          .put("Dog")
+   *          .put("House")
+   *          .at(2, 1)
+   *          .put("Apple");
+   *    });
+   *    // -> [["Cat", "Dog"], ["House", null], [null, "Apple"]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Scalars} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of a scalar + */ + Scalars byScalars(long... coordinates); + + /** + * Start to hydrate the targeted array with vectors. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first scalar of the first vector of this array. + * + * Example of usage: + *
{@code
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(3, 2), hydrator -> {
+   *        hydrator.byVectors()
+   *          .put("Cat", "Dog")
+   *          .put("House")
+   *          .at(2)
+   *          .put("Orange", "Apple");
+   *    });
+   *    // -> [["Cat", "Dog"], ["House", null], ["Orange", "Apple"]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Vectors} instance + * @throws IllegalArgumentException if hydrated array is of rank-0 or if {@code coordinates} are set but are not one of a vector + */ + Vectors byVectors(long... coordinates); + + /** + * Start to hydrate the targeted array with n-dimensional elements. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first element in the first (0) dimension of the hydrated array. + * + * Example of usage: + *
{@code
+   *    NdArray vector = NdArrays.vectorOfObjects("Cat", "Dog");
+   *    NdArray scalar = NdArrays.scalarOfObject("Apple");
+   *
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(4, 2), hydrator -> {
+   *        hydrator.byElements()
+   *          .put(vector)
+   *          .put(vector)
+   *          .at(2, 1)
+   *          .put(scalar)
+   *          .at(3)
+   *          .put(vector);
+   *    });
+   *    // -> [["Cat", "Dog"], ["Cat", "Dog"], [null, "Apple"], ["Cat", "Dog"]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Elements} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of an element of the hydrated array + */ + Elements byElements(long... coordinates); +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java index 0497095..86184e0 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java @@ -31,6 +31,20 @@ @SuppressWarnings("unchecked") public abstract class AbstractDenseNdArray> extends AbstractNdArray { + abstract public DataBuffer buffer(); + + public NdArraySequence elementsAt(long[] startCoords) { + DimensionalSpace elemDims = dimensions().from(startCoords.length); + try { + DataBufferWindow> elemWindow = buffer().window(elemDims.physicalSize()); + U element = instantiate(elemWindow.buffer(), elemDims); + return new FastElementSequence(this, startCoords, element, elemWindow); + } catch (UnsupportedOperationException e) { + // If buffer windows are not supported, fallback to slicing (and slower) sequence + return new SlicingElementSequence(this, startCoords, elemDims); + } + } + @Override public NdArraySequence elements(int dimensionIdx) { if (dimensionIdx >= shape().numDimensions()) { @@ -40,15 +54,7 @@ public NdArraySequence elements(int dimensionIdx) { if (rank() == 0 && dimensionIdx < 0) { return new SingleElementSequence<>(this); } - DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1); - try { - DataBufferWindow> elemWindow = buffer().window(elemDims.physicalSize()); - U element = instantiate(elemWindow.buffer(), elemDims); - return new FastElementSequence(this, dimensionIdx, element, elemWindow); - } catch (UnsupportedOperationException e) { - // If buffer windows are not supported, fallback to slicing (and slower) sequence - return new SlicingElementSequence<>(this, dimensionIdx, elemDims); - } + return elementsAt(new long[dimensionIdx + 1]); } @Override @@ -136,8 +142,6 @@ protected AbstractDenseNdArray(DimensionalSpace dimensions) { super(dimensions); } - abstract protected DataBuffer buffer(); - abstract U instantiate(DataBuffer buffer, DimensionalSpace dimensions); long positionOf(long[] coords, boolean isValue) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java index 0764146..a3caf40 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java @@ -31,6 +31,11 @@ public static BooleanNdArray create(BooleanDataBuffer buffer, Shape shape) { return new BooleanDenseNdArray(buffer, shape); } + @Override + public BooleanDataBuffer buffer() { + return buffer; + } + @Override public boolean getBoolean(long... indices) { return buffer.getBoolean(positionOf(indices, true)); @@ -77,11 +82,6 @@ BooleanDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dim return new BooleanDenseNdArray((BooleanDataBuffer)buffer, dimensions); } - @Override - protected BooleanDataBuffer buffer() { - return buffer; - } - private final BooleanDataBuffer buffer; private BooleanDenseNdArray(BooleanDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java index 172432b..fa3b722 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java @@ -31,6 +31,11 @@ public static ByteNdArray create(ByteDataBuffer buffer, Shape shape) { return new ByteDenseNdArray(buffer, shape); } + @Override + public ByteDataBuffer buffer() { + return buffer; + } + @Override public byte getByte(long... indices) { return buffer.getByte(positionOf(indices, true)); @@ -77,11 +82,6 @@ ByteDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimension return new ByteDenseNdArray((ByteDataBuffer)buffer, dimensions); } - @Override - protected ByteDataBuffer buffer() { - return buffer; - } - private final ByteDataBuffer buffer; private ByteDenseNdArray(ByteDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java index 819d95d..54d337b 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java @@ -50,7 +50,7 @@ DenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { } @Override - protected DataBuffer buffer() { + public DataBuffer buffer() { return buffer; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java index f54b8d0..d30c350 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java @@ -31,6 +31,11 @@ public static DoubleNdArray create(DoubleDataBuffer buffer, Shape shape) { return new DoubleDenseNdArray(buffer, shape); } + @Override + public DoubleDataBuffer buffer() { + return buffer; + } + @Override public double getDouble(long... indices) { return buffer.getDouble(positionOf(indices, true)); @@ -77,11 +82,6 @@ DoubleDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimen return new DoubleDenseNdArray((DoubleDataBuffer)buffer, dimensions); } - @Override - protected DoubleDataBuffer buffer() { - return buffer; - } - private final DoubleDataBuffer buffer; private DoubleDenseNdArray(DoubleDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java index 196b5ef..b164211 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java @@ -31,6 +31,11 @@ public static FloatNdArray create(FloatDataBuffer buffer, Shape shape) { return new FloatDenseNdArray(buffer, shape); } + @Override + public FloatDataBuffer buffer() { + return buffer; + } + @Override public float getFloat(long... indices) { return buffer.getFloat(positionOf(indices, true)); @@ -77,11 +82,6 @@ FloatDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensi return new FloatDenseNdArray((FloatDataBuffer) buffer, dimensions); } - @Override - public FloatDataBuffer buffer() { - return buffer; - } - private final FloatDataBuffer buffer; private FloatDenseNdArray(FloatDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java index a7af498..3cbd15e 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java @@ -31,6 +31,11 @@ public static IntNdArray create(IntDataBuffer buffer, Shape shape) { return new IntDenseNdArray(buffer, shape); } + @Override + public IntDataBuffer buffer() { + return buffer; + } + @Override public int getInt(long... indices) { return buffer.getInt(positionOf(indices, true)); @@ -77,11 +82,6 @@ IntDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensi return new IntDenseNdArray((IntDataBuffer)buffer, dimensions); } - @Override - protected IntDataBuffer buffer() { - return buffer; - } - private final IntDataBuffer buffer; private IntDenseNdArray(IntDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java index cd56dad..8c33528 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java @@ -31,6 +31,11 @@ public static LongNdArray create(LongDataBuffer buffer, Shape shape) { return new LongDenseNdArray(buffer, shape); } + @Override + public LongDataBuffer buffer() { + return buffer; + } + @Override public long getLong(long... indices) { return buffer.getLong(positionOf(indices, true)); @@ -77,11 +82,6 @@ LongDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimension return new LongDenseNdArray((LongDataBuffer)buffer, dimensions); } - @Override - protected LongDataBuffer buffer() { - return buffer; - } - private final LongDataBuffer buffer; private LongDenseNdArray(LongDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java index 291c01a..a44a81a 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java @@ -31,6 +31,11 @@ public static ShortNdArray create(ShortDataBuffer buffer, Shape shape) { return new ShortDenseNdArray(buffer, shape); } + @Override + public ShortDataBuffer buffer() { + return buffer; + } + @Override public short getShort(long... indices) { return buffer.getShort(positionOf(indices, true)); @@ -77,11 +82,6 @@ ShortDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensi return new ShortDenseNdArray((ShortDataBuffer)buffer, dimensions); } - @Override - protected ShortDataBuffer buffer() { - return buffer; - } - private final ShortDataBuffer buffer; private ShortDenseNdArray(ShortDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java new file mode 100644 index 0000000..d7bced0 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java @@ -0,0 +1,101 @@ +package org.tensorflow.ndarray.impl.dense.hydrator; + +import java.util.Iterator; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; +import org.tensorflow.ndarray.impl.sequence.PositionIterator; + +public class DenseNdArrayHydrator implements NdArrayHydrator { + + public DenseNdArrayHydrator(AbstractDenseNdArray> array) { + this.array = array; + } + + @Override + public Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public Vectors byVectors(long... coordinates) { + return new VectorsImpl(coordinates); + } + + @Override + public Elements byElements(long... coordinates) { + return new ElementsImpl(coordinates); + } + + class ScalarsImpl implements Scalars { + + public Scalars at(long... coordinates) { + positionIterator = Helpers.iterateByPosition(array, 0, coordinates); + return this; + } + + @Override + public Scalars put(T scalar) { + if (scalar == null) { + throw new IllegalArgumentException("Scalar value cannot be null"); + } + array.buffer().setObject(scalar, positionIterator.nextLong()); + return this; + } + + ScalarsImpl(long[] coordinates) { + positionIterator = Helpers.iterateByPosition(array, 0, coordinates); + } + + private PositionIterator positionIterator; + } + + class VectorsImpl implements Vectors { + + @Override + public Vectors at(long... coordinates) { + positionIterator = Helpers.iterateByPosition(array, 1, coordinates); + return this; + } + + @Override + public Vectors put(T... vector) { + Helpers.validateVectorLength(vector.length, array.shape()); + array.buffer().offset(positionIterator.nextLong()).write(vector); + return this; + } + + VectorsImpl(long[] coordinates) { + positionIterator = Helpers.iterateByPosition(array, 1, coordinates); + } + + private PositionIterator positionIterator; + } + + class ElementsImpl implements Elements { + + @Override + public Elements at(long... coordinates) { + this.elementIterator = Helpers.iterateByElement(array, coordinates); + return this; + } + + @Override + public Elements put(NdArray element) { + if (element == null) { + throw new IllegalArgumentException("Element cannot be null"); + } + element.copyTo(elementIterator.next()); + return this; + } + + ElementsImpl(long[] coordinates) { + this.elementIterator = Helpers.iterateByElement(array, coordinates); + } + + private Iterator> elementIterator; + } + + private final AbstractDenseNdArray> array; +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java new file mode 100644 index 0000000..087bb6f --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java @@ -0,0 +1,104 @@ +package org.tensorflow.ndarray.impl.dense.hydrator; + +import java.util.Iterator; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; +import org.tensorflow.ndarray.impl.sequence.PositionIterator; + +public class DoubleDenseNdArrayHydrator implements DoubleNdArrayHydrator { + + public DoubleDenseNdArrayHydrator(DoubleDenseNdArray array) { + this.array = array; + } + + @Override + public Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public Vectors byVectors(long... coordinates) { + return new VectorsImpl(coordinates); + } + + @Override + public Elements byElements(long... coordinates) { + return new ElementsImpl(coordinates); + } + + @Override + public NdArrayHydrator boxed() { + return new DenseNdArrayHydrator(array); + } + + class ScalarsImpl implements Scalars { + + public Scalars at(long... coordinates) { + positionIterator = Helpers.iterateByPosition(array, 0, coordinates); + return this; + } + + @Override + public Scalars put(double scalar) { + array.buffer().setObject(scalar, positionIterator.nextLong()); + return this; + } + + ScalarsImpl(long[] coordinates) { + positionIterator = Helpers.iterateByPosition(array, 0, coordinates); + } + + private PositionIterator positionIterator; + } + + class VectorsImpl implements Vectors { + + @Override + public Vectors at(long... coordinates) { + positionIterator = Helpers.iterateByPosition(array, 1, coordinates); + return this; + } + + @Override + public Vectors put(double... vector) { + Helpers.validateVectorLength(vector.length, array.shape()); + array.buffer().offset(positionIterator.nextLong()).write(vector); + return this; + } + + VectorsImpl(long[] coordinates) { + positionIterator = Helpers.iterateByPosition(array, 1, coordinates); + } + + private PositionIterator positionIterator; + } + + class ElementsImpl implements Elements { + + @Override + public Elements at(long... coordinates) { + this.elementIterator = Helpers.iterateByElement(array, coordinates); + return this; + } + + @Override + public Elements put(DoubleNdArray element) { + if (element == null) { + throw new IllegalArgumentException("Element cannot be null"); + } + element.copyTo(elementIterator.next()); + return this; + } + + ElementsImpl(long[] coordinates) { + this.elementIterator = Helpers.iterateByElement(array, coordinates); + } + + private Iterator elementIterator; + } + + private final DoubleDenseNdArray array; +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java new file mode 100644 index 0000000..5c503c9 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java @@ -0,0 +1,52 @@ +package org.tensorflow.ndarray.impl.dense.hydrator; + +import java.util.Arrays; +import java.util.Iterator; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sequence.IndexedPositionIterator; +import org.tensorflow.ndarray.impl.sequence.PositionIterator; + +final class Helpers { + + static PositionIterator iterateByPosition(AbstractDenseNdArray array, int elementRank, long[] coords) { + DimensionalSpace dimensions = array.dimensions(); + int dimensionIdx = dimensions.numDimensions() - elementRank - 1; + if (dimensionIdx < 0) { + throw new IllegalArgumentException("Cannot hydrate array of shape " + array.shape() + " with elements of rank " + elementRank); + } + if (coords == null || coords.length == 0) { + return PositionIterator.create(dimensions, dimensionIdx); + } + if ((coords.length - 1) != dimensionIdx) { + throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for dimension " + + dimensionIdx + " in an array of shape " + dimensions.shape()); + } + return PositionIterator.create(dimensions, coords); + } + + static > Iterator iterateByElement(AbstractDenseNdArray array, long[] coords) { + DimensionalSpace dimensions = array.dimensions(); + int dimensionIdx; + if (coords == null || coords.length == 0) { + return array.elements(0).iterator(); + } + if (coords.length > dimensions.numDimensions()) { + throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for an array of shape " + dimensions.shape()); + } + return array.elementsAt(coords).iterator(); + } + + static void validateVectorLength(int length, Shape shape) { + if (length == 0) { + throw new IllegalArgumentException("Vector cannot be empty"); + } + if (length > shape.get(-1)) { + throw new IllegalArgumentException("Vector cannot exceed " + shape.get(-1) + " elements"); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java index 8c9c9f8..a020526 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java @@ -17,9 +17,11 @@ package org.tensorflow.ndarray.impl.sequence; -final class CoordinatesIncrementor { +import java.util.Arrays; - boolean increment() { +public final class CoordinatesIncrementor { + + public boolean increment() { for (int i = coords.length - 1; i >= 0; --i) { if ((coords[i] = (coords[i] + 1) % shape[i]) > 0) { return true; @@ -28,11 +30,19 @@ boolean increment() { return false; } - CoordinatesIncrementor(long[] shape, int dimensionIdx) { + public CoordinatesIncrementor(long[] shape, int dimensionIdx) { this.shape = shape; this.coords = new long[dimensionIdx + 1]; } - final long[] shape; - final long[] coords; + public CoordinatesIncrementor(long[] shape, long[] coords) { + if (coords.length == 0 || coords.length > shape.length) { + throw new IllegalArgumentException(); + } + this.shape = shape; + this.coords = Arrays.copyOf(coords, coords.length); + } + + public final long[] shape; + public final long[] coords; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java index 92cebeb..2430030 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java @@ -34,8 +34,12 @@ public final class FastElementSequence> implements NdArraySequence { public FastElementSequence(AbstractNdArray ndArray, int dimensionIdx, U element, DataBufferWindow elementWindow) { + this(ndArray, new long[dimensionIdx + 1], element, elementWindow); + } + + public FastElementSequence(AbstractNdArray ndArray, long[] startCoords, U element, DataBufferWindow elementWindow) { this.ndArray = ndArray; - this.dimensionIdx = dimensionIdx; + this.startCoords = startCoords; this.element = element; this.elementWindow = elementWindow; } @@ -47,7 +51,7 @@ public Iterator iterator() { @Override public void forEachIndexed(BiConsumer consumer) { - PositionIterator.createIndexed(ndArray.dimensions(), dimensionIdx).forEachIndexed((long[] coords, long position) -> { + PositionIterator.createIndexed(ndArray.dimensions(), startCoords).forEachIndexed((long[] coords, long position) -> { elementWindow.slideTo(position); consumer.accept(coords, element); }); @@ -55,7 +59,7 @@ public void forEachIndexed(BiConsumer consumer) { @Override public NdArraySequence asSlices() { - return new SlicingElementSequence(ndArray, dimensionIdx); + return new SlicingElementSequence(ndArray, startCoords); } private class SequenceIterator implements Iterator { @@ -71,11 +75,11 @@ public U next() { return element; } - private final PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), dimensionIdx); + private final PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), startCoords); } private final AbstractNdArray ndArray; - private final int dimensionIdx; + private final long[] startCoords; private final U element; private final DataBufferWindow elementWindow; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java index 80b3de6..c7a3dcf 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java @@ -17,6 +17,8 @@ package org.tensorflow.ndarray.impl.sequence; +import java.util.Arrays; + import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; class IndexedSequentialPositionIterator extends SequentialPositionIterator implements IndexedPositionIterator { @@ -24,28 +26,28 @@ class IndexedSequentialPositionIterator extends SequentialPositionIterator imple @Override public void forEachIndexed(CoordsLongConsumer consumer) { while (hasNext()) { - consumer.consume(coords, nextLong()); - incrementCoords(); + consumer.consume(coords, super.nextLong()); + dimensions.incrementCoordinates(coords); } } - private void incrementCoords() { - for (int i = coords.length - 1; i >= 0; --i) { - if (coords[i] < shape[i] - 1) { - coords[i] += 1L; - return; - } - coords[i] = 0L; - } + @Override + public long nextLong() { + long tmp = super.nextLong(); + dimensions.incrementCoordinates(coords); + return tmp; } IndexedSequentialPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { - super(dimensions, dimensionIdx); - this.shape = dimensions.shape().asArray(); - this.coords = new long[dimensionIdx + 1]; - //this.coordsIncrementor = new CoordinatesIncrementor(dimensions.shape().asArray(), dimensionIdx); + this(dimensions, new long[dimensionIdx + 1]); + } + + IndexedSequentialPositionIterator(DimensionalSpace dimensions, long[] coords) { + super(dimensions, coords); + this.dimensions = dimensions; + this.coords = Arrays.copyOf(coords, coords.length); } - private final long[] shape; + private final DimensionalSpace dimensions; private final long[] coords; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java index 789474c..7dc2f96 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java @@ -17,6 +17,7 @@ package org.tensorflow.ndarray.impl.sequence; +import java.util.Arrays; import java.util.NoSuchElementException; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; @@ -33,7 +34,7 @@ public long nextLong() { throw new NoSuchElementException(); } long position = dimensions.positionOf(coords); - increment(); + incrementCoords(); return position; } @@ -41,28 +42,23 @@ public long nextLong() { public void forEachIndexed(CoordsLongConsumer consumer) { while (hasNext()) { consumer.consume(coords, dimensions.positionOf(coords)); - increment(); + incrementCoords(); } } - private void increment() { - if (!increment(coords, dimensions)) { + private void incrementCoords() { + if (!dimensions.incrementCoordinates(coords)) { coords = null; } } - static boolean increment(long[] coords, DimensionalSpace dimensions) { - for (int i = coords.length - 1; i >= 0; --i) { - if ((coords[i] = (coords[i] + 1) % dimensions.get(i).numElements()) > 0) { - return true; - } - } - return false; + NdPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { + this(dimensions, new long[dimensionIdx + 1]); } - NdPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { + NdPositionIterator(DimensionalSpace dimensions, long[] coords) { this.dimensions = dimensions; - this.coords = new long[dimensionIdx + 1]; + this.coords = Arrays.copyOf(coords, coords.length); } private final DimensionalSpace dimensions; diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java index 83ed940..a30c31c 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java @@ -17,6 +17,7 @@ package org.tensorflow.ndarray.impl.sequence; +import java.util.Arrays; import java.util.PrimitiveIterator; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; @@ -29,6 +30,13 @@ static PositionIterator create(DimensionalSpace dimensions, int dimensionIdx) { return new SequentialPositionIterator(dimensions, dimensionIdx); } + static PositionIterator create(DimensionalSpace dimensions, long... startCoords) { + if (dimensions.isSegmented()) { + return new NdPositionIterator(dimensions, startCoords); + } + return new SequentialPositionIterator(dimensions, startCoords); + } + static IndexedPositionIterator createIndexed(DimensionalSpace dimensions, int dimensionIdx) { if (dimensions.isSegmented()) { return new NdPositionIterator(dimensions, dimensionIdx); @@ -36,6 +44,13 @@ static IndexedPositionIterator createIndexed(DimensionalSpace dimensions, int di return new IndexedSequentialPositionIterator(dimensions, dimensionIdx); } + static IndexedPositionIterator createIndexed(DimensionalSpace dimensions, long... startCoords) { + if (dimensions.isSegmented()) { + return new NdPositionIterator(dimensions, startCoords); + } + return new IndexedSequentialPositionIterator(dimensions, startCoords); + } + static PositionIterator sequence(long stride, long end) { return new SequentialPositionIterator(stride, end); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java index 65c6fc9..71332ea 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java @@ -17,6 +17,7 @@ package org.tensorflow.ndarray.impl.sequence; +import java.util.Arrays; import java.util.NoSuchElementException; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; @@ -24,7 +25,7 @@ class SequentialPositionIterator implements PositionIterator { @Override public boolean hasNext() { - return index < end; + return pos < end; } @Override @@ -32,7 +33,7 @@ public long nextLong() { if (!hasNext()) { throw new NoSuchElementException(); } - return stride * index++; + return stride * pos++; } SequentialPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { @@ -42,6 +43,12 @@ public long nextLong() { } this.stride = dimensions.get(dimensionIdx).elementSize(); this.end = size; + this.pos = 0L; + } + + SequentialPositionIterator(DimensionalSpace dimensions, long[] coords) { + this(dimensions, coords.length - 1); + this.pos = dimensions.positionOf(coords) / stride; } SequentialPositionIterator(long stride, long end) { @@ -51,5 +58,5 @@ public long nextLong() { private final long stride; private final long end; - private long index; + private long pos; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java index 6fe8398..6d11968 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java @@ -33,18 +33,26 @@ public final class SlicingElementSequence> implements NdArraySequence { public SlicingElementSequence(AbstractNdArray ndArray, int dimensionIdx) { - this(ndArray, dimensionIdx, ndArray.dimensions().from(dimensionIdx + 1)); + this(ndArray, new long[dimensionIdx + 1]); + } + + public SlicingElementSequence(AbstractNdArray ndArray, long[] startCoords) { + this(ndArray, startCoords, ndArray.dimensions().from(startCoords.length)); } public SlicingElementSequence(AbstractNdArray ndArray, int dimensionIdx, DimensionalSpace elementDimensions) { + this(ndArray, new long[dimensionIdx + 1], elementDimensions); + } + + public SlicingElementSequence(AbstractNdArray ndArray, long[] startCoords, DimensionalSpace elementDimensions) { this.ndArray = ndArray; - this.dimensionIdx = dimensionIdx; + this.startCoords = startCoords; this.elementDimensions = elementDimensions; } @Override public Iterator iterator() { - PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), dimensionIdx); + PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), startCoords); return new Iterator() { @Override @@ -61,7 +69,7 @@ public U next() { @Override public void forEachIndexed(BiConsumer consumer) { - PositionIterator.createIndexed(ndArray.dimensions(), dimensionIdx).forEachIndexed((long[] coords, long position) -> + PositionIterator.createIndexed(ndArray.dimensions(), startCoords).forEachIndexed((long[] coords, long position) -> consumer.accept(coords, ndArray.slice(position, elementDimensions)) ); } @@ -72,6 +80,6 @@ public NdArraySequence asSlices() { } private final AbstractNdArray ndArray; - private final int dimensionIdx; + private final long[] startCoords; private final DimensionalSpace elementDimensions; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java index daffad9..984e7e2 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java @@ -14,6 +14,11 @@ =======================================================================*/ package org.tensorflow.ndarray.impl.sparse; +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.LongStream; import org.tensorflow.ndarray.IllegalRankException; import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; @@ -30,12 +35,6 @@ import org.tensorflow.ndarray.impl.sequence.SlicingElementSequence; import org.tensorflow.ndarray.index.Index; -import java.nio.ReadOnlyBufferException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.stream.LongStream; - /** * Abstract base class for sparse array. * diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java new file mode 100644 index 0000000..f99631c --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java @@ -0,0 +1,126 @@ +package org.tensorflow.ndarray.impl.sparse.hydrator; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; +import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray; + +public class DoubleSparseNdArrayHydrator implements DoubleNdArrayHydrator { + + public DoubleSparseNdArrayHydrator(DoubleSparseNdArray array) { + this.array = array; + } + + @Override + public Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public Vectors byVectors(long... coordinates) { + return new VectorsImpl(coordinates); + } + + @Override + public Elements byElements(long... coordinates) { + return new ElementsImpl(coordinates); + } + + @Override + public NdArrayHydrator boxed() { + return new SparseNdArrayHydrator(array); + } + + private class ScalarsImpl implements Scalars { + + @Override + public Scalars at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); + return this; + } + + @Override + public Scalars put(double scalar) { + addValue(scalar, coordinates); + array.dimensions().incrementCoordinates(coordinates); + return this; + } + + private long[] coordinates; + + private ScalarsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); + } + } + + private class VectorsImpl implements Vectors { + + @Override + public Vectors at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 1); + return this; + } + + @Override + public Vectors put(double... vector) { + if (vector.length == 0 || vector.length > array.shape().get(-1)) { + throw new IllegalArgumentException("Vector cannot be null nor exceed " + array.shape().get(-1) + " elements"); + } + for (int i = 0; i < vector.length; ++i) { + addValue(vector[i], coordinates, i); + } + array.dimensions().incrementCoordinates(coordinates); + return this; + } + + private long[] coordinates; + + private VectorsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 1); + } + } + + private class ElementsImpl implements Elements { + + @Override + public Elements at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates); + return this; + } + + @Override + public Elements put(DoubleNdArray element) { + if (element == null) { + throw new IllegalArgumentException("Element cannot be null"); + } + if (element.shape().isScalar()) { + addValue(element.getDouble(), coordinates); + } else { + element.scalars().forEachIndexed((scalarCoords, scalar) -> { + addValue(scalar.getDouble(), coordinates, scalarCoords); + }); + } + array.dimensions().incrementCoordinates(coordinates); + return this; + } + + private long[] coordinates; + + private ElementsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates); + } + } + + private final DoubleSparseNdArray array; + private long valueCount = 0; + + private void addValue(double value, long[] origin, long... coords) { + if (value != array.getDefaultValue()) { + array.getValues().setDouble(value, valueCount); + Helpers.writeValueCoords(array, valueCount, origin, coords); + ++valueCount; + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java new file mode 100644 index 0000000..440855f --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java @@ -0,0 +1,49 @@ +package org.tensorflow.ndarray.impl.sparse.hydrator; + +import java.util.Arrays; + +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; + +final class Helpers { + + static long[] validateCoordinates(AbstractSparseNdArray array, long[] coords, int elementRank) { + DimensionalSpace dimensions = array.dimensions(); + int dimensionIdx = 0; + if (elementRank >= 0) { + dimensionIdx = dimensions.numDimensions() - elementRank - 1; + if (dimensionIdx < 0) { + throw new IllegalArgumentException("Cannot hydrate array of shape " + array.shape() + " with elements of rank " + elementRank); + } + } + if (coords == null || coords.length == 0) { + return new long[dimensionIdx + 1]; + } + if ((coords.length - 1) != dimensionIdx) { + throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for dimension " + + dimensionIdx + " in an array of shape " + dimensions.shape()); + } + return Arrays.copyOf(coords, coords.length); + } + + static long[] validateCoordinates(AbstractSparseNdArray array, long[] coords) { + if (coords == null || coords.length == 0) { + return new long[1]; + } + int dimensionIdx = array.shape().numDimensions() - coords.length; + if (dimensionIdx < 0) { + throw new IllegalArgumentException("Cannot hydrate array of shape " + array.shape() + " with elements of rank " + (coords.length - 1)); + } + return Arrays.copyOf(coords, coords.length); + } + + static void writeValueCoords(AbstractSparseNdArray array, long valueIndex, long[] origin, long[] coords) { + int coordsIndex = 0; + for (long c: origin) { + array.getIndices().setLong(c, valueIndex, coordsIndex++); + } + for (long c: coords) { + array.getIndices().setLong(c, valueIndex, coordsIndex++); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java new file mode 100644 index 0000000..5d183f5 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java @@ -0,0 +1,120 @@ +package org.tensorflow.ndarray.impl.sparse.hydrator; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; + +public class SparseNdArrayHydrator implements NdArrayHydrator { + + public SparseNdArrayHydrator(AbstractSparseNdArray array) { + this.array = array; + } + + @Override + public Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public Vectors byVectors(long... coordinates) { + return new VectorsImpl(coordinates); + } + + @Override + public Elements byElements(long... coordinates) { + return new ElementsImpl(coordinates); + } + + private class ScalarsImpl implements Scalars { + + @Override + public Scalars at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); + return this; + } + + @Override + public Scalars put(T scalar) { + if (scalar == null) { + throw new IllegalArgumentException("Scalar cannot be null"); + } + if (scalar != array.getDefaultValue()) { + array.getValues().setObject(scalar, index); + array.getIndices().set(NdArrays.vectorOf(coordinates), index++); + } + array.dimensions().incrementCoordinates(coordinates); + return this; + } + + protected ScalarsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); + } + + protected long[] coordinates; + } + + private class VectorsImpl implements Vectors { + + @Override + public Vectors at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 1); + return this; + } + + @Override + public Vectors put(T... vector) { + if (vector.length == 0 || vector.length > array.shape().get(-1)) { + throw new IllegalArgumentException("Vector cannot be null nor exceed " + array.shape().get(-1) + " elements"); + } + for (T value : vector) { + if (value != array.getDefaultValue()) { + array.getValues().setObject(value, index); + array.getIndices().set(NdArrays.vectorOf(coordinates), index++); + } + array.dimensions().incrementCoordinates(coordinates); + } + return this; + } + + protected VectorsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); + } + + protected long[] coordinates; + } + + private class ElementsImpl implements Elements { + + @Override + public Elements at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, coordinates.length - 1); + return this; + } + + @Override + public Elements put(NdArray element) { + if (element == null) { + throw new IllegalArgumentException("Array cannot be null"); + } + element.scalars().forEach(s -> { + T value = s.getObject(); + if (value != array.getDefaultValue()) { + array.getValues().setObject(value, index); + array.getIndices().set(NdArrays.vectorOf(coordinates), index++); + } + array.dimensions().incrementCoordinates(coordinates); + }); + return this; + } + + protected ElementsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, coordinates.length - 1); + } + + protected long[] coordinates; + } + + private final AbstractSparseNdArray array; + private long index = 0; +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydratorTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydratorTestBase.java new file mode 100644 index 0000000..69b38f9 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydratorTestBase.java @@ -0,0 +1,141 @@ +package org.tensorflow.ndarray.hydrator; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; + +public abstract class DoubleNdArrayHydratorTestBase { + + protected abstract DoubleNdArray newArray(Shape shape, long numValues, Consumer hydrate); + + @Test + public void hydrateNdArrayByScalars() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 14, hydrator -> { + hydrator + .byScalars() + .put(0.0) + .put(0.1) + .put(0.2) + .put(0.3) + .put(0.4) + .put(0.5) + .put(1.0) + .put(1.1) + .put(1.2) + .at(2, 0, 0) + .put(2.0) + .put(2.1) + .put(2.2) + .put(2.3) + .put(2.4) + .put(2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{0.0, 0.1, 0.2}, {0.3, 0.4, 0.5}}, + {{1.0, 1.1, 1.2}, {0.0, 0.0, 0.0}}, + {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} + }), array); + + array = newArray(Shape.of(3, 2), 4, hydrator -> { + hydrator + .byScalars() + .put(10.0) + .put(20.0) + .put(30.0) + .at(2, 1) + .put(40.0); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][]{{10.0, 20.0}, {30.0, 0.0}, {0.0, 40.0}}), array); + } + + @Test + public void hydrateNdArrayByVectors() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 14, hydrator -> { + hydrator + .byVectors() + .put(0.0, 0.1, 0.2) + .put(0.3, 0.4, 0.5) + .put(1.0, 1.1, 1.2) + .at(2, 0) + .put(2.0, 2.1, 2.2) + .put(2.3, 2.4, 2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{0.0, 0.1, 0.2}, {0.3, 0.4, 0.5}}, + {{1.0, 1.1, 1.2}, {0.0, 0.0, 0.0}}, + {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} + }), array); + + array = newArray(Shape.of(3, 2), 5, hydrator -> { + hydrator + .byVectors() + .put(10.0, 20.0) + .put(30.0) + .at(2) + .put(40.0, 50.0); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][]{{10.0, 20.0}, {30.0, 0.0}, {40.0, 50.0}}), array); + } + + @Test + public void vectorCannotBeEmpty() { + try { + newArray(Shape.of(3, 2), 1, hydrator -> hydrator.byVectors().put()); + fail(); + } catch (IllegalArgumentException e) { + // ok + } + } + + @Test + public void hydrateNdArrayByElements() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 14, hydrator -> { + hydrator + .byElements() + .put(StdArrays.ndCopyOf(new double[][]{ + {0.0, 0.1, 0.2}, + {0.3, 0.4, 0.5} + })) + .at(1, 0) + .put(NdArrays.vectorOf(1.0, 1.1, 1.2)) + .at(2) + .put(StdArrays.ndCopyOf(new double[][]{ + {2.0, 2.1, 2.2}, + {2.3, 2.4, 2.5} + })); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{0.0, 0.1, 0.2}, {0.3, 0.4, 0.5}}, + {{1.0, 1.1, 1.2}, {0.0, 0.0, 0.0}}, + {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} + }), array); + + DoubleNdArray vector = NdArrays.vectorOf(10.0, 20.0); + DoubleNdArray scalar = NdArrays.scalarOf(30.0); + + array = newArray(Shape.of(4, 2), 7, hydrator -> { + hydrator + .byElements() + .put(vector) + .put(vector) + .at(2, 1) + .put(scalar) + .at(3) + .put(vector); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][]{{10.0, 20.0}, {10.0, 20.0}, {0.0, 30.0}, {10.0, 20.0}}), array); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java new file mode 100644 index 0000000..63badf4 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java @@ -0,0 +1,24 @@ +package org.tensorflow.ndarray.impl.dense.hydrator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydratorTestBase; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; + +public class DoubleDenseNdArrayHydratorTest extends DoubleNdArrayHydratorTestBase { + + @Override + protected DoubleNdArray newArray(Shape shape, long numValues, Consumer hydrate) { + return NdArrays.ofDoubles(shape, hydrate); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java index bad7840..3ad462d 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java @@ -98,7 +98,7 @@ public void slicingElementSequenceReturnsUniqueInstances() { public void fastElementSequenceReturnsSameInstance() { IntNdArray array = NdArrays.ofInts(Shape.of(2, 3, 2)); IntNdArray element = array.get(0); - NdArraySequence sequence = new FastElementSequence( + NdArraySequence sequence = new FastElementSequence( (AbstractNdArray) array, 1, element, mockDataBufferWindow(2)); sequence.forEach(e -> { if (e != element) { diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java new file mode 100644 index 0000000..d76922d --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java @@ -0,0 +1,23 @@ +package org.tensorflow.ndarray.impl.sparse.hydrator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydratorTestBase; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; + +public class DoubleSparseNdArrayHydratorTest extends DoubleNdArrayHydratorTestBase { + + @Override + protected DoubleNdArray newArray(Shape shape, long numValues, Consumer hydrate) { + return NdArrays.sparseOfDoubles(shape, numValues, hydrate); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java new file mode 100644 index 0000000..8ae36f5 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java @@ -0,0 +1,85 @@ +package org.tensorflow.ndarray.impl.sparse.hydrator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; + +public class SparseNdArrayHydratorTest { + + @Test + public void hydrateNdArrayByScalars() { + DoubleNdArray array = NdArrays.sparseOfDoubles(15, Shape.of(3, 2, 3), hydrator -> { + hydrator + .byScalars() + .put(0.0) + .put(0.1) + .put(0.2) + .put(0.3) + .put(0.4) + .put(0.5) + .put(1.0) + .put(1.1) + .put(1.2) + .at(2, 0, 0) + .put(2.0) + .put(2.1) + .put(2.2) + .put(2.3) + .put(2.4) + .put(2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][] { + {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, + {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, + {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} + }), array); + } + + @Test + public void hydrateNdArrayByVectors() { + DoubleNdArray array = NdArrays.sparseOfDoubles(15, Shape.of(3, 2, 3), hydrator -> { + hydrator.byVectors() + .put(0.0, 0.1, 0.2) + .put(0.3, 0.4, 0.5) + .put(1.0, 1.1, 1.2) + .at(2, 0) + .put(2.0, 2.1, 2.2) + .put(2.3, 2.4, 2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][] { + {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, + {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, + {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} + }), array); + } + + @Test + public void hydrateNdArrayByElements() { + DoubleNdArray array = NdArrays.sparseOfDoubles(15, Shape.of(3, 2, 3), hydrator -> { + hydrator.byElements() + .put(StdArrays.ndCopyOf(new double[][] { + { 0.0, 0.1, 0.2 }, + { 0.3, 0.4, 0.5 } + })) + .at(1, 0) + .put(NdArrays.vectorOf(1.0, 1.1, 1.2)) + .at(2) + .put(StdArrays.ndCopyOf(new double[][] { + { 2.0, 2.1, 2.2 }, + { 2.3, 2.4, 2.5 } + })); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][] { + {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, + {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, + {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} + }), array); + } +}