27
27
import org .elasticsearch .xpack .core .ml .action .StartTrainedModelDeploymentAction ;
28
28
import org .elasticsearch .xpack .core .ml .action .UpdateTrainedModelDeploymentAction ;
29
29
import org .elasticsearch .xpack .core .ml .inference .assignment .AdaptiveAllocationsSettings ;
30
+ import org .elasticsearch .xpack .core .ml .inference .assignment .AssignmentState ;
30
31
import org .elasticsearch .xpack .core .ml .inference .assignment .AssignmentStats ;
31
32
import org .elasticsearch .xpack .core .ml .inference .assignment .Priority ;
32
33
import org .elasticsearch .xpack .core .ml .inference .assignment .TrainedModelAssignment ;
@@ -85,7 +86,7 @@ public void tearDown() throws Exception {
85
86
super .tearDown ();
86
87
}
87
88
88
- private ClusterState getClusterState (int numAllocations ) {
89
+ private ClusterState getClusterState (int numAllocations , AssignmentState assignmentState ) {
89
90
ClusterState clusterState = mock (ClusterState .class );
90
91
Metadata metadata = mock (Metadata .class );
91
92
when (clusterState .getMetadata ()).thenReturn (metadata );
@@ -107,7 +108,7 @@ private ClusterState getClusterState(int numAllocations) {
107
108
100_000_000
108
109
),
109
110
new AdaptiveAllocationsSettings (true , null , null )
110
- ).build ()
111
+ ).setAssignmentState ( assignmentState ). build ()
111
112
)
112
113
)
113
114
);
@@ -118,7 +119,8 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
118
119
int numAllocations ,
119
120
int inferenceCount ,
120
121
double latency ,
121
- boolean recentStartup
122
+ boolean recentStartup ,
123
+ AssignmentState assignmentState
122
124
) {
123
125
return new GetDeploymentStatsAction .Response (
124
126
List .of (),
@@ -155,15 +157,15 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
155
157
)
156
158
),
157
159
Priority .NORMAL
158
- )
160
+ ). setState ( assignmentState )
159
161
),
160
162
0
161
163
);
162
164
}
163
165
164
166
public void test_scaleUp () {
165
167
// Initialize the cluster with a deployment with 1 allocation.
166
- ClusterState clusterState = getClusterState (1 );
168
+ ClusterState clusterState = getClusterState (1 , AssignmentState . STARTED );
167
169
when (clusterService .state ()).thenReturn (clusterState );
168
170
169
171
AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService (
@@ -189,7 +191,7 @@ public void test_scaleUp() {
189
191
doAnswer (invocationOnMock -> {
190
192
@ SuppressWarnings ("unchecked" )
191
193
var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
192
- listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , false ));
194
+ listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , false , AssignmentState . STARTED ));
193
195
return Void .TYPE ;
194
196
}).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
195
197
@@ -205,7 +207,7 @@ public void test_scaleUp() {
205
207
doAnswer (invocationOnMock -> {
206
208
@ SuppressWarnings ("unchecked" )
207
209
var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
208
- listener .onResponse (getDeploymentStatsResponse (1 , 150 , 10.0 , false ));
210
+ listener .onResponse (getDeploymentStatsResponse (1 , 150 , 10.0 , false , AssignmentState . STARTED ));
209
211
return Void .TYPE ;
210
212
}).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
211
213
doAnswer (invocationOnMock -> {
@@ -226,7 +228,7 @@ public void test_scaleUp() {
226
228
verifyNoMoreInteractions (client , clusterService );
227
229
reset (client , clusterService );
228
230
229
- clusterState = getClusterState (2 );
231
+ clusterState = getClusterState (2 , AssignmentState . STARTED );
230
232
ClusterChangedEvent clusterChangedEvent = mock (ClusterChangedEvent .class );
231
233
when (clusterChangedEvent .state ()).thenReturn (clusterState );
232
234
service .clusterChanged (clusterChangedEvent );
@@ -236,7 +238,7 @@ public void test_scaleUp() {
236
238
doAnswer (invocationOnMock -> {
237
239
@ SuppressWarnings ("unchecked" )
238
240
var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
239
- listener .onResponse (getDeploymentStatsResponse (2 , 0 , 9.0 , false ));
241
+ listener .onResponse (getDeploymentStatsResponse (2 , 0 , 9.0 , false , AssignmentState . STARTED ));
240
242
return Void .TYPE ;
241
243
}).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
242
244
doAnswer (invocationOnMock -> {
@@ -257,7 +259,7 @@ public void test_scaleUp() {
257
259
258
260
public void test_scaleDownToZero_whenNoRequests () {
259
261
// Initialize the cluster with a deployment with 1 allocation.
260
- ClusterState clusterState = getClusterState (1 );
262
+ ClusterState clusterState = getClusterState (1 , AssignmentState . STARTED );
261
263
when (clusterService .state ()).thenReturn (clusterState );
262
264
263
265
AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService (
@@ -283,7 +285,7 @@ public void test_scaleDownToZero_whenNoRequests() {
283
285
doAnswer (invocationOnMock -> {
284
286
@ SuppressWarnings ("unchecked" )
285
287
var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
286
- listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , false ));
288
+ listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , false , AssignmentState . STARTED ));
287
289
return Void .TYPE ;
288
290
}).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
289
291
@@ -299,7 +301,7 @@ public void test_scaleDownToZero_whenNoRequests() {
299
301
doAnswer (invocationOnMock -> {
300
302
@ SuppressWarnings ("unchecked" )
301
303
var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
302
- listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false ));
304
+ listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false , AssignmentState . STARTED ));
303
305
return Void .TYPE ;
304
306
}).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
305
307
doAnswer (invocationOnMock -> {
@@ -322,9 +324,65 @@ public void test_scaleDownToZero_whenNoRequests() {
322
324
service .stop ();
323
325
}
324
326
327
+ public void test_dontScale_whenNotStarted () {
328
+ // Initialize the cluster with a deployment with 1 allocation.
329
+ ClusterState clusterState = getClusterState (1 , AssignmentState .STARTING );
330
+ when (clusterService .state ()).thenReturn (clusterState );
331
+
332
+ AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService (
333
+ threadPool ,
334
+ clusterService ,
335
+ client ,
336
+ inferenceAuditor ,
337
+ meterRegistry ,
338
+ true ,
339
+ 1 ,
340
+ 1 ,
341
+ 2_000
342
+ );
343
+ service .start ();
344
+
345
+ verify (clusterService ).state ();
346
+ verify (clusterService ).addListener (same (service ));
347
+ verifyNoMoreInteractions (client , clusterService );
348
+ reset (client , clusterService );
349
+
350
+ // First cycle: many inference requests
351
+ when (client .threadPool ()).thenReturn (threadPool );
352
+ doAnswer (invocationOnMock -> {
353
+ @ SuppressWarnings ("unchecked" )
354
+ var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
355
+ listener .onResponse (getDeploymentStatsResponse (1 , 10000 , 10.0 , false , AssignmentState .STARTING ));
356
+ return Void .TYPE ;
357
+ }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
358
+
359
+ safeSleep (1200 );
360
+
361
+ verify (client , times (1 )).threadPool ();
362
+ verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
363
+ verifyNoMoreInteractions (client , clusterService );
364
+ reset (client , clusterService );
365
+
366
+ // Second cycle: again many inference requests
367
+ when (client .threadPool ()).thenReturn (threadPool );
368
+ doAnswer (invocationOnMock -> {
369
+ @ SuppressWarnings ("unchecked" )
370
+ var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
371
+ listener .onResponse (getDeploymentStatsResponse (1 , 20000 , 10.0 , false , AssignmentState .STARTING ));
372
+ return Void .TYPE ;
373
+ }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
374
+
375
+ safeSleep (1200 );
376
+
377
+ verify (client , times (1 )).threadPool ();
378
+ verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
379
+ verifyNoMoreInteractions (client , clusterService );
380
+ service .stop ();
381
+ }
382
+
325
383
public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode () {
326
384
// Initialize the cluster with a deployment with 1 allocation.
327
- ClusterState clusterState = getClusterState (1 );
385
+ ClusterState clusterState = getClusterState (1 , AssignmentState . STARTED );
328
386
when (clusterService .state ()).thenReturn (clusterState );
329
387
330
388
AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService (
@@ -350,7 +408,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
350
408
doAnswer (invocationOnMock -> {
351
409
@ SuppressWarnings ("unchecked" )
352
410
var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
353
- listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , true ));
411
+ listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , true , AssignmentState . STARTED ));
354
412
return Void .TYPE ;
355
413
}).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
356
414
@@ -366,7 +424,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
366
424
doAnswer (invocationOnMock -> {
367
425
@ SuppressWarnings ("unchecked" )
368
426
var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
369
- listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , true ));
427
+ listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , true , AssignmentState . STARTED ));
370
428
return Void .TYPE ;
371
429
}).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
372
430
doAnswer (invocationOnMock -> {
@@ -388,7 +446,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
388
446
doAnswer (invocationOnMock -> {
389
447
@ SuppressWarnings ("unchecked" )
390
448
var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
391
- listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false ));
449
+ listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false , AssignmentState . STARTED ));
392
450
return Void .TYPE ;
393
451
}).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
394
452
doAnswer (invocationOnMock -> {
0 commit comments