diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f7d1b105964d5..a69b80428472a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -715,7 +715,8 @@ trait ComplexTypeMergingExpression extends Expression { "The collection of input data types must not be empty.") require( TypeCoercion.haveSameType(inputTypesForMerging), - "All input types must be the same except nullable, containsNull, valueContainsNull flags.") + "All input types must be the same except nullable, containsNull, valueContainsNull flags." + + s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}") inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 41fe0c3b60d9e..65afcdd4d67f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -843,6 +843,14 @@ object SQLConf { .intConf .createWithDefault(10) + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") + .internal() + .doc("State format version used by flatMapGroupsWithState operation in a streaming query") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation") .doc("The default location for storing checkpoint data for streaming queries.") .stringConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 02e095b42a506..0c4ea857fd1d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -504,9 +504,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case FlatMapGroupsWithState( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) val execPlan = FlatMapGroupsWithStateExec( - func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, - timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion, + outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) execPlan :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 8e82cccbc8fa3..bfe7d00f56048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,10 +23,8 @@ import 
org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -52,6 +50,7 @@ case class FlatMapGroupsWithStateExec( outputObjAttr: Attribute, stateInfo: Option[StatefulOperatorStateInfo], stateEncoder: ExpressionEncoder[Any], + stateFormatVersion: Int, outputMode: OutputMode, timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], @@ -60,32 +59,15 @@ case class FlatMapGroupsWithStateExec( ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { import GroupStateImpl._ + import FlatMapGroupsWithStateExecHelper._ private val isTimeoutEnabled = timeoutConf != NoTimeout - private val timestampTimeoutAttribute = - AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() - private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes - if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs - } - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } - // Get the deserializer for the state. Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. - private val stateDeserializer = stateEncoder.resolveAndBind().deserializer - private val watermarkPresent = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } + private[sql] val stateManager = + createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion) /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -125,11 +107,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, - stateAttributes.toStructType, + stateManager.stateSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val updater = new StateStoreUpdater(store) + val processor = new InputProcessor(store) // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForData match { @@ -143,7 +125,7 @@ case class FlatMapGroupsWithStateExec( // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. 
val outputIterator = - updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() + processor.processNewData(filteredIter) ++ processor.processTimedOutState() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -158,7 +140,7 @@ case class FlatMapGroupsWithStateExec( } /** Helper class to update the state store */ - class StateStoreUpdater(store: StateStore) { + class InputProcessor(store: StateStore) { // Converters for translating input keys, values, output data between rows and Java objects private val getKeyObj = @@ -167,14 +149,6 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converters for translating state between rows and Java objects - private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateAttributes) - private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Index of the additional metadata fields in the state row - private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) - // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") private val numOutputRows = longMetric("numOutputRows") @@ -183,20 +157,19 @@ case class FlatMapGroupsWithStateExec( * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows */ - def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( - keyUnsafeRow, + stateManager.getState(store, keyUnsafeRow), valueRowIter, - store.get(keyUnsafeRow), hasTimedOut = false) } } /** Find the groups that have timeout set and are timing out right now, and call the function */ - def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get @@ -205,12 +178,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutPairs = store.getRange(None, None).filter { rowPair => - val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) - timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } - timingOutPairs.flatMap { rowPair => - callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) + timingOutPairs.flatMap { stateData => + callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) } } else Iterator.empty } @@ -220,22 +192,19 @@ case class FlatMapGroupsWithStateExec( * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. 
* - * @param keyRow Row representing the key, cannot be null + * @param stateData All the data related to the state to be updated * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty - * @param prevStateRow Row representing the previous state, can be null * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( - keyRow: UnsafeRow, + stateData: StateData, valueRowIter: Iterator[InternalRow], - prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyObj = getKeyObj(keyRow) // convert key to objects + val keyObj = getKeyObj(stateData.keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObj = getStateObj(prevStateRow) - val keyedState = GroupStateImpl.createForStreaming( - Option(stateObj), + val groupState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, @@ -243,50 +212,24 @@ case class FlatMapGroupsWithStateExec( watermarkPresent) // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => numOutputRows += 1 getOutputRow(obj) } // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp - // If the state has not yet been set but timeout has been set, then - // we have to generate a row to save the timeout. However, attempting serialize - // null using case class encoder throws - - // java.lang.NullPointerException: Null value appeared in non-nullable field: - // If the schema is inferred from a Scala tuple / case class, or a Java bean, please - // try to use scala.Option[_] or other nullable types. 
- if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { - throw new IllegalStateException( - "Cannot set timeout when state is not defined, that is, state has not been" + - "initialized or has been removed") - } - - if (keyedState.hasRemoved) { - store.remove(keyRow) + if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { + stateManager.removeState(store, stateData.keyRow) numUpdatedStateRows += 1 - } else { - val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) - val stateRowToWrite = if (keyedState.hasUpdated) { - getStateRow(keyedState.get) - } else { - prevStateRow - } - - val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp - val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + val currentTimeoutTimestamp = groupState.getTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged if (shouldWriteState) { - if (stateRowToWrite == null) { - // This should never happen because checks in GroupStateImpl should avoid cases - // where empty state would need to be written - throw new IllegalStateException("Attempting to write empty state") - } - setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow, stateRowToWrite) + val updatedStateObj = if (groupState.exists) groupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) numUpdatedStateRows += 1 } } @@ -295,28 +238,5 @@ case class FlatMapGroupsWithStateExec( // Return an iterator of rows such that fully consumed, the updated state value will be saved CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } - - /** Returns the state as Java object if defined */ - def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled && stateRow != null) { - stateRow.getLong(timeoutTimestampIndex) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 1ae3f36c152cf..9847756f22d4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper +import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} /** * An ordered collection of offsets, used to track the progress of processing data from one or more @@ -87,7 +88,8 @@ case class OffsetSeqMetadata( object OffsetSeqMetadata extends Logging { private 
implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( - SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY) + SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) /** * Default values of relevant configurations that are used for backward compatibility. @@ -100,7 +102,9 @@ object OffsetSeqMetadata extends Logging { * with a specific default value for ensuring same behavior of the query as before. */ private val relevantSQLConfDefaultValues = Map[String, String]( - STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME, + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> + FlatMapGroupsWithStateExecHelper.legacyVersion.toString ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala new file mode 100644 index 0000000000000..0a16a3819b778 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.ObjectOperator +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.types._ + + +object FlatMapGroupsWithStateExecHelper { + + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + /** + * Class to capture deserialized state and timestamp returned by the state manager. + * This is intended for reuse. 
+ */ + case class StateData( + var keyRow: UnsafeRow = null, + var stateRow: UnsafeRow = null, + var stateObj: Any = null, + var timeoutTimestamp: Long = -1) { + + private[FlatMapGroupsWithStateExecHelper] def withNew( + newKeyRow: UnsafeRow, + newStateRow: UnsafeRow, + newStateObj: Any, + newTimeout: Long): this.type = { + keyRow = newKeyRow + stateRow = newStateRow + stateObj = newStateObj + timeoutTimestamp = newTimeout + this + } + } + + /** Interface for interacting with state data of FlatMapGroupsWithState */ + sealed trait StateManager extends Serializable { + def stateSchema: StructType + def getState(store: StateStore, keyRow: UnsafeRow): StateData + def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit + def removeState(store: StateStore, keyRow: UnsafeRow): Unit + def getAllState(store: StateStore): Iterator[StateData] + } + + def createStateManager( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean, + stateFormatVersion: Int): StateManager = { + stateFormatVersion match { + case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp) + case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } + + // =============================================================================================== + // =========================== Private implementations of StateManager =========================== + // =============================================================================================== + + /** Common methods for StateManager implementations */ + private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) + extends StateManager { + + protected def stateSerializerExprs: Seq[Expression] + protected def stateDeserializerExpr: Expression + protected def timeoutTimestampOrdinalInRow: Int + + /** Get deserialized state and corresponding timeout timestamp for a key */ + override def getState(store: StateStore, keyRow: UnsafeRow): StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew(keyRow, stateRow, getStateObject(stateRow), getTimestamp(stateRow)) + } + + /** Put state and timeout timestamp for a key */ + override def putState(store: StateStore, key: UnsafeRow, state: Any, timestamp: Long): Unit = { + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(key, stateRow) + } + + override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + override def getAllState(store: StateStore): Iterator[StateData] = { + val stateData = StateData() + store.getRange(None, None).map { p => + stateData.withNew(p.key, p.value, getStateObject(p.value), getTimestamp(p.value)) + } + } + + private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs) + private lazy val stateDeserializerFunc = { + ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes) + } + private lazy val stateDataForGets = StateData() + + protected def getStateObject(row: UnsafeRow): Any = { + if (row != null) stateDeserializerFunc(row) else null + } + + protected def getStateRow(obj: Any): UnsafeRow = { + stateSerializerFunc(obj) + } + + /** Returns the timeout timestamp of a state row, if set */ + private def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinalInRow) + } else NO_TIMESTAMP + } + + /** 
Set the timestamp in a state row */ + private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinalInRow, timeoutTimestamps) + } + } + + /** + * Version 1 of the StateManager which stores the user-defined state as flattened columns in + * the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The + * unsafe rows will look like this. + * + * UnsafeRow[ col1 | col2 | col3 | timestamp ] + * + * The limitation of this format is that timestamp cannot be set when the user-defined + * state has been removed. This is because the columns cannot be collectively marked to be + * empty/null. + */ + private class StateManagerImplV1( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + + private val timestampTimeoutAttribute = + AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() + + private val stateAttributes: Seq[Attribute] = { + val encSchemaAttribs = stateEncoder.schema.toAttributes + if (shouldStoreTimestamp) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs + } + + override val stateSchema: StructType = stateAttributes.toStructType + + override val timeoutTimestampOrdinalInRow: Int = { + stateAttributes.indexOf(timestampTimeoutAttribute) + } + + override val stateSerializerExprs: Seq[Expression] = { + val encoderSerializer = stateEncoder.namedExpressions + if (shouldStoreTimestamp) { + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) + } else { + encoderSerializer + } + } + + override val stateDeserializerExpr: Expression = { + // Note that this must be done in the driver, as resolving and binding of deserializer + // expressions to the encoded type can be safely done only in the driver. + stateEncoder.resolveAndBind().deserializer + } + + override protected def getStateRow(obj: Any): UnsafeRow = { + require(obj != null, "State object cannot be null") + super.getStateRow(obj) + } + } + + /** + * Version 2 of the StateManager which stores the user-defined state as a nested struct + * in the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The + * unsafe rows will look like this. + * ___________________________ + * | | + * | V + * UnsafeRow[ nested-struct | timestamp | UnsafeRow[ col1 | col2 | col3 ] ] + * + * This allows the entire user-defined state to be collectively marked as empty/null, + * thus allowing timestamp to be set without requiring the state to be present. 
+ */ + private class StateManagerImplV2( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + + /** Schema of the state rows saved in the state store */ + override val stateSchema: StructType = { + var schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) + if (shouldStoreTimestamp) schema = schema.add("timeoutTimestamp", LongType, nullable = false) + schema + } + + // Ordinals of the information stored in the state row + private val nestedStateOrdinal = 0 + override val timeoutTimestampOrdinalInRow = 1 + + override val stateSerializerExprs: Seq[Expression] = { + val boundRefToSpecificInternalRow = BoundReference( + 0, stateEncoder.serializer.head.collect { case b: BoundReference => b.dataType }.head, true) + + val nestedStateSerExpr = + CreateNamedStruct(stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) + + val nullSafeNestedStateSerExpr = { + val nullLiteral = Literal(null, nestedStateSerExpr.dataType) + CaseWhen(Seq(IsNull(boundRefToSpecificInternalRow) -> nullLiteral), nestedStateSerExpr) + } + + if (shouldStoreTimestamp) { + Seq(nullSafeNestedStateSerExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + } else { + Seq(nullSafeNestedStateSerExpr) + } + } + + override val stateDeserializerExpr: Expression = { + // Note that this must be done in the driver, as resolving and binding of deserializer + // expressions to the encoded type can be safely done only in the driver. + val boundRefToNestedState = + BoundReference(nestedStateOrdinal, stateEncoder.schema, nullable = true) + val deserExpr = stateEncoder.resolveAndBind().deserializer.transformUp { + case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) + } + val nullLiteral = Literal(null, deserExpr.dataType) + CaseWhen(Seq(IsNull(boundRefToNestedState) -> nullLiteral), elseValue = deserExpr) + } + } +} diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata new file mode 100644 index 0000000000000..372180b2096ee --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"04d960cd-d38f-4ce6-b8d0-ebcf84c9dccc"} \ No newline at end of file diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 new file mode 100644 index 0000000000000..807d7b0063b96 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531292029003,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 new file mode 100644 index 0000000000000..cce541073fb4b --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531292030005,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta new file mode 100644 index 0000000000000..193524ffe15b5 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta differ diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala new file mode 100644 index 0000000000000..dec30fd01f7e2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.GroupStateImpl._ +import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types._ + + +class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { + + import testImplicits._ + import FlatMapGroupsWithStateExecHelper._ + + // ============================ StateManagerImplV1 ============================ + + test(s"StateManager v1 - primitive type - without timestamp") { + val schema = new StructType().add("value", IntegerType, nullable = false) + testStateManagerWithoutTimestamp[Int](version = 1, schema, Seq(0, 10)) + } + + test(s"StateManager v1 - primitive type - with timestamp") { + val schema = new StructType() + .add("value", IntegerType, nullable = false) + .add("timeoutTimestamp", IntegerType, nullable = false) + testStateManagerWithTimestamp[Int](version = 1, schema, Seq(0, 10)) + } + + test(s"StateManager v1 - nested type - without timestamp") { + val schema = StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType)) + )) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null)) + + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues) + + // Verify the limitation of v1 with null state + intercept[Exception] { + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null)) + } + } + + test(s"StateManager v1 - nested type - with timestamp") { + val schema = StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType)) + )), + StructField("timeoutTimestamp", IntegerType, nullable = false) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null)) + + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues) + + // Verify the limitation of v1 with null state + intercept[Exception] { + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null)) + } + } + + // ============================ StateManagerImplV2 ============================ + + test(s"StateManager v2 - primitive type - without timestamp") { + val schema = new StructType() + 
.add("groupState", new StructType().add("value", IntegerType, nullable = false)) + testStateManagerWithoutTimestamp[Int](version = 2, schema, Seq(0, 10)) + } + + test(s"StateManager v2 - primitive type - with timestamp") { + val schema = new StructType() + .add("groupState", new StructType().add("value", IntegerType, nullable = false)) + .add("timeoutTimestamp", LongType, nullable = false) + testStateManagerWithTimestamp[Int](version = 2, schema, Seq(0, 10)) + } + + test(s"StateManager v2 - nested type - without timestamp") { + val schema = StructType(Seq( + StructField("groupState", StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType) + ))) + ))) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null), + null) + + testStateManagerWithoutTimestamp[NestedStruct](version = 2, schema, testValues) + } + + test(s"StateManager v2 - nested type - with timestamp") { + val schema = StructType(Seq( + StructField("groupState", StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType) + ))) + ))), + StructField("timeoutTimestamp", LongType, nullable = false) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null), + null) + + testStateManagerWithTimestamp[NestedStruct](version = 2, schema, testValues) + } + + + def testStateManagerWithoutTimestamp[T: Encoder]( + version: Int, + expectedStateSchema: StructType, + testValues: Seq[T]): Unit = { + val stateManager = newStateManager[T](version, withTimestamp = false) + assert(stateManager.stateSchema === expectedStateSchema) + testStateManager(stateManager, testValues, NO_TIMESTAMP) + } + + def testStateManagerWithTimestamp[T: Encoder]( + version: Int, + expectedStateSchema: StructType, + testValues: Seq[T]): Unit = { + val stateManager = newStateManager[T](version, withTimestamp = true) + assert(stateManager.stateSchema === expectedStateSchema) + for (timestamp <- Seq(NO_TIMESTAMP, 1000)) { + testStateManager(stateManager, testValues, timestamp) + } + } + + private def testStateManager[T: Encoder]( + stateManager: StateManager, + values: Seq[T], + timestamp: Long): Unit = { + val keys = (1 to values.size).map(_ => newKey()) + val store = new MemoryStateStore() + + // Test stateManager.getState(), putState(), removeState() + keys.zip(values).foreach { case (key, value) => + try { + stateManager.putState(store, key, value, timestamp) + val data = stateManager.getState(store, key) + assert(data.stateObj == value) + assert(data.timeoutTimestamp === timestamp) + stateManager.removeState(store, key) + assert(stateManager.getState(store, key).stateObj == null) + } catch { + case e: Throwable => + fail(s"put/get/remove test with '$value' failed", e) + } + } + + // Test stateManager.getAllState() + for (i <- keys.indices) { + stateManager.putState(store, keys(i), values(i), timestamp) + } + val allData = stateManager.getAllState(store).map(_.copy()).toArray + assert(allData.map(_.timeoutTimestamp).toSet == Set(timestamp)) + assert(allData.map(_.stateObj).toSet == values.toSet) + } + + private def newStateManager[T: Encoder](version: Int, withTimestamp: Boolean): StateManager = { + FlatMapGroupsWithStateExecHelper.createStateManager( + 
implicitly[Encoder[T]].asInstanceOf[ExpressionEncoder[Any]], + withTimestamp, + version) + } + + private val proj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val keyCounter = new AtomicInteger(0) + private def newKey(): UnsafeRow = { + proj.apply(new GenericInternalRow(Array[Any](keyCounter.getAndDecrement()))).copy() + } +} + +case class Struct(d: Double, str: String) +case class NestedStruct(i: Int, nested: Struct) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 988c8e6753e25..82d7755aef5f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.streaming +import java.io.File import java.sql.Date import java.util.concurrent.ConcurrentHashMap +import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfterAll import org.scalatest.exceptions.TestFailedException @@ -31,10 +33,12 @@ import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec -import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils /** Class to check custom state types */ case class RunningCount(count: Long) @@ -359,13 +363,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } } - // Values used for testing StateStoreUpdater + // Values used for testing InputProcessor val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout + // Tests for InputProcessor.processNewData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" val testName = s"NoTimeout - $priorStateStr - " @@ -396,7 +400,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout + // Tests for InputProcessor.processNewData() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { var testName = "" @@ -443,6 +447,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // state should be removed } + // Tests with ProcessingTimeTimeout + if (priorState == None) { + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated without initializing state", + stateUpdates = state => { 
state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + } + + testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = @@ -453,6 +469,30 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated after state removed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + + // Tests with EventTimeTimeout + + if (priorState == None) { + testStateUpdateWithData( + s"EventTimeTimeout - $testName - timeout updated without initializing state", + stateUpdates = state => { + state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - state and timeout timestamp updated", + stateUpdates = @@ -477,48 +517,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest priorTimeoutTimestamp = priorTimeoutTimestamp, expectedState = Some(5), // state should change expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - } - } - // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), - // Try to remove these cases in the future - for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - val testName = - if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + testStateUpdateWithData( + s"EventTimeTimeout - $testName - timeout updated after state removed", + stateUpdates = state => 
{ + state.remove(); state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } } - // Tests for StateStoreUpdater.updateStateForTimedOutKeys() + // Tests for InputProcessor.processTimedOutState() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { testStateUpdateWithTimeout( @@ -590,7 +603,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = 5000) // timestamp should change - test("flatMapGroupsWithState - streaming") { + testWithAllStateVersions("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -669,7 +682,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } - test("flatMapGroupsWithState - streaming + aggregation") { + testWithAllStateVersions("flatMapGroupsWithState - streaming + aggregation") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -728,7 +741,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF) } - test("flatMapGroupsWithState - streaming with processing time timeout") { + testWithAllStateVersions("flatMapGroupsWithState - streaming with processing time timeout") { // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. // It returns the count if changed, or -1 if the state was removed by timeout. val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -792,7 +805,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } - test("flatMapGroupsWithState - streaming with event time timeout + watermark") { + testWithAllStateVersions("flatMapGroupsWithState - streaming w/ event time timeout + watermark") { // Function to maintain the max event time as state and set the timeout timestamp based on the // current max event time seen. It returns the max event time in the state, or -1 if the state // was removed by timeout. 
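[Note on the next hunk: it adds two end-to-end tests for format selection. The state format is chosen from SQLConf only when a query starts with a fresh checkpoint; thereafter the version recorded in the offset log (see the OffsetSeqMetadata change above) is pinned across restarts. A minimal usage sketch, assuming a SparkSession named `spark` (the session name is the only assumption; the conf key is the one added by this patch):

    // Hypothetical usage, not part of this patch: force the legacy state format for a
    // query starting with a fresh checkpoint. On restart from an existing checkpoint,
    // the version recorded in the offset log wins, regardless of this setting.
    spark.conf.set("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion", "1")
]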
@@ -843,6 +856,105 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } + test("flatMapGroupsWithState - uses state format version 2 by default") { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + Iterator((key, count.toString)) + } + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { query => + // Verify state format = 2 + val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 2) + } + ) + } + + test("flatMapGroupsWithState - recovery from checkpoint uses state format version 1") { + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. + val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } + + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) + } else { + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) + } + } + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS + .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(("a", 11), ("a", 13), ("a", 15)) + inputData.addData(("a", 4)) + + testStream(result, Update)( + StartStream( + checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2")), + /* + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + */ + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. 
+ // Watermark is still 5 as max event time for all data is still 15. + + Execute { query => + // Verify state format = 1 + val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 1) + }, + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } + + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) @@ -1032,7 +1144,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } - test(s"StateStoreUpdater - updates with data - $testName") { + test(s"InputProcessor - process new data - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") @@ -1054,7 +1166,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - test(s"StateStoreUpdater - updates for timeout - $testName") { + test(s"InputProcessor - process timed out state - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") @@ -1081,21 +1193,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( mapGroupsFunc, timeoutConf, currentBatchTimestamp) - val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store) + val stateManager = mapGroupsSparkPlan.stateManager val key = intToRow(0) // Prepare store with prior state configs - if (priorState.nonEmpty) { - val row = updater.getStateRow(priorState.get) - updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) - store.put(key.copy(), row.copy()) + if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) { + stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp) } // Call updating function to update state store def callFunction() = { val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() + inputProcessor.processTimedOutState() } else { - updater.updateStateForKeysWithData(Iterator(key)) + inputProcessor.processNewData(Iterator(key)) } returnedIter.size // consume the iterator to force state updates } @@ -1106,15 +1217,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } else { // Call function to update and verify updated state in store callFunction() - val updatedStateRow = store.get(key) - assert( - Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, + val updatedState = stateManager.getState(store, key) + assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow != null) { - assert( - updater.getTimeoutTimestamp(updatedStateRow) === 
expectedTimeoutTimestamp, - "final timeout timestamp not as expected") - } + assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") } } @@ -1122,6 +1229,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout, batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { + val stateFormatVersion = spark.conf.get(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) MemoryStream[Int] .toDS .groupByKey(x => x) @@ -1129,7 +1237,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .logicalPlan.collectFirst { case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( - f, k, v, g, d, o, None, s, m, t, + f, k, v, g, d, o, None, s, stateFormatVersion, m, t, Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd")) }.get } @@ -1162,6 +1270,16 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } def rowToInt(row: UnsafeRow): Int = row.getInt(0) + + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { + for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { + test(s"$name - state format version $version") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> version.toString) { + func + } + } + } + } } object FlatMapGroupsWithStateSuite {
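[For reference, a minimal sketch (not part of the patch) of how the two state managers lay out the same user-defined state, reusing the RunningCount case class and the encoder cast from this suite; the expected schemas follow the assertions in FlatMapGroupsWithStateExecHelperSuite:

    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper

    // Encoder for the user-defined state type, cast the same way the test suite does
    val enc = Encoders.product[RunningCount].asInstanceOf[ExpressionEncoder[Any]]
    val v1 = FlatMapGroupsWithStateExecHelper.createStateManager(
      enc, shouldStoreTimestamp = true, stateFormatVersion = 1)
    val v2 = FlatMapGroupsWithStateExecHelper.createStateManager(
      enc, shouldStoreTimestamp = true, stateFormatVersion = 2)

    // v1.stateSchema: [count: bigint, timeoutTimestamp: int]
    //   -- state fields flattened into the row, so a null state cannot be written
    // v2.stateSchema: [groupState: struct<count: bigint>, timeoutTimestamp: bigint]
    //   -- state nested in a nullable struct, so a timeout can be saved with no state

The schema difference is the whole point of the version bump: V2 can persist a timeout timestamp for a key whose state has been removed, which V1 structurally cannot.]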