@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
22
22
import net .razorvine .pickle .{Pickler , Unpickler }
23
23
24
24
import org .apache .spark .TaskContext
25
- import org .apache .spark .api .python .PythonRunner
25
+ import org .apache .spark .api .python .{ PythonFunction , PythonRunner }
26
26
import org .apache .spark .rdd .RDD
27
27
import org .apache .spark .sql .catalyst .InternalRow
28
- import org .apache .spark .sql .catalyst .expressions .{ Attribute , GenericMutableRow , JoinedRow , UnsafeProjection }
28
+ import org .apache .spark .sql .catalyst .expressions ._
29
29
import org .apache .spark .sql .execution .SparkPlan
30
30
import org .apache .spark .sql .types .{StructField , StructType }
31
31
@@ -45,6 +45,18 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
45
45
46
46
def children : Seq [SparkPlan ] = child :: Nil
47
47
48
/**
 * Flattens a chain of directly nested Python UDFs into the list of their functions.
 *
 * When the only child of `udf` is itself a [[PythonUDF]], the chain is followed
 * recursively; the functions are returned innermost first, with `udf.func` last.
 *
 * @param udf the outermost Python UDF of the chain
 * @return a pair of (functions of the chain, innermost first; child expressions of
 *         the innermost UDF in the chain)
 */
private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
  val inputs = udf.children
  inputs match {
    case Seq(nested: PythonUDF) =>
      // Single argument that is another Python UDF: descend, then append our own function.
      val (innerFuncs, leafInputs) = collectFunctions(nested)
      (innerFuncs :+ udf.func, leafInputs)
    case _ =>
      // There should not be any other UDFs, or the children can't be evaluated directly.
      assert(!inputs.exists(_.find(_.isInstanceOf[PythonUDF]).isDefined))
      (Seq(udf.func), inputs)
  }
}
59
+
48
60
protected override def doExecute (): RDD [InternalRow ] = {
49
61
val inputRDD = child.execute().map(_.copy())
50
62
val bufferSize = inputRDD.conf.getInt(" spark.buffer.size" , 65536 )
@@ -57,9 +69,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
57
69
// combine input with output from Python.
58
70
val queue = new java.util.concurrent.ConcurrentLinkedQueue [InternalRow ]()
59
71
72
+ val (pyFuncs, children) = collectFunctions(udf)
73
+
60
74
val pickle = new Pickler
61
- val currentRow = newMutableProjection(udf. children, child.output)()
62
- val fields = udf. children.map(_.dataType)
75
+ val currentRow = newMutableProjection(children, child.output)()
76
+ val fields = children.map(_.dataType)
63
77
val schema = new StructType (fields.map(t => new StructField (" " , t, true )).toArray)
64
78
65
79
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
@@ -75,7 +89,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
75
89
val context = TaskContext .get()
76
90
77
91
// Output iterator for results from Python.
78
- val outputIterator = new PythonRunner (Seq (udf.func) , bufferSize, reuseWorker, true )
92
+ val outputIterator = new PythonRunner (pyFuncs , bufferSize, reuseWorker, true )
79
93
.compute(inputIterator, context.partitionId(), context)
80
94
81
95
val unpickle = new Unpickler
0 commit comments