Skip to content

Commit 6cacb59

Browse files
committed
Merge pull request #1391 from pavanky/compile_fix
Fix warnings and memory allocation for cpu sort
2 parents 37a47fb + 9443bf2 commit 6cacb59

File tree

2 files changed

+42
-30
lines changed

2 files changed

+42
-30
lines changed

src/backend/cpu/kernel/sort_by_key_impl.hpp

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,11 @@ void sort0ByKeyIterative(Array<Tk> okey, Array<Tv> oval)
3131
Tk *okey_ptr = okey.get();
3232
Tv *oval_ptr = oval.get();
3333

34-
std::vector<IndexPair<Tk, Tv> > X;
35-
X.reserve(okey.dims()[0]);
34+
typedef IndexPair<Tk, Tv> CurrentPair;
35+
36+
dim_t size = okey.dims()[0];
37+
size_t bytes = size * sizeof(CurrentPair);
38+
CurrentPair *pairKeyVal = (CurrentPair *)memAlloc<char>(bytes);
3639

3740
for(dim_t w = 0; w < okey.dims()[3]; w++) {
3841
dim_t okeyW = w * okey.strides()[3];
@@ -47,23 +50,24 @@ void sort0ByKeyIterative(Array<Tk> okey, Array<Tv> oval)
4750
dim_t okeyOffset = okeyWZ + y * okey.strides()[1];
4851
dim_t ovalOffset = ovalWZ + y * oval.strides()[1];
4952

50-
X.clear();
51-
std::transform(okey_ptr + okeyOffset, okey_ptr + okeyOffset + okey.dims()[0],
52-
oval_ptr + ovalOffset,
53-
std::back_inserter(X),
54-
[](Tk v_, Tv i_) { return std::make_pair(v_, i_); }
55-
);
53+
Tk *okey_col_ptr = okey_ptr + okeyOffset;
54+
Tv *oval_col_ptr = oval_ptr + ovalOffset;
55+
56+
for(dim_t x = 0; x < size; x++) {
57+
pairKeyVal[x] = std::make_tuple(okey_col_ptr[x], oval_col_ptr[x]);
58+
}
5659

57-
std::stable_sort(X.begin(), X.end(), IPCompare<Tk, Tv, isAscending>());
60+
std::stable_sort(pairKeyVal, pairKeyVal + size, IPCompare<Tk, Tv, isAscending>());
5861

59-
for(unsigned it = 0; it < X.size(); it++) {
60-
okey_ptr[okeyOffset + it] = X[it].first;
61-
oval_ptr[ovalOffset + it] = X[it].second;
62+
for(unsigned x = 0; x < size; x++) {
63+
okey_ptr[okeyOffset + x] = std::get<0>(pairKeyVal[x]);
64+
oval_ptr[ovalOffset + x] = std::get<1>(pairKeyVal[x]);
6265
}
6366
}
6467
}
6568
}
6669

70+
memFree((char *)pairKeyVal);
6771
return;
6872
}
6973

@@ -108,24 +112,27 @@ void sortByKeyBatched(Array<Tk> okey, Array<Tv> oval)
108112
Tk *okey_ptr = okey.get();
109113
Tv *oval_ptr = oval.get();
110114

111-
std::vector<KeyIndexPair<Tk, Tv> > X;
112-
X.reserve(okey.elements());
115+
typedef KeyIndexPair<Tk, Tv> CurrentTuple;
116+
size_t size = okey.elements();
117+
size_t bytes = okey.elements() * sizeof(CurrentTuple);
118+
CurrentTuple *tupleKeyValIdx = (CurrentTuple *)memAlloc<char>(bytes);
113119

