Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions speedtests/sequences.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ bool Next(const TRange& range, TValue& value)
template <class TValue, TValue... values>
struct BoostSequence
{
using ValueType = TValue;
using value_type = TValue;
using const_iterator = const TValue*;
constexpr const_iterator begin() const { return arr.begin(); }
constexpr const_iterator end() const { return arr.end(); }
Expand All @@ -188,7 +188,7 @@ struct span_impl;
template <class Start, class T, T... Ns>
struct span_impl<Start, std::integer_sequence<T, Ns...>>
{
using ValueType = T;
using value_type = T;
using const_iterator = const T*;
T data[sizeof...(Ns)] = {(Ns + Start{})...};

Expand Down
81 changes: 44 additions & 37 deletions src/include/miopen/sequences.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ To be a sequence class must follow the concept:

class Sequence
{
using ValueType = <unspecified>;
using value_type = <unspecified>;

Sequence();

Expand All @@ -61,11 +61,10 @@ class Sequence
constexpr <unspecified> find(TValue value) const;
}

ValueType may have any constraints depending on specific sequences but it also must have operator==.
It is forbidden to use any of existing sequences with several equal values. It will lead to hangs.
Each of existing sequences which can produce such scenario has assert which will fail with a reason
of failure stated.
Example: sequence 1-2-1 will be looped as:
'value_type' may have any constraints depending on specific sequences but it also must have
operator==. It is forbidden to use any of existing sequences with several equal values. It will lead
to hangs. Each of existing sequences which can produce such scenario has assert which will fail with
a reason of failure stated. Example: sequence 1-2-1 will be looped as:

1-2-1 -> 1-2-1 -> 1-2-1 -> 1-2-1
| | | |
Expand Down Expand Up @@ -185,19 +184,27 @@ auto GenericFind(const TRange& range, const TValue& value)
template <class TValue, TValue... values>
struct Sequence
{
using const_iterator = const TValue*;
using ValueType = TValue;
using sequence_type = std::array<TValue, sizeof...(values)>;
using value_type = typename sequence_type::value_type;
using iterator = typename sequence_type::iterator;
using const_iterator = typename sequence_type::const_iterator;
using reference = typename sequence_type::reference;
using const_reference = typename sequence_type::const_reference;
using difference_type = typename sequence_type::difference_type;

Sequence() { assert(ValidateValues() && "Values must be unique"); }

constexpr const_iterator begin() const { return data.begin(); }
constexpr const_iterator end() const { return data.end(); }
constexpr const_iterator find(const TValue& value) const { return data.data() + find_(value); }
constexpr const_iterator find(const_reference value) const
{
return std::next(data.begin(), find_(value));
}

private:
static constexpr std::array<int, sizeof...(values)> data = {{values...}};

static constexpr int ValuesCount() { return sizeof...(values); }
static constexpr difference_type ValuesCount() { return sizeof...(values); }

static constexpr bool ValidateValues()
{
Expand All @@ -209,31 +216,31 @@ struct Sequence
return true;
}

template <int icur, TValue cur, TValue... rest>
template <difference_type icur, TValue cur, TValue... rest>
struct Find
{
int operator()(const TValue& value) const
constexpr difference_type operator()(const TValue& value) const
{
if(value == cur)
return icur;
return rest_(value);
}

Find<icur + 1, rest...> rest_ = {};
Find<icur + 1, rest...> rest_{};
};

template <int icur, TValue cur>
template <difference_type icur, TValue cur>
struct Find<icur, cur>
{
int operator()(const TValue& value) const
constexpr difference_type operator()(const TValue& value) const
{
if(value == cur)
return icur;
return icur + 1;
}
};

Find<0, values...> find_ = {};
Find<0, values...> find_{};
};

template <class TValue>
Expand Down Expand Up @@ -280,7 +287,7 @@ template <class TValue, TValue low, TValue high>
struct Span
{
using const_iterator = SpanIterator<int, high>;
using ValueType = TValue;
using value_type = TValue;

constexpr const_iterator begin() const { return {low}; }
constexpr const_iterator end() const { return {}; }
Expand Down Expand Up @@ -326,7 +333,7 @@ template <class TValue, TValue low, TValue high>
struct TwoPowersSpan
{
using const_iterator = TwoPowersSpanIterator<int, high>;
using ValueType = TValue;
using value_type = TValue;

constexpr const_iterator begin() const { return {low}; }
constexpr const_iterator end() const { return {high + 1}; }
Expand All @@ -349,12 +356,12 @@ struct TwoPowersSpan
};

template <class TFirst, class... TRest>
struct JoinIterator : public SequenceIteratorBase<typename TFirst::ValueType>
struct JoinIterator : public SequenceIteratorBase<typename TFirst::value_type>
{
using ValueType = typename TFirst::ValueType;
using value_type = typename TFirst::value_type;

JoinIterator() : SequenceIteratorBase<ValueType>(0) {}
JoinIterator(ValueType value_) : SequenceIteratorBase<ValueType>(value_), finished(false) {}
JoinIterator() : SequenceIteratorBase<value_type>(0) {}
JoinIterator(value_type value_) : SequenceIteratorBase<value_type>(value_), finished(false) {}

JoinIterator& operator++()
{
Expand Down Expand Up @@ -389,7 +396,7 @@ struct JoinIterator : public SequenceIteratorBase<typename TFirst::ValueType>
template <class TCur, class TNext, class... TRest_>
struct Next<TCur, TNext, TRest_...>
{
bool operator()(ValueType& value)
bool operator()(value_type& value)
{
auto it = GenericFind(cur, value);

Expand All @@ -415,7 +422,7 @@ struct JoinIterator : public SequenceIteratorBase<typename TFirst::ValueType>
template <class TSeq>
struct Next<TSeq>
{
bool operator()(ValueType& value)
bool operator()(value_type& value)
{
auto it = GenericFind(seq, value);

Expand All @@ -436,20 +443,20 @@ struct JoinIterator : public SequenceIteratorBase<typename TFirst::ValueType>
template <class TFirst, class... TRest>
struct Join
{
using ValueType = typename TFirst::ValueType;
using value_type = typename TFirst::value_type;
using const_iterator = JoinIterator<TFirst, TRest...>;

Join() { assert(Validate() && "Values must be unique"); }

constexpr const_iterator begin() const { return {*first.begin()}; }
constexpr const_iterator end() const { return {}; }
constexpr const_iterator find(ValueType value) const { return find_(value); }
constexpr const_iterator find(value_type value) const { return find_(value); }

private:
template <class TCur, class... TRest_>
struct Find
{
const_iterator operator()(ValueType value) const
const_iterator operator()(value_type value) const
{
auto it = cur(value);
return it ? it : rest(value);
Expand All @@ -463,7 +470,7 @@ struct Join
template <class TSeq>
struct Find<TSeq>
{
const_iterator operator()(ValueType value) const
const_iterator operator()(value_type value) const
{
auto it = GenericFind(seq, value);

Expand All @@ -480,7 +487,7 @@ struct Join
bool Validate() const
{
auto cur = begin();
std::vector<ValueType> values(1, *cur);
std::vector<value_type> values(1, *cur);

while(++cur != end())
{
Expand All @@ -497,11 +504,11 @@ struct Join
Find<TFirst, TRest...> find_ = {};
};

template <class TInner, typename TInner::ValueType mul>
template <class TInner, typename TInner::value_type mul>
struct MultipliedIterator
{
using InnerIterator = typename TInner::const_iterator;
using ValueType = typename TInner::ValueType;
using value_type = typename TInner::value_type;

MultipliedIterator(InnerIterator inner_) : inner(inner_) {}

Expand All @@ -518,7 +525,7 @@ struct MultipliedIterator
return copy;
}

ValueType operator*() const { return mul * *inner; }
value_type operator*() const { return mul * *inner; }
bool operator==(const MultipliedIterator& other) const { return inner == other.inner; }
bool operator!=(const MultipliedIterator& other) const { return !(*this == other); }

Expand All @@ -527,15 +534,15 @@ struct MultipliedIterator
};

/// A sequence containing values of another sequence multiplied by a constant.
template <class TInner, typename TInner::ValueType mul>
template <class TInner, typename TInner::value_type mul>
struct Multiplied
{
using ValueType = typename TInner::ValueType;
using value_type = typename TInner::value_type;
using const_iterator = MultipliedIterator<TInner, mul>;

constexpr const_iterator begin() const { return {inner.begin()}; }
constexpr const_iterator end() const { return {inner.end()}; }
constexpr const_iterator find(ValueType value) const
constexpr const_iterator find(value_type value) const
{
return {GenericFind(inner, value / mul)};
}
Expand Down Expand Up @@ -614,7 +621,7 @@ bool SeqNextImpl(rank<1>, const Sequence<TValue, first, values...>&, TValue& val
}

template <class TSequence>
bool SeqNextImpl(rank<0>, const TSequence& seq, typename TSequence::ValueType& value)
bool SeqNextImpl(rank<0>, const TSequence& seq, typename TSequence::value_type& value)
{
auto it = GenericFind(seq, value);

Expand All @@ -630,7 +637,7 @@ bool SeqNextImpl(rank<0>, const TSequence& seq, typename TSequence::ValueType& v
}

template <class TSequence>
bool SeqNext(const TSequence& seq, typename TSequence::ValueType& value)
bool SeqNext(const TSequence& seq, typename TSequence::value_type& value)
{
return SeqNextImpl(rank<16>{}, seq, value);
}
Expand Down