Skip to content

Commit e2ab5d5

Browse files
committed
test
1 parent c165355 commit e2ab5d5

File tree

12 files changed

+76
-14
lines changed

12 files changed

+76
-14
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ private[spark] object PythonEvalType {
7070
// Arrow UDFs
7171
val SQL_SCALAR_ARROW_UDF = 250
7272
val SQL_SCALAR_ARROW_ITER_UDF = 251
73+
val SQL_GROUPED_AGG_ARROW_UDF = 252
7374

7475
val SQL_TABLE_UDF = 300
7576
val SQL_ARROW_TABLE_UDF = 301
@@ -101,6 +102,7 @@ private[spark] object PythonEvalType {
101102
// Arrow UDFs
102103
case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF"
103104
case SQL_SCALAR_ARROW_ITER_UDF => "SQL_SCALAR_ARROW_ITER_UDF"
105+
case SQL_GROUPED_AGG_ARROW_UDF => "SQL_GROUPED_AGG_ARROW_UDF"
104106
}
105107
}
106108

python/pyspark/sql/connect/udf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,13 +280,14 @@ def register(
280280
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
281281
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
282282
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
283+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
283284
]:
284285
raise PySparkTypeError(
285286
errorClass="INVALID_UDF_EVAL_TYPE",
286287
messageParameters={
287288
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
288-
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or "
289-
"SQL_GROUPED_AGG_PANDAS_UDF"
289+
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF, "
290+
"SQL_GROUPED_AGG_PANDAS_UDF or SQL_GROUPED_AGG_ARROW_UDF"
290291
},
291292
)
292293
self.sparkSession._client.register_udf(

python/pyspark/sql/pandas/_typing/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ GroupedMapUDFTransformWithStateInitStateType = Literal[214]
6363
# Arrow UDFs
6464
ArrowScalarUDFType = Literal[250]
6565
ArrowScalarIterUDFType = Literal[251]
66+
ArrowGroupedAggUDFType = Literal[252]
6667

6768
class ArrowVariadicScalarToScalarFunction(Protocol):
6869
def __call__(self, *_: pyarrow.Array) -> pyarrow.Array: ...

python/pyspark/sql/pandas/functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class ArrowUDFType:
4848

4949
SCALAR_ITER = PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF
5050

51+
GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
52+
5153

5254
def arrow_udf(f=None, returnType=None, functionType=None):
5355
return vectorized_udf(f, returnType, functionType, "arrow")

python/pyspark/sql/pandas/typehints.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
PandasGroupedAggUDFType,
2828
ArrowScalarUDFType,
2929
ArrowScalarIterUDFType,
30+
ArrowGroupedAggUDFType,
3031
)
3132

3233

