Skip to content

Commit 076b8bd

Browse files
hejufang and xiangyuf authored
[FLINK-37294][state] Support state migration between disabling and enabling ttl in HeapKeyedStateBackend (#26651)
Co-authored-by: Xiangyu Feng <[email protected]>
1 parent 31785e0 commit 076b8bd

File tree

11 files changed

+200
-257
lines changed

11 files changed

+200
-257
lines changed

flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractHeapState.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
import org.apache.flink.api.java.tuple.Tuple2;
2525
import org.apache.flink.queryablestate.client.state.serialization.KvStateSerializer;
2626
import org.apache.flink.runtime.state.internal.InternalKvState;
27+
import org.apache.flink.runtime.state.ttl.TtlAwareSerializer;
28+
import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
29+
import org.apache.flink.runtime.state.ttl.TtlValue;
2730
import org.apache.flink.util.Preconditions;
2831

2932
/**
@@ -112,6 +115,18 @@ public byte[] getSerializedValue(
112115
return KvStateSerializer.serializeValue(result, safeValueSerializer);
113116
}
114117

118+
@SuppressWarnings("unchecked")
119+
public SV migrateTtlValue(
120+
SV stateValue,
121+
TtlAwareSerializer<SV, ?> currentTtlAwareSerializer,
122+
TtlTimeProvider ttlTimeProvider) {
123+
if (currentTtlAwareSerializer.isTtlEnabled()) {
124+
return (SV) new TtlValue<>(stateValue, ttlTimeProvider.currentTimestamp());
125+
}
126+
127+
return (SV) ((TtlValue<?>) stateValue).getUserValue();
128+
}
129+
115130
/** This should only be used for testing. */
116131
@VisibleForTesting
117132
public StateTable<K, N, SV> getStateTable() {

flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import org.apache.flink.runtime.state.StateSnapshotRestore;
5151
import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
5252
import org.apache.flink.runtime.state.StateSnapshotTransformers;
53+
import org.apache.flink.runtime.state.StateTransformationFunction;
5354
import org.apache.flink.runtime.state.StreamCompressionDecorator;
5455
import org.apache.flink.runtime.state.metrics.LatencyTrackingStateConfig;
5556
import org.apache.flink.runtime.state.metrics.SizeTrackingStateConfig;
@@ -65,6 +66,7 @@
6566

6667
import java.util.ArrayList;
6768
import java.util.HashMap;
69+
import java.util.Iterator;
6870
import java.util.List;
6971
import java.util.Map;
7072
import java.util.Spliterators;
@@ -212,7 +214,7 @@ private <N, V> StateTable<K, N, V> tryRegisterStateTable(
212214
StateDescriptor<?, V> stateDesc,
213215
@Nonnull StateSnapshotTransformFactory<V> snapshotTransformFactory,
214216
boolean allowFutureMetadataUpdates)
215-
throws StateMigrationException {
217+
throws Exception {
216218

217219
@SuppressWarnings("unchecked")
218220
StateTable<K, N, V> stateTable =
@@ -259,17 +261,12 @@ private <N, V> StateTable<K, N, V> tryRegisterStateTable(
259261
+ ") must not be incompatible with the old state serializer ("
260262
+ previousStateSerializer
261263
+ ").");
262-
}
263-
264-
// HeapKeyedStateBackend doesn't support ttl state migration currently.
265-
if (TtlAwareSerializer.needTtlStateMigration(
266-
previousStateSerializer, newStateSerializer)) {
267-
throw new StateMigrationException(
268-
"For heap backends, the new state serializer ("
269-
+ newStateSerializer
270-
+ ") must not need ttl state migration with the old state serializer ("
271-
+ previousStateSerializer
272-
+ ").");
264+
} else if (stateCompatibility.isCompatibleAfterMigration()
265+
&& TtlAwareSerializer.needTtlStateMigration(
266+
previousStateSerializer, newStateSerializer)) {
267+
// State migration without ttl change will be performed automatically during
268+
// checkpoint, so we only perform state ttl migration here.
269+
migrateTtlAwareStateValues(stateDesc, previousStateSerializer, newStateSerializer);
273270
}
274271

275272
restoredKvMetaInfo =
@@ -299,6 +296,52 @@ private <N, V> StateTable<K, N, V> tryRegisterStateTable(
299296
return stateTable;
300297
}
301298

299+
@SuppressWarnings("unchecked")
300+
private <V, N> void migrateTtlAwareStateValues(
301+
StateDescriptor<?, V> stateDesc,
302+
TypeSerializer<V> previousSerializer,
303+
TypeSerializer<V> currentSerializer)
304+
throws Exception {
305+
final StateTable<K, N, V> stateTable =
306+
(StateTable<K, N, V>) registeredKVStates.get(stateDesc.getName());
307+
final Iterator<StateEntry<K, N, V>> iterator = stateTable.iterator();
308+
309+
LOG.info(
310+
"Performing state migration for state {} because the state serializer's ttl"
311+
+ " config has been changed from {} to {}.",
312+
stateDesc,
313+
TtlAwareSerializer.isSerializerTtlEnabled(previousSerializer),
314+
TtlAwareSerializer.isSerializerTtlEnabled(currentSerializer));
315+
316+
// we need to get an actual state instance because migration is different
317+
// for different state types. For example, ListState needs to deal with
318+
// individual elements
319+
StateCreateFactory stateCreateFactory = STATE_CREATE_FACTORIES.get(stateDesc.getType());
320+
if (stateCreateFactory == null) {
321+
throw new FlinkRuntimeException(stateNotSupportedMessage(stateDesc));
322+
}
323+
State state =
324+
stateCreateFactory.createState(stateDesc, stateTable, stateTable.keySerializer);
325+
if (!(state instanceof AbstractHeapState)) {
326+
throw new FlinkRuntimeException(
327+
"State should be an AbstractRocksDBState but is " + state);
328+
}
329+
AbstractHeapState<K, N, V> heapState = (AbstractHeapState<K, N, V>) state;
330+
TtlAwareSerializer<V, ?> currentTtlAwareSerializer =
331+
(TtlAwareSerializer<V, ?>)
332+
TtlAwareSerializer.wrapTtlAwareSerializer(currentSerializer);
333+
334+
stateTable.transformAll(
335+
null,
336+
new StateTransformationFunction<V, V>() {
337+
@Override
338+
public V apply(V previousState, V value) throws Exception {
339+
return heapState.migrateTtlValue(
340+
previousState, currentTtlAwareSerializer, ttlTimeProvider);
341+
}
342+
});
343+
}
344+
302345
@SuppressWarnings("unchecked")
303346
@Override
304347
public <N> Stream<K> getKeys(String state, N namespace) {

flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapListState.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
2828
import org.apache.flink.queryablestate.client.state.serialization.KvStateSerializer;
2929
import org.apache.flink.runtime.state.internal.InternalListState;
30+
import org.apache.flink.runtime.state.ttl.TtlAwareSerializer;
31+
import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
32+
import org.apache.flink.runtime.state.ttl.TtlValue;
3033
import org.apache.flink.util.Preconditions;
3134

3235
import java.io.ByteArrayOutputStream;
@@ -142,6 +145,20 @@ public byte[] getSerializedValue(
142145
return baos.toByteArray();
143146
}
144147

148+
@SuppressWarnings("unchecked")
149+
public List<V> migrateTtlValue(
150+
List<V> stateValue,
151+
TtlAwareSerializer<List<V>, ?> currentTtlAwareSerializer,
152+
TtlTimeProvider ttlTimeProvider) {
153+
if (currentTtlAwareSerializer.isTtlEnabled()) {
154+
stateValue.replaceAll(v -> (V) new TtlValue<>(v, ttlTimeProvider.currentTimestamp()));
155+
} else {
156+
stateValue.replaceAll(v -> (V) ((TtlValue<?>) v).getUserValue());
157+
}
158+
159+
return stateValue;
160+
}
161+
145162
// ------------------------------------------------------------------------
146163
// state merging
147164
// ------------------------------------------------------------------------

flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
import org.apache.flink.api.java.tuple.Tuple2;
2727
import org.apache.flink.queryablestate.client.state.serialization.KvStateSerializer;
2828
import org.apache.flink.runtime.state.internal.InternalMapState;
29+
import org.apache.flink.runtime.state.ttl.TtlAwareSerializer;
30+
import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
31+
import org.apache.flink.runtime.state.ttl.TtlValue;
2932
import org.apache.flink.util.Preconditions;
3033

3134
import java.util.Collections;
@@ -200,6 +203,28 @@ public byte[] getSerializedValue(
200203
result.entrySet(), dupUserKeySerializer, dupUserValueSerializer);
201204
}
202205

206+
@SuppressWarnings("unchecked")
207+
public Map<UK, UV> migrateTtlValue(
208+
Map<UK, UV> stateValue,
209+
TtlAwareSerializer<Map<UK, UV>, ?> currentTtlAwareSerializer,
210+
TtlTimeProvider ttlTimeProvider) {
211+
212+
if (currentTtlAwareSerializer.isTtlEnabled()) {
213+
for (Map.Entry<UK, UV> entry : stateValue.entrySet()) {
214+
UV value =
215+
(UV) new TtlValue<>(entry.getValue(), ttlTimeProvider.currentTimestamp());
216+
stateValue.put(entry.getKey(), value);
217+
}
218+
} else {
219+
for (Map.Entry<UK, UV> entry : stateValue.entrySet()) {
220+
UV value = (UV) ((TtlValue<?>) entry.getValue()).getUserValue();
221+
stateValue.put(entry.getKey(), value);
222+
}
223+
}
224+
225+
return stateValue;
226+
}
227+
203228
@SuppressWarnings("unchecked")
204229
static <UK, UV, K, N, SV, S extends State, IS extends S> IS create(
205230
StateDescriptor<S, SV> stateDesc,

flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@
3939
import java.util.Arrays;
4040
import java.util.Collection;
4141
import java.util.Iterator;
42+
import java.util.List;
4243
import java.util.Objects;
4344
import java.util.Spliterators;
45+
import java.util.stream.Collectors;
4446
import java.util.stream.Stream;
4547
import java.util.stream.StreamSupport;
4648

@@ -212,6 +214,29 @@ public <T> void transform(
212214
getMapForKeyGroup(keyGroup).transform(key, namespace, value, transformation);
213215
}
214216

217+
/**
218+
* Applies the given {@link StateTransformationFunction} to all state (1st input argument),
219+
* using the given value as second input argument. The result of {@link
220+
* StateTransformationFunction#apply(Object, Object)} is then stored as the new state. This
221+
* function is basically an optimization for get-update-put pattern.
222+
*
223+
* @param value the value to use in transforming the state. Can be null.
224+
* @throws Exception if some exception happens in the transformation function.
225+
*/
226+
public <T> void transformAll(T value, StateTransformationFunction<S, T> transformation)
227+
throws Exception {
228+
for (StateMap<K, N, S> stateMap : keyGroupedStateMaps) {
229+
List<StateEntry<K, N, S>> entries =
230+
StreamSupport.stream(
231+
Spliterators.spliteratorUnknownSize(stateMap.iterator(), 0),
232+
false)
233+
.collect(Collectors.toList());
234+
for (StateEntry<K, N, S> entry : entries) {
235+
stateMap.transform(entry.getKey(), entry.getNamespace(), value, transformation);
236+
}
237+
}
238+
}
239+
215240
// For queryable state ------------------------------------------------------------------------
216241

217242
/**
@@ -297,7 +322,6 @@ public int getKeyGroupOffset() {
297322
return keyGroupRange.getStartKeyGroup();
298323
}
299324

300-
@VisibleForTesting
301325
public StateMap<K, N, S> getMapForKeyGroup(int keyGroupIndex) {
302326
final int pos = indexToOffset(keyGroupIndex);
303327
if (pos >= 0 && pos < keyGroupedStateMaps.length) {

flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendMigrationTestBase.java

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,44 +1203,85 @@ void testStateMigrationAfterChangingTTL() throws Exception {
12031203
}
12041204

12051205
@TestTemplate
1206-
protected void testStateMigrationAfterChangingTTLFromEnablingToDisabling() throws Exception {
1206+
protected void testStateMigrationAfterChangingTTLFromDisablingToEnabling() throws Exception {
12071207
final String stateName = "test-ttl";
12081208

12091209
ValueStateDescriptor<TestType> initialAccessDescriptor =
12101210
new ValueStateDescriptor<>(stateName, new TestType.V1TestTypeSerializer());
1211-
initialAccessDescriptor.enableTimeToLive(
1211+
ValueStateDescriptor<TestType> newAccessDescriptorAfterRestore =
1212+
new ValueStateDescriptor<>(
1213+
stateName,
1214+
// restore with a V2 serializer that has a different schema
1215+
new TestType.V2TestTypeSerializer());
1216+
newAccessDescriptorAfterRestore.enableTimeToLive(
12121217
StateTtlConfig.newBuilder(Duration.ofDays(1)).build());
12131218

1214-
ValueStateDescriptor<TestType> newAccessDescriptorAfterRestore =
1215-
new ValueStateDescriptor<>(stateName, new TestType.V2TestTypeSerializer());
1219+
ListStateDescriptor<TestType> initialAccessListDescriptor =
1220+
new ListStateDescriptor<>(stateName, new TestType.V1TestTypeSerializer());
1221+
ListStateDescriptor<TestType> newAccessListDescriptorAfterRestore =
1222+
new ListStateDescriptor<>(
1223+
stateName,
1224+
// restore with a V2 serializer that has a different schema
1225+
new TestType.V2TestTypeSerializer());
1226+
newAccessListDescriptorAfterRestore.enableTimeToLive(
1227+
StateTtlConfig.newBuilder(Duration.ofDays(1)).build());
12161228

1217-
assertThatThrownBy(
1218-
() ->
1219-
testKeyedValueStateUpgrade(
1220-
initialAccessDescriptor, newAccessDescriptorAfterRestore))
1221-
.satisfiesAnyOf(
1222-
e -> assertThat(e).isInstanceOf(StateMigrationException.class),
1223-
e -> assertThat(e).hasCauseInstanceOf(StateMigrationException.class));
1229+
MapStateDescriptor<Integer, TestType> initialAccessMapDescriptor =
1230+
new MapStateDescriptor<>(
1231+
stateName, IntSerializer.INSTANCE, new TestType.V1TestTypeSerializer());
1232+
MapStateDescriptor<Integer, TestType> newAccessMapDescriptorAfterRestore =
1233+
new MapStateDescriptor<>(
1234+
stateName,
1235+
IntSerializer.INSTANCE,
1236+
// restore with a V2 serializer that has a different schema
1237+
new TestType.V2TestTypeSerializer());
1238+
newAccessMapDescriptorAfterRestore.enableTimeToLive(
1239+
StateTtlConfig.newBuilder(Duration.ofDays(1)).build());
1240+
1241+
testKeyedValueStateUpgrade(initialAccessDescriptor, newAccessDescriptorAfterRestore);
1242+
testKeyedListStateUpgrade(initialAccessListDescriptor, newAccessListDescriptorAfterRestore);
1243+
testKeyedMapStateUpgrade(initialAccessMapDescriptor, newAccessMapDescriptorAfterRestore);
12241244
}
12251245

12261246
@TestTemplate
1227-
protected void testStateMigrationAfterChangingTTLFromDisablingToEnabling() throws Exception {
1247+
protected void testStateMigrationAfterChangingTTLFromEnablingToDisabling() throws Exception {
12281248
final String stateName = "test-ttl";
12291249

12301250
ValueStateDescriptor<TestType> initialAccessDescriptor =
12311251
new ValueStateDescriptor<>(stateName, new TestType.V1TestTypeSerializer());
1252+
initialAccessDescriptor.enableTimeToLive(
1253+
StateTtlConfig.newBuilder(Duration.ofDays(1)).build());
12321254
ValueStateDescriptor<TestType> newAccessDescriptorAfterRestore =
1233-
new ValueStateDescriptor<>(stateName, new TestType.V2TestTypeSerializer());
1234-
newAccessDescriptorAfterRestore.enableTimeToLive(
1255+
new ValueStateDescriptor<>(
1256+
stateName,
1257+
// restore with a V2 serializer that has a different schema
1258+
new TestType.V2TestTypeSerializer());
1259+
1260+
ListStateDescriptor<TestType> initialAccessListDescriptor =
1261+
new ListStateDescriptor<>(stateName, new TestType.V1TestTypeSerializer());
1262+
initialAccessListDescriptor.enableTimeToLive(
12351263
StateTtlConfig.newBuilder(Duration.ofDays(1)).build());
1264+
ListStateDescriptor<TestType> newAccessListDescriptorAfterRestore =
1265+
new ListStateDescriptor<>(
1266+
stateName,
1267+
// restore with a V2 serializer that has a different schema
1268+
new TestType.V2TestTypeSerializer());
12361269

1237-
assertThatThrownBy(
1238-
() ->
1239-
testKeyedValueStateUpgrade(
1240-
initialAccessDescriptor, newAccessDescriptorAfterRestore))
1241-
.satisfiesAnyOf(
1242-
e -> assertThat(e).isInstanceOf(StateMigrationException.class),
1243-
e -> assertThat(e).hasCauseInstanceOf(StateMigrationException.class));
1270+
MapStateDescriptor<Integer, TestType> initialAccessMapDescriptor =
1271+
new MapStateDescriptor<>(
1272+
stateName, IntSerializer.INSTANCE, new TestType.V1TestTypeSerializer());
1273+
initialAccessMapDescriptor.enableTimeToLive(
1274+
StateTtlConfig.newBuilder(Duration.ofDays(1)).build());
1275+
MapStateDescriptor<Integer, TestType> newAccessMapDescriptorAfterRestore =
1276+
new MapStateDescriptor<>(
1277+
stateName,
1278+
IntSerializer.INSTANCE,
1279+
// restore with a V2 serializer that has a different schema
1280+
new TestType.V2TestTypeSerializer());
1281+
1282+
testKeyedValueStateUpgrade(initialAccessDescriptor, newAccessDescriptorAfterRestore);
1283+
testKeyedListStateUpgrade(initialAccessListDescriptor, newAccessListDescriptorAfterRestore);
1284+
testKeyedMapStateUpgrade(initialAccessMapDescriptor, newAccessMapDescriptorAfterRestore);
12441285
}
12451286

12461287
// -------------------------------------------------------------------------------

flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
3232
import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
3333
import org.apache.flink.util.Preconditions;
34-
import org.apache.flink.util.StateMigrationException;
3534

3635
import org.junit.jupiter.api.AfterEach;
3736
import org.junit.jupiter.api.BeforeEach;
@@ -46,7 +45,6 @@
4645

4746
import static org.apache.flink.runtime.state.ttl.StateBackendTestContext.NUMBER_OF_KEY_GROUPS;
4847
import static org.assertj.core.api.Assertions.assertThat;
49-
import static org.assertj.core.api.Assertions.assertThatThrownBy;
5048
import static org.assertj.core.api.Assumptions.assumeThat;
5149

5250
/** State TTL base test suite. */
@@ -508,8 +506,7 @@ protected void testRestoreTtlAndRegisterNonTtlStateCompatFailure() throws Except
508506
sbetc.createAndRestoreKeyedStateBackend(snapshot);
509507

510508
sbetc.setCurrentKey("defaultKey");
511-
assertThatThrownBy(() -> sbetc.createState(ctx().createStateDescriptor(), ""))
512-
.isInstanceOf(StateMigrationException.class);
509+
sbetc.createState(ctx().createStateDescriptor(), "");
513510
}
514511

515512
@TestTemplate

0 commit comments

Comments
 (0)