[SPARK-36452][SQL]: Add the support in Spark for having group by map datatype column for the scenario that works in Hive #33679

Closed · wants to merge 3 commits
@@ -318,7 +318,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
}

// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
if (!RowOrdering.isOrderable(expr.dataType, isGroupingExpr = true)) {
failAnalysis(
s"expression ${expr.sql} cannot be used as a grouping expression " +
s"because its data type ${expr.dataType.catalogString} is not an orderable " +
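For reference, this is the check that rejected map-typed grouping keys before this PR. A minimal sketch of the failing case (assuming a SparkSession with `spark.implicits._` in scope; the error text is paraphrased from the message built above):

```scala
val df = Seq((1, Map(1 -> 2))).toDF("id", "mapInfo")

// Prior to this change, analysis fails along the lines of:
//   org.apache.spark.sql.AnalysisException: expression `mapInfo` cannot be used
//   as a grouping expression because its data type map<int,int> is not an
//   orderable data type
df.groupBy("mapInfo").count()
```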
@@ -99,11 +99,16 @@ object RowOrdering extends CodeGeneratorWithInterpretedFallback[Seq[SortOrder],
/**
* Returns true iff the data type can be ordered (i.e. can be sorted).
*/
def isOrderable(dataType: DataType): Boolean = dataType match {
def isOrderable(dataType: DataType,
Member:
Should we fix #31967 first?

Contributor Author:
@HyukjinKwon - Thanks for checking this PR. Yes, we can wait for PR #32552. The fix in this PR will work with group by, order by, and partition by in window.

isGroupingExpr: Boolean = false): Boolean = dataType match {
case NullType => true
case dt: AtomicType => true
case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
case array: ArrayType => isOrderable(array.elementType)
// Treat MapType as orderable when the request comes from check analysis
// for a grouping expression.
case map: MapType if isGroupingExpr =>
isOrderable(map.keyType) && isOrderable(map.valueType)
case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
case _ => false
}
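Taken together with the CheckAnalysis change above, the end-user effect this aims for (mirroring the new DataFrameAggregateSuite test further down; assumes a SparkSession with `spark.implicits._` in scope) is roughly:

```scala
val df = Seq((1, Map(1 -> 2)), (2, Map(1 -> 2))).toDF("id", "mapInfo")

// Grouping by the map column is now accepted by analysis, because MapType is
// treated as orderable for grouping expressions.
df.groupBy("mapInfo").count().show()
// Expected, per the new test: a single row Map(1 -> 2) with count 2.
```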
@@ -92,7 +92,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case _ => needNormalize(expr.dataType)
}

private def needNormalize(dt: DataType): Boolean = dt match {
private[sql] def needNormalize(dt: DataType): Boolean = dt match {
case FloatType | DoubleType => true
case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
case ArrayType(et, _) => needNormalize(et)
@@ -588,6 +588,7 @@ class AnalysisErrorSuite extends AnalysisTest {
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
DateType, TimestampType,
ArrayType(IntegerType),
MapType(StringType, LongType),
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", StringType, nullable = true),
@@ -600,7 +601,6 @@
}

val unsupportedDataTypes = Seq(
MapType(StringType, LongType),
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", MapType(StringType, LongType), nullable = true),
@@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{MapType, StructType}

/**
* Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting
@@ -498,10 +498,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
// `groupingExpressions` is not extracted during logical phase.
val normalizedGroupingExpressions = groupingExpressions.map { e =>
NormalizeFloatingNumbers.normalize(e) match {
case n: NamedExpression => n
// Keep the name of the original expression.
case other => Alias(other, e.name)(exprId = e.exprId)
e.dataType match {
// Support MapType in the group by when no aggregate expression
// references the map attribute and neither the key type nor the
// value type needs floating-point normalization.
case MapType(kt, vt, _)
if !aggregateExpressions.exists(_.references == e.references) &&
!NormalizeFloatingNumbers.needNormalize(kt) &&
!NormalizeFloatingNumbers.needNormalize(vt) => e
case _ =>
NormalizeFloatingNumbers.normalize(e) match {
case n: NamedExpression => n
// Keep the name of the original expression.
case other => Alias(other, e.name)(exprId = e.exprId)
}
}
}
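A rough restatement of the type-based part of the new branch above (illustrative only; `needNormalize` is `private[sql]`, so this would have to live under `org.apache.spark.sql`, and the real check additionally requires that no aggregate expression references the map column):

```scala
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
import org.apache.spark.sql.types._

// Hypothetical helper: a map-typed grouping key may skip normalization only if
// neither its key type nor its value type contains Float/Double.
def mapGroupingKeySkipsNormalization(dt: DataType): Boolean = dt match {
  case MapType(kt, vt, _) =>
    !NormalizeFloatingNumbers.needNormalize(kt) &&
      !NormalizeFloatingNumbers.needNormalize(vt)
  case _ => false
}

mapGroupingKeySkipsNormalization(MapType(IntegerType, LongType))   // true: key kept as-is
mapGroupingKeySkipsNormalization(MapType(IntegerType, DoubleType)) // false: falls back to
// normalize(e), which is the source of the "grouping/join/window partition keys
// cannot be map type." error asserted in the new test below.
```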

@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.types.MapType

/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -76,14 +77,22 @@ object AggUtils {
resultExpressions = resultExpressions,
child = child)
} else {
SortAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
resultExpressions = resultExpressions,
child = child)
// SortAggregateExec requires its grouping expression data types to be
// orderable, and MapType is not orderable there. Validate that the
// grouping expressions do not contain a MapType before planning it.
if (!groupingExpressions.exists(_.dataType.isInstanceOf[MapType])) {
SortAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
resultExpressions = resultExpressions,
child = child)
} else {
throw new IllegalStateException("grouping keys cannot be map type for SortAggregateExec")
}
}
}
}
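In isolation, the new guard amounts to the check sketched below (with a hypothetical map-typed grouping attribute); when the planner falls back to sort-based aggregation for such a key, it now fails fast instead of producing a plan that cannot order its grouping keys:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{IntegerType, MapType}

// Hypothetical grouping key, as the guard above would see it.
val groupingExpressions = Seq(
  AttributeReference("mapInfo", MapType(IntegerType, IntegerType))())

if (groupingExpressions.exists(_.dataType.isInstanceOf[MapType])) {
  // Mirrors the new else branch: SortAggregateExec cannot sort by map keys.
  throw new IllegalStateException("grouping keys cannot be map type for SortAggregateExec")
}
```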
@@ -1427,6 +1427,33 @@ class DataFrameAggregateSuite extends QueryTest
assert (df.schema == expectedSchema)
checkAnswer(df, Seq(Row(LocalDateTime.parse(ts1), 2), Row(LocalDateTime.parse(ts2), 1)))
}

test("SPARK-36452: Support Map Type column in group by") {
var df = Seq((1, Map(1 -> 2)), (2, Map(1 -> 2))).toDF("id", "mapInfo")
// group by map column
checkAnswer(df.groupBy("mapInfo").count(), Seq(Row(Map[Any, Any](1 -> 2), 2)))
// group by map column and other column
checkAnswer(df.groupBy("id", "mapInfo").count(),
Seq(Row(1, Map[Any, Any](1 -> 2), 1), Row(2, Map[Any, Any](1 -> 2), 1)))
checkAnswer(df.groupBy("mapInfo").agg(avg("id")),
Seq(Row(Map[Any, Any](1 -> 2), 1.5)))
// Not supported when the aggregate expressions reference the map type column
var error = intercept[IllegalStateException] {
df.groupBy("mapInfo").agg(max(map_keys(col("mapinfo")))).collect
}
assert(error.getMessage.contains("grouping/join/window partition keys cannot be map type."))
// Not supported when the map type has float/double keys or values
df = Seq((1, Map(1 -> 2.0)), (2, Map(1 -> 2.0))).toDF("id", "mapInfo")
error = intercept[IllegalStateException] {
df.groupBy("mapInfo").agg(max(map_keys(col("mapinfo")))).collect
}
assert(error.getMessage.contains("grouping/join/window partition keys cannot be map type."))
df = Seq((1, Map(1.1 -> 2.0)), (2, Map(1.1 -> 2.0))).toDF("id", "mapInfo")
error = intercept[IllegalStateException] {
df.groupBy("mapInfo").agg(max(map_keys(col("mapinfo")))).collect
}
assert(error.getMessage.contains("grouping/join/window partition keys cannot be map type."))
}
}

case class B(c: Option[Double])