Skip to content

Commit e5fa55b

Browse files
Peter Enescu and facebook-github-bot
authored and committed
feat: Add Presto function array_top_n
Summary: Adds the Presto function array_top_n as a simple function in Velox. The function uses a min-heap and a temporary vector to select the top n input values (n is the second input to the function) and returns them in descending order. Updates ArrayFunctions.h with struct ArrayTopNFunction and adds a new test file, ArrayTopNTest.cpp. Differential Revision: D68031372
1 parent ce273fa commit e5fa55b

File tree

4 files changed

+533
-0
lines changed

4 files changed

+533
-0
lines changed

velox/functions/prestosql/ArrayFunctions.h

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121
#include "velox/expression/PrestoCastHooks.h"
2222
#include "velox/functions/Udf.h"
2323
#include "velox/functions/lib/CheckedArithmetic.h"
24+
#include "velox/functions/lib/ComparatorUtil.h"
2425
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
2526
#include "velox/functions/prestosql/types/JsonType.h"
2627
#include "velox/type/Conversions.h"
2728
#include "velox/type/FloatingPointUtil.h"
2829

30+
#include <queue>
31+
2932
namespace facebook::velox::functions {
3033

3134
template <typename TExecCtx, bool isMax>
@@ -729,6 +732,120 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) {
729732
}
730733
}
731734

735+
/// This class implements the array_top_n function.
736+
///
737+
/// DEFINITION:
738+
/// array_top_n(array(T), int) -> array(T)
739+
/// Returns the top n elements of the array in descending order.
740+
template <typename T>
741+
struct ArrayTopNFunction {
742+
VELOX_DEFINE_FUNCTION_TYPES(T);
743+
744+
// Definition for primitives.
745+
template <typename TReturn, typename TInput>
746+
FOLLY_ALWAYS_INLINE void
747+
call(TReturn& result, const TInput& array, int64_t n) {
748+
VELOX_CHECK(
749+
n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));
750+
751+
// Define min-heap to store the top n elements.
752+
std::priority_queue<
753+
typename TInput::element_t,
754+
std::vector<typename TInput::element_t>,
755+
std::greater<>>
756+
minHeap;
757+
758+
// Iterate through the array and push elements to the min-heap.
759+
int numNull = 0;
760+
for (const auto& item : array) {
761+
if (item.has_value()) {
762+
minHeap.push(item.value());
763+
if (minHeap.size() > n) {
764+
minHeap.pop();
765+
}
766+
} else {
767+
++numNull;
768+
}
769+
}
770+
771+
// Reverse the min-heap to get the top n elements in descending order.
772+
std::vector<typename TInput::element_t> reversed(minHeap.size());
773+
auto index = minHeap.size();
774+
while (!minHeap.empty()) {
775+
reversed[--index] = minHeap.top();
776+
minHeap.pop();
777+
}
778+
779+
// Copy mutated vector to result vector up to minHeap's size items.
780+
for (const auto& item : reversed) {
781+
result.push_back(item);
782+
}
783+
784+
// Backfill nulls if needed.
785+
while (result.size() < n && numNull > 0) {
786+
result.add_null();
787+
--numNull;
788+
}
789+
}
790+
791+
// Generic implementation.
792+
FOLLY_ALWAYS_INLINE void call(
793+
out_type<Array<Orderable<T1>>>& result,
794+
const arg_type<Array<Orderable<T1>>>& array,
795+
const int64_t n) {
796+
VELOX_CHECK(
797+
n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));
798+
799+
// Define comparator to compare complex types.
800+
struct ComplexTypeComparator {
801+
const arg_type<Array<Orderable<T1>>>& array;
802+
ComplexTypeComparator(const arg_type<Array<Orderable<T1>>>& array)
803+
: array(array) {}
804+
805+
bool operator()(const int64_t& a, const int64_t& b) const {
806+
static constexpr CompareFlags kFlags = {
807+
.nullHandlingMode =
808+
CompareFlags::NullHandlingMode::kNullAsIndeterminate};
809+
return array[a].value().compare(array[b].value(), kFlags).value() > 0;
810+
}
811+
};
812+
813+
// Iterate through the array and push elements to the min-heap.
814+
std::priority_queue<int64_t, std::vector<int64_t>, ComplexTypeComparator>
815+
minHeap(array);
816+
int numNull = 0;
817+
for (int i = 0; i < array.size(); ++i) {
818+
if (array[i].has_value()) {
819+
minHeap.push(i);
820+
if (minHeap.size() > n) {
821+
minHeap.pop();
822+
}
823+
} else {
824+
++numNull;
825+
}
826+
}
827+
828+
// Reverse the min-heap to get the top n elements in descending order.
829+
std::vector<int64_t> reversed(minHeap.size());
830+
auto index = minHeap.size();
831+
while (!minHeap.empty()) {
832+
reversed[--index] = minHeap.top();
833+
minHeap.pop();
834+
}
835+
836+
// Copy mutated vector to result vector up to minHeap's size items.
837+
for (const auto& index : reversed) {
838+
result.push_back(array[index].value());
839+
}
840+
841+
// Backfill nulls if needed.
842+
while (result.size() < n && numNull > 0) {
843+
result.add_null();
844+
--numNull;
845+
}
846+
}
847+
};
848+
732849
template <typename T>
733850
struct ArrayTrimFunction {
734851
VELOX_DEFINE_FUNCTION_TYPES(T);

velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ inline void registerArrayTrimFunctions(const std::string& prefix) {
9797
{prefix + "trim_array"});
9898
}
9999

100+
template <typename T>
101+
inline void registerArrayTopNFunction(const std::string& prefix) {
102+
registerFunction<ArrayTopNFunction, Array<T>, Array<T>, int64_t>(
103+
{prefix + "array_top_n"});
104+
}
105+
100106
template <typename T>
101107
inline void registerArrayRemoveNullFunctions(const std::string& prefix) {
102108
registerFunction<ArrayRemoveNullFunction, Array<T>, Array<T>>(
@@ -241,6 +247,19 @@ void registerArrayFunctions(const std::string& prefix) {
241247
Array<Varchar>,
242248
int64_t>({prefix + "trim_array"});
243249

250+
registerArrayTopNFunction<int8_t>(prefix);
251+
registerArrayTopNFunction<int16_t>(prefix);
252+
registerArrayTopNFunction<int32_t>(prefix);
253+
registerArrayTopNFunction<int64_t>(prefix);
254+
registerArrayTopNFunction<int128_t>(prefix);
255+
registerArrayTopNFunction<float>(prefix);
256+
registerArrayTopNFunction<double>(prefix);
257+
registerArrayTopNFunction<Varchar>(prefix);
258+
registerArrayTopNFunction<Timestamp>(prefix);
259+
registerArrayTopNFunction<Date>(prefix);
260+
registerArrayTopNFunction<Varbinary>(prefix);
261+
registerArrayTopNFunction<Orderable<T1>>(prefix);
262+
244263
registerArrayRemoveNullFunctions<int8_t>(prefix);
245264
registerArrayRemoveNullFunctions<int16_t>(prefix);
246265
registerArrayRemoveNullFunctions<int32_t>(prefix);

0 commit comments

Comments
 (0)