diff --git a/docs_input/api/creation/operators/diag.rst b/docs_input/api/creation/operators/diag.rst index 325406513..9108cc3bf 100644 --- a/docs_input/api/creation/operators/diag.rst +++ b/docs_input/api/creation/operators/diag.rst @@ -4,11 +4,11 @@ diag ==== `diag` comes in two forms: a generator and an operator. The generator version is used to generate a diagonal -tensor with a given value, while the operator pulls diagonal elements from a tensor. +tensor with a given value, while the operator pulls diagonal elements from a tensor (2D inputs and above) or builds a diagonal matrix from a 1D input operator. Operator ________ -.. doxygenfunction:: matx::diag(T1 t) +.. doxygenfunction:: diag(T1 t, index_t k = 0) Examples ~~~~~~~~ @@ -19,10 +19,23 @@ Examples :end-before: example-end diag-op-test-1 :dedent: +.. literalinclude:: ../../../../test/00_operators/GeneratorTests.cu + :language: cpp + :start-after: example-begin diag-op-test-2 + :end-before: example-end diag-op-test-2 + :dedent: + +.. literalinclude:: ../../../../test/00_operators/GeneratorTests.cu + :language: cpp + :start-after: example-begin diag-op-test-3 + :end-before: example-end diag-op-test-3 + :dedent: + Generator _________ .. doxygenfunction:: matx::diag(const index_t (&s)[RANK], T val) +.. 
doxygenfunction:: matx::diag(T val) Examples ~~~~~~~~ diff --git a/include/matx/generators/diag.h b/include/matx/generators/diag.h index 20d4bf109..6bed59519 100644 --- a/include/matx/generators/diag.h +++ b/include/matx/generators/diag.h @@ -121,7 +121,8 @@ namespace matx * */ template ::type>, bool> = true> + std::enable_if_t::type> && + !is_matx_op(), bool> = true> inline auto diag(ShapeType &&s, T val) { return detail::Diag(std::forward(s), val); diff --git a/include/matx/operators/diag.h b/include/matx/operators/diag.h index 4667b7d8a..bba89f1ad 100644 --- a/include/matx/operators/diag.h +++ b/include/matx/operators/diag.h @@ -52,6 +52,7 @@ namespace matx { private: typename base_type::type op_; + index_t k_; public: using matxop = bool; @@ -59,32 +60,74 @@ namespace matx __MATX_INLINE__ std::string str() const { return "diag(" + op_.str() + ")"; } - __MATX_INLINE__ DiagOp(T1 op) : op_(op) {} + __MATX_INLINE__ DiagOp(T1 op, index_t k) : op_(op), k_(k) { } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const { - static_assert(sizeof...(Is) == RANK - 1, "Diagonal operator must have one fewer index than rank of operator"); - static_assert(RANK > 1, "Cannot make get diagonals from 0D tensor"); - + static_assert(RANK != 0, "Cannot make get diagonals from 0D tensor"); using tt = cuda::std::tuple_element_t<0, cuda::std::tuple>; - auto tup = cuda::std::make_tuple(indices..., static_cast(0)); - cuda::std::get(tup) = pp_get(indices...); - return cuda::std::apply(op_, tup); + + if constexpr (RANK == 1) { + static_assert(sizeof...(Is) == 2, "Indexing of diag() on a 1D input must be 2 indices"); + if (((pp_get<0>(indices...) 
== indices) && ...)) { + return (value_type)(op_(pp_get<0>(indices...))); // On the diagonal: emit the 1D input's element, not its index + } + else { + return (value_type)(0); + } + } + else { + static_assert(sizeof...(Is) == RANK - 1, "Diagonal operator must have one fewer op() index than rank of operator"); + + // Offset either the rows or columns by k_, depending on if it's negative + if (k_ < 0) { + auto tup = cuda::std::make_tuple(indices..., static_cast(0)); + cuda::std::get(tup) = pp_get(indices...) ; + cuda::std::get(tup) = cuda::std::get(tup) - k_; + return cuda::std::apply(op_, tup); + } + else { + auto tup = cuda::std::make_tuple(indices..., static_cast(0)); + cuda::std::get(tup) = pp_get(indices...) + k_; + return cuda::std::apply(op_, tup); + } + } } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() { - return RANK - 1; + if constexpr (RANK == 1) { + return 2; + } + else { + return RANK - 1; + } } constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size([[maybe_unused]] int dim) const { - if (dim < RANK - 2) { - return op_.Size(dim); + if constexpr (RANK == 1) { + return op_.Size(0); } else { - return cuda::std::min(op_.Size(RANK - 1), op_.Size(RANK-2)); + if (dim < RANK - 2) { + return op_.Size(dim); + } + else { + if (k_ == 0) { + return cuda::std::min(op_.Size(RANK - 1), op_.Size(RANK-2)); + } + else { + // Off the main diagonal: k > 0 shifts the last (column) index and k < 0 the row index, so shrink that dimension by |k| + if (k_ > 0) { + return cuda::std::min(op_.Size(RANK - 1) - k_, op_.Size(RANK-2)); + } + else { + return cuda::std::min(op_.Size(RANK - 1), op_.Size(RANK-2) + k_); + } + } + } } } @@ -107,11 +150,21 @@ namespace matx } /** - * Get the elements on the diagonal + * Get the elements on the diagonal (2D inputs and above), or generate a diagonal matrix (1D input) * * @param t * Input operator + * @param k + * Diagonal to pull (0 is the main diagonal). 
Only used for 2D tensors and above. Positive k selects a diagonal above the main diagonal, negative k one below it */ - template - auto __MATX_INLINE__ diag(T1 t) { return detail::DiagOp(t); } +#ifdef DOXYGEN_ONLY + auto __MATX_INLINE__ diag(T1 t, index_t k = 0) { +#else + template (), bool> = true> + auto __MATX_INLINE__ diag(T1 t, index_t k = 0) { +#endif + MATX_ASSERT_STR(T1::Rank() != 1 || k == 0, matxInvalidParameter, + "k parameter in diag() can only be used for 2D tensors and above"); + return detail::DiagOp(t, k); + } } // end namespace matx diff --git a/test/00_operators/GeneratorTests.cu b/test/00_operators/GeneratorTests.cu index 8ffb1e60e..742747f19 100644 --- a/test/00_operators/GeneratorTests.cu +++ b/test/00_operators/GeneratorTests.cu @@ -145,10 +145,11 @@ TYPED_TEST(BasicGeneratorTestsAll, Diag) MATX_ENTER_HANDLER(); { // example-begin diag-op-test-1 - // The generator form of `diag()` takes an operator input and returns only + // The operator form of `diag()` with a >= 2D input takes an operator input and returns only // the diagonal elements as output auto tc = make_tensor({10, 10}); auto td = make_tensor({10}); + auto tdk = make_tensor({4}); // Initialize the diagonal elements of `tc` for (int i = 0; i < 10; i++) { @@ -171,6 +172,38 @@ TYPED_TEST(BasicGeneratorTestsAll, Diag) } } + { + // example-begin diag-op-test-2 + // Assign the elements of the 6th diagonal of `tc` (6 above the main diagonal) to `tdk`. + auto op = diag(tc, 6); + (tdk = op).run(exec); + // example-end diag-op-test-2 + + exec.sync(); + + ASSERT_EQ(op.Size(0), 4); + ASSERT_EQ(op.Rank(), 1); + for (int i = 0; i < tdk.Size(0); i++) { + MATX_ASSERT_EQ(tdk(i), tc(i, i + 6)); + } + } + + { + // example-begin diag-op-test-3 + // Assign the elements of the -6th diagonal of `tc` (6 below the main diagonal) to `tdk`. 
+ auto op = diag(tc, -6); + (tdk = op).run(exec); + // example-end diag-op-test-3 + + exec.sync(); + + ASSERT_EQ(op.Size(0), 4); + ASSERT_EQ(op.Rank(), 1); + for (int i = 0; i < tdk.Size(0); i++) { + MATX_ASSERT_EQ(tdk(i), tc(i + 6, i)); + } + } + // Test with a nested transform. 
Restrict to floating point types for // the convolution if constexpr (std::is_same_v || std::is_same_v)