[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs #12014
python/pyspark/worker.py:
```diff
@@ -50,6 +50,18 @@ def add_path(path):
     sys.path.insert(1, path)
 
 
+def read_command(serializer, file):
+    command = serializer._read_with_length(file)
+    if isinstance(command, Broadcast):
+        command = serializer.loads(command.value)
+    return command
+
+
+def chain(f, g):
+    """chain two functions together """
+    return lambda x: g(f(x))
+
+
 def main(infile, outfile):
     try:
         boot_time = time.time()
```
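As an aside, `chain` is ordinary function composition. A minimal standalone sketch of its behavior (the helper names below are illustrative, not from the patch):

```python
def chain(f, g):
    """Chain two functions together: apply f first, then g."""
    return lambda x: g(f(x))

inc = lambda x: x + 1   # stands in for a first UDF
dbl = lambda x: x * 2   # stands in for a second UDF applied to the first's output

composed = chain(inc, dbl)   # equivalent to lambda x: dbl(inc(x))
assert composed(3) == 8      # (3 + 1) * 2
```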
```diff
@@ -95,10 +107,23 @@ def main(infile, outfile):
                 _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
-        command = pickleSer._read_with_length(infile)
-        if isinstance(command, Broadcast):
-            command = pickleSer.loads(command.value)
-        func, profiler, deserializer, serializer = command
+        row_based = read_int(infile)
+        num_commands = read_int(infile)
+        if row_based:
+            profiler = None  # profiling is not supported for UDF
+            row_func = None
+            for i in range(num_commands):
+                f, returnType, deserializer = read_command(pickleSer, infile)
+                if row_func is None:
+                    row_func = f
+                else:
+                    row_func = chain(row_func, f)
+            serializer = deserializer
+            func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+        else:
+            assert num_commands == 1
+            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
 
         init_time = time.time()
 
         def process():
```

Review comment (on `profiler = None  # profiling is not supported for UDF`): The other branch also has …
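To see the folding logic in isolation, here is a standalone sketch; `build_row_func` and the pre-read `commands` list are hypothetical stand-ins for the stream handling above, and only the left-to-right chaining mirrors the diff:

```python
def chain(f, g):
    return lambda x: g(f(x))

def build_row_func(commands):
    # Fold the deserialized UDFs left to right: the first command read from
    # the stream runs first, and each later command consumes its output.
    # The returnType and deserializer slots are ignored in this sketch.
    row_func = None
    for f, return_type, deserializer in commands:
        row_func = f if row_func is None else chain(row_func, f)
    return row_func

# With two commands, a row value flows through f1 and then f2:
f1 = lambda x: x + 1
f2 = lambda x: x * 2
assert build_row_func([(f1, None, None), (f2, None, None)])(3) == 8
```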
|
ExtractPythonUDFs.scala (package org.apache.spark.sql.execution.python):
```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.python
 
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -25,17 +26,40 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
  * alone in a batch.
  *
+ * Only extracts the PythonUDFs that could be evaluated in Python (the single child is a
+ * PythonUDF, or all the children could be evaluated in the JVM).
+ *
  * This has the limitation that the input to the Python UDF is not allowed to include
  * attributes from multiple child operators.
  */
 private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
 
+  private def hasPythonUDF(e: Expression): Boolean = {
+    e.find(_.isInstanceOf[PythonUDF]).isDefined
+  }
+
+  private def canEvaluateInPython(e: PythonUDF): Boolean = {
+    e.children match {
+      // single PythonUDF child could be chained and evaluated in Python
+      case Seq(u: PythonUDF) => canEvaluateInPython(u)
+      // Python UDF can't be evaluated directly in JVM
+      case children => !children.exists(hasPythonUDF)
+    }
+  }
+
+  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
+    expr.collect {
+      case udf: PythonUDF if canEvaluateInPython(udf) => udf
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // Skip EvaluatePython nodes.
     case plan: EvaluatePython => plan
 
     case plan: LogicalPlan if plan.resolved =>
       // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
+      val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
       if (udfs.isEmpty) {
         // If there aren't any, we are done.
         plan
```

Review comment (on `// Extract any PythonUDFs from the current operator.`): We should update the comments to explain our new strategy of extracting and evaluating Python UDFs.

Reply: Updated.
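To make the recursion concrete, here is a toy Python model of the two predicates above (`Expr` and `PyUDF` are simplified illustrative classes, not real Catalyst types):

```python
class Expr:
    def __init__(self, children=()):
        self.children = list(children)

class PyUDF(Expr):
    pass

def has_python_udf(e):
    # Mirrors e.find(_.isInstanceOf[PythonUDF]).isDefined: true if e itself
    # or any descendant is a Python UDF.
    return isinstance(e, PyUDF) or any(has_python_udf(c) for c in e.children)

def can_evaluate_in_python(udf):
    # A single PythonUDF child can be chained and evaluated in Python;
    # otherwise every child must be computable in the JVM (no UDFs inside).
    if len(udf.children) == 1 and isinstance(udf.children[0], PyUDF):
        return can_evaluate_in_python(udf.children[0])
    return not any(has_python_udf(c) for c in udf.children)

col = Expr()
assert can_evaluate_in_python(PyUDF([PyUDF([col])]))           # chained UDFs: OK
assert not can_evaluate_in_python(PyUDF([col, PyUDF([col])]))  # mixed inputs: not OK
```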
Review comment: What's the point of creating a new `_wrap_function` here? To decrease the size of the serialized Python function?

Follow-up: Oh, I see: we want to chain the functions at the Python worker side.
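For reference, the user-visible pattern this PR enables is a nested Python UDF call evaluated in a single worker pass. A hypothetical usage sketch (the session setup and names are illustrative, and the later `SparkSession` API is used for brevity):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["x"])

inc = udf(lambda x: x + 1, IntegerType())
dbl = udf(lambda x: x * 2, IntegerType())

# dbl(inc(x)) is a chained Python UDF: with this patch both lambdas are
# serialized as separate commands and composed on the worker via chain(),
# instead of requiring a separate Python round trip per UDF.
df.select(dbl(inc(df["x"])).alias("y")).show()
```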