Skip to content

Commit a7a93a1

Browse files
Davies Liu (davies)
authored and committed
[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?

This PR adds support for chained Python UDFs, for example:

```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```

Also, directly chained unary Python UDFs are put in a single batch of Python UDFs; others may require multiple batches.

For example,

```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#10 AS double(double(1))#9]
:     +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
   +- Scan OneRowRelation[]
>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
:     +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
   +- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
      +- !BatchPythonEvaluation double(1), [pythonUDF#17]
         +- Scan OneRowRelation[]
```

TODO: will support multiple unrelated Python UDFs in one batch (another PR).

## How was this patch tested?

Added new unit tests for chained UDFs.

Author: Davies Liu <[email protected]>

Closes apache#12014 from davies/py_udfs.
1 parent e58c4cb commit a7a93a1

File tree

6 files changed

+116
-35
lines changed

6 files changed

+116
-35
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
5959
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
6060

6161
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
62-
val runner = new PythonRunner(func, bufferSize, reuse_worker)
62+
val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
6363
runner.compute(firstParent.iterator(split, context), split.index, context)
6464
}
6565
}
@@ -81,14 +81,18 @@ private[spark] case class PythonFunction(
8181
* A helper class to run Python UDFs in Spark.
8282
*/
8383
private[spark] class PythonRunner(
84-
func: PythonFunction,
84+
funcs: Seq[PythonFunction],
8585
bufferSize: Int,
86-
reuse_worker: Boolean)
86+
reuse_worker: Boolean,
87+
rowBased: Boolean)
8788
extends Logging {
8889

89-
private val envVars = func.envVars
90-
private val pythonExec = func.pythonExec
91-
private val accumulator = func.accumulator
90+
// All the Python functions should have the same exec, version and envvars.
91+
private val envVars = funcs.head.envVars
92+
private val pythonExec = funcs.head.pythonExec
93+
private val pythonVer = funcs.head.pythonVer
94+
95+
private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
9296

9397
def compute(
9498
inputIterator: Iterator[_],
@@ -228,10 +232,8 @@ private[spark] class PythonRunner(
228232

229233
@volatile private var _exception: Exception = null
230234

231-
private val pythonVer = func.pythonVer
232-
private val pythonIncludes = func.pythonIncludes
233-
private val broadcastVars = func.broadcastVars
234-
private val command = func.command
235+
private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
236+
private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
235237

236238
setDaemon(true)
237239

@@ -256,13 +258,13 @@ private[spark] class PythonRunner(
256258
// sparkFilesDir
257259
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
258260
// Python includes (*.zip and *.egg files)
259-
dataOut.writeInt(pythonIncludes.size())
260-
for (include <- pythonIncludes.asScala) {
261+
dataOut.writeInt(pythonIncludes.size)
262+
for (include <- pythonIncludes) {
261263
PythonRDD.writeUTF(include, dataOut)
262264
}
263265
// Broadcast variables
264266
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
265-
val newBids = broadcastVars.asScala.map(_.id).toSet
267+
val newBids = broadcastVars.map(_.id).toSet
266268
// number of different broadcasts
267269
val toRemove = oldBids.diff(newBids)
268270
val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +274,7 @@ private[spark] class PythonRunner(
272274
dataOut.writeLong(- bid - 1) // bid >= 0
273275
oldBids.remove(bid)
274276
}
275-
for (broadcast <- broadcastVars.asScala) {
277+
for (broadcast <- broadcastVars) {
276278
if (!oldBids.contains(broadcast.id)) {
277279
// send new broadcast
278280
dataOut.writeLong(broadcast.id)
@@ -282,8 +284,12 @@ private[spark] class PythonRunner(
282284
}
283285
dataOut.flush()
284286
// Serialized command:
285-
dataOut.writeInt(command.length)
286-
dataOut.write(command)
287+
dataOut.writeInt(if (rowBased) 1 else 0)
288+
dataOut.writeInt(funcs.length)
289+
funcs.foreach { f =>
290+
dataOut.writeInt(f.command.length)
291+
dataOut.write(f.command)
292+
}
287293
// Data values
288294
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
289295
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)

python/pyspark/sql/functions.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from itertools import imap as map
2626

2727
from pyspark import since, SparkContext
28-
from pyspark.rdd import _wrap_function, ignore_unicode_prefix
28+
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
2929
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
3030
from pyspark.sql.types import StringType
3131
from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1648,6 +1648,14 @@ def sort_array(col, asc=True):
16481648

16491649
# ---------------------------- User Defined Function ----------------------------------
16501650

1651+
def _wrap_function(sc, func, returnType):
1652+
ser = AutoBatchedSerializer(PickleSerializer())
1653+
command = (func, returnType, ser)
1654+
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
1655+
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
1656+
sc.pythonVer, broadcast_vars, sc._javaAccumulator)
1657+
1658+
16511659
class UserDefinedFunction(object):
16521660
"""
16531661
User defined function in Python
@@ -1662,14 +1670,12 @@ def __init__(self, func, returnType, name=None):
16621670

16631671
def _create_judf(self, name):
16641672
from pyspark.sql import SQLContext
1665-
f, returnType = self.func, self.returnType # put them in closure `func`
1666-
func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
1667-
ser = AutoBatchedSerializer(PickleSerializer())
16681673
sc = SparkContext.getOrCreate()
1669-
wrapped_func = _wrap_function(sc, func, ser, ser)
1674+
wrapped_func = _wrap_function(sc, self.func, self.returnType)
16701675
ctx = SQLContext.getOrCreate(sc)
16711676
jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
16721677
if name is None:
1678+
f = self.func
16731679
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
16741680
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
16751681
name, wrapped_func, jdt)

python/pyspark/sql/tests.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,15 @@ def test_udf2(self):
305305
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
306306
self.assertEqual(4, res[0])
307307

308+
def test_chained_python_udf(self):
309+
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
310+
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
311+
self.assertEqual(row[0], 2)
312+
[row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
313+
self.assertEqual(row[0], 4)
314+
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
315+
self.assertEqual(row[0], 6)
316+
308317
def test_udf_with_array_type(self):
309318
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
310319
rdd = self.sc.parallelize(d)

python/pyspark/worker.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ def add_path(path):
5050
sys.path.insert(1, path)
5151

5252

53+
def read_command(serializer, file):
54+
command = serializer._read_with_length(file)
55+
if isinstance(command, Broadcast):
56+
command = serializer.loads(command.value)
57+
return command
58+
59+
60+
def chain(f, g):
61+
"""chain two function together """
62+
return lambda x: g(f(x))
63+
64+
5365
def main(infile, outfile):
5466
try:
5567
boot_time = time.time()
@@ -95,10 +107,23 @@ def main(infile, outfile):
95107
_broadcastRegistry.pop(bid)
96108

97109
_accumulatorRegistry.clear()
98-
command = pickleSer._read_with_length(infile)
99-
if isinstance(command, Broadcast):
100-
command = pickleSer.loads(command.value)
101-
func, profiler, deserializer, serializer = command
110+
row_based = read_int(infile)
111+
num_commands = read_int(infile)
112+
if row_based:
113+
profiler = None # profiling is not supported for UDF
114+
row_func = None
115+
for i in range(num_commands):
116+
f, returnType, deserializer = read_command(pickleSer, infile)
117+
if row_func is None:
118+
row_func = f
119+
else:
120+
row_func = chain(row_func, f)
121+
serializer = deserializer
122+
func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
123+
else:
124+
assert num_commands == 1
125+
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
126+
102127
init_time = time.time()
103128

104129
def process():

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
2222
import net.razorvine.pickle.{Pickler, Unpickler}
2323

2424
import org.apache.spark.TaskContext
25-
import org.apache.spark.api.python.PythonRunner
25+
import org.apache.spark.api.python.{PythonFunction, PythonRunner}
2626
import org.apache.spark.rdd.RDD
2727
import org.apache.spark.sql.catalyst.InternalRow
28-
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection}
28+
import org.apache.spark.sql.catalyst.expressions._
2929
import org.apache.spark.sql.execution.SparkPlan
3030
import org.apache.spark.sql.types.{StructField, StructType}
3131

@@ -45,6 +45,18 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
4545

4646
def children: Seq[SparkPlan] = child :: Nil
4747

48+
private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
49+
udf.children match {
50+
case Seq(u: PythonUDF) =>
51+
val (fs, children) = collectFunctions(u)
52+
(fs ++ Seq(udf.func), children)
53+
case children =>
54+
// There should not be any other UDFs, or the children can't be evaluated directly.
55+
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
56+
(Seq(udf.func), udf.children)
57+
}
58+
}
59+
4860
protected override def doExecute(): RDD[InternalRow] = {
4961
val inputRDD = child.execute().map(_.copy())
5062
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -57,9 +69,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
5769
// combine input with output from Python.
5870
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
5971

72+
val (pyFuncs, children) = collectFunctions(udf)
73+
6074
val pickle = new Pickler
61-
val currentRow = newMutableProjection(udf.children, child.output)()
62-
val fields = udf.children.map(_.dataType)
75+
val currentRow = newMutableProjection(children, child.output)()
76+
val fields = children.map(_.dataType)
6377
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
6478

6579
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
@@ -75,11 +89,8 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
7589
val context = TaskContext.get()
7690

7791
// Output iterator for results from Python.
78-
val outputIterator = new PythonRunner(
79-
udf.func,
80-
bufferSize,
81-
reuseWorker
82-
).compute(inputIterator, context.partitionId(), context)
92+
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
93+
.compute(inputIterator, context.partitionId(), context)
8394

8495
val unpickle = new Unpickler
8596
val row = new GenericMutableRow(1)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.execution.python
1919

20+
import org.apache.spark.sql.catalyst.expressions.Expression
2021
import org.apache.spark.sql.catalyst.plans.logical
2122
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2223
import org.apache.spark.sql.catalyst.rules.Rule
@@ -25,17 +26,40 @@ import org.apache.spark.sql.catalyst.rules.Rule
2526
* Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
2627
* alone in a batch.
2728
*
29+
* Only extracts the PythonUDFs that could be evaluated in Python (the single child is a PythonUDF
30+
* or all the children could be evaluated in JVM).
31+
*
2832
* This has the limitation that the input to the Python UDF is not allowed to include attributes from
2933
* multiple child operators.
3034
*/
3135
private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
36+
37+
private def hasPythonUDF(e: Expression): Boolean = {
38+
e.find(_.isInstanceOf[PythonUDF]).isDefined
39+
}
40+
41+
private def canEvaluateInPython(e: PythonUDF): Boolean = {
42+
e.children match {
43+
// single PythonUDF child could be chained and evaluated in Python
44+
case Seq(u: PythonUDF) => canEvaluateInPython(u)
45+
// Python UDF can't be evaluated directly in JVM
46+
case children => !children.exists(hasPythonUDF)
47+
}
48+
}
49+
50+
private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
51+
expr.collect {
52+
case udf: PythonUDF if canEvaluateInPython(udf) => udf
53+
}
54+
}
55+
3256
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
3357
// Skip EvaluatePython nodes.
3458
case plan: EvaluatePython => plan
3559

3660
case plan: LogicalPlan if plan.resolved =>
3761
// Extract any PythonUDFs from the current operator.
38-
val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
62+
val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
3963
if (udfs.isEmpty) {
4064
// If there aren't any, we are done.
4165
plan

0 commit comments

Comments
 (0)