Skip to content

Commit 7cd8f90

Browse files
xinrong-mengHyukjinKwon
authored andcommitted
[SPARK-43440][PYTHON][CONNECT] Support registration of an Arrow-optimized Python UDF
### What changes were proposed in this pull request? The PR proposes to provide support for the registration of an Arrow-optimized Python UDF in both vanilla PySpark and Spark Connect. ### Why are the changes needed? Currently, when users register an Arrow-optimized Python UDF, it will be registered as a pickled Python UDF and thus, executed without Arrow optimization. We should support Arrow-optimized Python UDFs registration and execute them with Arrow optimization. ### Does this PR introduce _any_ user-facing change? Yes. No API changes, but result differences are expected in some cases. Previously, a registered Arrow-optimized Python UDF will be executed without Arrow optimization. Now, it will be executed with Arrow optimization, as shown below. ```sh >>> df = spark.range(2) >>> df.createOrReplaceTempView("df") >>> from pyspark.sql.functions import udf >>> udf(useArrow=True) ... def f(x): ... return str(x) ... >>> spark.udf.register('str_f', f) <pyspark.sql.udf.UserDefinedFunction object at 0x7fa1980c16a0> >>> spark.sql("select str_f(id) from df").explain() # Executed with Arrow optimization == Physical Plan == *(2) Project [pythonUDF0#32 AS f(id)#30] +- ArrowEvalPython [f(id#27L)#29], [pythonUDF0#32], 101 +- *(1) Range (0, 2, step=1, splits=16) ``` Enabling or disabling Arrow optimization can produce result differences in some cases - we are working on minimizing the result differences though. ### How was this patch tested? Unit test. Closes #41125 from xinrong-meng/registerArrowPythonUDF. Authored-by: Xinrong Meng <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent dd4db21 commit 7cd8f90

File tree

4 files changed

+39
-14
lines changed

4 files changed

+39
-14
lines changed

python/pyspark/sql/connect/udf.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,31 +252,32 @@ def register(
252252
f = cast("UserDefinedFunctionLike", f)
253253
if f.evalType not in [
254254
PythonEvalType.SQL_BATCHED_UDF,
255+
PythonEvalType.SQL_ARROW_BATCHED_UDF,
255256
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
256257
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
257258
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
258259
]:
259260
raise PySparkTypeError(
260261
error_class="INVALID_UDF_EVAL_TYPE",
261262
message_parameters={
262-
"eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
263-
"SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF"
263+
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
264+
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or "
265+
"SQL_GROUPED_AGG_PANDAS_UDF"
264266
},
265267
)
266-
return_udf = f
267268
self.sparkSession._client.register_udf(
268269
f.func, f.returnType, name, f.evalType, f.deterministic
269270
)
271+
return f
270272
else:
271273
if returnType is None:
272274
returnType = StringType()
273-
return_udf = _create_udf(
275+
py_udf = _create_udf(
274276
f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name
275277
)
276278

277-
self.sparkSession._client.register_udf(f, returnType, name)
278-
279-
return return_udf
279+
self.sparkSession._client.register_udf(py_udf.func, returnType, name)
280+
return py_udf
280281

281282
register.__doc__ = PySparkUDFRegistration.register.__doc__
282283

python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def test_register_grouped_map_udf(self):
219219
exception=pe.exception,
220220
error_class="INVALID_UDF_EVAL_TYPE",
221221
message_parameters={
222-
"eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
222+
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
223223
"SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF"
224224
},
225225
)

python/pyspark/sql/tests/test_arrow_python_udf.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,24 @@ def test_eval_type(self):
119119
udf(lambda x: str(x), useArrow=False).evalType, PythonEvalType.SQL_BATCHED_UDF
120120
)
121121

122+
def test_register(self):
123+
df = self.spark.range(1).selectExpr(
124+
"array(1, 2, 3) as array",
125+
)
126+
str_repr_func = self.spark.udf.register("str_repr", udf(lambda x: str(x), useArrow=True))
127+
128+
# To verify that Arrow optimization is on
129+
self.assertEquals(
130+
df.selectExpr("str_repr(array) AS str_id").first()[0],
131+
"[1 2 3]", # The input is a NumPy array when the Arrow optimization is on
132+
)
133+
134+
# To verify that a UserDefinedFunction is returned
135+
self.assertListEqual(
136+
df.selectExpr("str_repr(array) AS str_id").collect(),
137+
df.select(str_repr_func("array").alias("str_id")).collect(),
138+
)
139+
122140

123141
class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
124142
@classmethod

python/pyspark/sql/udf.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -623,32 +623,38 @@ def register(
623623
f = cast("UserDefinedFunctionLike", f)
624624
if f.evalType not in [
625625
PythonEvalType.SQL_BATCHED_UDF,
626+
PythonEvalType.SQL_ARROW_BATCHED_UDF,
626627
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
627628
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
628629
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
629630
]:
630631
raise PySparkTypeError(
631632
error_class="INVALID_UDF_EVAL_TYPE",
632633
message_parameters={
633-
"eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
634-
"SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF"
634+
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
635+
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or "
636+
"SQL_GROUPED_AGG_PANDAS_UDF"
635637
},
636638
)
637-
register_udf = _create_udf(
639+
source_udf = _create_udf(
638640
f.func,
639641
returnType=f.returnType,
640642
name=name,
641643
evalType=f.evalType,
642644
deterministic=f.deterministic,
643-
)._unwrapped # type: ignore[attr-defined]
644-
return_udf = f
645+
)
646+
if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:
647+
register_udf = _create_arrow_py_udf(source_udf)._unwrapped
648+
else:
649+
register_udf = source_udf._unwrapped # type: ignore[attr-defined]
650+
return_udf = register_udf
645651
else:
646652
if returnType is None:
647653
returnType = StringType()
648654
return_udf = _create_udf(
649655
f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name
650656
)
651-
register_udf = return_udf._unwrapped # type: ignore[attr-defined]
657+
register_udf = return_udf._unwrapped
652658
self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
653659
return return_udf
654660

0 commit comments

Comments
 (0)