[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs #12014


Closed
wants to merge 4 commits
Changes from all commits
38 changes: 22 additions & 16 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = new PythonRunner(func, bufferSize, reuse_worker)
val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -81,14 +81,18 @@ private[spark] case class PythonFunction(
* A helper class to run Python UDFs in Spark.
*/
private[spark] class PythonRunner(
func: PythonFunction,
funcs: Seq[PythonFunction],
bufferSize: Int,
reuse_worker: Boolean)
reuse_worker: Boolean,
rowBased: Boolean)
extends Logging {

private val envVars = func.envVars
private val pythonExec = func.pythonExec
private val accumulator = func.accumulator
// All the Python functions should have the same exec, version and envvars.
private val envVars = funcs.head.envVars
private val pythonExec = funcs.head.pythonExec
private val pythonVer = funcs.head.pythonVer

private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF

def compute(
inputIterator: Iterator[_],
@@ -228,10 +232,8 @@ private[spark] class PythonRunner(

@volatile private var _exception: Exception = null

private val pythonVer = func.pythonVer
private val pythonIncludes = func.pythonIncludes
private val broadcastVars = func.broadcastVars
private val command = func.command
private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)

setDaemon(true)

@@ -256,13 +258,13 @@
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.size())
for (include <- pythonIncludes.asScala) {
dataOut.writeInt(pythonIncludes.size)
for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
val newBids = broadcastVars.asScala.map(_.id).toSet
val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +274,7 @@
dataOut.writeLong(- bid - 1) // bid >= 0
oldBids.remove(bid)
}
for (broadcast <- broadcastVars.asScala) {
for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
@@ -282,8 +284,12 @@
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
dataOut.write(command)
dataOut.writeInt(if (rowBased) 1 else 0)
dataOut.writeInt(funcs.length)
funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
}
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
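For reference, a toy sketch of the new framing this hunk writes before the data section: a row-based flag, the number of chained commands, then each pickled command prefixed by its length. It uses plain struct/io stand-ins rather than Spark's actual stream utilities, and the payload bytes are made up.

```python
import io
import struct

def write_commands(out, commands, row_based):
    # Mirrors PythonRunner's writes: Java's DataOutputStream.writeInt is a
    # big-endian 4-byte int, hence the ">i" format.
    out.write(struct.pack(">i", 1 if row_based else 0))   # rowBased flag
    out.write(struct.pack(">i", len(commands)))           # number of chained UDFs
    for command in commands:                               # pickled (func, returnType, ser)
        out.write(struct.pack(">i", len(command)))
        out.write(command)

buf = io.BytesIO()
write_commands(buf, [b"pickled-udf-1", b"pickled-udf-2"], row_based=True)
print(buf.getvalue()[:8])  # b'\x00\x00\x00\x01\x00\x00\x00\x02'
```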
16 changes: 11 additions & 5 deletions python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@
from itertools import imap as map

from pyspark import since, SparkContext
from pyspark.rdd import _wrap_function, ignore_unicode_prefix
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1648,6 +1648,14 @@ def sort_array(col, asc=True):

# ---------------------------- User Defined Function ----------------------------------

def _wrap_function(sc, func, returnType):
Contributor: What's the point of creating a new _wrap_function here? To decrease the size of the serialized Python function?

Contributor: Oh, I see: we want to chain the functions on the Python worker side. (A standalone sketch of this worker-side chaining appears after the python/pyspark/worker.py diff below.)

ser = AutoBatchedSerializer(PickleSerializer())
command = (func, returnType, ser)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)


class UserDefinedFunction(object):
"""
User defined function in Python
@@ -1662,14 +1670,12 @@ def __init__(self, func, returnType, name=None):

def _create_judf(self, name):
from pyspark.sql import SQLContext
f, returnType = self.func, self.returnType # put them in closure `func`
func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
ser = AutoBatchedSerializer(PickleSerializer())
sc = SparkContext.getOrCreate()
wrapped_func = _wrap_function(sc, func, ser, ser)
wrapped_func = _wrap_function(sc, self.func, self.returnType)
ctx = SQLContext.getOrCreate(sc)
jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
if name is None:
f = self.func
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
name, wrapped_func, jdt)
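For context on the review thread above, a rough standalone sketch of the command shape the new _wrap_function serializes: each UDF becomes its own (func, returnType, serializer) tuple rather than a pre-built row mapper, which is what lets the worker receive several commands and compose them. Plain pickle stands in for Spark's CloudPickle-based serializer, and the return type and serializer values are placeholders.

```python
import pickle

def double(x):
    return x + x

# Stand-ins: the real command carries the actual DataType object and an
# AutoBatchedSerializer(PickleSerializer()) instance.
return_type = "IntegerType()"
serializer = None

command = (double, return_type, serializer)   # one command per UDF
payload = pickle.dumps(command)               # bytes handed to the JVM as PythonFunction.command
print(len(payload) > 0)
```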
9 changes: 9 additions & 0 deletions python/pyspark/sql/tests.py
@@ -305,6 +305,15 @@ def test_udf2(self):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])

def test_chained_python_udf(self):
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
self.assertEqual(row[0], 2)
[row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
self.assertEqual(row[0], 4)
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
self.assertEqual(row[0], 6)

def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
33 changes: 29 additions & 4 deletions python/pyspark/worker.py
@@ -50,6 +50,18 @@ def add_path(path):
sys.path.insert(1, path)


def read_command(serializer, file):
command = serializer._read_with_length(file)
if isinstance(command, Broadcast):
command = serializer.loads(command.value)
return command


def chain(f, g):
"""chain two function together """
return lambda x: g(f(x))


def main(infile, outfile):
try:
boot_time = time.time()
@@ -95,10 +107,23 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)

_accumulatorRegistry.clear()
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
func, profiler, deserializer, serializer = command
row_based = read_int(infile)
num_commands = read_int(infile)
if row_based:
profiler = None # profiling is not supported for UDF
Member: profiler seems to need to be defined before this if block; the code refers to profiler later, outside the block.

Contributor Author: The other branch also sets profiler, so I think it's fine.

row_func = None
for i in range(num_commands):
f, returnType, deserializer = read_command(pickleSer, infile)
if row_func is None:
row_func = f
else:
row_func = chain(row_func, f)
serializer = deserializer
func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
else:
assert num_commands == 1
func, profiler, deserializer, serializer = read_command(pickleSer, infile)

init_time = time.time()

def process():
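A standalone illustration (toy types, no Spark imports) of the row mapper the row-based branch of main() builds: the deserialized functions are folded into one chained function, each outer UDF receives the single value produced by the inner one, and only the last UDF's return type converts the result back to Spark's internal representation. ToyIntegerType and the command list are stand-ins for the real deserialized commands.

```python
def chain(f, g):
    # Same composition as worker.py: feed f's result into g.
    return lambda x: g(f(x))

class ToyIntegerType:
    def toInternal(self, value):   # stand-in for DataType.toInternal
        return int(value)

# Stand-ins for the (func, returnType) pairs read off the stream.
commands = [(lambda x: x + x, ToyIntegerType()), (lambda x: x * 10, ToyIntegerType())]

row_func = None
for f, return_type in commands:
    row_func = f if row_func is None else chain(row_func, f)

# Only the last return type matters for the conversion back to Spark.
mapper = lambda _, it: map(lambda row: return_type.toInternal(row_func(*row)), it)
print(list(mapper(None, [(1,), (2,)])))   # [20, 40]
```

Because the outer function in chain receives exactly one argument, only a UDF whose sole child is another Python UDF can be chained this way; that constraint is enforced on the JVM side by canEvaluateInPython below.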
@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.PythonRunner
import org.apache.spark.api.python.{PythonFunction, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}

@@ -45,6 +45,18 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:

def children: Seq[SparkPlan] = child :: Nil

private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (fs, children) = collectFunctions(u)
(fs ++ Seq(udf.func), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(Seq(udf.func), udf.children)
}
}

protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -57,9 +69,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

val (pyFuncs, children) = collectFunctions(udf)

val pickle = new Pickler
val currentRow = newMutableProjection(udf.children, child.output)()
val fields = udf.children.map(_.dataType)
val currentRow = newMutableProjection(children, child.output)()
val fields = children.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)

// Input iterator to Python: input rows are grouped so we send them in batches to Python.
@@ -75,11 +89,8 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(
udf.func,
bufferSize,
reuseWorker
).compute(inputIterator, context.partitionId(), context)
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
val row = new GenericMutableRow(1)
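To make the collectFunctions recursion above concrete, here is a toy model in plain Python (illustrative classes, not Catalyst expressions): for a chain like double(double(col_a)) it returns the functions innermost-first plus the innermost UDF's inputs, which are the only columns that need to be projected and sent to the worker.

```python
class PythonUDF:
    def __init__(self, func, children):
        self.func, self.children = func, children

def collect_functions(udf):
    # Walk down while the only child is itself a PythonUDF; otherwise stop and
    # treat the children as the inputs evaluated on the JVM side.
    if len(udf.children) == 1 and isinstance(udf.children[0], PythonUDF):
        funcs, inputs = collect_functions(udf.children[0])
        return funcs + [udf.func], inputs
    return [udf.func], udf.children

inner = PythonUDF("double#1", ["col_a"])    # double(col_a)
outer = PythonUDF("double#2", [inner])      # double(double(col_a))
print(collect_functions(outer))             # (['double#1', 'double#2'], ['col_a'])
```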
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.python

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -25,17 +26,40 @@ import org.apache.spark.sql.catalyst.rules.Rule
* Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
* alone in a batch.
*
* Only extracts the PythonUDFs that can be evaluated in Python (the single child is a PythonUDF,
* or all the children can be evaluated in the JVM).
*
* This has the limitation that the input to the Python UDF is not allowed to include attributes from
* multiple child operators.
*/
private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {

private def hasPythonUDF(e: Expression): Boolean = {
e.find(_.isInstanceOf[PythonUDF]).isDefined
}

private def canEvaluateInPython(e: PythonUDF): Boolean = {
e.children match {
// single PythonUDF child could be chained and evaluated in Python
case Seq(u: PythonUDF) => canEvaluateInPython(u)
// Python UDF can't be evaluated directly in JVM
case children => !children.exists(hasPythonUDF)
}
}

private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
expr.collect {
case udf: PythonUDF if canEvaluateInPython(udf) => udf
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan

case plan: LogicalPlan if plan.resolved =>
// Extract any PythonUDFs from the current operator.
Contributor: We should update the comments to explain our new strategy for extracting and evaluating Python UDFs.

Contributor Author: Updated.

val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
if (udfs.isEmpty) {
// If there aren't any, we are done.
plan
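Finally, a toy version of the extraction predicate added here (plain Python stand-ins, not Catalyst): a PythonUDF qualifies when its only child is another PythonUDF (the chainable case) or when none of its children contain a PythonUDF. An expression like double(double(a) + 1) therefore needs the inner UDF extracted into its own batch before the outer one can be evaluated, which matches the behavior exercised by test_chained_python_udf.

```python
class Expr:
    def __init__(self, children=()):
        self.children = list(children)

class PythonUDF(Expr):
    pass

def has_python_udf(e):
    return isinstance(e, PythonUDF) or any(has_python_udf(c) for c in e.children)

def can_evaluate_in_python(udf):
    # Chainable: the only child is another Python UDF.
    if len(udf.children) == 1 and isinstance(udf.children[0], PythonUDF):
        return can_evaluate_in_python(udf.children[0])
    # Otherwise every child must be free of Python UDFs (evaluable in the JVM).
    return not any(has_python_udf(c) for c in udf.children)

col = Expr()
inner = PythonUDF([col])                              # double(a)
print(can_evaluate_in_python(PythonUDF([inner])))     # True:  double(double(a))
plus = Expr([inner, Expr()])                          # double(a) + 1
print(can_evaluate_in_python(PythonUDF([plus])))      # False: extract the inner UDF first
```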