
Commit 3eb0ee0

JoshRosen authored and gatorsmile committed
[SPARK-20685] Fix BatchPythonEvaluation bug in case of single UDF w/ repeated arg.
## What changes were proposed in this pull request?

There is a latent corner-case bug in PySpark UDF evaluation: executing a `BatchPythonEvaluation` with a single multi-argument UDF where _at least one argument value is repeated_ crashes at execution time with a confusing error.

This problem was introduced in #12057: the code there has a fast path for handling a "batch UDF evaluation consisting of a single Python UDF", but that branch incorrectly assumes that a single UDF won't have repeated arguments and therefore skips the code for unpacking arguments from the input row (whose schema may not match the UDF inputs, due to de-duplication of repeated arguments performed in the JVM before sending UDF inputs to Python).

The fix here is simply to remove the special-casing: the code in the "multiple UDFs" branch happens to work for the single-UDF case as well, because Python treats `(x)` as equivalent to `x`, not as a single-argument tuple.

## How was this patch tested?

New regression test in the `pyspark.sql.tests` module (tested and confirmed that it fails before my fix).

Author: Josh Rosen <[email protected]>

Closes #17927 from JoshRosen/SPARK-20685.

(cherry picked from commit 8ddbc43)
Signed-off-by: Xiao Li <[email protected]>
1 parent 86cef4d · commit 3eb0ee0
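The correctness argument above hinges on Python's tuple syntax. Here is a minimal standalone sketch of that point; the name `f0` mirrors the generated-mapper convention in `worker.py`, and the snippet itself is illustrative rather than part of the patch:

```python
# A parenthesized single expression is just that expression; only a
# trailing comma makes a one-element tuple.
assert (1 + 1) == 2        # plain int, not a 1-tuple
assert (1 + 1,) == (2,)    # a 1-tuple

# So a generated mapper of the form "lambda a: (f0(...))" yields a bare
# value for one UDF, and a tuple of results for several UDFs.
f0 = lambda x: x * 2
single = eval("lambda a: (f0(a[0]))", {"f0": f0})
assert single([21]) == 42  # bare value, the format the JVM side expects
```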

File tree: 2 files changed (+19, -16 lines)


python/pyspark/sql/tests.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -324,6 +324,12 @@ def test_chained_udf(self):
         [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
+    def test_single_udf_with_repeated_argument(self):
+        # regression test for SPARK-20685
+        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+        row = self.spark.sql("SELECT add(1, 1)").first()
+        self.assertEqual(tuple(row), (2, ))
+
     def test_multiple_udfs(self):
         self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
         [row] = self.spark.sql("SELECT double(1), double(2)").collect()
```
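For context, a self-contained reproduction in the spirit of this regression test might look as follows. The local session setup is an assumption for illustration, while the `registerFunction` call mirrors the test; before this fix, the repeated argument in `add(1, 1)` crashed the Python worker:

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

# Illustrative local session; any SparkSession would do.
spark = SparkSession.builder.master("local[1]").appName("spark-20685-repro").getOrCreate()

# Register a two-argument UDF, then call it with a repeated argument value.
# The JVM de-duplicates the repeated expression to a single input column
# before shipping rows to the Python worker.
spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
row = spark.sql("SELECT add(1, 1)").first()
assert tuple(row) == (2,)

spark.stop()
```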

python/pyspark/worker.py

Lines changed: 13 additions & 16 deletions
```diff
@@ -87,22 +87,19 @@ def read_single_udf(pickleSer, infile):
 
 def read_udfs(pickleSer, infile):
     num_udfs = read_int(infile)
-    if num_udfs == 1:
-        # fast path for single UDF
-        _, udf = read_single_udf(pickleSer, infile)
-        mapper = lambda a: udf(*a)
-    else:
-        udfs = {}
-        call_udf = []
-        for i in range(num_udfs):
-            arg_offsets, udf = read_single_udf(pickleSer, infile)
-            udfs['f%d' % i] = udf
-            args = ["a[%d]" % o for o in arg_offsets]
-            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
-        # Create function like this:
-        # lambda a: (f0(a0), f1(a1, a2), f2(a3))
-        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
-        mapper = eval(mapper_str, udfs)
+    udfs = {}
+    call_udf = []
+    for i in range(num_udfs):
+        arg_offsets, udf = read_single_udf(pickleSer, infile)
+        udfs['f%d' % i] = udf
+        args = ["a[%d]" % o for o in arg_offsets]
+        call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+    # Create function like this:
+    # lambda a: (f0(a0), f1(a1, a2), f2(a3))
+    # In the special case of a single UDF this will return a single result rather
+    # than a tuple of results; this is the format that the JVM side expects.
+    mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+    mapper = eval(mapper_str, udfs)
 
     func = lambda _, it: map(mapper, it)
     ser = BatchedSerializer(PickleSerializer(), 100)
```
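To see why the unified branch also covers the repeated-argument case, here is a hedged simulation of what `read_udfs` builds when the JVM has de-duplicated both arguments of `add(1, 1)` down to one column; the offsets and the input row are assumed values for illustration:

```python
# Simulated metadata for one two-argument UDF whose repeated argument was
# de-duplicated: both offsets point at the same column of the input row.
udfs = {'f0': lambda x, y: x + y}
arg_offsets = [0, 0]

# Same string construction as read_udfs, yielding "lambda a: (f0(a[0], a[0]))".
call = "f0(%s)" % ", ".join("a[%d]" % o for o in arg_offsets)
mapper_str = "lambda a: (%s)" % call
mapper = eval(mapper_str, udfs)

# The row carries the de-duplicated value once; indexing by offset feeds it
# to both parameters, and the single result stays un-tupled.
assert mapper([1]) == 2
```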