@@ -38,6 +39,7 @@ def infer_eval_type(
3839
"PandasGroupedAggUDFType",
3940
"ArrowScalarUDFType",
4041
"ArrowScalarIterUDFType",
42+
"ArrowGroupedAggUDFType",
4143
]:
4244
"""
4345
Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
@@ -175,6 +177,13 @@ def infer_eval_type(
175177
and not check_tuple_annotation(return_annotation)
176178
)
177179

180+
# pa.Array, ... -> Any
181+
is_array_agg = all(a == pa.Array for a in parameters_sig) and (
182+
return_annotation != pa.Array
183+
and not check_iterator_annotation(return_annotation)
184+
and not check_tuple_annotation(return_annotation)
185+
)
186+
178187
if is_series_or_frame:
179188
return PandasUDFType.SCALAR
180189
elif is_arrow_array:
@@ -185,6 +194,8 @@ def infer_eval_type(
185194
return ArrowUDFType.SCALAR_ITER
186195
elif is_series_or_frame_agg:
187196
return PandasUDFType.GROUPED_AGG
197+
elif is_array_agg:
198+
return ArrowUDFType.GROUPED_AGG
188199
else:
189200
raise PySparkNotImplementedError(
190201
errorClass="UNSUPPORTED_SIGNATURE",

python/pyspark/sql/udf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,13 +655,14 @@ def register(
655655
PythonEvalType.SQL_SCALAR_ARROW_UDF,
656656
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
657657
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
658+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
658659
]:
659660
raise PySparkTypeError(
660661
errorClass="INVALID_UDF_EVAL_TYPE",
661662
messageParameters={
662663
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
663-
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or "
664-
"SQL_GROUPED_AGG_PANDAS_UDF"
664+
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF, "
665+
"SQL_GROUPED_AGG_PANDAS_UDF or SQL_GROUPED_AGG_ARROW_UDF"
665666
},
666667
)
667668
source_udf = _create_udf(

python/pyspark/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
GroupedMapUDFTransformWithStateInitStateType,
6868
ArrowScalarUDFType,
6969
ArrowScalarIterUDFType,
70+
ArrowGroupedAggUDFType,
7071
)
7172
from pyspark.sql._typing import (
7273
SQLArrowBatchedUDFType,
@@ -651,6 +652,7 @@ class PythonEvalType:
651652
# Arrow UDFs
652653
SQL_SCALAR_ARROW_UDF: "ArrowScalarUDFType" = 250
653654
SQL_SCALAR_ARROW_ITER_UDF: "ArrowScalarIterUDFType" = 251
655+
SQL_GROUPED_AGG_ARROW_UDF: "ArrowGroupedAggUDFType" = 252
654656

655657
SQL_TABLE_UDF: "SQLTableUDFType" = 300
656658
SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301

python/pyspark/worker.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,25 @@ def wrapped(*series):
796796
)
797797

798798

799+
def wrap_grouped_agg_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
800+
func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)
801+
802+
arrow_return_type = to_arrow_type(
803+
return_type, prefers_large_types=use_large_var_types(runner_conf)
804+
)
805+
806+
def wrapped(*series):
807+
import pyarrow as pa
808+
809+
result = func(*series)
810+
return pa.array([result])
811+
812+
return (
813+
args_kwargs_offsets,
814+
lambda *a: (wrapped(*a), arrow_return_type),
815+
)
816+
817+
799818
def wrap_window_agg_pandas_udf(
800819
f, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index
801820
):
@@ -974,6 +993,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
974993
# The below doesn't support named argument, but shares the same protocol.
975994
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
976995
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
996+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
977997
):
978998
args_offsets = []
979999
kwargs_offsets = {}
@@ -1070,6 +1090,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
10701090
return wrap_grouped_agg_pandas_udf(
10711091
func, args_offsets, kwargs_offsets, return_type, runner_conf
10721092
)
1093+
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF:
1094+
return wrap_grouped_agg_arrow_udf(
1095+
func, args_offsets, kwargs_offsets, return_type, runner_conf
1096+
)
10731097
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
10741098
return wrap_window_agg_pandas_udf(
10751099
func, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index
@@ -1815,6 +1839,7 @@ def read_udfs(pickleSer, infile, eval_type):
18151839
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
18161840
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
18171841
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
1842+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
18181843
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
18191844
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
18201845
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
@@ -1911,6 +1936,7 @@ def read_udfs(pickleSer, infile, eval_type):
19111936
elif eval_type in (
19121937
PythonEvalType.SQL_SCALAR_ARROW_UDF,
19131938
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
1939+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
19141940
):
19151941
# Arrow cast for type coercion is disabled by default
19161942
ser = ArrowStreamArrowUDFSerializer(timezone, safecheck, _assign_cols_by_name, False)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,10 @@ case class PythonUDAF(
120120
dataType: DataType,
121121
children: Seq[Expression],
122122
udfDeterministic: Boolean,
123+
evalType: Int,
123124
resultId: ExprId = NamedExpression.newExprId)
124125
extends UnevaluableAggregateFunc with PythonFuncExpression {
125126

126-
override def evalType: Int = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
127-
128127
override def sql(isDistinct: Boolean): String = {
129128
val distinct = if (isDistinct) "DISTINCT " else ""
130129
s"$name($distinct${children.mkString(", ")})"

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
644644

645645
case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child)
646646
if aggExpressions.forall(_.aggregateFunction.isInstanceOf[PythonUDAF]) =>
647-
Seq(execution.python.AggregateInPandasExec(
647+
Seq(execution.python.ArrowAggregatePythonExec(
648648
groupingExpressions,
649649
aggExpressions,
650650
resultExpressions,

0 commit comments

Comments (0)