
Commit a21f450

dvogelbacher authored and emanuelebardelli committed
[SPARK-27805][PYTHON] Propagate SparkExceptions during toPandas with arrow enabled
## What changes were proposed in this pull request?

Similar to apache#24070, we now propagate SparkExceptions that are encountered during the collect in the Java process to the Python process.

Fixes https://jira.apache.org/jira/browse/SPARK-27805

## How was this patch tested?

Added a new unit test.

Closes apache#24677 from dvogelbacher/dv/betterErrorMsgWhenUsingArrow.

Authored-by: David Vogelbacher <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
1 parent 430eb48 commit a21f450

3 files changed: 44 additions, 14 deletions

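Before the per-file diffs, here is a minimal sketch of the user-visible effect described above. It assumes a local PySpark build that includes this patch, with pandas and pyarrow installed; the Arrow configuration key is the one used in this repo's tests and may be spelled differently in other Spark versions, and the failing UDF mirrors the new unit test. Before the patch, a job failure during the Arrow collect was not forwarded to Python, so toPandas surfaced an unhelpful error instead of the Spark error message; with the patch, the message reaches the caller as a RuntimeError.

# Illustrative repro of the behavior this patch targets (not part of the commit).
# Assumes a local PySpark build containing this change, plus pandas and pyarrow.
# The Arrow conf key below is the one used in this repo's tests; other Spark
# versions may use spark.sql.execution.arrow.enabled instead.
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = (SparkSession.builder
         .master("local[2]")
         .config("spark.sql.execution.arrow.pyspark.enabled", "true")
         .getOrCreate())

def raise_exception():
    raise Exception("My error")

# Same failing-UDF setup as the new test_propagates_spark_exception test below
df = spark.range(3).withColumn("error", udf(raise_exception, IntegerType())())

try:
    df.toPandas()  # the Spark job fails while collecting Arrow batches in the JVM
except RuntimeError as e:
    # With this patch, the original SparkException message ("My error" ...) arrives here
    print("toPandas failed with:", e)

spark.stop()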

python/pyspark/serializers.py

Lines changed: 5 additions & 1 deletion
@@ -206,8 +206,12 @@ def load_stream(self, stream):
         for batch in self.serializer.load_stream(stream):
             yield batch

-        # load the batch order indices
+        # load the batch order indices or propagate any error that occurred in the JVM
         num = read_int(stream)
+        if num == -1:
+            error_msg = UTF8Deserializer().loads(stream)
+            raise RuntimeError("An error occurred while calling "
+                               "ArrowCollectSerializer.load_stream: {}".format(error_msg))
         batch_order = []
         for i in xrange(num):
             index = read_int(stream)
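The new branch changes the footer that load_stream reads once the Arrow batches have been streamed: a non-negative integer is the count of batch-order indices, while -1 signals that the JVM hit a SparkException and that a length-prefixed UTF-8 error message follows. The sketch below is a simplified, self-contained stand-in for that footer handling, assuming plain 4-byte big-endian framing in place of Spark's read_int/UTF8Deserializer helpers; the function names and sample payloads are illustrative, not Spark APIs.

# Simplified stand-in for the footer handling in ArrowCollectSerializer.load_stream
# (illustrative only; the helpers below are not Spark APIs). Framing assumptions:
# 4-byte big-endian ints, and the error message arrives length-prefixed in UTF-8,
# matching what read_int and UTF8Deserializer expect.
import io
import struct

def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]

def read_footer(stream):
    num = read_int(stream)
    if num == -1:
        # The JVM signalled a failure: the next frame is the SparkException message
        length = read_int(stream)
        error_msg = stream.read(length).decode("utf-8")
        raise RuntimeError("An error occurred while collecting: {}".format(error_msg))
    # Success: the batch order indices follow
    return [read_int(stream) for _ in range(num)]

# Success footer: 3 batches, to be reordered using arrival indices 2, 0, 1
ok = io.BytesIO(struct.pack("!iiii", 3, 2, 0, 1))
print(read_footer(ok))  # [2, 0, 1]

# Failure footer: -1 sentinel, then the length-prefixed error message
msg = "Job aborted due to stage failure: ...".encode("utf-8")
bad = io.BytesIO(struct.pack("!ii", -1, len(msg)) + msg)
try:
    read_footer(bad)
except RuntimeError as e:
    print(e)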

python/pyspark/sql/tests/test_arrow.py

Lines changed: 12 additions & 0 deletions
@@ -23,6 +23,7 @@
 import warnings

 from pyspark.sql import Row
+from pyspark.sql.functions import udf
 from pyspark.sql.types import *
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
@@ -205,6 +206,17 @@ def test_no_partition_frame(self):
         self.assertEqual(pdf.columns[0], "field1")
         self.assertTrue(pdf.empty)

+    def test_propagates_spark_exception(self):
+        df = self.spark.range(3).toDF("i")
+
+        def raise_exception():
+            raise Exception("My error")
+        exception_udf = udf(raise_exception, IntegerType())
+        df = df.withColumn("error", exception_udf())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(RuntimeError, 'My error'):
+                df.toPandas()
+
     def _createDataFrame_toggle(self, pdf, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 27 additions & 13 deletions
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal

 import org.apache.commons.lang3.StringUtils

-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
@@ -3321,20 +3321,34 @@ class Dataset[T] private[sql](
           }
         }

-        val arrowBatchRdd = toArrowBatchRdd(plan)
-        sparkSession.sparkContext.runJob(
-          arrowBatchRdd,
-          (it: Iterator[Array[Byte]]) => it.toArray,
-          handlePartitionBatches)
+        var sparkException: Option[SparkException] = None
+        try {
+          val arrowBatchRdd = toArrowBatchRdd(plan)
+          sparkSession.sparkContext.runJob(
+            arrowBatchRdd,
+            (it: Iterator[Array[Byte]]) => it.toArray,
+            handlePartitionBatches)
+        } catch {
+          case e: SparkException =>
+            sparkException = Some(e)
+        }

-        // After processing all partitions, end the stream and write batch order indices
+        // After processing all partitions, end the batch stream
         batchWriter.end()
-        out.writeInt(batchOrder.length)
-        // Sort by (index of partition, batch index in that partition) tuple to get the
-        // overall_batch_index from 0 to N-1 batches, which can be used to put the
-        // transferred batches in the correct order
-        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-          out.writeInt(overallBatchIndex)
+        sparkException match {
+          case Some(exception) =>
+            // Signal failure and write error message
+            out.writeInt(-1)
+            PythonRDD.writeUTF(exception.getMessage, out)
+          case None =>
+            // Write batch order indices
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+              out.writeInt(overallBatchIndex)
+            }
         }
       }
     }
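The Scala change above produces the matching footer on the driver: once batchWriter.end() has closed the Arrow stream, it writes either the -1 failure sentinel plus the exception message (via PythonRDD.writeUTF), or the number of batches followed by each batch's arrival position sorted by its (partition index, batch index in partition) tuple, so Python can reorder the batches it already received. Below is a rough Python counterpart of that writer, under the same framing assumptions as the reader sketch earlier (illustrative only, not the actual implementation).

# Rough Python counterpart of the footer written by the Scala branch above
# (illustrative only; the real writer uses DataOutputStream and PythonRDD.writeUTF).
# Framing assumptions as in the reader sketch: 4-byte big-endian ints and a
# length-prefixed UTF-8 error message.
import io
import struct

def write_footer(out, batch_order, error_message=None):
    if error_message is not None:
        # Signal failure with the -1 sentinel and write the error message
        data = error_message.encode("utf-8")
        out.write(struct.pack("!i", -1))
        out.write(struct.pack("!i", len(data)))
        out.write(data)
        return
    # Success: write how many batches were sent, then, sorted by the
    # (partition index, batch index in partition) tuple, the arrival position of
    # each batch, so the receiver can reorder what it already streamed.
    out.write(struct.pack("!i", len(batch_order)))
    for arrival_index, _ in sorted(enumerate(batch_order), key=lambda x: x[1]):
        out.write(struct.pack("!i", arrival_index))

# Batches arrived from partitions 1, 0, 2 (one batch each):
buf = io.BytesIO()
write_footer(buf, [(1, 0), (0, 0), (2, 0)])
print(struct.unpack("!iiii", buf.getvalue()))  # (3, 1, 0, 2)

# Failure case, which the Python reader turns into a RuntimeError:
err = io.BytesIO()
write_footer(err, [], error_message="Job aborted due to stage failure: ...")

Note that the Scala code still ends the Arrow batch stream before writing the sentinel, so the Python reader can always finish consuming the batches it was sent before deciding whether to raise.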
