Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/_sources/api/creation/operators/ones.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ ones

Generate an operator of ones

ones() has both a shapeless and a shaped version. The shapeless version is preferred
in contexts where the shape can be deduced by the expression, thus simplifying the code.
If the shape cannot be deducded, the explicit shape version is used to specify the shape
directly.

.. doxygenfunction:: matx::ones()
.. doxygenfunction:: matx::ones(ShapeType &&s)
.. doxygenfunction:: matx::ones(const index_t (&s)[RANK])

Expand All @@ -17,3 +23,10 @@ Examples
:end-before: example-end ones-gen-test-1
:dedent:

.. literalinclude:: ../../../../test/00_operators/GeneratorTests.cu
:language: cpp
:start-after: example-begin ones-gen-test-2
:end-before: example-end ones-gen-test-2
:dedent:


6 changes: 6 additions & 0 deletions docs_input/api/creation/operators/zeros.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ zeros

Generate an operator of zeros

zeros() has both a shapeless and a shaped version. The shapeless version is preferred
in contexts where the shape can be deduced by the expression, thus simplifying the code.
If the shape cannot be deducded, the explicit shape version is used to specify the shape
directly.

.. doxygenfunction:: matx::zeros()
.. doxygenfunction:: matx::zeros(ShapeType &&s)
.. doxygenfunction:: matx::zeros(const index_t (&s)[RANK])

Expand Down
36 changes: 27 additions & 9 deletions include/matx/generators/ones.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ namespace matx
* Shape of tensor
*/
template <typename T = int, typename ShapeType,
std::enable_if_t<!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
inline auto ones(ShapeType &&s)
{
return detail::ConstVal<T, ShapeType>(std::forward<ShapeType>(s), T(1));
}
std::enable_if_t<!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
inline auto ones(ShapeType &&s)
{
return detail::ConstVal<T, ShapeType>(std::forward<ShapeType>(s), T(1));
}

/**
* Return one for all elements
Expand All @@ -68,9 +68,27 @@ namespace matx
* Shape of tensor
*/
template <typename T = int, int RANK>
inline auto ones(const index_t (&s)[RANK])
{
return ones<T>(detail::to_array(s));
}
inline auto ones(const index_t (&s)[RANK])
{
return ones<T>(detail::to_array(s));
}

/**
* Return one for all elements
*
* Ones is used as an operator that always returns a 1 type for all
* elements. It can be used in place of memset to set all values to 1.
* This version of ones() is shapeless and can be used in contexts where the shape
* can be deduced.
*
* @tparam T
* Data type
*
*/
template <typename T = int>
inline auto ones()
{
return ones<T, NoShape>(NoShape{});
}

} // end namespace matx
36 changes: 27 additions & 9 deletions include/matx/generators/zeros.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ namespace matx
* Shape of tensor
*/
template <typename T = int, typename ShapeType,
std::enable_if_t<!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
inline auto zeros(ShapeType &&s)
{
return detail::ConstVal<T, ShapeType>(std::forward<ShapeType>(s), T(0));
}
std::enable_if_t<!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
inline auto zeros(ShapeType &&s)
{
return detail::ConstVal<T, ShapeType>(std::forward<ShapeType>(s), T(0));
}

/**
* Return zero for all elements
Expand All @@ -67,8 +67,26 @@ namespace matx
* Shape of tensor
*/
template <typename T = int, int RANK>
inline auto zeros(const index_t (&s)[RANK])
{
return zeros<T>(detail::to_array(s));
}
inline auto zeros(const index_t (&s)[RANK])
{
return zeros<T>(detail::to_array(s));
}

/**
* Return zeros for all elements
*
* zeros is used as an operator that always returns a 0 type for all
* elements. It can be used in place of memset to set all values to 0.
* This version of zeros() is shapeless and can be used in contexts where the shape
* can be deduced.
*
* @tparam T
* Data type
*
*/
template <typename T = int>
inline auto zeros()
{
return zeros<T, NoShape>(NoShape{});
}
} // end namespace matx
19 changes: 15 additions & 4 deletions include/matx/operators/constval.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,22 @@ namespace matx
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ T operator()(Is...) const {
return v_; };

constexpr inline __MATX_HOST__ __MATX_DEVICE__ auto Size(int dim) const
{
return *(s_.begin() + dim);
constexpr inline __MATX_HOST__ __MATX_DEVICE__ auto Size(int dim) const {
if constexpr (!is_noshape_v<ShapeType>) {
return *(s_.begin() + dim);
}
else {
return index_t(0);
}
}
static inline constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() {
if constexpr (!is_noshape_v<ShapeType>) {
return RANK;
}
else {
return matxNoRank;
}
}
static inline constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() { return RANK; }
};
}
} // end namespace matx
2 changes: 1 addition & 1 deletion include/matx/transforms/cov.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
make_tensor(devsT, tmp, MATX_ASYNC_DEVICE_MEMORY, stream);

// Populate our ones matrix
(onesM = ones({a.Size(RANK - 2), a.Size(RANK - 2)})).run(stream);
(onesM = ones()).run(stream);
}

static CovParams_t GetCovParams([[maybe_unused]] TensorTypeC &c, const TensorTypeA &a, cudaStream_t stream = 0)
Expand Down
8 changes: 6 additions & 2 deletions test/00_operators/GeneratorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@ TEST(OperatorTests, Kron)
tensor_t<dtype, 2> ov2({4, 6});
av.SetVals({{1, 2, 3}, {4, 5, 6}});

