
Commit 024a822

Author: Davies Liu (committed)

support chained Python UDFs

1 parent: d63ec84

File tree: 4 files changed (54 additions, 8 deletions)

python/pyspark/sql/tests.py

Lines changed: 9 additions & 0 deletions

```diff
@@ -305,6 +305,15 @@ def test_udf2(self):
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])
 
+    def test_chained_python_udf(self):
+        self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+        self.assertEqual(row[0], 2)
+        [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+        self.assertEqual(row[0], 4)
+        [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+        self.assertEqual(row[0], 6)
+
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)
```

python/pyspark/worker.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -57,6 +57,11 @@ def read_command(serializer, file):
     return command
 
 
+def chain(f, g):
+    """chain two functions together"""
+    return lambda x: g(f(x))
+
+
 def main(infile, outfile):
     try:
         boot_time = time.time()
@@ -112,8 +117,7 @@ def main(infile, outfile):
                 if row_func is None:
                     row_func = f
                 else:
-                    # chain multiple UDF together
-                    row_func = lambda x: f(row_func(x))
+                    row_func = chain(row_func, f)
             serializer = deserializer
             func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
         else:
```
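The switch from an inline lambda to `chain` matters because of Python's late-binding closures: the removed `row_func = lambda x: f(row_func(x))` captures the *names* `f` and `row_func`, which are looked up only when the composed function is finally called, and by then `row_func` refers to the new lambda itself, so the call recurses forever. A minimal standalone sketch of the pitfall (not Spark code; `double` here stands in for a deserialized UDF):

```python
def chain(f, g):
    """chain two functions together: chain(f, g)(x) == g(f(x))"""
    return lambda x: g(f(x))

def double(x):
    return x + x

# Broken pattern from the removed line: the lambda looks up `row_func`
# at call time, and by then `row_func` *is* this lambda, so it recurses.
row_func = double
row_func = lambda x: double(row_func(x))
try:
    row_func(1)
except RecursionError:
    print("late-binding closure recursed")

# Fixed pattern: chain() receives both functions as arguments, so each
# composition step is frozen and safe to rebuild inside a loop.
row_func = double
row_func = chain(row_func, double)
print(row_func(1))  # 4 == double(double(1))
```

Because `chain` binds both functions at the moment it is called, the worker can fold any number of UDFs read from the stream into a single `row_func`.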

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala

Lines changed: 19 additions & 5 deletions

```diff
@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.PythonRunner
+import org.apache.spark.api.python.{PythonFunction, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.{StructField, StructType}
 
@@ -45,6 +45,18 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
 
   def children: Seq[SparkPlan] = child :: Nil
 
+  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (fs, children) = collectFunctions(u)
+        (fs ++ Seq(udf.func), children)
+      case children =>
+        // There should not be any other UDFs, or the children can't be evaluated directly.
+        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+        (Seq(udf.func), udf.children)
+    }
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val inputRDD = child.execute().map(_.copy())
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -57,9 +69,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
     // combine input with output from Python.
     val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
 
+    val (pyFuncs, children) = collectFunctions(udf)
+
     val pickle = new Pickler
-    val currentRow = newMutableProjection(udf.children, child.output)()
-    val fields = udf.children.map(_.dataType)
+    val currentRow = newMutableProjection(children, child.output)()
+    val fields = children.map(_.dataType)
     val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
 
     // Input iterator to Python: input rows are grouped so we send them in batches to Python.
@@ -75,7 +89,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
     val context = TaskContext.get()
 
     // Output iterator for results from Python.
-    val outputIterator = new PythonRunner(Seq(udf.func), bufferSize, reuseWorker, true)
+    val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
       .compute(inputIterator, context.partitionId(), context)
 
     val unpickle = new Unpickler
```
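`collectFunctions` linearizes a chain like `double(double(a + 1))` into the functions ordered innermost-first, together with the input expressions of the innermost UDF; those functions are sent to the Python worker, which folds them back together with `chain`. A rough Python rendering of the same recursion (the toy `PythonUDF` class and string stand-ins are illustrative only, not Catalyst or Py4J objects):

```python
class PythonUDF:
    def __init__(self, func, children):
        self.func = func          # stands in for a PythonFunction
        self.children = children  # child expressions

def collect_functions(udf):
    # A single PythonUDF child means we are in the middle of a chain:
    # recurse inward, then append this level's function.
    if len(udf.children) == 1 and isinstance(udf.children[0], PythonUDF):
        funcs, inputs = collect_functions(udf.children[0])
        return funcs + [udf.func], inputs
    # Base case: no nested UDF, so these children are the real inputs.
    # (The Scala version asserts no other UDFs hide inside them.)
    return [udf.func], udf.children

inner = PythonUDF("double#1", ["a + 1"])
outer = PythonUDF("double#2", [inner])
print(collect_functions(outer))  # (['double#1', 'double#2'], ['a + 1'])
```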

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 20 additions & 1 deletion

```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.python
 
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -29,13 +30,31 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * multiple child operators.
  */
 private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+
+  private def hasUDF(e: Expression): Boolean = {
+    e.find(_.isInstanceOf[PythonUDF]).isDefined
+  }
+
+  private def canEvaluate(e: PythonUDF): Boolean = {
+    e.children match {
+      case Seq(u: PythonUDF) => canEvaluate(u)
+      case children => !children.exists(hasUDF)
+    }
+  }
+
+  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
+    expr.collect {
+      case udf: PythonUDF if canEvaluate(udf) => udf
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // Skip EvaluatePython nodes.
     case plan: EvaluatePython => plan
 
     case plan: LogicalPlan if plan.resolved =>
       // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
+      val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
       if (udfs.isEmpty) {
         // If there aren't any, we are done.
         plan
```
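The new predicate only extracts a `PythonUDF` that forms a pure chain: each level has exactly one child that is itself a `PythonUDF`, and the innermost children contain no UDFs at all. A mixed expression like the new test's `double(double(1) + 1)` fails the check at the outer level, so only the inner `double(1)` is extracted first; the outer UDF is left to be handled after the inner one has been rewritten to a column reference. A hedged Python sketch of the predicate (toy classes, not Catalyst expressions):

```python
class Expr:
    def __init__(self, children=()):
        self.children = list(children)

class PythonUDF(Expr):
    pass

def has_udf(e):
    # Mirrors e.find(_.isInstanceOf[PythonUDF]).isDefined: the node
    # itself counts, as does any descendant.
    return isinstance(e, PythonUDF) or any(has_udf(c) for c in e.children)

def can_evaluate(udf):
    # Follow a single-UDF child inward; otherwise require UDF-free children.
    if len(udf.children) == 1 and isinstance(udf.children[0], PythonUDF):
        return can_evaluate(udf.children[0])
    return not any(has_udf(c) for c in udf.children)

col = Expr()                                   # plain column reference
chained = PythonUDF([PythonUDF([col])])        # double(double(a)): pure chain
mixed = PythonUDF([Expr([PythonUDF([col])])])  # double(double(a) + 1)

print(can_evaluate(chained))  # True  -> extracted as one batch
print(can_evaluate(mixed))    # False -> inner UDF extracted first
```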
