
Commit 6c5cb85

HeartSaVioR authored and tdas committed
[SPARK-24763][SS] Remove redundant key data from value in streaming aggregation
## What changes were proposed in this pull request?

This patch proposes a new flag option for stateful aggregation: remove redundant key data from the value. Enabling the new option behaves the same as the current behavior, but uses less memory for state, with savings depending on the key/value fields of the state operator. Please refer to the link below for detailed perf test results: https://issues.apache.org/jira/browse/SPARK-24763?focusedCommentId=16536539&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16536539

Since state written with the option enabled is not compatible with state written with it disabled, the option is set to 'disable' by default (to ensure backward compatibility), and OffsetSeqMetadata prevents modifying the option after a query has been executed.

## How was this patch tested?

Modified unit tests to cover both disabling and enabling the option. Also did manual tests to verify that the proposed patch improves state memory usage.

Closes #21733 from HeartSaVioR/SPARK-24763.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Tathagata Das <[email protected]>
1 parent 72ecfd0 commit 6c5cb85
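
As a usage sketch (an illustration, not part of the commit): since the new entry is a session conf, it would be set before the query first runs. The session, source, and column names below are hypothetical.

    // A minimal sketch, assuming a SparkSession named `spark`.
    // Pinning the version to 1 keeps the legacy state format; the new
    // default (2) stores values without the redundant key columns.
    import spark.implicits._

    spark.conf.set("spark.sql.streaming.aggregation.stateFormatVersion", "1")

    val counts = spark.readStream
      .format("rate")            // built-in test source emitting (timestamp, value)
      .load()
      .groupBy($"value" % 10)    // streaming aggregation keeps per-key state
      .count()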

File tree

26 files changed: +573, -85 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
@@ -888,6 +888,16 @@ object SQLConf {
     .intConf
     .createWithDefault(2)
 
+  val STREAMING_AGGREGATION_STATE_FORMAT_VERSION =
+    buildConf("spark.sql.streaming.aggregation.stateFormatVersion")
+      .internal()
+      .doc("State format version used by streaming aggregation operations in a streaming query. " +
+        "State between versions tends to be incompatible, so the state format version shouldn't " +
+        "be modified after running.")
+      .intConf
+      .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
+      .createWithDefault(2)
+
   val UNSUPPORTED_OPERATION_CHECK_ENABLED =
     buildConf("spark.sql.streaming.unsupportedOperationCheck")
       .internal()
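
A small sketch (hypothetical session `spark`) of what the checkValue guard above buys: an unsupported version is rejected when the conf is set, rather than failing later during planning.

    // Illustration only: values outside {1, 2} fail fast at set time.
    spark.conf.set("spark.sql.streaming.aggregation.stateFormatVersion", "3")
    // => java.lang.IllegalArgumentException: ... Valid versions are 1 and 2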

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 0 deletions
@@ -328,10 +328,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             "Streaming aggregation doesn't support group aggregate pandas UDF")
         }
 
+        val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
+
         aggregate.AggUtils.planStreamingAggregation(
           namedGroupingExpressions,
           aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
           rewrittenResultExpressions,
+          stateVersion,
           planLater(child))
 
       case _ => Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 4 additions & 1 deletion
@@ -260,6 +260,7 @@ object AggUtils {
       groupingExpressions: Seq[NamedExpression],
       functionsWithoutDistinct: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
+      stateFormatVersion: Int,
       child: SparkPlan): Seq[SparkPlan] = {
 
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
@@ -291,7 +292,8 @@
         child = partialAggregate)
     }
 
-    val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1)
+    val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion,
+      partialMerged1)
 
     val partialMerged2: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
@@ -315,6 +317,7 @@
         stateInfo = None,
         outputMode = None,
         eventTimeWatermark = None,
+        stateFormatVersion = stateFormatVersion,
         partialMerged2)
 
     val finalAndCompleteAggregate: SparkPlan = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala

Lines changed: 4 additions & 2 deletions
@@ -102,19 +102,21 @@ class IncrementalExecution(
     val state = new Rule[SparkPlan] {
 
       override def apply(plan: SparkPlan): SparkPlan = plan transform {
-        case StateStoreSaveExec(keys, None, None, None,
+        case StateStoreSaveExec(keys, None, None, None, stateFormatVersion,
               UnaryExecNode(agg,
-                StateStoreRestoreExec(_, None, child))) =>
+                StateStoreRestoreExec(_, None, _, child))) =>
           val aggStateInfo = nextStatefulOperationStateInfo
           StateStoreSaveExec(
             keys,
             Some(aggStateInfo),
             Some(outputMode),
             Some(offsetSeqMetadata.batchWatermarkMs),
+            stateFormatVersion,
             agg.withNewChildren(
               StateStoreRestoreExec(
                 keys,
                 Some(aggStateInfo),
+                stateFormatVersion,
                 child) :: Nil))
 
         case StreamingDeduplicateExec(keys, child, None, None) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala

Lines changed: 5 additions & 3 deletions
@@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.RuntimeConfig
-import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper
+import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager}
 import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _}
 
 /**
@@ -89,7 +89,7 @@ object OffsetSeqMetadata extends Logging {
   private implicit val format = Serialization.formats(NoTypeHints)
   private val relevantSQLConfs = Seq(
     SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
-    FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
+    FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
 
   /**
    * Default values of relevant configurations that are used for backward compatibility.
@@ -104,7 +104,9 @@
   private val relevantSQLConfDefaultValues = Map[String, String](
     STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME,
     FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
-      FlatMapGroupsWithStateExecHelper.legacyVersion.toString
+      FlatMapGroupsWithStateExecHelper.legacyVersion.toString,
+    STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
+      StreamingAggregationStateManager.legacyVersion.toString
   )
 
   def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
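
A sketch of the intended backward-compatibility behavior, under the assumption that the default-value map above is consulted when an entry is absent from a pre-existing offset log: old checkpoints fall back to legacyVersion, so only brand-new queries pick up the new format.

    // Illustration only: how the fallback in relevantSQLConfDefaultValues plays out.
    val confsFromOldCheckpoint = Map("spark.sql.shuffle.partitions" -> "200") // no version entry
    val effectiveVersion = confsFromOldCheckpoint.getOrElse(
      "spark.sql.streaming.aggregation.stateFormatVersion",
      StreamingAggregationStateManager.legacyVersion.toString) // "1" for old checkpoints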
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala

Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Base trait for a state manager intended to be used with streaming aggregations.
+ */
+sealed trait StreamingAggregationStateManager extends Serializable {
+
+  /** Extract the key columns from an input row, and return a new row containing only the key. */
+  def getKey(row: UnsafeRow): UnsafeRow
+
+  /** Calculate the schema for the value of state. The schema is mainly passed to StateStoreRDD. */
+  def getStateValueSchema: StructType
+
+  /** Get the current value of a non-null key from the target state store. */
+  def get(store: StateStore, key: UnsafeRow): UnsafeRow
+
+  /**
+   * Put a new value for a non-null key into the target state store. Note that the key will be
+   * extracted from the input row, and will be the same as the result of getKey(inputRow).
+   */
+  def put(store: StateStore, row: UnsafeRow): Unit
+
+  /**
+   * Commit all the updates that have been made to the target state store, and return the
+   * new version.
+   */
+  def commit(store: StateStore): Long
+
+  /** Remove a single non-null key from the target state store. */
+  def remove(store: StateStore, key: UnsafeRow): Unit
+
+  /** Return an iterator containing all the key-value pairs in the target state store. */
+  def iterator(store: StateStore): Iterator[UnsafeRowPair]
+
+  /** Return an iterator containing all the keys in the target state store. */
+  def keys(store: StateStore): Iterator[UnsafeRow]
+
+  /** Return an iterator containing all the values in the target state store. */
+  def values(store: StateStore): Iterator[UnsafeRow]
+}
+
+object StreamingAggregationStateManager extends Logging {
+  val supportedVersions = Seq(1, 2)
+  val legacyVersion = 1
+
+  def createStateManager(
+      keyExpressions: Seq[Attribute],
+      inputRowAttributes: Seq[Attribute],
+      stateFormatVersion: Int): StreamingAggregationStateManager = {
+    stateFormatVersion match {
+      case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes)
+      case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes)
+      case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
+    }
+  }
+}
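
A minimal sketch of constructing a manager from the factory above; the attributes are hypothetical stand-ins for what the physical plan would normally supply.

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.{LongType, StringType}

    val user = AttributeReference("user", StringType)()
    val count = AttributeReference("count", LongType)()

    // Version 2 trims the key columns out of the stored value.
    val manager = StreamingAggregationStateManager.createStateManager(
      keyExpressions = Seq(user),
      inputRowAttributes = Seq(user, count),
      stateFormatVersion = 2)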
+
+abstract class StreamingAggregationStateManagerBaseImpl(
+    protected val keyExpressions: Seq[Attribute],
+    protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager {
+
+  @transient protected lazy val keyProjector =
+    GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)
+
+  override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row)
+
+  override def commit(store: StateStore): Long = store.commit()
+
+  override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key)
+
+  override def keys(store: StateStore): Iterator[UnsafeRow] = {
+    // discard values and don't convert them, to avoid unnecessary computation
+    store.getRange(None, None).map(_.key)
+  }
+}
+
+/**
+ * The implementation of StreamingAggregationStateManager for state format version 1.
+ * In state format version 1, the schemas of the key and value in state are as follows:
+ *
+ * - key: Same as key expressions.
+ * - value: Same as input row attributes. The schema of the value contains key expressions as well.
+ *
+ * @param keyExpressions The attributes of keys.
+ * @param inputRowAttributes The attributes of the input row.
+ */
+class StreamingAggregationStateManagerImplV1(
+    keyExpressions: Seq[Attribute],
+    inputRowAttributes: Seq[Attribute])
+  extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {
+
+  override def getStateValueSchema: StructType = inputRowAttributes.toStructType
+
+  override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
+    store.get(key)
+  }
+
+  override def put(store: StateStore, row: UnsafeRow): Unit = {
+    store.put(getKey(row), row)
+  }
+
+  override def iterator(store: StateStore): Iterator[UnsafeRowPair] = {
+    store.iterator()
+  }
+
+  override def values(store: StateStore): Iterator[UnsafeRow] = {
+    store.iterator().map(_.value)
+  }
+}
+
+/**
+ * The implementation of StreamingAggregationStateManager for state format version 2.
+ * In state format version 2, the schemas of the key and value in state are as follows:
+ *
+ * - key: Same as key expressions.
+ * - value: The diff between input row attributes and key expressions.
+ *
+ * The schema of the value is changed to optimize the memory/space usage of state, by removing
+ * columns duplicated between the key and the value. Hence key columns are excluded from the
+ * schema of the value.
+ *
+ * @param keyExpressions The attributes of keys.
+ * @param inputRowAttributes The attributes of the input row.
+ */
+class StreamingAggregationStateManagerImplV2(
+    keyExpressions: Seq[Attribute],
+    inputRowAttributes: Seq[Attribute])
+  extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {
+
+  private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions)
+  private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
+
+  // flag to check whether the row needs to be projected into input row attributes after the join,
+  // e.g. if the fields in the joined row are not in the expected order
+  private val needToProjectToRestoreValue: Boolean =
+    keyValueJoinedExpressions != inputRowAttributes
+
+  @transient private lazy val valueProjector =
+    GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes)
+
+  @transient private lazy val joiner =
+    GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
+      StructType.fromAttributes(valueExpressions))
+  @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate(
+    inputRowAttributes, keyValueJoinedExpressions)
+
+  override def getStateValueSchema: StructType = valueExpressions.toStructType
+
+  override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
+    val savedState = store.get(key)
+    if (savedState == null) {
+      return savedState
+    }
+
+    restoreOriginalRow(key, savedState)
+  }
+
+  override def put(store: StateStore, row: UnsafeRow): Unit = {
+    val key = keyProjector(row)
+    val value = valueProjector(row)
+    store.put(key, value)
+  }
+
+  override def iterator(store: StateStore): Iterator[UnsafeRowPair] = {
+    store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair)))
+  }
+
+  override def values(store: StateStore): Iterator[UnsafeRow] = {
+    store.iterator().map(rowPair => restoreOriginalRow(rowPair))
+  }
+
+  private def restoreOriginalRow(rowPair: UnsafeRowPair): UnsafeRow = {
+    restoreOriginalRow(rowPair.key, rowPair.value)
+  }
+
+  private def restoreOriginalRow(key: UnsafeRow, value: UnsafeRow): UnsafeRow = {
+    val joinedRow = joiner.join(key, value)
+    if (needToProjectToRestoreValue) {
+      restoreValueProjector(joinedRow)
+    } else {
+      joinedRow
+    }
+  }
+}
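
To make the V1/V2 difference concrete, a hypothetical continuation of the earlier sketch: with key (user) and input row (user, count), V1 stores the whole input row as the value, while V2 stores only the non-key columns.

    val v1 = StreamingAggregationStateManager.createStateManager(Seq(user), Seq(user, count), 1)
    val v2 = StreamingAggregationStateManager.createStateManager(Seq(user), Seq(user, count), 2)

    v1.getStateValueSchema // roughly struct<user: string, count: bigint> -- key duplicated in value
    v2.getStateValueSchema // roughly struct<count: bigint>               -- key columns removed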
