Skip to content

Commit bb6ffd1

Browse files
zedtang authored and attilapiros committed
[SPARK-48761][SQL] Introduce clusterBy DataFrameWriter API for Scala
### What changes were proposed in this pull request?

Introduce a new `clusterBy` DataFrame API in Scala. This PR adds the API for both the DataFrameWriter V1 and V2, as well as Spark Connect.

### Why are the changes needed?

Introduce more ways for users to interact with clustered tables.

### Does this PR introduce _any_ user-facing change?

Yes, it adds a new `clusterBy` DataFrame API in Scala to allow specifying the clustering columns when writing DataFrames.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#47301 from zedtang/clusterby-scala-api.

Authored-by: Jiaheng Tang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 0e59ffb commit bb6ffd1

File tree

18 files changed

+482
-11
lines changed

18 files changed

+482
-11
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,20 @@
471471
],
472472
"sqlState" : "0A000"
473473
},
474+
"CLUSTERING_COLUMNS_MISMATCH" : {
475+
"message" : [
476+
"Specified clustering does not match that of the existing table <tableName>.",
477+
"Specified clustering columns: [<specifiedClusteringString>].",
478+
"Existing clustering columns: [<existingClusteringString>]."
479+
],
480+
"sqlState" : "42P10"
481+
},
482+
"CLUSTERING_NOT_SUPPORTED" : {
483+
"message" : [
484+
"'<operation>' does not support clustering."
485+
],
486+
"sqlState" : "42000"
487+
},
474488
"CODEC_NOT_AVAILABLE" : {
475489
"message" : [
476490
"The codec <codecName> is not available."

connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3090,6 +3090,11 @@ class SparkConnectPlanner(
30903090
w.partitionBy(names.toSeq: _*)
30913091
}
30923092

3093+
if (writeOperation.getClusteringColumnsCount > 0) {
3094+
val names = writeOperation.getClusteringColumnsList.asScala
3095+
w.clusterBy(names.head, names.tail.toSeq: _*)
3096+
}
3097+
30933098
if (writeOperation.hasSource) {
30943099
w.format(writeOperation.getSource)
30953100
}
@@ -3153,6 +3158,11 @@ class SparkConnectPlanner(
31533158
w.partitionedBy(names.head, names.tail: _*)
31543159
}
31553160

3161+
if (writeOperation.getClusteringColumnsCount > 0) {
3162+
val names = writeOperation.getClusteringColumnsList.asScala
3163+
w.clusterBy(names.head, names.tail.toSeq: _*)
3164+
}
3165+
31563166
writeOperation.getMode match {
31573167
case proto.WriteOperationV2.Mode.MODE_CREATE =>
31583168
if (writeOperation.hasProvider) {

connect/server/src/test/scala/org/apache/spark/sql/connect/dsl/package.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ package object dsl {
219219
mode: Option[String] = None,
220220
sortByColumns: Seq[String] = Seq.empty,
221221
partitionByCols: Seq[String] = Seq.empty,
222+
clusterByCols: Seq[String] = Seq.empty,
222223
bucketByCols: Seq[String] = Seq.empty,
223224
numBuckets: Option[Int] = None): Command = {
224225
val writeOp = WriteOperation.newBuilder()
@@ -242,6 +243,7 @@ package object dsl {
242243
}
243244
sortByColumns.foreach(writeOp.addSortColumnNames(_))
244245
partitionByCols.foreach(writeOp.addPartitioningColumns(_))
246+
clusterByCols.foreach(writeOp.addClusteringColumns(_))
245247

246248
if (numBuckets.nonEmpty && bucketByCols.nonEmpty) {
247249
val op = WriteOperation.BucketBy.newBuilder()
@@ -272,13 +274,15 @@ package object dsl {
272274
options: Map[String, String] = Map.empty,
273275
tableProperties: Map[String, String] = Map.empty,
274276
partitionByCols: Seq[Expression] = Seq.empty,
277+
clusterByCols: Seq[String] = Seq.empty,
275278
mode: Option[String] = None,
276279
overwriteCondition: Option[Expression] = None): Command = {
277280
val writeOp = WriteOperationV2.newBuilder()
278281
writeOp.setInput(logicalPlan)
279282
tableName.foreach(writeOp.setTableName)
280283
provider.foreach(writeOp.setProvider)
281284
partitionByCols.foreach(writeOp.addPartitioningColumns)
285+
clusterByCols.foreach(writeOp.addClusteringColumns)
282286
options.foreach { case (k, v) =>
283287
writeOp.putOptions(k, v)
284288
}

connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,48 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
596596
}
597597
}
598598

599+
test("Write with clustering") {
600+
// Cluster by existing column.
601+
withTable("testtable") {
602+
transform(
603+
localRelation.write(
604+
tableName = Some("testtable"),
605+
tableSaveMethod = Some("save_as_table"),
606+
format = Some("parquet"),
607+
clusterByCols = Seq("id")))
608+
}
609+
610+
// Cluster by non-existing column.
611+
assertThrows[AnalysisException](
612+
transform(
613+
localRelation
614+
.write(
615+
tableName = Some("testtable"),
616+
tableSaveMethod = Some("save_as_table"),
617+
format = Some("parquet"),
618+
clusterByCols = Seq("noid"))))
619+
}
620+
621+
test("Write V2 with clustering") {
622+
// Cluster by existing column.
623+
withTable("testtable") {
624+
transform(
625+
localRelation.writeV2(
626+
tableName = Some("testtable"),
627+
mode = Some("MODE_CREATE"),
628+
clusterByCols = Seq("id")))
629+
}
630+
631+
// Cluster by non-existing column.
632+
assertThrows[AnalysisException](
633+
transform(
634+
localRelation
635+
.writeV2(
636+
tableName = Some("testtable"),
637+
mode = Some("MODE_CREATE"),
638+
clusterByCols = Seq("noid"))))
639+
}
640+
599641
test("Write with invalid bucketBy configuration") {
600642
val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0))
601643
assertThrows[InvalidCommandInput] {

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,22 @@ final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) {
201201
this
202202
}
203203

204+
/**
205+
* Clusters the output by the given columns on the storage. The rows with matching values in the
206+
* specified clustering columns will be consolidated within the same group.
207+
*
208+
* For instance, if you cluster a dataset by date, the data sharing the same date will be stored
209+
* together in a file. This arrangement improves query efficiency when you apply selective
210+
* filters to these clustering columns, thanks to data skipping.
211+
*
212+
* @since 4.0.0
213+
*/
214+
@scala.annotation.varargs
215+
def clusterBy(colName: String, colNames: String*): DataFrameWriter[T] = {
216+
this.clusteringColumns = Option(colName +: colNames)
217+
this
218+
}
219+
204220
/**
205221
* Saves the content of the `DataFrame` at the specified path.
206222
*
@@ -242,6 +258,7 @@ final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) {
242258
source.foreach(builder.setSource)
243259
sortColumnNames.foreach(names => builder.addAllSortColumnNames(names.asJava))
244260
partitioningColumns.foreach(cols => builder.addAllPartitioningColumns(cols.asJava))
261+
clusteringColumns.foreach(cols => builder.addAllClusteringColumns(cols.asJava))
245262

246263
numBuckets.foreach(n => {
247264
val bucketBuilder = proto.WriteOperation.BucketBy.newBuilder()
@@ -509,4 +526,6 @@ final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) {
509526
private var numBuckets: Option[Int] = None
510527

511528
private var sortColumnNames: Option[Seq[String]] = None
529+
530+
private var clusteringColumns: Option[Seq[String]] = None
512531
}

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
4141

4242
private var partitioning: Option[Seq[proto.Expression]] = None
4343

44+
private var clustering: Option[Seq[String]] = None
45+
4446
private var overwriteCondition: Option[proto.Expression] = None
4547

4648
override def using(provider: String): CreateTableWriter[T] = {
@@ -77,6 +79,12 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
7779
this
7880
}
7981

82+
@scala.annotation.varargs
83+
override def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] = {
84+
this.clustering = Some(colName +: colNames)
85+
this
86+
}
87+
8088
override def create(): Unit = {
8189
executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE)
8290
}
@@ -133,6 +141,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
133141
provider.foreach(builder.setProvider)
134142

135143
partitioning.foreach(columns => builder.addAllPartitioningColumns(columns.asJava))
144+
clustering.foreach(columns => builder.addAllClusteringColumns(columns.asJava))
136145

137146
options.foreach { case (k, v) =>
138147
builder.putOptions(k, v)
@@ -252,8 +261,22 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
252261
*
253262
* @since 3.4.0
254263
*/
264+
@scala.annotation.varargs
255265
def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T]
256266

267+
/**
268+
* Clusters the output by the given columns on the storage. The rows with matching values in the
269+
* specified clustering columns will be consolidated within the same group.
270+
*
271+
* For instance, if you cluster a dataset by date, the data sharing the same date will be stored
272+
* together in a file. This arrangement improves query efficiency when you apply selective
273+
* filters to these clustering columns, thanks to data skipping.
274+
*
275+
* @since 4.0.0
276+
*/
277+
@scala.annotation.varargs
278+
def clusterBy(colName: String, colNames: String*): CreateTableWriter[T]
279+
257280
/**
258281
* Specifies a provider for the underlying output data source. Spark's default catalog supports
259282
* "parquet", "json", etc.

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
8585
.setNumBuckets(2)
8686
.addBucketColumnNames("col1")
8787
.addBucketColumnNames("col2"))
88+
.addClusteringColumns("col3")
8889

8990
val expectedPlan = proto.Plan
9091
.newBuilder()
@@ -95,6 +96,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
9596
.sortBy("col1")
9697
.partitionBy("col99")
9798
.bucketBy(2, "col1", "col2")
99+
.clusterBy("col3")
98100
.parquet("my/test/path")
99101
val actualPlan = service.getAndClearLatestInputPlan()
100102
assert(actualPlan.equals(expectedPlan))
@@ -136,6 +138,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
136138
.setTableName("t1")
137139
.addPartitioningColumns(col("col99").expr)
138140
.setProvider("json")
141+
.addClusteringColumns("col3")
139142
.putTableProperties("key", "value")
140143
.putOptions("key2", "value2")
141144
.setMode(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE)
@@ -147,6 +150,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
147150

148151
df.writeTo("t1")
149152
.partitionedBy(col("col99"))
153+
.clusterBy("col3")
150154
.using("json")
151155
.tableProperty("key", "value")
152156
.options(Map("key2" -> "value2"))

project/MimaExcludes.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ object MimaExcludes {
100100
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext#implicits._sqlContext"),
101101
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits._sqlContext"),
102102
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.session"),
103-
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SparkSession#implicits._sqlContext")
103+
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SparkSession#implicits._sqlContext"),
104+
// SPARK-48761: Add clusterBy() to CreateTableWriter.
105+
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.CreateTableWriter.clusterBy")
104106
)
105107

106108
// Default exclude rules

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,22 @@ object ClusterBySpec {
197197
ret
198198
}
199199

200+
/**
201+
* Converts the clustering column property to a ClusterBySpec.
202+
*/
200203
def fromProperty(columns: String): ClusterBySpec = {
201204
ClusterBySpec(mapper.readValue[Seq[Seq[String]]](columns).map(FieldReference(_)))
202205
}
203206

207+
/**
208+
* Converts a ClusterBySpec to a clustering column property map entry, with validation
209+
* of the column names against the schema.
210+
*
211+
* @param schema the schema of the table.
212+
* @param clusterBySpec the ClusterBySpec to be converted to a property.
213+
* @param resolver the resolver used to match the column names.
214+
* @return a map entry for the clustering column property.
215+
*/
204216
def toProperty(
205217
schema: StructType,
206218
clusterBySpec: ClusterBySpec,
@@ -209,10 +221,25 @@ object ClusterBySpec {
209221
normalizeClusterBySpec(schema, clusterBySpec, resolver).toJson
210222
}
211223

224+
/**
225+
* Converts a ClusterBySpec to a clustering column property map entry, without validating
226+
* the column names against the schema.
227+
*
228+
* @param clusterBySpec existing ClusterBySpec to be converted to properties.
229+
* @return a map entry for the clustering column property.
230+
*/
231+
def toPropertyWithoutValidation(clusterBySpec: ClusterBySpec): (String, String) = {
232+
(CatalogTable.PROP_CLUSTERING_COLUMNS -> clusterBySpec.toJson)
233+
}
234+
212235
private def normalizeClusterBySpec(
213236
schema: StructType,
214237
clusterBySpec: ClusterBySpec,
215238
resolver: Resolver): ClusterBySpec = {
239+
if (schema.isEmpty) {
240+
return clusterBySpec
241+
}
242+
216243
val normalizedColumns = clusterBySpec.columnNames.map { columnName =>
217244
val position = SchemaUtils.findColumnPosition(
218245
columnName.fieldNames().toImmutableArraySeq, schema, resolver)
@@ -239,6 +266,10 @@ object ClusterBySpec {
239266
val normalizedClusterBySpec = normalizeClusterBySpec(schema, clusterBySpec, resolver)
240267
ClusterByTransform(normalizedClusterBySpec.columnNames)
241268
}
269+
270+
def fromColumnNames(names: Seq[String]): ClusterBySpec = {
271+
ClusterBySpec(names.map(FieldReference(_)))
272+
}
242273
}
243274

244275
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,18 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
18661866
"existingBucketString" -> existingBucketString))
18671867
}
18681868

1869+
def mismatchedTableClusteringError(
1870+
tableName: String,
1871+
specifiedClusteringString: String,
1872+
existingClusteringString: String): Throwable = {
1873+
new AnalysisException(
1874+
errorClass = "CLUSTERING_COLUMNS_MISMATCH",
1875+
messageParameters = Map(
1876+
"tableName" -> tableName,
1877+
"specifiedClusteringString" -> specifiedClusteringString,
1878+
"existingClusteringString" -> existingClusteringString))
1879+
}
1880+
18691881
def specifyPartitionNotAllowedWhenTableSchemaNotDefinedError(): Throwable = {
18701882
new AnalysisException(
18711883
errorClass = "_LEGACY_ERROR_TEMP_1165",
@@ -4100,4 +4112,22 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
41004112
messageParameters = Map("functionName" -> functionName)
41014113
)
41024114
}
4115+
4116+
def operationNotSupportClusteringError(operation: String): Throwable = {
4117+
new AnalysisException(
4118+
errorClass = "CLUSTERING_NOT_SUPPORTED",
4119+
messageParameters = Map("operation" -> operation))
4120+
}
4121+
4122+
def clusterByWithPartitionedBy(): Throwable = {
4123+
new AnalysisException(
4124+
errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED",
4125+
messageParameters = Map.empty)
4126+
}
4127+
4128+
def clusterByWithBucketing(): Throwable = {
4129+
new AnalysisException(
4130+
errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED",
4131+
messageParameters = Map.empty)
4132+
}
41034133
}

0 commit comments

Comments (0)