Skip to content

Commit f5fc530

Browse files
JkSelfglutenperfbot
authored andcommitted
Fix array_union on NaN (7086)
1 parent 2606b95 commit f5fc530

File tree

5 files changed

+327
-0
lines changed

5 files changed

+327
-0
lines changed

velox/docs/functions/spark/array.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,15 @@ Array Functions
8686
SELECT array_sort(array(NULL, 1, NULL)); -- [1, NULL, NULL]
8787
SELECT array_sort(array(NULL, 2, 1)); -- [1, 2, NULL]
8888

89+
.. spark:function:: array_union(array(E), array(E1)) -> array(E2)
90+
91+
Returns an array of the elements in the union of array1 and array2, without duplicates. ::
92+
93+
SELECT array_union(array(1, 2, 3), array(1, 3, 5)); -- [1, 2, 3, 5]
94+
SELECT array_union(array(1, 3, 5), array(1, 2, 3)); -- [1, 3, 5, 2]
95+
SELECT array_union(array(1, 2, 3), array(1, 3, 5, null)); -- [1, 2, 3, 5, null]
96+
SELECT array_union(array(1, 2, NaN), array(1, 3, NaN)); -- [1, 2, NaN, 3]
97+
8998
.. spark:function:: concat(array(E), array(E1), ..., array(En)) -> array(E, E1, ..., En)
9099
91100
Returns the concatenation of array(E), array(E1), ..., array(En). ::
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
namespace facebook::velox::functions::sparksql {
20+
21+
/// This class implements the array union function.
22+
///
23+
/// DEFINITION:
24+
/// array_union(x, y) → array
25+
/// Returns an array of the elements in the union of x and y, without
26+
/// duplicates.
27+
template <typename T>
28+
struct ArrayUnionFunction {
29+
VELOX_DEFINE_FUNCTION_TYPES(T)
30+
31+
// Fast path for primitives.
32+
template <typename Out, typename In>
33+
void call(Out& out, const In& inputArray1, const In& inputArray2) {
34+
folly::F14FastSet<typename In::element_t> elementSet;
35+
bool nullAdded = false;
36+
bool nanAdded = false;
37+
auto addItems = [&](auto& inputArray) {
38+
for (const auto& item : inputArray) {
39+
if (item.has_value()) {
40+
if constexpr (
41+
std::is_same_v<In, arg_type<Array<float>>> ||
42+
std::is_same_v<In, arg_type<Array<double>>>) {
43+
bool isNaN = std::isnan(item.value());
44+
if ((isNaN && !nanAdded) ||
45+
(!isNaN && elementSet.insert(item.value()).second)) {
46+
auto& newItem = out.add_item();
47+
newItem = item.value();
48+
}
49+
if (!nanAdded && isNaN) {
50+
nanAdded = true;
51+
}
52+
} else if (elementSet.insert(item.value()).second) {
53+
auto& newItem = out.add_item();
54+
newItem = item.value();
55+
}
56+
} else if (!nullAdded) {
57+
nullAdded = true;
58+
out.add_null();
59+
}
60+
}
61+
};
62+
addItems(inputArray1);
63+
addItems(inputArray2);
64+
}
65+
66+
void call(
67+
out_type<Array<Generic<T1>>>& out,
68+
const arg_type<Array<Generic<T1>>>& inputArray1,
69+
const arg_type<Array<Generic<T1>>>& inputArray2) {
70+
folly::F14FastSet<exec::GenericView> elementSet;
71+
bool nullAdded = false;
72+
auto addItems = [&](auto& inputArray) {
73+
for (const auto& item : inputArray) {
74+
if (item.has_value()) {
75+
if (elementSet.insert(item.value()).second) {
76+
auto& newItem = out.add_item();
77+
newItem.copy_from(item.value());
78+
}
79+
} else if (!nullAdded) {
80+
nullAdded = true;
81+
out.add_null();
82+
}
83+
}
84+
};
85+
addItems(inputArray1);
86+
addItems(inputArray2);
87+
}
88+
};
89+
} // namespace facebook::velox::functions::sparksql

velox/functions/sparksql/Register.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "velox/functions/prestosql/StringFunctions.h"
2929
#include "velox/functions/sparksql/ArrayMinMaxFunction.h"
3030
#include "velox/functions/sparksql/ArraySort.h"
31+
#include "velox/functions/sparksql/ArrayUnionFunction.h"
3132
#include "velox/functions/sparksql/Bitwise.h"
3233
#include "velox/functions/sparksql/DateTimeFunctions.h"
3334
#include "velox/functions/sparksql/Hash.h"
@@ -150,6 +151,12 @@ inline void registerArrayMinMaxFunctions(const std::string& prefix) {
150151
}
151152
} // namespace
152153

