
Commit 4041902

cloud-fan authored and rxin committed
[SPARK-12882][SQL] simplify bucket tests and add more comments
Right now, the bucket tests are kind of hard to understand; this PR simplifies them and adds more comments.

Author: Wenchen Fan <[email protected]>

Closes #10813 from cloud-fan/bucket-comment.
1 parent 4f11e3f commit 4041902

File tree: 2 files changed (+78, -46 lines)


sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
(36 additions, 20 deletions)
@@ -23,6 +23,7 @@ import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLC
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.datasources.BucketSpec
 import org.apache.spark.sql.execution.joins.SortMergeJoin
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -61,15 +62,30 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
   private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
   private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
 
+  /**
+   * A helper method to test the bucket read functionality using a join. It saves `df1` and `df2`
+   * to hive tables, bucketed or not, according to the given bucket specs. It then joins
+   * these 2 tables, first making sure the answer is correct, and then checking whether a shuffle
+   * exists on each side, as the user expects per `shuffleLeft` and `shuffleRight`.
+   */
   private def testBucketing(
-      bucketing1: DataFrameWriter => DataFrameWriter,
-      bucketing2: DataFrameWriter => DataFrameWriter,
+      bucketSpecLeft: Option[BucketSpec],
+      bucketSpecRight: Option[BucketSpec],
       joinColumns: Seq[String],
       shuffleLeft: Boolean,
       shuffleRight: Boolean): Unit = {
     withTable("bucketed_table1", "bucketed_table2") {
-      bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1")
-      bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2")
+      def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
+        bucketSpec.map { spec =>
+          writer.bucketBy(
+            spec.numBuckets,
+            spec.bucketColumnNames.head,
+            spec.bucketColumnNames.tail: _*)
+        }.getOrElse(writer)
+      }
+
+      withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1")
+      withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2")
 
       withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
         val t1 = hiveContext.table("bucketed_table1")
@@ -95,42 +111,42 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
   }
 
   test("avoid shuffle when join 2 bucketed tables") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
   }
 
   // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
   ignore("avoid shuffle when join keys are a super-set of bucket keys") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
   }
 
   test("only shuffle one side when join bucketed table and non-bucketed table") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("only shuffle one side when 2 bucketed tables have different bucket number") {
-    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j")
-    testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("only shuffle one side when 2 bucketed tables have different bucket keys") {
-    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j")
-    testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
+    val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
+    testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("shuffle when join keys are not equal to bucket keys") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true)
+    val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
   }
 
   test("shuffle when join 2 bucketed tables with bucketing disabled") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
     withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
-      testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+      testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
     }
   }
 
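The assertion half of testBucketing sits between the hunks shown above and is not part of this diff. As a hedged sketch of what such a check can look like, using the Exchange and SortMergeJoin operators already imported in this file (the join-condition construction and the exact plan traversal here are assumptions, not the committed code, and the names come from the helper's enclosing scope):

// Sketch only: plan the join with broadcast joins disabled, find the sort-merge join,
// then check each side for an Exchange (shuffle) operator.
val t1 = hiveContext.table("bucketed_table1")
val t2 = hiveContext.table("bucketed_table2")
// Assumed construction of the join condition from `joinColumns`.
val joinCondition = joinColumns.map(c => t1(c) === t2(c)).reduce(_ && _)
val joined = t1.join(t2, joinCondition)

val joinOperator = joined.queryExecution.executedPlan.collect {
  case smj: SortMergeJoin => smj
}.head

// A shuffle should appear under a side exactly when the caller expects one.
assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft)
assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight)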
sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
(42 additions, 26 deletions)
@@ -65,39 +65,55 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
 
   private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
 
+  /**
+   * A helper method to check the bucket write functionality at a low level, i.e. it checks the
+   * written bucket files to see if the data is correct. The caller should pass in the data dir
+   * that these bucket files are written to, the format of the data (parquet, json, etc.), and
+   * the bucketing information.
+   */
   private def testBucketing(
       dataDir: File,
       source: String,
+      numBuckets: Int,
       bucketCols: Seq[String],
       sortCols: Seq[String] = Nil): Unit = {
     val allBucketFiles = dataDir.listFiles().filterNot(f =>
       f.getName.startsWith(".") || f.getName.startsWith("_")
     )
-    val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get)
-    assert(groupedBucketFiles.size <= 8)
-
-    for ((bucketId, bucketFiles) <- groupedBucketFiles) {
-      for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
-        val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
-        val columns = (bucketCols ++ sortCols).zip(types).map {
-          case (colName, dt) => col(colName).cast(dt)
-        }
-        val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)
 
-        if (sortCols.nonEmpty) {
-          checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
-        }
+    for (bucketFile <- allBucketFiles) {
+      val bucketId = BucketingUtils.getBucketId(bucketFile.getName).get
+      assert(bucketId >= 0 && bucketId < numBuckets)
 
-        val qe = readBack.select(bucketCols.map(col): _*).queryExecution
-        val rows = qe.toRdd.map(_.copy()).collect()
-        val getBucketId = UnsafeProjection.create(
-          HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
-          qe.analyzed.output)
+      // We may lose the type information after writing (e.g. the json format doesn't keep schema
+      // information), so here we get the types from the original dataframe.
+      val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+      val columns = (bucketCols ++ sortCols).zip(types).map {
+        case (colName, dt) => col(colName).cast(dt)
+      }
 
-        for (row <- rows) {
-          val actualBucketId = getBucketId(row).getInt(0)
-          assert(actualBucketId == bucketId)
-        }
+      // Read the bucket file into a dataframe, so that it's easier to test.
+      val readBack = sqlContext.read.format(source)
+        .load(bucketFile.getAbsolutePath)
+        .select(columns: _*)
+
+      // If we specified sort columns when writing the bucketed table, make sure the data in this
+      // bucket file is already sorted.
+      if (sortCols.nonEmpty) {
+        checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
+      }
+
+      // Go through all rows in this bucket file, calculate the bucket id according to the bucket
+      // column values, and make sure it equals the expected bucket id inferred from the file name.
+      val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+      val rows = qe.toRdd.map(_.copy()).collect()
+      val getBucketId = UnsafeProjection.create(
+        HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil,
+        qe.analyzed.output)
+
+      for (row <- rows) {
+        val actualBucketId = getBucketId(row).getInt(0)
+        assert(actualBucketId == bucketId)
       }
     }
   }
@@ -113,7 +129,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
 
         val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
         for (i <- 0 until 5) {
-          testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k"))
+          testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
         }
       }
     }
@@ -131,7 +147,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
 
         val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
         for (i <- 0 until 5) {
-          testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k"))
+          testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k"))
         }
       }
     }
@@ -146,7 +162,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
           .saveAsTable("bucketed_table")
 
         val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
-        testBucketing(tableDir, source, Seq("i", "j"))
+        testBucketing(tableDir, source, 8, Seq("i", "j"))
       }
     }
   }
@@ -161,7 +177,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
           .saveAsTable("bucketed_table")
 
         val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
-        testBucketing(tableDir, source, Seq("i", "j"), Seq("k"))
+        testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k"))
       }
     }
   }
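For orientation, here is a rough sketch of the kind of write these tests verify. The testBucketing call mirrors the final hunk above; the bucketBy/sortBy chain is an assumed illustration of how such a table gets written, not the exact test body, which is outside this diff:

// Write df into 8 buckets on (i, j), sorted by k within each bucket file.
df.write
  .format("parquet")
  .bucketBy(8, "i", "j")
  .sortBy("k")
  .saveAsTable("bucketed_table")

// Each bucket file name encodes its bucket id. The helper reads the id back with
// BucketingUtils.getBucketId(file.getName), then recomputes the id for every row via
// HashPartitioning(bucketColumns, numBuckets).partitionIdExpression and asserts they match.
val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
testBucketing(tableDir, "parquet", 8, Seq("i", "j"), Seq("k"))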
