Skip to content

Commit 72a601e

Browse files
committed
Merge pull request alteryx#152 from rxin/repl
Propagate SparkContext local properties from spark-repl caller thread to the repl execution thread.
2 parents dd63c54 + 3192999 commit 72a601e

File tree

3 files changed: +48 −4 lines changed

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,12 @@ class SparkContext(
280280
override protected def childValue(parent: Properties): Properties = new Properties(parent)
281281
}
282282

283+
/** Returns the `Properties` object stored in `localProperties` for the calling thread. */
private[spark] def getLocalProperties(): Properties = localProperties.get()
284+
285+
/**
 * Replaces the calling thread's local properties with `props`.
 *
 * NOTE(review): `localProperties` appears to be an inheritable thread-local
 * (see the `childValue` override above) — confirm before relying on child-thread
 * visibility of values set here.
 */
private[spark] def setLocalProperties(props: Properties): Unit = {
  localProperties.set(props)
}
288+
283289
/** Installs a fresh, empty `Properties` object as the calling thread's local properties. */
def initLocalProperties(): Unit = {
  localProperties.set(new Properties())
}

repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -878,14 +878,21 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
878878
(message, false)
879879
}
880880
}
881+
882+
// Get a copy of the local properties from SparkContext, and set it later in the thread
883+
// that triggers the execution. This is to make sure the caller of this function can pass
884+
// the right thread local (inheritable) properties down into Spark.
885+
val sc = org.apache.spark.repl.Main.interp.sparkContext
886+
val props = if (sc != null) sc.getLocalProperties() else null
881887

882888
try {
883889
val execution = lineManager.set(originalLine) {
884890
// MATEI: set the right SparkEnv for our SparkContext, because
885891
// this execution will happen in a separate thread
886-
val sc = org.apache.spark.repl.Main.interp.sparkContext
887-
if (sc != null && sc.env != null)
892+
if (sc != null && sc.env != null) {
888893
SparkEnv.set(sc.env)
894+
sc.setLocalProperties(props)
895+
}
889896
// Execute the line
890897
lineRep call "$export"
891898
}

repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ import java.io._
2121
import java.net.URLClassLoader
2222

2323
import scala.collection.mutable.ArrayBuffer
24-
import scala.collection.JavaConversions._
2524

26-
import org.scalatest.FunSuite
2725
import com.google.common.io.Files
26+
import org.scalatest.FunSuite
27+
import org.apache.spark.SparkContext
28+
2829

2930
class ReplSuite extends FunSuite {
31+
3032
def runInterpreter(master: String, input: String): String = {
3133
val in = new BufferedReader(new StringReader(input + "\n"))
3234
val out = new StringWriter()
@@ -64,6 +66,35 @@ class ReplSuite extends FunSuite {
6466
"Interpreter output contained '" + message + "':\n" + output)
6567
}
6668

69+
test("propagation of local properties") {
  // A mock ILoop that does not install the SIGINT handler, so it can run inside a test JVM.
  class ILoop(out: PrintWriter) extends SparkILoop(None, out, None) {
    settings = new scala.tools.nsc.Settings
    settings.usejavacp.value = true
    org.apache.spark.repl.Main.interp = this
    override def createInterpreter() {
      intp = new SparkILoopInterpreter
      intp.setContextClassLoader()
    }
  }

  val output = new StringWriter()
  val interp = new ILoop(new PrintWriter(output))
  interp.sparkContext = new SparkContext("local", "repl-test")
  interp.createInterpreter()
  interp.intp.initialize()

  // Set a local property on the caller thread; the interpreter executes commands on a
  // separate thread, so reading it back there verifies that the thread-local properties
  // are propagated from the caller into the execution thread.
  interp.sparkContext.setLocalProperty("someKey", "someValue")
  interp.interpret("org.apache.spark.repl.Main.interp.sparkContext.getLocalProperty(\"someKey\")")
  assert(output.toString.contains("someValue"))

  // Tear down the context and clear the system properties it set, so later tests
  // can create a fresh SparkContext.
  interp.sparkContext.stop()
  System.clearProperty("spark.driver.port")
  System.clearProperty("spark.hostPort")
}
97+
6798
test ("simple foreach with accumulator") {
6899
val output = runInterpreter("local", """
69100
val accum = sc.accumulator(0)

0 commit comments

Comments
 (0)