// example-begin ones-gen-test-2
// Explicit shape specified in ones()
(ov2 = kron(av, ones({2, 2}))).run(exec);
// example-end ones-gen-test-2
cudaStreamSynchronize(0);
MATX_TEST_ASSERT_COMPARE(pb, ov2, "rect", 0);

Expand Down Expand Up @@ -330,7 +333,7 @@ TYPED_TEST(BasicGeneratorTestsAll, Zeros)
std::array<index_t, 1> s({count});
auto t1 = make_tensor<TestType>(s);

(t1 = zeros(s)).run(exec);
(t1 = zeros()).run(exec);
// example-end zeros-gen-test-1

cudaStreamSynchronize(0);
Expand All @@ -357,7 +360,7 @@ TYPED_TEST(BasicGeneratorTestsAll, Ones)
std::array<index_t, 1> s({count});
auto t1 = make_tensor<TestType>(s);

(t1 = ones(s)).run(exec);
(t1 = ones()).run(exec);
// example-end ones-gen-test-1
cudaStreamSynchronize(0);

Expand All @@ -369,6 +372,7 @@ TYPED_TEST(BasicGeneratorTestsAll, Ones)
EXPECT_TRUE(MatXUtils::MatXTypeCompare(t1(i), (TestType)1));
}
}

MATX_EXIT_HANDLER();
}

Expand Down
2 changes: 1 addition & 1 deletion test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1559,7 +1559,7 @@ TYPED_TEST(OperatorTestsNumericAllExecs, RemapRankZero)
// 2D source tensor cases
{
auto from = make_tensor<int>({N,N});
(from = ones(from.Shape())).run(exec);
(from = ones()).run(exec);
sync();

auto i0 = make_tensor<int>({});
Expand Down
4 changes: 2 additions & 2 deletions test/00_tensor/BasicTensorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ TYPED_TEST(BasicTensorTestsAll, Print)
using TestType = std::tuple_element_t<0, TypeParam>;

auto t = make_tensor<TestType>({3});
(t = ones(t.Shape())).run(this->exec);
(t = ones()).run(this->exec);
print(t);

MATX_EXIT_HANDLER();
Expand All @@ -474,7 +474,7 @@ TYPED_TEST(BasicTensorTestsAll, DevicePrint)
using TestType = std::tuple_element_t<0, TypeParam>;

auto t = make_tensor<TestType>({3}, MATX_DEVICE_MEMORY);
(t = ones(t.Shape())).run(this->exec);
(t = ones()).run(this->exec);
print(t);

MATX_EXIT_HANDLER();
Expand Down
14 changes: 7 additions & 7 deletions test/00_tensor/EinsumTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ TYPED_TEST(EinsumTestsFloatNonComplexNonHalfTypes, GEMM)
auto b2 = make_tensor<TestType>({20,10});
auto c2 = make_tensor<TestType>({10,10});
auto c22 = make_tensor<TestType>({10,10});
(a2 = ones(a2.Shape())).run(exec);
(b2 = ones(b2.Shape())).run(exec);
(a2 = ones()).run(exec);
(b2 = ones()).run(exec);

// Perform a GEMM of a2 * b2. Compare results to traditional matmul call
(c2 = cutensor::einsum("mk,kn->mn", a2, b2)).run(exec);
Expand Down Expand Up @@ -196,8 +196,8 @@ TYPED_TEST(EinsumTestsFloatNonComplexNonHalfTypes, GEMMTranspose)
auto b2 = make_tensor<TestType>({20,10});
auto c2 = make_tensor<TestType>({10,5});
auto c22 = make_tensor<TestType>({5,10});
(a2 = ones(a2.Shape())).run(exec);
(b2 = ones(b2.Shape())).run(exec);
(a2 = ones()).run(exec);
(b2 = ones()).run(exec);

// Perform a GEMM of a2 * b2 and store the results transposed
(c2 = cutensor::einsum("mk,kn->nm", a2, b2)).run(exec);
Expand Down Expand Up @@ -225,8 +225,8 @@ TYPED_TEST(EinsumTestsFloatNonComplexNonHalfTypes, Permute)
auto a = make_tensor<TestType>({5,20,4,3});
auto b = make_tensor<TestType>({20,3,4,5});
auto b2 = make_tensor<TestType>({20,3,4,5});
(a = ones(a.Shape())).run(exec);
(b = ones(b.Shape())).run(exec);
(a = ones()).run(exec);
(b = ones()).run(exec);

// Permute a 4D tensor. This gives the same output as Permute, but is much faster
(b = cutensor::einsum("ijkl->jlki", a)).run(exec);
Expand Down Expand Up @@ -287,7 +287,7 @@ TYPED_TEST(EinsumTestsFloatNonComplexNonHalfTypes, Trace)
auto a2 = make_tensor<TestType>({10,10});
auto c0_0 = make_tensor<TestType>({});
auto c0_1 = make_tensor<TestType>({});
(a2 = ones(a2.Shape())).run(exec);
(a2 = ones()).run(exec);

// Perform a GEMM of a2 * b2. Compare results to traditional matmul call
(c0_0 = cutensor::einsum("ii->", a2)).run(exec);
Expand Down