Skip to content

Commit 074a422

Browse files
committed
Fix: deal with spark.files.overwrite
1 parent 03ed3a8 commit 074a422

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -333,8 +333,8 @@ private[spark] object Utils extends Logging {
333333
val fileName = url.split("/").last
334334
val targetFile = new File(targetDir, fileName)
335335
if (useCache) {
336-
val cachedFileName = url.hashCode + timestamp + "_cache"
337-
val lockFileName = url.hashCode + timestamp + "_lock"
336+
val cachedFileName = s"${url.hashCode}${timestamp}_cache"
337+
val lockFileName = s"${url.hashCode}${timestamp}_lock"
338338
val localDir = new File(getLocalDir(conf))
339339
val lockFile = new File(localDir, lockFileName)
340340
val raf = new RandomAccessFile(lockFile, "rw")
@@ -345,15 +345,24 @@ private[spark] object Utils extends Logging {
345345
val cachedFile = new File(localDir, cachedFileName)
346346
try {
347347
if (!cachedFile.exists()) {
348-
doFetchFile(url, localDir, conf, securityMgr, hadoopConf)
349-
Files.move(new File(localDir, fileName), cachedFile)
348+
doFetchFile(url, localDir, cachedFileName, conf, securityMgr, hadoopConf)
350349
}
351350
} finally {
352351
lock.release()
353352
}
353+
if (targetFile.exists && !Files.equal(cachedFile, targetFile)) {
354+
if (conf.getBoolean("spark.files.overwrite", false)) {
355+
targetFile.delete()
356+
logInfo(("File %s exists and does not match contents of %s, " +
357+
"replacing it with %s").format(targetFile, url, url))
358+
} else {
359+
throw new SparkException(
360+
"File " + targetFile + " exists and does not match contents of" + " " + url)
361+
}
362+
}
354363
Files.copy(cachedFile, targetFile)
355364
} else {
356-
doFetchFile(url, targetDir, conf, securityMgr, hadoopConf)
365+
doFetchFile(url, targetDir, fileName, conf, securityMgr, hadoopConf)
357366
}
358367

359368
// Decompress the file if it's a .tar or .tar.gz
@@ -378,10 +387,10 @@ private[spark] object Utils extends Logging {
378387
private def doFetchFile(
379388
url: String,
380389
targetDir: File,
390+
filename: String,
381391
conf: SparkConf,
382392
securityMgr: SecurityManager,
383393
hadoopConf: Configuration) {
384-
val filename = url.split("/").last
385394
val tempDir = getLocalDir(conf)
386395
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
387396
val targetFile = new File(targetDir, filename)

0 commit comments

Comments
 (0)