
Commit 354e936

rxin authored and hvanhovell committed
[SPARK-18775][SQL] Limit the max number of records written per file
## What changes were proposed in this pull request?

Currently, Spark writes a single file out per task, sometimes leading to very large files. It would be great to have an option to limit the max number of records written per file in a task, to avoid humongous files. This patch introduces a new write config option `maxRecordsPerFile` (defaulting to the session-wide setting `spark.sql.files.maxRecordsPerFile`) that limits the max number of records written to a single file. A non-positive value indicates there is no limit (same behavior as not having this flag).

## How was this patch tested?

Added test cases in PartitionedWriteSuite for both dynamic partition insert and non-dynamic partition insert.

Author: Reynold Xin <[email protected]>

Closes #16204 from rxin/SPARK-18775.
1 parent 078c71c commit 354e936
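
For readers skimming the change, a minimal usage sketch of the new knob (the output paths and record counts below are illustrative, not taken from the patch):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("maxRecordsPerFile-demo").getOrCreate()

// Per-write option: cap each output file at 1000 records.
spark.range(0, 10000)
  .write
  .option("maxRecordsPerFile", 1000)
  .mode("overwrite")
  .parquet("/tmp/demo-per-write")          // hypothetical output path

// Session-wide default, used when the write option is not set;
// zero or a negative value means no limit (the default is 0).
spark.conf.set("spark.sql.files.maxRecordsPerFile", 1000)
spark.range(0, 10000).write.mode("overwrite").parquet("/tmp/demo-session-default")
```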

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala
sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala

4 files changed (+179, -39 lines changed)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 79 additions & 30 deletions
@@ -31,13 +31,12 @@ import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
 import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -47,6 +46,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 /** A helper object for writing FileFormat data out to a location. */
 object FileFormatWriter extends Logging {
 
+  /**
+   * Max number of files a single task writes out due to file size. In most cases the number of
+   * files written should be very small. This is just a safe guard to protect some really bad
+   * settings, e.g. maxRecordsPerFile = 1.
+   */
+  private val MAX_FILE_COUNTER = 1000 * 1000
+
   /** Describes how output files should be placed in the filesystem. */
   case class OutputSpec(
       outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String])
@@ -61,7 +67,8 @@ object FileFormatWriter extends Logging {
       val nonPartitionColumns: Seq[Attribute],
       val bucketSpec: Option[BucketSpec],
       val path: String,
-      val customPartitionLocations: Map[TablePartitionSpec, String])
+      val customPartitionLocations: Map[TablePartitionSpec, String],
+      val maxRecordsPerFile: Long)
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
@@ -116,7 +123,10 @@ object FileFormatWriter extends Logging {
       nonPartitionColumns = dataColumns,
       bucketSpec = bucketSpec,
       path = outputSpec.outputPath,
-      customPartitionLocations = outputSpec.customPartitionLocations)
+      customPartitionLocations = outputSpec.customPartitionLocations,
+      maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+    )
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
@@ -225,32 +235,49 @@ object FileFormatWriter extends Logging {
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {
 
-    private[this] var outputWriter: OutputWriter = {
+    private[this] var currentWriter: OutputWriter = _
+
+    private def newOutputWriter(fileCounter: Int): Unit = {
+      val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
       val tmpFilePath = committer.newTaskTempFile(
         taskAttemptContext,
         None,
-        description.outputWriterFactory.getFileExtension(taskAttemptContext))
+        f"-c$fileCounter%03d" + ext)
 
-      val outputWriter = description.outputWriterFactory.newInstance(
+      currentWriter = description.outputWriterFactory.newInstance(
         path = tmpFilePath,
         dataSchema = description.nonPartitionColumns.toStructType,
         context = taskAttemptContext)
-      outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
-      outputWriter
+      currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
+      var fileCounter = 0
+      var recordsInFile: Long = 0L
+      newOutputWriter(fileCounter)
       while (iter.hasNext) {
+        if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
+          fileCounter += 1
+          assert(fileCounter < MAX_FILE_COUNTER,
+            s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+          recordsInFile = 0
+          releaseResources()
+          newOutputWriter(fileCounter)
+        }
+
         val internalRow = iter.next()
-        outputWriter.writeInternal(internalRow)
+        currentWriter.writeInternal(internalRow)
+        recordsInFile += 1
       }
+      releaseResources()
       Set.empty
     }
 
     override def releaseResources(): Unit = {
-      if (outputWriter != null) {
-        outputWriter.close()
-        outputWriter = null
+      if (currentWriter != null) {
+        currentWriter.close()
+        currentWriter = null
       }
     }
   }
@@ -300,8 +327,15 @@ object FileFormatWriter extends Logging {
      * Open and returns a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
     * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
+     *
+     * @param key vaues for fields consisting of partition keys for the current row
+     * @param partString a function that projects the partition values into a string
+     * @param fileCounter the number of files that have been written in the past for this specific
+     *                    partition. This is used to limit the max number of records written for a
+     *                    single file. The value should start from 0.
      */
-    private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = {
+    private def newOutputWriter(
+        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
       val partDir =
         if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
 
@@ -311,7 +345,10 @@
       } else {
         ""
       }
-      val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext)
+
+      // This must be in a form that matches our bucketing format. See BucketingUtils.
+      val ext = f"$bucketId.c$fileCounter%03d" +
+        description.outputWriterFactory.getFileExtension(taskAttemptContext)
 
       val customPath = partDir match {
         case Some(dir) =>
@@ -324,12 +361,12 @@
       } else {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }
-      val newWriter = description.outputWriterFactory.newInstance(
+
+      currentWriter = description.outputWriterFactory.newInstance(
         path = path,
         dataSchema = description.nonPartitionColumns.toStructType,
         context = taskAttemptContext)
-      newWriter.initConverter(description.nonPartitionColumns.toStructType)
-      newWriter
+      currentWriter.initConverter(description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -349,7 +386,7 @@
         description.nonPartitionColumns, description.allColumns)
 
       // Returns the partition path given a partition key.
-      val getPartitionString = UnsafeProjection.create(
+      val getPartitionStringFunc = UnsafeProjection.create(
         Seq(Concat(partitionStringExpression)), description.partitionColumns)
 
       // Sorts the data before write, so that we only need one writer at the same time.
@@ -366,7 +403,6 @@
         val currentRow = iter.next()
         sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
       }
-      logInfo(s"Sorting complete. Writing out partition files one at a time.")
 
       val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
         identity
@@ -379,30 +415,43 @@
       val sortedIterator = sorter.sortedIterator()
 
       // If anything below fails, we should abort the task.
+      var recordsInFile: Long = 0L
+      var fileCounter = 0
       var currentKey: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
       while (sortedIterator.next()) {
         val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
         if (currentKey != nextKey) {
-          if (currentWriter != null) {
-            currentWriter.close()
-            currentWriter = null
-          }
+          // See a new key - write to a new partition (new file).
           currentKey = nextKey.copy()
           logDebug(s"Writing partition: $currentKey")
 
-          currentWriter = newOutputWriter(currentKey, getPartitionString)
-          val partitionPath = getPartitionString(currentKey).getString(0)
+          recordsInFile = 0
+          fileCounter = 0
+
+          releaseResources()
+          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
           if (partitionPath.nonEmpty) {
            updatedPartitions.add(partitionPath)
          }
+        } else if (description.maxRecordsPerFile > 0 &&
+            recordsInFile >= description.maxRecordsPerFile) {
+          // Exceeded the threshold in terms of the number of records per file.
+          // Create a new file by increasing the file counter.
+          recordsInFile = 0
+          fileCounter += 1
+          assert(fileCounter < MAX_FILE_COUNTER,
+            s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+          releaseResources()
+          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
        }
+
         currentWriter.writeInternal(sortedIterator.getValue)
+        recordsInFile += 1
       }
-      if (currentWriter != null) {
-        currentWriter.close()
-        currentWriter = null
-      }
+      releaseResources()
       updatedPartitions.toSet
     }

sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 17 additions & 9 deletions
@@ -466,6 +466,19 @@ object SQLConf {
     .longConf
     .createWithDefault(4 * 1024 * 1024)
 
+  val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
+    .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
+      "encountering corrupted or non-existing and contents that have been read will still be " +
+      "returned.")
+    .booleanConf
+    .createWithDefault(false)
+
+  val MAX_RECORDS_PER_FILE = SQLConfigBuilder("spark.sql.files.maxRecordsPerFile")
+    .doc("Maximum number of records to write out to a single file. " +
+      "If this value is zero or negative, there is no limit.")
+    .longConf
+    .createWithDefault(0)
+
   val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse")
     .internal()
     .doc("When true, the planner will try to find out duplicated exchanges and re-use them.")
@@ -629,13 +642,6 @@ object SQLConf {
     .doubleConf
     .createWithDefault(0.05)
 
-  val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
-    .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
-      "encountering corrupted or non-existing and contents that have been read will still be " +
-      "returned.")
-    .booleanConf
-    .createWithDefault(false)
-
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -700,6 +706,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES)
 
+  def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
+
+  def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE)
+
   def useCompression: Boolean = getConf(COMPRESS_CACHED)
 
   def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
@@ -821,8 +831,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def warehousePath: String = new Path(getConf(StaticSQLConf.WAREHOUSE_PATH)).toString
 
-  def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
-
   override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
 
   override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
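
To spell out the precedence the FileFormatWriter and SQLConf changes establish, a tiny self-contained sketch (the helper name is ours, not part of the patch): the per-write option wins, otherwise the session conf `spark.sql.files.maxRecordsPerFile` applies, and its default of 0 means unlimited.

```scala
// Hypothetical helper mirroring the fallback in FileFormatWriter.write:
// use the write option if present, otherwise the session-wide conf value.
def resolveMaxRecordsPerFile(
    writeOptions: Map[String, String],
    sessionConfValue: Long): Long = {
  writeOptions.get("maxRecordsPerFile").map(_.toLong).getOrElse(sessionConfValue)
}

assert(resolveMaxRecordsPerFile(Map("maxRecordsPerFile" -> "1000"), 0L) == 1000L)
assert(resolveMaxRecordsPerFile(Map.empty, 5000L) == 5000L) // falls back to the conf
assert(resolveMaxRecordsPerFile(Map.empty, 0L) == 0L)       // 0 or negative => no limit
```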
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.SparkFunSuite
+
+class BucketingUtilsSuite extends SparkFunSuite {
+
+  test("generate bucket id") {
+    assert(BucketingUtils.bucketIdToString(0) == "_00000")
+    assert(BucketingUtils.bucketIdToString(10) == "_00010")
+    assert(BucketingUtils.bucketIdToString(999999) == "_999999")
+  }
+
+  test("match bucket ids") {
+    def testCase(filename: String, expected: Option[Int]): Unit = withClue(s"name: $filename") {
+      assert(BucketingUtils.getBucketId(filename) == expected)
+    }
+
+    testCase("a_1", Some(1))
+    testCase("a_1.txt", Some(1))
+    testCase("a_9999999", Some(9999999))
+    testCase("a_9999999.txt", Some(9999999))
+    testCase("a_1.c2.txt", Some(1))
+    testCase("a_1.", Some(1))
+
+    testCase("a_1:txt", None)
+    testCase("a_1-c2.txt", None)
+  }
+
+}
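
The new test file above pins down the file-name shapes BucketingUtils must accept now that a `.c###` counter can follow the bucket id. A rough standalone approximation of that rule, for intuition only (this is not the actual BucketingUtils implementation):

```scala
import scala.util.matching.Regex

// Accept "<anything>_<bucketId>" optionally followed by a dot-separated suffix,
// such as the new ".c###" file counter or a file extension.
val bucketedFileName: Regex = """.*_(\d+)(?:\..*)?""".r

def getBucketIdSketch(fileName: String): Option[Int] = fileName match {
  case bucketedFileName(bucketId) => Some(bucketId.toInt)
  case _ => None
}

assert(getBucketIdSketch("a_1.c2.txt").contains(1))  // counter suffix is tolerated
assert(getBucketIdSketch("a_1-c2.txt").isEmpty)      // "-c2" breaks the expected shape
```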

sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala

Lines changed: 37 additions & 0 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.sources
 
+import java.io.File
+
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -61,4 +63,39 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
       assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
     }
   }
+
+  test("maxRecordsPerFile setting in non-partitioned write path") {
+    withTempDir { f =>
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+        .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4)
+
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+        .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2)
+
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+        .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1)
+    }
+  }
+
+  test("maxRecordsPerFile setting in dynamic partition writes") {
+    withTempDir { f =>
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1).selectExpr("id", "id id1")
+        .write
+        .partitionBy("id")
+        .option("maxRecordsPerFile", 1)
+        .mode("overwrite")
+        .parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4)
+    }
+  }
+
+  /** Lists files recursively. */
+  private def recursiveList(f: File): Array[File] = {
+    require(f.isDirectory)
+    val current = f.listFiles
+    current ++ current.filter(_.isDirectory).flatMap(recursiveList)
+  }
 }
