Skip to content

Commit 08ddbec

Browse files
authored
Adaptive allocations improvements (#126307) (#126379)
* Adaptive allocations: don't update deployments that aren't started. * AssignmentPlanner: don't plan deployments with zero allocations * Update JavaDoc
1 parent 82165b1 commit 08ddbec

File tree

4 files changed

+90
-18
lines changed

4 files changed

+90
-18
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ private TrainedModelAssignment(
144144
* @param assignmentState used to track the state of the assignment for rebalancing, autoscaling, and more
145145
* @param reason may contain a human-readable explanation for the current state
146146
* @param startTime the time when the assignment was created
147-
* @param maxAssignedAllocations used for adaptive allocations
147+
* @param maxAssignedAllocations keeps track of the maximum number of allocations used for this assignment
148148
* @param adaptiveAllocationsSettings how the assignment should scale based on usage
149149
*/
150150
TrainedModelAssignment(

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
2929
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
3030
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
31+
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
3132
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
3233
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
3334
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
@@ -389,13 +390,15 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo
389390

390391
Map<String, Stats> recentStatsByDeployment = new HashMap<>();
391392
Map<String, Integer> numberOfAllocations = new HashMap<>();
393+
Map<String, AssignmentState> assignmentStates = new HashMap<>();
392394
// Check for recent scale ups in the deployment stats, because a different node may have
393395
// caused a scale up when an inference request arrives and there were zero allocations.
394396
Set<String> hasRecentObservedScaleUp = new HashSet<>();
395397

396398
for (AssignmentStats assignmentStats : statsResponse.getStats().results()) {
397399
String deploymentId = assignmentStats.getDeploymentId();
398400
numberOfAllocations.put(deploymentId, assignmentStats.getNumberOfAllocations());
401+
assignmentStates.put(deploymentId, assignmentStats.getState());
399402
Map<String, Stats> deploymentStats = lastInferenceStatsByDeploymentAndNode.computeIfAbsent(
400403
deploymentId,
401404
key -> new HashMap<>()
@@ -447,6 +450,14 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo
447450
logger.debug("adaptive allocations scaler: skipping scaling down [{}] because of recent scaleup.", deploymentId);
448451
continue;
449452
}
453+
if (assignmentStates.get(deploymentId) != AssignmentState.STARTED) {
454+
logger.debug(
455+
"adaptive allocations scaler: skipping scaling [{}] because it is in [{}] state.",
456+
deploymentId,
457+
assignmentStates.get(deploymentId)
458+
);
459+
continue;
460+
}
450461
if (newNumberOfAllocations > numberOfAllocations.get(deploymentId)) {
451462
lastScaleUpTimesMillis.put(deploymentId, now);
452463
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ public class AssignmentPlanner {
5050

5151
public AssignmentPlanner(List<Node> nodes, List<AssignmentPlan.Deployment> deployments) {
5252
this.nodes = nodes.stream().sorted(Comparator.comparing(Node::id)).toList();
53-
this.deployments = deployments.stream().sorted(Comparator.comparing(AssignmentPlan.Deployment::deploymentId)).toList();
53+
this.deployments = deployments.stream()
54+
.filter(deployment -> deployment.allocations() > 0)
55+
.sorted(Comparator.comparing(AssignmentPlan.Deployment::deploymentId))
56+
.toList();
5457
}
5558

5659
public AssignmentPlan computePlan() {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
2828
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
2929
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
30+
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
3031
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
3132
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
3233
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
@@ -85,7 +86,7 @@ public void tearDown() throws Exception {
8586
super.tearDown();
8687
}
8788

88-
private ClusterState getClusterState(int numAllocations) {
89+
private ClusterState getClusterState(int numAllocations, AssignmentState assignmentState) {
8990
ClusterState clusterState = mock(ClusterState.class);
9091
Metadata metadata = mock(Metadata.class);
9192
when(clusterState.getMetadata()).thenReturn(metadata);
@@ -107,7 +108,7 @@ private ClusterState getClusterState(int numAllocations) {
107108
100_000_000
108109
),
109110
new AdaptiveAllocationsSettings(true, null, null)
110-
).build()
111+
).setAssignmentState(assignmentState).build()
111112
)
112113
)
113114
);
@@ -118,7 +119,8 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
118119
int numAllocations,
119120
int inferenceCount,
120121
double latency,
121-
boolean recentStartup
122+
boolean recentStartup,
123+
AssignmentState assignmentState
122124
) {
123125
return new GetDeploymentStatsAction.Response(
124126
List.of(),
@@ -155,15 +157,15 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
155157
)
156158
),
157159
Priority.NORMAL
158-
)
160+
).setState(assignmentState)
159161
),
160162
0
161163
);
162164
}
163165

