diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index a3e7276fc83e1..511ea3ca6ba04 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.worker import java.io.{File, FileOutputStream, InputStream, IOException} import scala.collection.Map +import scala.collection.mutable import scala.jdk.CollectionConverters._ import org.apache.spark.{SecurityManager, SSLOptions} @@ -79,20 +80,24 @@ object CommandUtils extends Logging { val libraryPathEntries = command.libraryPathEntries val cmdLibraryPath = command.environment.get(libraryPathName) - var newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { + val newEnvironment = new mutable.HashMap[String, String]() + newEnvironment.addAll(env) + + if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName) - command.environment ++ Map(libraryPathName -> libraryPaths.mkString(File.pathSeparator)) - } else { - command.environment + newEnvironment.put(libraryPathName, libraryPaths.mkString(File.pathSeparator)) + } + + for ((k, v) <- command.environment) { + newEnvironment.getOrElseUpdate(k, v) } // set auth secret to env variable if needed if (securityMgr.isAuthenticationEnabled()) { - newEnvironment = newEnvironment ++ - Map(SecurityManager.ENV_AUTH_SECRET -> securityMgr.getSecretKey()) + newEnvironment.put(SecurityManager.ENV_AUTH_SECRET, securityMgr.getSecretKey()) } // set SSL env variables if needed - newEnvironment ++= securityMgr.getEnvironmentForSslRpcPasswords + newEnvironment.addAll(securityMgr.getEnvironmentForSslRpcPasswords) Command( command.mainClass, diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala index e864b609d0e48..a08a22c7a646f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala @@ -94,4 +94,35 @@ class CommandUtilsSuite extends SparkFunSuite with Matchers with PrivateMethodTe env => assert(cmd.environment(env) === "password") ) } + + test("SPARK-46912: local environment takes a precedence") { + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val foo = "Foo" + val anotherKey = "AnotherKey" + val doesntExistInLocal = "DoesntExistInLocal" + val envFromCommand = Map(foo -> "commandBar", anotherKey -> "commandValue", + doesntExistInLocal -> "I don't exist", "JAVA_HOME" -> "opt/command/jdk") + val localEnv = Map(foo -> "localBar", anotherKey -> "localValue", + "JAVA_HOME" -> "opt/local/jdk") + + + val cmd = Command("mainClass", Seq(), envFromCommand, Seq(), Seq("libraryPathToB"), Seq()) + val builder = CommandUtils.buildProcessBuilder( + cmd, new SecurityManager(new SparkConf), 512, sparkHome, t => t, Seq(), + env = localEnv) + val libraryPath = Utils.libraryPathEnvName + val env = builder.environment + + assert(env.containsKey(foo)) + assert(env.containsKey(anotherKey)) + assert(env.containsKey(libraryPath)) + assert(env.containsKey(doesntExistInLocal)) + assert(env.containsKey("JAVA_HOME")) + + assert(env.get(foo) equals "localBar") + assert(env.get(anotherKey) equals "localValue") + assert(env.get(doesntExistInLocal) equals "I don't exist") + assert(env.get(libraryPath).startsWith("libraryPathToB")) + assert(builder.command().get(0).startsWith("opt/local/jdk")) + } }