154+
template <typename T>
155+
inline void registerArrayUnionFunctions(const std::string& prefix) {
156+
registerFunction<sparksql::ArrayUnionFunction, Array<T>, Array<T>, Array<T>>(
157+
{prefix + "array_union"});
158+
}
159+
153160
void registerFunctions(const std::string& prefix) {
154161
registerAllSpecialFormGeneralFunctions();
155162

@@ -407,6 +414,19 @@ void registerFunctions(const std::string& prefix) {
407414
{prefix + "monotonically_increasing_id"});
408415

409416
registerFunction<UuidFunction, Varchar, Constant<int64_t>>({prefix + "uuid"});
417+
registerArrayUnionFunctions<bool>(prefix);
418+
registerArrayUnionFunctions<int8_t>(prefix);
419+
registerArrayUnionFunctions<int16_t>(prefix);
420+
registerArrayUnionFunctions<int32_t>(prefix);
421+
registerArrayUnionFunctions<int64_t>(prefix);
422+
registerArrayUnionFunctions<int128_t>(prefix);
423+
registerArrayUnionFunctions<float>(prefix);
424+
registerArrayUnionFunctions<double>(prefix);
425+
registerArrayUnionFunctions<Varchar>(prefix);
426+
registerArrayUnionFunctions<Varbinary>(prefix);
427+
registerArrayUnionFunctions<Date>(prefix);
428+
registerArrayUnionFunctions<Timestamp>(prefix);
429+
registerArrayUnionFunctions<Generic<T1>>(prefix);
410430
}
411431