114-
for(unsigned i = 0; i < okey.elements(); i++) {
115-
X.push_back(std::make_pair(std::make_pair(okey_ptr[i], oval_ptr[i]), key[i]));
120+
for(unsigned i = 0; i < size; i++) {
121+
tupleKeyValIdx[i] = std::make_tuple(okey_ptr[i], oval_ptr[i], key[i]);
116122
}
117123

118124
memFree(key); // key is no longer required
119125

120-
std::stable_sort(X.begin(), X.end(), KIPCompareV<Tk, Tv, isAscending>());
126+
std::stable_sort(tupleKeyValIdx, tupleKeyValIdx + size, KIPCompareV<Tk, Tv, isAscending>());
121127

122-
std::stable_sort(X.begin(), X.end(), KIPCompareK<Tk, Tv, true>());
128+
std::stable_sort(tupleKeyValIdx, tupleKeyValIdx + size, KIPCompareK<Tk, Tv, true>());
123129

124-
for(unsigned it = 0; it < okey.elements(); it++) {
125-
okey_ptr[it] = X[it].first.first;
126-
oval_ptr[it] = X[it].first.second;
130+
for(unsigned x = 0; x < okey.elements(); x++) {
131+
okey_ptr[x] = std::get<0>(tupleKeyValIdx[x]);
132+
oval_ptr[x] = std::get<1>(tupleKeyValIdx[x]);
127133
}
128134

135+
memFree((char *)tupleKeyValIdx);
129136
return;
130137
}
131138

@@ -163,4 +170,3 @@ void sort0ByKey(Array<Tk> okey, Array<Tv> oval)
163170
INSTANTIATE(Tk, uintl , dr)
164171
}
165172
}
166-

src/backend/cpu/kernel/sort_helper.hpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,34 @@ namespace cpu
1414
namespace kernel
1515
{
1616
template <typename Tk, typename Tv>
17-
using IndexPair = std::pair<Tk, Tv>;
17+
using IndexPair = std::tuple<Tk, Tv>;
1818

1919
template <typename Tk, typename Tv, bool isAscending>
2020
struct IPCompare
2121
{
2222
bool operator()(const IndexPair<Tk, Tv> &lhs, const IndexPair<Tk, Tv> &rhs)
2323
{
2424
// Check stable sort condition
25-
if(isAscending) return (lhs.first < rhs.first);
26-
else return (lhs.first > rhs.first);
25+
Tk lhsVal = std::get<0>(lhs);
26+
Tk rhsVal = std::get<0>(rhs);
27+
if(isAscending) return (lhsVal < rhsVal);
28+
else return (lhsVal > rhsVal);
2729
}
2830
};
2931

3032
template <typename Tk, typename Tv>
31-
using KeyIndexPair = std::pair<IndexPair<Tk, Tv>, uint>;
33+
using KeyIndexPair = std::tuple<Tk, Tv, uint>;
3234

3335
template <typename Tk, typename Tv, bool isAscending>
3436
struct KIPCompareV
3537
{
3638
bool operator()(const KeyIndexPair<Tk, Tv> &lhs, const KeyIndexPair<Tk, Tv> &rhs)
3739
{
3840
// Check stable sort condition
39-
if(isAscending) return (lhs.first.first < rhs.first.first);
40-
else return (lhs.first.first > rhs.first.first);
41+
Tk lhsVal = std::get<0>(lhs);
42+
Tk rhsVal = std::get<0>(rhs);
43+
if(isAscending) return (lhsVal < rhsVal);
44+
else return (lhsVal > rhsVal);
4145
}
4246
};
4347

@@ -46,8 +50,10 @@ namespace cpu
4650
{
4751
bool operator()(const KeyIndexPair<Tk, Tv> &lhs, const KeyIndexPair<Tk, Tv> &rhs)
4852
{
49-
if(isAscending) return (lhs.second < rhs.second);
50-
else return (lhs.second > rhs.second);
53+
uint lhsVal = std::get<2>(lhs);
54+
uint rhsVal = std::get<2>(rhs);
55+
if(isAscending) return (lhsVal < rhsVal);
56+
else return (lhsVal > rhsVal);
5157
}
5258
};
5359
}

0 commit comments

Comments
 (0)