
Commit 032d179

ueshin authored and HyukjinKwon committed
[SPARK-31945][SQL][PYSPARK] Enable cache for the same Python function
### What changes were proposed in this pull request?

This PR proposes to make `PythonFunction` hold `Seq[Byte]` instead of `Array[Byte]`, so that the cache manager can compare whether two commands contain the same byte values.

### Why are the changes needed?

Currently the cache manager doesn't use the cache for a `udf` if the `udf` is created again, even when the underlying function is the same:

```py
>>> func = lambda x: x

>>> df = spark.range(1)
>>> df.select(udf(func)("id")).cache()
```

```py
>>> df.select(udf(func)("id")).explain()
== Physical Plan ==
*(2) Project [pythonUDF0#14 AS <lambda>(id)#12]
+- BatchEvalPython [<lambda>(id#0L)], [pythonUDF0#14]
   +- *(1) Range (0, 1, step=1, splits=12)
```

This happens because `PythonFunction` holds `Array[Byte]`, and the `equals` method of an array returns true only when both sides are the same instance.

### Does this PR introduce _any_ user-facing change?

Yes. If the user reuses the same Python function for a UDF, the cache manager will detect the same function and use the cache for it.

### How was this patch tested?

Added a test case; also verified manually:

```py
>>> df.select(udf(func)("id")).explain()
== Physical Plan ==
InMemoryTableScan [<lambda>(id)#12]
   +- InMemoryRelation [<lambda>(id)#12], StorageLevel(disk, memory, deserialized, 1 replicas)
         +- *(2) Project [pythonUDF0#5 AS <lambda>(id)#3]
            +- BatchEvalPython [<lambda>(id#0L)], [pythonUDF0#5]
               +- *(1) Range (0, 1, step=1, splits=12)
```

Closes #28774 from ueshin/issues/SPARK-31945/udf_cache.

Authored-by: Takuya UESHIN <[email protected]>

Signed-off-by: HyukjinKwon <[email protected]>
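For context, a minimal Scala sketch (not part of the patch) of the equality difference that drives this change:

```scala
object ArrayVsSeqEquality extends App {
  val a1 = Array[Byte](1, 2, 3)
  val a2 = Array[Byte](1, 2, 3)

  // Arrays are plain JVM arrays: `equals` (and hence `==`) is reference
  // identity, so two arrays with identical contents are not "equal".
  println(a1 == a2)              // false

  // Seq implementations define structural (element-wise) equality, which is
  // what the cache manager's plan comparison needs.
  println(a1.toSeq == a2.toSeq)  // true
}
```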
1 parent: e14029b

4 files changed, +25 -4 lines changed


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 14 additions & 2 deletions
```diff
@@ -74,13 +74,25 @@ private[spark] class PythonRDD(
  * runner.
  */
 private[spark] case class PythonFunction(
-    command: Array[Byte],
+    command: Seq[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: PythonAccumulatorV2)
+    accumulator: PythonAccumulatorV2) {
+
+  def this(
+      command: Array[Byte],
+      envVars: JMap[String, String],
+      pythonIncludes: JList[String],
+      pythonExec: String,
+      pythonVer: String,
+      broadcastVars: JList[Broadcast[PythonBroadcast]],
+      accumulator: PythonAccumulatorV2) = {
+    this(command.toSeq, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator)
+  }
+}

 /**
  * A wrapper for chained Python functions (from bottom to top).
```
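A stripped-down sketch of the pattern above (hypothetical names, not the actual Spark class): the case class stores `Seq[Byte]` so it gains structural equality, while an auxiliary constructor keeps existing `Array[Byte]` call sites compiling unchanged.

```scala
case class Func(command: Seq[Byte]) {
  // Auxiliary constructor mirroring the one added to PythonFunction:
  // accepts the old Array[Byte] form and converts it to Seq.
  def this(command: Array[Byte]) = this(command.toSeq)
}

object FuncEqualityDemo extends App {
  // Two distinct arrays with equal contents...
  val f1 = new Func(Array[Byte](1, 2, 3))
  val f2 = new Func(Array[Byte](1, 2, 3))
  // ...now yield case-class instances that compare equal, so a cache
  // lookup keyed on the plan (and thus on the function) can match.
  println(f1 == f2)  // true
}
```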

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -613,7 +613,7 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
   protected override def writeCommand(dataOut: DataOutputStream): Unit = {
     val command = funcs.head.funcs.head.command
     dataOut.writeInt(command.length)
-    dataOut.write(command)
+    dataOut.write(command.toArray)
   }

   protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
```
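The `.toArray` conversion is needed because `java.io.DataOutputStream.write` takes a `byte[]`; the same bytes are written as before, so the wire format seen by the Python worker is unchanged. A self-contained sketch of the pattern (illustrative names only):

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

object WriteCommandSketch extends App {
  val command: Seq[Byte] = Seq(1, 2, 3)

  val buffer = new ByteArrayOutputStream()
  val dataOut = new DataOutputStream(buffer)

  dataOut.writeInt(command.length)  // 4-byte length prefix, as in the runner
  dataOut.write(command.toArray)    // DataOutputStream.write expects Array[Byte]
  dataOut.flush()

  println(buffer.toByteArray.length)  // 7 = 4 (length prefix) + 3 (payload)
}
```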

python/pyspark/sql/tests/test_udf.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -642,6 +642,15 @@ def f(*a):
         r = df.select(fUdf(*df.columns))
         self.assertEqual(r.first()[0], "success")

+    def test_udf_cache(self):
+        func = lambda x: x
+
+        df = self.spark.range(1)
+        df.select(udf(func)("id")).cache()
+
+        self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution()
+                         .withCachedData().getClass().getSimpleName(), 'InMemoryRelation')
+

 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
```
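The assertion reaches into the JVM via Py4J: `QueryExecution.withCachedData` is the analyzed plan after the cache manager substitutes matching cached subplans, so seeing `InMemoryRelation` at the root confirms the second `udf(func)` plan was matched against the cached one — which only works once `PythonFunction` compares by value.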

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -104,7 +104,7 @@ object PythonUDFRunner {
       dataOut.writeInt(chained.funcs.length)
       chained.funcs.foreach { f =>
         dataOut.writeInt(f.command.length)
-        dataOut.write(f.command)
+        dataOut.write(f.command.toArray)
       }
     }
   }
```