164166
public void test_scaleUp() {
165167
// Initialize the cluster with a deployment with 1 allocation.
166-
ClusterState clusterState = getClusterState(1);
168+
ClusterState clusterState = getClusterState(1, AssignmentState.STARTED);
167169
when(clusterService.state()).thenReturn(clusterState);
168170

169171
AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
@@ -189,7 +191,7 @@ public void test_scaleUp() {
189191
doAnswer(invocationOnMock -> {
190192
@SuppressWarnings("unchecked")
191193
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
192-
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
194+
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false, AssignmentState.STARTED));
193195
return Void.TYPE;
194196
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
195197

@@ -205,7 +207,7 @@ public void test_scaleUp() {
205207
doAnswer(invocationOnMock -> {
206208
@SuppressWarnings("unchecked")
207209
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
208-
listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0, false));
210+
listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0, false, AssignmentState.STARTED));
209211
return Void.TYPE;
210212
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
211213
doAnswer(invocationOnMock -> {
@@ -226,7 +228,7 @@ public void test_scaleUp() {
226228
verifyNoMoreInteractions(client, clusterService);
227229
reset(client, clusterService);
228230

229-
clusterState = getClusterState(2);
231+
clusterState = getClusterState(2, AssignmentState.STARTED);
230232
ClusterChangedEvent clusterChangedEvent = mock(ClusterChangedEvent.class);
231233
when(clusterChangedEvent.state()).thenReturn(clusterState);
232234
service.clusterChanged(clusterChangedEvent);
@@ -236,7 +238,7 @@ public void test_scaleUp() {
236238
doAnswer(invocationOnMock -> {
237239
@SuppressWarnings("unchecked")
238240
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
239-
listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0, false));
241+
listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0, false, AssignmentState.STARTED));
240242
return Void.TYPE;
241243
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
242244
doAnswer(invocationOnMock -> {
@@ -257,7 +259,7 @@ public void test_scaleUp() {
257259

258260
public void test_scaleDownToZero_whenNoRequests() {
259261
// Initialize the cluster with a deployment with 1 allocation.
260-
ClusterState clusterState = getClusterState(1);
262+
ClusterState clusterState = getClusterState(1, AssignmentState.STARTED);
261263
when(clusterService.state()).thenReturn(clusterState);
262264

263265
AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
@@ -283,7 +285,7 @@ public void test_scaleDownToZero_whenNoRequests() {
283285
doAnswer(invocationOnMock -> {
284286
@SuppressWarnings("unchecked")
285287
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
286-
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
288+
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false, AssignmentState.STARTED));
287289
return Void.TYPE;
288290
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
289291

@@ -299,7 +301,7 @@ public void test_scaleDownToZero_whenNoRequests() {
299301
doAnswer(invocationOnMock -> {
300302
@SuppressWarnings("unchecked")
301303
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
302-
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
304+
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false, AssignmentState.STARTED));
303305
return Void.TYPE;
304306
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
305307
doAnswer(invocationOnMock -> {
@@ -322,9 +324,65 @@ public void test_scaleDownToZero_whenNoRequests() {
322324
service.stop();
323325
}
324326

327+
public void test_dontScale_whenNotStarted() {
328+
// Initialize the cluster with a deployment with 1 allocation.
329+
ClusterState clusterState = getClusterState(1, AssignmentState.STARTING);
330+
when(clusterService.state()).thenReturn(clusterState);
331+
332+
AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
333+
threadPool,
334+
clusterService,
335+
client,
336+
inferenceAuditor,
337+
meterRegistry,
338+
true,
339+
1,
340+
1,
341+
2_000
342+
);
343+
service.start();
344+
345+
verify(clusterService).state();
346+
verify(clusterService).addListener(same(service));
347+
verifyNoMoreInteractions(client, clusterService);
348+
reset(client, clusterService);
349+
350+
// First cycle: many inference requests
351+
when(client.threadPool()).thenReturn(threadPool);
352+
doAnswer(invocationOnMock -> {
353+
@SuppressWarnings("unchecked")
354+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
355+
listener.onResponse(getDeploymentStatsResponse(1, 10000, 10.0, false, AssignmentState.STARTING));
356+
return Void.TYPE;
357+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
358+
359+
safeSleep(1200);
360+
361+
verify(client, times(1)).threadPool();
362+
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
363+
verifyNoMoreInteractions(client, clusterService);
364+
reset(client, clusterService);
365+
366+
// Second cycle: again many inference requests
367+
when(client.threadPool()).thenReturn(threadPool);
368+
doAnswer(invocationOnMock -> {
369+
@SuppressWarnings("unchecked")
370+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
371+
listener.onResponse(getDeploymentStatsResponse(1, 20000, 10.0, false, AssignmentState.STARTING));
372+
return Void.TYPE;
373+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
374+
375+
safeSleep(1200);
376+
377+
verify(client, times(1)).threadPool();
378+
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
379+
verifyNoMoreInteractions(client, clusterService);
380+
service.stop();
381+
}
382+
325383
public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
326384
// Initialize the cluster with a deployment with 1 allocation.
327-
ClusterState clusterState = getClusterState(1);
385+
ClusterState clusterState = getClusterState(1, AssignmentState.STARTED);
328386
when(clusterService.state()).thenReturn(clusterState);
329387

330388
AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
@@ -350,7 +408,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
350408
doAnswer(invocationOnMock -> {
351409
@SuppressWarnings("unchecked")
352410
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
353-
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, true));
411+
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, true, AssignmentState.STARTED));
354412
return Void.TYPE;
355413
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
356414

@@ -366,7 +424,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
366424
doAnswer(invocationOnMock -> {
367425
@SuppressWarnings("unchecked")
368426
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
369-
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, true));
427+
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, true, AssignmentState.STARTED));
370428
return Void.TYPE;
371429
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
372430
doAnswer(invocationOnMock -> {
@@ -388,7 +446,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
388446
doAnswer(invocationOnMock -> {
389447
@SuppressWarnings("unchecked")
390448
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
391-
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
449+
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false, AssignmentState.STARTED));
392450
return Void.TYPE;
393451
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
394452
doAnswer(invocationOnMock -> {

0 commit comments

Comments
 (0)