412432
} // namespace sparksql
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "velox/common/base/tests/GTestUtils.h"
18+
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"
19+
20+
using namespace facebook::velox;
21+
using namespace facebook::velox::test;
22+
23+
namespace facebook::velox::functions::sparksql::test {
24+
namespace {
25+
26+
class ArrayUnionTest : public SparkFunctionBaseTest {
27+
protected:
28+
void testExpression(
29+
const std::string& expression,
30+
const std::vector<VectorPtr>& input,
31+
const VectorPtr& expected) {
32+
auto result = evaluate(expression, makeRowVector(input));
33+
assertEqualVectors(expected, result);
34+
}
35+
36+
template <typename T>
37+
void testFloatArray() {
38+
const auto array1 = makeArrayVector<T>(
39+
{{1.99, 2.78, 3.98, 4.01},
40+
{3.89, 4.99, 5.13},
41+
{7.13, 8.91, std::numeric_limits<T>::quiet_NaN()},
42+
{10.02, 20.01, std::numeric_limits<T>::quiet_NaN()}});
43+
const auto array2 = makeArrayVector<T>(
44+
{{2.78, 4.01, 5.99},
45+
{3.89, 4.99, 5.13},
46+
{7.13, 8.91, std::numeric_limits<T>::quiet_NaN()},
47+
{40.99, 50.12}});
48+
49+
VectorPtr expected;
50+
expected = makeArrayVector<T>({
51+
{1.99, 2.78, 3.98, 4.01, 5.99},
52+
{3.89, 4.99, 5.13},
53+
{7.13, 8.91, std::numeric_limits<T>::quiet_NaN()},
54+
{10.02, 20.01, std::numeric_limits<T>::quiet_NaN(), 40.99, 50.12},
55+
});
56+
testExpression("array_union(c0, c1)", {array1, array2}, expected);
57+
58+
expected = makeArrayVector<T>({
59+
{2.78, 4.01, 5.99, 1.99, 3.98},
60+
{3.89, 4.99, 5.13},
61+
{7.13, 8.91, std::numeric_limits<T>::quiet_NaN()},
62+
{40.99, 50.12, 10.02, 20.01, std::numeric_limits<T>::quiet_NaN()},
63+
});
64+
testExpression("array_union(c0, c1)", {array2, array1}, expected);
65+
}
66+
};
67+
68+
// Union two integer arrays.
69+
TEST_F(ArrayUnionTest, intArray) {
70+
const auto array1 = makeArrayVector<int64_t>(
71+
{{1, 2, 3, 4}, {3, 4, 5}, {7, 8, 9}, {10, 20, 30}});
72+
const auto array2 =
73+
makeArrayVector<int64_t>({{2, 4, 5}, {3, 4, 5}, {}, {40, 50}});
74+
VectorPtr expected;
75+
76+
expected = makeArrayVector<int64_t>({
77+
{1, 2, 3, 4, 5},
78+
{3, 4, 5},
79+
{7, 8, 9},
80+
{10, 20, 30, 40, 50},
81+
});
82+
testExpression("array_union(c0, c1)", {array1, array2}, expected);
83+
84+
expected = makeArrayVector<int64_t>({
85+
{2, 4, 5, 1, 3},
86+
{3, 4, 5},
87+
{7, 8, 9},
88+
{40, 50, 10, 20, 30},
89+
});
90+
testExpression("array_union(c0, c1)", {array2, array1}, expected);
91+
}
92+
93+
// Union two float or double arrays.
94+
TEST_F(ArrayUnionTest, floatArray) {
95+
testFloatArray<float>();
96+
testFloatArray<double>();
97+
}
98+
99+
// Union two string arrays.
100+
TEST_F(ArrayUnionTest, stringArray) {
101+
const auto array1 =
102+
makeArrayVector<StringView>({{"foo", "bar"}, {"foo", "baz"}});
103+
const auto array2 =
104+
makeArrayVector<StringView>({{"foo", "bar"}, {"bar", "baz"}});
105+
VectorPtr expected;
106+
107+
expected = makeArrayVector<StringView>({
108+
{"foo", "bar"},
109+
{"foo", "baz", "bar"},
110+
});
111+
testExpression("array_union(c0, c1)", {array1, array2}, expected);
112+
}
113+
114+
// Union two integer arrays with null.
115+
TEST_F(ArrayUnionTest, nullArray) {
116+
const auto array1 = makeNullableArrayVector<int64_t>({
117+
{{1, std::nullopt, 3, 4}},
118+
{7, 8, 9},
119+
{{10, std::nullopt, std::nullopt}},
120+
});
121+
const auto array2 = makeNullableArrayVector<int64_t>({
122+
{{std::nullopt, std::nullopt, 3, 5}},
123+
std::nullopt,
124+
{{1, 10}},
125+
});
126+
VectorPtr expected;
127+
128+
expected = makeNullableArrayVector<int64_t>({
129+
{{1, std::nullopt, 3, 4, 5}},
130+
std::nullopt,
131+
{{10, std::nullopt, 1}},
132+
});
133+
testExpression("array_union(c0, c1)", {array1, array2}, expected);
134+
135+
expected = makeNullableArrayVector<int64_t>({
136+
{{std::nullopt, 3, 5, 1, 4}},
137+
std::nullopt,
138+
{{1, 10, std::nullopt}},
139+
});
140+
testExpression("array_union(c0, c1)", {array2, array1}, expected);
141+
}
142+
143+
// Union array vectors.
144+
TEST_F(ArrayUnionTest, complexTypes) {
145+
auto baseVector = makeArrayVector<int64_t>(
146+
{{1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {6, 6}});
147+
148+
// Create arrays of array vector using above base vector.
149+
// [[1, 1], [2, 2]]
150+
// [[3, 3], [4, 4]]
151+
// [[5, 5], [6, 6]]
152+
auto arrayOfArrays1 = makeArrayVector({0, 2, 4}, baseVector);
153+
// [[1, 1], [2, 2], [3, 3]]
154+
// [[4, 4]]
155+
// [[5, 5], [6, 6]]
156+
auto arrayOfArrays2 = makeArrayVector({0, 3, 4}, baseVector);
157+
158+
// [[1, 1], [2, 2], [3, 3]]
159+
// [[3, 3], [4, 4]]
160+
// [[5, 5], [6, 6]]
161+
auto expected = makeArrayVector(
162+
{0, 3, 5},
163+
makeArrayVector<int64_t>(
164+
{{1, 1}, {2, 2}, {3, 3}, {3, 3}, {4, 4}, {5, 5}, {6, 6}}));
165+
166+
testExpression(
167+
"array_union(c0, c1)", {arrayOfArrays1, arrayOfArrays2}, expected);
168+
}
169+
170+
// Union double array vectors.
171+
TEST_F(ArrayUnionTest, complexDoubleType) {
172+
auto baseVector = makeArrayVector<double>(
173+
{{1.0, 1.0},
174+
{2.0, 2.0},
175+
{3.0, 3.0},
176+
{4.0, 4.0},
177+
{5.0, std::numeric_limits<double>::quiet_NaN()},
178+
{6.0, 6.0}});
179+
180+
// Create arrays of array vector using above base vector.
181+
// [[1.0, 1.0], [2.0, 2.0]]
182+
// [[3.0, 3.0], [4.0, 4.0]]
183+
// [[5.0, NaN], [6.0, 6.0]]
184+
auto arrayOfArrays1 = makeArrayVector({0, 2, 4}, baseVector);
185+
// [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
186+
// [[4.0, 4.0]]
187+
// [[5.0, NaN], [6.0, 6.0]]
188+
auto arrayOfArrays2 = makeArrayVector({0, 3, 4}, baseVector);
189+
190+
// [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
191+
// [[3.0, 3.0], [4.0, 4.0]]
192+
// [[5.0, NaN], [6.0, 6.0]]
193+
auto expected = makeArrayVector(
194+
{0, 3, 5},
195+
makeArrayVector<double>(
196+
{{1.0, 1.0},
197+
{2.0, 2.0},
198+
{3.0, 3.0},
199+
{3.0, 3.0},
200+
{4.0, 4.0},
201+
{5.0, std::numeric_limits<double>::quiet_NaN()},
202+
{6.0, 6.0}}));
203+
204+
testExpression(
205+
"array_union(c0, c1)", {arrayOfArrays1, arrayOfArrays2}, expected);
206+
}
207+
} // namespace
208+
} // namespace facebook::velox::functions::sparksql::test

velox/functions/sparksql/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ add_executable(
1818
ArrayMaxTest.cpp
1919
ArrayMinTest.cpp
2020
ArraySortTest.cpp
21+
ArrayUnionTest.cpp
2122
BitwiseTest.cpp
2223
ComparisonsTest.cpp
2324
DateTimeFunctionsTest.cpp

0 commit comments

Comments
 (0)