Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/matx/operators/shift.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ namespace matx
__MATX_INLINE__ ShiftOp(T1 op, T2 shift) : op_(op), shift_(shift)
{
static_assert(DIM < Rank(), "Dimension to shift must be less than rank of tensor");
ASSERT_COMPATIBLE_OP_SIZES(shift_);
ASSERT_COMPATIBLE_OP_SIZES(op_);
}

template <typename... Is>
Expand Down Expand Up @@ -103,12 +105,14 @@ namespace matx

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
return detail::get_rank<T1>();
return detail::matx_max(detail::get_rank<T1>(), detail::get_rank<T2>());
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto Size(int dim) const noexcept
{
return op_.Size(dim);
index_t size1 = detail::get_expanded_size<Rank()>(op_, dim);
index_t size2 = detail::get_expanded_size<Rank()>(shift_, dim);
return detail::matx_max(size1,size2);
}

template<typename R> __MATX_INLINE__ auto operator=(const R &rhs) { return set(*this, rhs); }
Expand Down