Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions docs_input/api/creation/operators/diag.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ diag
====

`diag` comes in two forms: a generator and an operator. The generator version is used to generate a diagonal
tensor with a given value, while the operator pulls diagonal elements from a tensor.
tensor with a given value or a 1D input operator, while the operator pulls diagonal elements from a tensor.

Operator
________
.. doxygenfunction:: matx::diag(T1 t)
.. doxygenfunction:: diag(T1 t, index_t k = 0)

Examples
~~~~~~~~
Expand All @@ -19,10 +19,23 @@ Examples
:end-before: example-end diag-op-test-1
:dedent:

.. literalinclude:: ../../../../test/00_operators/GeneratorTests.cu
:language: cpp
:start-after: example-begin diag-op-test-2
:end-before: example-end diag-op-test-2
:dedent:

.. literalinclude:: ../../../../test/00_operators/GeneratorTests.cu
:language: cpp
:start-after: example-begin diag-op-test-3
:end-before: example-end diag-op-test-3
:dedent:

Generator
_________

.. doxygenfunction:: matx::diag(const index_t (&s)[RANK], T val)
.. doxygenfunction:: matx::diag(T val)

Examples
~~~~~~~~
Expand Down
3 changes: 2 additions & 1 deletion include/matx/generators/diag.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ namespace matx
*
*/
template <typename T = int, typename ShapeType,
std::enable_if_t<!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
std::enable_if_t<!std::is_array_v<typename remove_cvref<ShapeType>::type> &&
!is_matx_op<ShapeType>(), bool> = true>
inline auto diag(ShapeType &&s, T val)
{
return detail::Diag<T, ShapeType>(std::forward<ShapeType>(s), val);
Expand Down
81 changes: 67 additions & 14 deletions include/matx/operators/diag.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,39 +52,82 @@ namespace matx
{
private:
typename base_type<T1>::type op_;
index_t k_;

public:
using matxop = bool;
using value_type = typename T1::value_type;

__MATX_INLINE__ std::string str() const { return "diag(" + op_.str() + ")"; }

__MATX_INLINE__ DiagOp(T1 op) : op_(op) {}
__MATX_INLINE__ DiagOp(T1 op, index_t k) : op_(op), k_(k) { }

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const
{
static_assert(sizeof...(Is) == RANK - 1, "Diagonal operator must have one fewer index than rank of operator");
static_assert(RANK > 1, "Cannot make get diagonals from 0D tensor");

static_assert(RANK != 0, "Cannot make get diagonals from 0D tensor");
using tt = cuda::std::tuple_element_t<0, cuda::std::tuple<Is...>>;
auto tup = cuda::std::make_tuple(indices..., static_cast<tt>(0));
cuda::std::get<RANK - 1>(tup) = pp_get<RANK-2>(indices...);
return cuda::std::apply(op_, tup);

if constexpr (RANK == 1) {
static_assert(sizeof...(Is) == 2, "Indexing of diag() on a 1D input must be 2 indices");
if (((pp_get<0>(indices...) == indices) && ...)) {
return (value_type)(pp_get<0>(indices...));
}
else {
return (value_type)(0);
}
}
else {
static_assert(sizeof...(Is) == RANK - 1, "Diagonal operator must have one fewer op() index than rank of operator");

// Offset either the rows or columns by k_, depending on if it's negative
if (k_ < 0) {
auto tup = cuda::std::make_tuple(indices..., static_cast<tt>(0));
cuda::std::get<RANK - 1>(tup) = pp_get<RANK-2>(indices...) ;
cuda::std::get<RANK - 2>(tup) = cuda::std::get<RANK - 2>(tup) - k_;
return cuda::std::apply(op_, tup);
}
else {
auto tup = cuda::std::make_tuple(indices..., static_cast<tt>(0));
cuda::std::get<RANK - 1>(tup) = pp_get<RANK-2>(indices...) + k_;
return cuda::std::apply(op_, tup);
}
}
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
return RANK - 1;
if constexpr (RANK == 1) {
return 2;
}
else {
return RANK - 1;
}
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size([[maybe_unused]] int dim) const
{
if (dim < RANK - 2) {
return op_.Size(dim);
if constexpr (RANK == 1) {
return op_.Size(0);
}
else {
return cuda::std::min(op_.Size(RANK - 1), op_.Size(RANK-2));
if (dim < RANK - 2) {
return op_.Size(dim);
}
else {
if (k_ == 0) {
return cuda::std::min(op_.Size(RANK - 1), op_.Size(RANK-2));
}
else {
// If k is off the main diagonal we need to adjust the sizes
if (k_ > 0) {
return cuda::std::min(op_.Size(RANK - 1), op_.Size(RANK-2) - k_);
}
else {
return cuda::std::min(op_.Size(RANK - 1) + k_, op_.Size(RANK-2));
}
}
}
}
}

Expand All @@ -107,11 +150,21 @@ namespace matx
}

/**
* Get the elements on the diagonal
* Get the elements on the diagonal (2D inputs and above), or generate a diagonal matrix (1D input)
*
* @param t
* Input operator
* @param k
* Diagonal to pull (0 is the main diagonal). Only used for 2D tensors and above
*/
template <typename T1>
auto __MATX_INLINE__ diag(T1 t) { return detail::DiagOp<T1, T1::Rank()>(t); }
#ifdef DOXYGEN_ONLY
auto __MATX_INLINE__ diag(T1 t, index_t k = 0) {
#else
template <typename T1, std::enable_if_t<is_matx_op<T1>(), bool> = true>
auto __MATX_INLINE__ diag(T1 t, index_t k = 0) {
#endif
MATX_ASSERT_STR(T1::Rank() != 1 || k == 0, matxInvalidParameter,
"k parameter in diag() can only be used for 2D tensors and above");
return detail::DiagOp<T1, T1::Rank()>(t, k);
}
} // end namespace matx
35 changes: 34 additions & 1 deletion test/00_operators/GeneratorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,11 @@ TYPED_TEST(BasicGeneratorTestsAll, Diag)
MATX_ENTER_HANDLER();
{
// example-begin diag-op-test-1
// The generator form of `diag()` takes an operator input and returns only
// The generator form of `diag()` with a >= 2D input takes an operator input and returns only
// the diagonal elements as output
auto tc = make_tensor<TestType>({10, 10});
auto td = make_tensor<TestType>({10});
auto tdk = make_tensor<TestType>({4});

// Initialize the diagonal elements of `tc`
for (int i = 0; i < 10; i++) {
Expand All @@ -171,6 +172,38 @@ TYPED_TEST(BasicGeneratorTestsAll, Diag)
}
}

{
// example-begin diag-op-test-2
// Assign the diagonal elements of `tc` to `td` with the 6th diagonal.
auto op = diag(tc, 6);
(tdk = op).run(exec);
// example-end diag-op-test-2

exec.sync();

ASSERT_EQ(op.Size(0), 4);
ASSERT_EQ(op.Rank(), 1);
for (int i = 0; i < tdk.Size(0); i++) {
MATX_ASSERT_EQ(tdk(i), tc(i, i + 6));
}
}

{
// example-begin diag-op-test-3
// Assign the diagonal elements of `tc` to `td` with the 6th diagonal.
auto op = diag(tc, -6);
(tdk = op).run(exec);
// example-end diag-op-test-3

exec.sync();

ASSERT_EQ(op.Size(0), 4);
ASSERT_EQ(op.Rank(), 1);
for (int i = 0; i < tdk.Size(0); i++) {
MATX_ASSERT_EQ(tdk(i), tc(i + 6, i));
}
}

// Test with a nested transform. Restrict to floating point types for
// the convolution
if constexpr (std::is_same_v<TestType,float> || std::is_same_v<TestType,double>)
Expand Down