Skip to content

Commit 494cde0

Browse files
committed
Propagate TaskContext to writer thread
1 parent 323bb2b commit 494cde0

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ import org.apache.hadoop.hive.serde.serdeConstants
2828
import org.apache.hadoop.hive.serde2.AbstractSerDe
2929
import org.apache.hadoop.hive.serde2.objectinspector._
3030

31-
import org.apache.spark.Logging
31+
import org.apache.spark.{TaskContext, Logging}
3232
import org.apache.spark.rdd.RDD
3333
import org.apache.spark.sql.catalyst.InternalRow
3434
import org.apache.spark.sql.catalyst.CatalystTypeConverters
@@ -98,7 +98,8 @@ case class ScriptTransformation(
9898
ioschema,
9999
outputStream,
100100
proc,
101-
stderrBuffer
101+
stderrBuffer,
102+
TaskContext.get()
102103
)
103104

104105
// This nullability is a performance optimization in order to avoid an Option.foreach() call
@@ -221,7 +222,8 @@ private class ScriptTransformationWriterThread(
221222
ioschema: HiveScriptIOSchema,
222223
outputStream: OutputStream,
223224
proc: Process,
224-
stderrBuffer: CircularBuffer
225+
stderrBuffer: CircularBuffer,
226+
taskContext: TaskContext
225227
) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
226228

227229
setDaemon(true)
@@ -232,6 +234,8 @@ private class ScriptTransformationWriterThread(
232234
def exception: Option[Throwable] = Option(_exception)
233235

234236
override def run(): Unit = Utils.logUncaughtExceptions {
237+
TaskContext.setTaskContext(taskContext)
238+
235239
val dataOutputStream = new DataOutputStream(outputStream)
236240

237241
// We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution
2020
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
2121
import org.scalatest.exceptions.TestFailedException
2222

23+
import org.apache.spark.TaskContext
2324
import org.apache.spark.rdd.RDD
2425
import org.apache.spark.sql.SQLContext
2526
import org.apache.spark.sql.catalyst.InternalRow
@@ -113,6 +114,7 @@ class ScriptTransformationSuite extends SparkPlanTest {
113114
private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryNode {
114115
override protected def doExecute(): RDD[InternalRow] = {
115116
child.execute().map { x =>
117+
assert(TaskContext.get() != null) // Make sure that TaskContext is defined.
116118
Thread.sleep(1000) // This sleep gives the external process time to start.
117119
throw new IllegalArgumentException("intentional exception")
118120
}

0 commit comments

Comments
 (0)