Skip to content

Commit c16c392

Browse files
committed
put addPyFile in front of sys.path
1 parent 56dae30 commit c16c392

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

python/pyspark/context.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
171171

172172
SparkFiles._sc = self
173173
root_dir = SparkFiles.getRootDirectory()
174-
sys.path.append(root_dir)
174+
sys.path.insert(1, root_dir)
175175

176176
# Deploy any code dependencies specified in the constructor
177177
self._python_includes = list()
@@ -183,10 +183,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
183183
for path in self._conf.get("spark.submit.pyFiles", "").split(","):
184184
if path != "":
185185
(dirname, filename) = os.path.split(path)
186-
self._python_includes.append(filename)
187-
sys.path.append(path)
188-
if dirname not in sys.path:
189-
sys.path.append(dirname)
186+
if filename.lower().endswith("zip") or filename.lower().endswith("egg"):
187+
self._python_includes.append(filename)
188+
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
190189

191190
# Create a temporary directory inside spark.local.dir:
192191
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
@@ -667,7 +666,7 @@ def addPyFile(self, path):
667666
if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
668667
self._python_includes.append(filename)
669668
# for tests in local mode
670-
sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))
669+
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
671670

672671
def setCheckpointDir(self, dirName):
673672
"""

python/pyspark/worker.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ def report_times(outfile, boot, init, finish):
4343
write_long(1000 * finish, outfile)
4444

4545

46+
def add_path(path):
47+
# the worker can be reused, so do not add the same path multiple times
48+
if path not in sys.path:
49+
# insert near the front of sys.path so it takes precedence over system packages
50+
sys.path.insert(1, path)
51+
52+
4653
def main(infile, outfile):
4754
try:
4855
boot_time = time.time()
@@ -61,11 +68,11 @@ def main(infile, outfile):
6168
SparkFiles._is_running_on_worker = True
6269

6370
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
64-
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
71+
add_path(spark_files_dir) # *.py files that were added will be copied here
6572
num_python_includes = read_int(infile)
6673
for _ in range(num_python_includes):
6774
filename = utf8_deserializer.loads(infile)
68-
sys.path.append(os.path.join(spark_files_dir, filename))
75+
add_path(os.path.join(spark_files_dir, filename))
6976

7077
# fetch names and values of broadcast variables
7178
num_broadcast_variables = read_int(infile)

0 commit comments

Comments
 (0)