Skip to content

Commit c9f600b

Browse files
committed
Added conf and multi-version tests in FlatMapGroupsWithStateSuite
1 parent 9525484 commit c9f600b

File tree

7 files changed

+45
-21
lines changed

7 files changed

+45
-21
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,14 @@ object SQLConf {
814814
.intConf
815815
.createWithDefault(10)
816816

817+
val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
818+
buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
819+
.internal()
820+
.doc("State format version used by flatMapGroupsWithState operation in a streaming query")
821+
.intConf
822+
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
823+
.createWithDefault(2)
824+
817825
val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation")
818826
.doc("The default location for storing checkpoint data for streaming queries.")
819827
.stringConf

sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ case class ObjectType(cls: Class[_]) extends DataType {
4343

4444
def asNullable: DataType = this
4545

46-
override def simpleString: String = s"Object[${cls.getName}]"
46+
override def simpleString: String = cls.getName
4747

4848
override def acceptsType(other: DataType): Boolean = other match {
4949
case ObjectType(otherCls) => cls.isAssignableFrom(otherCls)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
485485
case FlatMapGroupsWithState(
486486
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
487487
timeout, child) =>
488+
val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
488489
val execPlan = FlatMapGroupsWithStateExec(
489-
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode,
490-
timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
490+
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion,
491+
outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
491492
execPlan :: Nil
492493
case _ =>
493494
Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ case class FlatMapGroupsWithStateExec(
5050
outputObjAttr: Attribute,
5151
stateInfo: Option[StatefulOperatorStateInfo],
5252
stateEncoder: ExpressionEncoder[Any],
53+
stateFormatVersion: Int,
5354
outputMode: OutputMode,
5455
timeoutConf: GroupStateTimeout,
5556
batchTimestampMs: Option[Long],
@@ -65,7 +66,8 @@ case class FlatMapGroupsWithStateExec(
6566
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
6667
case _ => false
6768
}
68-
private[sql] val stateManager = createStateManager(stateEncoder, isTimeoutEnabled, 2)
69+
private[sql] val stateManager =
70+
createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)
6971

7072
/** Distribute by grouping attributes */
7173
override def requiredChildDistribution: Seq[Distribution] =

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.execution.streaming.state
1919

2020
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
21-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, Expression, GenericInternalRow, GetStructField, If, IsNull, Literal, SpecificInternalRow, UnsafeRow}
21+
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.execution.ObjectOperator
2323
import org.apache.spark.sql.execution.streaming.GroupStateImpl
2424
import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
2727

2828
object FlatMapGroupsWithStateExecHelper {
2929

30-
val DEFAULT_STATE_MANAGER_VERSION = 2
30+
val supportedVersions = Seq(1, 2)
3131

3232
/**
3333
* Class to capture deserialized state and timestamp returned by the state manager.
@@ -58,24 +58,26 @@ object FlatMapGroupsWithStateExecHelper {
5858
def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit
5959
def removeState(store: StateStore, keyRow: UnsafeRow): Unit
6060
def getAllState(store: StateStore): Iterator[StateData]
61+
def version: Int
6162
}
6263

6364
def createStateManager(
6465
stateEncoder: ExpressionEncoder[Any],
6566
shouldStoreTimestamp: Boolean,
66-
version: Int): StateManager = {
67-
version match {
67+
stateFormatVersion: Int): StateManager = {
68+
stateFormatVersion match {
6869
case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp)
6970
case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp)
70-
case _ => throw new IllegalArgumentException(s"Version $version")
71+
case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
7172
}
7273
}
7374

7475
// ===============================================================================================
7576
// =========================== Private implementations of StateManager ===========================
7677
// ===============================================================================================
7778

78-
private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) extends StateManager {
79+
private abstract class StateManagerImplBase(val version: Int, shouldStoreTimestamp: Boolean)
80+
extends StateManager {
7981

8082
protected def stateSerializerExprs: Seq[Expression]
8183
protected def stateDeserializerExpr: Expression
@@ -135,7 +137,7 @@ object FlatMapGroupsWithStateExecHelper {
135137

136138
private class StateManagerImplV1(
137139
stateEncoder: ExpressionEncoder[Any],
138-
shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) {
140+
shouldStoreTimestamp: Boolean) extends StateManagerImplBase(1, shouldStoreTimestamp) {
139141

140142
private val timestampTimeoutAttribute =
141143
AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
@@ -175,7 +177,7 @@ object FlatMapGroupsWithStateExecHelper {
175177

176178
private class StateManagerImplV2(
177179
stateEncoder: ExpressionEncoder[Any],
178-
shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) {
180+
shouldStoreTimestamp: Boolean) extends StateManagerImplBase(2, shouldStoreTimestamp) {
179181

180182
/** Schema of the state rows saved in the state store */
181183
override val stateSchema: StructType = {

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.streaming.state
1919

2020
import java.util.concurrent.atomic.AtomicInteger
2121

22-
import org.apache.spark.sql.{Encoder, QueryTest}
22+
import org.apache.spark.sql.Encoder
2323
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
24-
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, UnsafeProjection, UnsafeRow}
24+
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
2525
import org.apache.spark.sql.execution.streaming.GroupStateImpl._
2626
import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite._
2727
import org.apache.spark.sql.streaming.StreamTest

sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
3131
import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
3232
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
3333
import org.apache.spark.sql.execution.RDDScanExec
34-
import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
35-
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
34+
import org.apache.spark.sql.execution.streaming._
35+
import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
36+
import org.apache.spark.sql.internal.SQLConf
3637
import org.apache.spark.sql.streaming.util.StreamManualClock
3738
import org.apache.spark.sql.types.{DataType, IntegerType}
3839

@@ -601,7 +602,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
601602
expectedState = Some(5), // state should change
602603
expectedTimeoutTimestamp = 5000) // timestamp should change
603604

604-
test("flatMapGroupsWithState - streaming") {
605+
testWithAllStateVersions("flatMapGroupsWithState - streaming") {
605606
// Function to maintain running count up to 2, and then remove the count
606607
// Returns the data and the count if state is defined, otherwise does not return anything
607608
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
@@ -680,7 +681,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
680681
)
681682
}
682683

683-
test("flatMapGroupsWithState - streaming + aggregation") {
684+
testWithAllStateVersions("flatMapGroupsWithState - streaming + aggregation") {
684685
// Function to maintain running count up to 2, and then remove the count
685686
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
686687
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
@@ -739,7 +740,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
739740
checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF)
740741
}
741742

742-
test("flatMapGroupsWithState - streaming with processing time timeout") {
743+
testWithAllStateVersions("flatMapGroupsWithState - streaming with processing time timeout") {
743744
// Function to maintain the count as state and set the proc. time timeout delay of 10 seconds.
744745
// It returns the count if changed, or -1 if the state was removed by timeout.
745746
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
@@ -803,7 +804,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
803804
)
804805
}
805806

806-
test("flatMapGroupsWithState - streaming with event time timeout + watermark") {
807+
testWithAllStateVersions("flatMapGroupsWithState - streaming with event time timeout") {
807808
// Function to maintain the max event time as state and set the timeout timestamp based on the
808809
// current max event time seen. It returns the max event time in the state, or -1 if the state
809810
// was removed by timeout.
@@ -1135,7 +1136,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
11351136
.logicalPlan.collectFirst {
11361137
case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) =>
11371138
FlatMapGroupsWithStateExec(
1138-
f, k, v, g, d, o, None, s, m, t,
1139+
f, k, v, g, d, o, None, s, 2, m, t,
11391140
Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd"))
11401141
}.get
11411142
}
@@ -1168,6 +1169,16 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
11681169
}
11691170

11701171
def rowToInt(row: UnsafeRow): Int = row.getInt(0)
1172+
1173+
def testWithAllStateVersions(name: String)(func: => Unit): Unit = {
1174+
for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) {
1175+
test(s"$name - state format version $version") {
1176+
withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> version.toString) {
1177+
func
1178+
}
1179+
}
1180+
}
1181+
}
11711182
}
11721183

11731184
object FlatMapGroupsWithStateSuite {

0 commit comments

Comments
 (0)