Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 51 additions & 47 deletions include/matx/core/iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
namespace matx {
/**
* @brief Iterator around operators for libraries that can take iterators as input (CUB).
*
*
* @tparam T Data type
* @tparam RANK Rank of tensor
* @tparam Desc Descriptor for tensor
*
*
*/
template <typename OperatorType, bool ConvertType = true>
struct RandomOperatorIterator {
Expand All @@ -66,14 +66,14 @@ struct RandomOperatorIterator {

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}

/**
* @brief Dereference value at a pre-computed offset
*
* @return Value at offset
*
* @return Value at offset
*/
[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ value_type operator*() const
{
Expand All @@ -86,9 +86,9 @@ struct RandomOperatorIterator {
return cuda::std::apply([&](auto &&...args) -> value_type {
const auto tmp = t_.operator()(args...);
return tmp;
}, arrs);
}, arrs);
}
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator=(const self_type &rhs)
{
Expand All @@ -97,7 +97,7 @@ struct RandomOperatorIterator {
}
offset_ = rhs.offset_;
return *this;
}
}


[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator+(difference_type offset) const
Expand All @@ -114,27 +114,27 @@ struct RandomOperatorIterator {
auto arrs = detail::GetIdxFromAbs(t_, offset_+offset);
return cuda::std::apply([&](auto &&...args) -> value_type {
return static_cast<value_type>(t_.operator()(args...));
}, arrs);
}, arrs);
}
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator++(int)
{
self_type retval = *this;
offset_++;
return retval;
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator++()
{
offset_++;
return *this;
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ difference_type offset()
{
return offset_;
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator+=(difference_type offset)
{
Expand All @@ -146,7 +146,7 @@ struct RandomOperatorIterator {
{
offset_ -= offset;
return *this;
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ friend bool operator!=(const self_type &a, const self_type &b)
{
Expand All @@ -156,7 +156,7 @@ struct RandomOperatorIterator {
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ friend bool operator==(const self_type &a, const self_type &b)
{
return a.offset_ == b.offset_;
}
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() {
return OperatorType::Rank();
Expand All @@ -165,27 +165,27 @@ struct RandomOperatorIterator {
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
{
return t_.Size(dim);
}
}

OperatorBaseType t_;
stride_type offset_;
stride_type offset_;
};

template <typename OperatorType, bool ConvertType = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t operator-(const RandomOperatorIterator<OperatorType, ConvertType> &a, const RandomOperatorIterator<OperatorType, ConvertType> &b)
{
return a.offset_ - b.offset_;
}
}



/**
* @brief Iterator around operators for libraries that can take iterators as output (CUB).
*
*
* @tparam T Data type
* @tparam RANK Rank of tensor
* @tparam Desc Descriptor for tensor
*
*
*/
template <typename OperatorType, bool ConvertType = true>
struct RandomOperatorOutputIterator {
Expand All @@ -209,7 +209,7 @@ struct RandomOperatorOutputIterator {

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}

Expand All @@ -225,33 +225,32 @@ struct RandomOperatorOutputIterator {
return cuda::std::apply([&](auto &&...args) -> reference {
auto &tmp = t_.operator()(args...);
return tmp;
}, arrs);
}, arrs);
}
}
}

[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator+(difference_type offset) const
{
return self_type{t_, offset_ + offset};
}


[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator[](difference_type offset)
[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator[](difference_type offset) const
{
return *self_type{t_, offset_ + offset};
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator++(int)
{
self_type retval = *this;
offset_++;
return retval;
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator++()
{
offset_++;
return *this;
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator+=(difference_type offset)
{
Expand All @@ -276,7 +275,7 @@ struct RandomOperatorOutputIterator {
{
offset_ -= offset;
return *this;
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator--() {
--offset_;
Expand Down Expand Up @@ -316,26 +315,26 @@ struct RandomOperatorOutputIterator {
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
{
return t_.Size(dim);
}
}

OperatorBaseType t_;
stride_type offset_;
stride_type offset_;
};

/**
* @brief Iterator around operators for libraries that can take iterators as input/output (Thrust).
*
*
* @tparam T Data type
* @tparam RANK Rank of tensor
* @tparam Desc Descriptor for tensor
*
*
*/
template <typename OperatorType, bool ConvertType = true>
struct RandomOperatorThrustIterator {
using self_type = RandomOperatorThrustIterator<OperatorType, ConvertType>;
using const_strip_type = remove_cvref_t<typename OperatorType::value_type>;
using value_type = typename std::conditional_t<ConvertType,
detail::convert_matx_type_t<const_strip_type>,
detail::convert_matx_type_t<const_strip_type>,
const_strip_type>;
// using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
// index_t>;
Expand All @@ -356,7 +355,7 @@ struct RandomOperatorThrustIterator {

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}

Expand All @@ -372,33 +371,32 @@ struct RandomOperatorThrustIterator {
return cuda::std::apply([&](auto &&...args) -> reference {
auto &tmp = const_cast<const_strip_type&>(t_.operator()(args...));
return tmp;
}, arrs);
}, arrs);
}
}
}

[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator+(difference_type offset) const
{
return self_type{t_, offset_ + offset};
}


[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator[](difference_type offset)
[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator[](difference_type offset) const
{
return *self_type{t_, offset_ + offset};
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator++(int)
{
self_type retval = *this;
offset_++;
return retval;
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type operator++()
{
offset_++;
return *this;
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator+=(difference_type offset)
{
Expand All @@ -423,7 +421,7 @@ struct RandomOperatorThrustIterator {
{
offset_ -= offset;
return *this;
}
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator--() {
--offset_;
Expand All @@ -438,7 +436,7 @@ struct RandomOperatorThrustIterator {
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ friend bool operator==(const self_type &a, const self_type &b)
{
return a.offset_ == b.offset_;
}
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() {
return OperatorType::Rank();
Expand All @@ -447,10 +445,10 @@ struct RandomOperatorThrustIterator {
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
{
return t_.Size(dim);
}
}

OperatorBaseType t_;
stride_type offset_;
stride_type offset_;
};


Expand Down Expand Up @@ -544,6 +542,12 @@ struct EndOffset {
return self_type{size_, offset_ + offset};
}

__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ self_type& operator+=(difference_type offset)
{
offset_ += offset;
return *this;
}

[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ stride_type operator[](difference_type offset) const
{
return ( offset + 1) * size_;
Expand All @@ -559,7 +563,7 @@ template <typename OperatorType>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t operator-(const RandomOperatorOutputIterator<OperatorType> &a, const RandomOperatorOutputIterator<OperatorType> &b)
{
return a.offset_ - b.offset_;
}
}


template <typename Op>
Expand Down