7
7
8
8
@pytest .mark .asyncio
9
9
@patch ("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine" )
10
- @patch ("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest" )
10
+ @patch (
11
+ "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
12
+ )
11
13
@patch (
12
14
"model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent"
13
15
)
@@ -25,9 +27,9 @@ async def test_batch_inference(
25
27
mock_get_s3_client ,
26
28
mock_generate_with_vllm ,
27
29
mock_create_batch_completions_request_content ,
28
- mock_create_batch_completions_request ,
30
+ mock_create_batch_completions_engine_request ,
29
31
mock_vllm ,
30
- create_batch_completions_request ,
32
+ create_batch_completions_engine_request ,
31
33
create_batch_completions_request_content ,
32
34
mock_s3_client ,
33
35
mock_process ,
@@ -36,7 +38,9 @@ async def test_batch_inference(
36
38
# Mock the necessary objects and data
37
39
mock_popen .return_value = mock_process
38
40
mock_get_s3_client .return_value = mock_s3_client
39
- mock_create_batch_completions_request .parse_file .return_value = create_batch_completions_request
41
+ mock_create_batch_completions_engine_request .parse_file .return_value = (
42
+ create_batch_completions_engine_request
43
+ )
40
44
mock_create_batch_completions_request_content .parse_raw .return_value = (
41
45
create_batch_completions_request_content
42
46
)
@@ -48,7 +52,7 @@ async def test_batch_inference(
48
52
await batch_inference ()
49
53
50
54
# Assertions
51
- mock_create_batch_completions_request .parse_file .assert_called_once ()
55
+ mock_create_batch_completions_engine_request .parse_file .assert_called_once ()
52
56
mock_open_func .assert_has_calls (
53
57
[
54
58
call ("input_data_path" , "r" ),
@@ -61,7 +65,9 @@ async def test_batch_inference(
61
65
62
66
@pytest .mark .asyncio
63
67
@patch ("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine" )
64
- @patch ("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest" )
68
+ @patch (
69
+ "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
70
+ )
65
71
@patch (
66
72
"model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent"
67
73
)
@@ -79,9 +85,9 @@ async def test_batch_inference_failed_to_download_model_but_proceed(
79
85
mock_get_s3_client ,
80
86
mock_generate_with_vllm ,
81
87
mock_create_batch_completions_request_content ,
82
- mock_create_batch_completions_request ,
88
+ mock_create_batch_completions_engine_request ,
83
89
mock_vllm ,
84
- create_batch_completions_request ,
90
+ create_batch_completions_engine_request ,
85
91
create_batch_completions_request_content ,
86
92
mock_s3_client ,
87
93
mock_process ,
@@ -91,7 +97,9 @@ async def test_batch_inference_failed_to_download_model_but_proceed(
91
97
mock_process .returncode = 1 # Failed to download model
92
98
mock_popen .return_value = mock_process
93
99
mock_get_s3_client .return_value = mock_s3_client
94
- mock_create_batch_completions_request .parse_file .return_value = create_batch_completions_request
100
+ mock_create_batch_completions_engine_request .parse_file .return_value = (
101
+ create_batch_completions_engine_request
102
+ )
95
103
mock_create_batch_completions_request_content .parse_raw .return_value = (
96
104
create_batch_completions_request_content
97
105
)
@@ -103,7 +111,7 @@ async def test_batch_inference_failed_to_download_model_but_proceed(
103
111
await batch_inference ()
104
112
105
113
# Assertions
106
- mock_create_batch_completions_request .parse_file .assert_called_once ()
114
+ mock_create_batch_completions_engine_request .parse_file .assert_called_once ()
107
115
mock_open_func .assert_has_calls (
108
116
[
109
117
call ("input_data_path" , "r" ),
@@ -116,7 +124,9 @@ async def test_batch_inference_failed_to_download_model_but_proceed(
116
124
117
125
@pytest .mark .asyncio
118
126
@patch ("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine" )
119
- @patch ("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest" )
127
+ @patch (
128
+ "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
129
+ )
120
130
@patch (
121
131
"model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent"
122
132
)
@@ -136,9 +146,9 @@ async def test_batch_inference_two_workers(
136
146
mock_get_s3_client ,
137
147
mock_generate_with_vllm ,
138
148
mock_create_batch_completions_request_content ,
139
- mock_create_batch_completions_request ,
149
+ mock_create_batch_completions_engine_request ,
140
150
mock_vllm ,
141
- create_batch_completions_request ,
151
+ create_batch_completions_engine_request ,
142
152
create_batch_completions_request_content ,
143
153
mock_s3_client ,
144
154
mock_process ,
@@ -147,8 +157,10 @@ async def test_batch_inference_two_workers(
147
157
# Mock the necessary objects and data
148
158
mock_popen .return_value = mock_process
149
159
mock_get_s3_client .return_value = mock_s3_client
150
- create_batch_completions_request .data_parallelism = 2
151
- mock_create_batch_completions_request .parse_file .return_value = create_batch_completions_request
160
+ create_batch_completions_engine_request .data_parallelism = 2
161
+ mock_create_batch_completions_engine_request .parse_file .return_value = (
162
+ create_batch_completions_engine_request
163
+ )
152
164
mock_create_batch_completions_request_content .parse_raw .return_value = (
153
165
create_batch_completions_request_content
154
166
)
@@ -168,7 +180,7 @@ def side_effect(key, default):
168
180
await batch_inference ()
169
181
170
182
# Assertions
171
- mock_create_batch_completions_request .parse_file .assert_called_once ()
183
+ mock_create_batch_completions_engine_request .parse_file .assert_called_once ()
172
184
mock_open_func .assert_has_calls (
173
185
[
174
186
call ("input_data_path" , "r" ),
@@ -198,7 +210,9 @@ def side_effect(key, default):
198
210
199
211
@pytest .mark .asyncio
200
212
@patch ("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine" )
201
- @patch ("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest" )
213
+ @patch (
214
+ "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
215
+ )
202
216
@patch (
203
217
"model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent"
204
218
)
@@ -218,9 +232,9 @@ async def test_batch_inference_delete_chunks(
218
232
mock_get_s3_client ,
219
233
mock_generate_with_vllm ,
220
234
mock_create_batch_completions_request_content ,
221
- mock_create_batch_completions_request ,
235
+ mock_create_batch_completions_engine_request ,
222
236
mock_vllm ,
223
- create_batch_completions_request ,
237
+ create_batch_completions_engine_request ,
224
238
create_batch_completions_request_content ,
225
239
mock_s3_client ,
226
240
mock_process ,
@@ -229,9 +243,11 @@ async def test_batch_inference_delete_chunks(
229
243
# Mock the necessary objects and data
230
244
mock_popen .return_value = mock_process
231
245
mock_get_s3_client .return_value = mock_s3_client
232
- create_batch_completions_request .data_parallelism = 2
233
- create_batch_completions_request .output_data_path = "s3://bucket/key"
234
- mock_create_batch_completions_request .parse_file .return_value = create_batch_completions_request
246
+ create_batch_completions_engine_request .data_parallelism = 2
247
+ create_batch_completions_engine_request .output_data_path = "s3://bucket/key"
248
+ mock_create_batch_completions_engine_request .parse_file .return_value = (
249
+ create_batch_completions_engine_request
250
+ )
235
251
mock_create_batch_completions_request_content .parse_raw .return_value = (
236
252
create_batch_completions_request_content
237
253
)
@@ -251,7 +267,7 @@ def side_effect(key, default):
251
267
await batch_inference ()
252
268
253
269
# Assertions
254
- mock_create_batch_completions_request .parse_file .assert_called_once ()
270
+ mock_create_batch_completions_engine_request .parse_file .assert_called_once ()
255
271
mock_open_func .assert_has_calls (
256
272
[
257
273
call ("input_data_path" , "r" ),
@@ -310,7 +326,9 @@ def test_file_exists_no_such_key():
310
326
311
327
@pytest .mark .asyncio
312
328
@patch ("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine" )
313
- @patch ("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest" )
329
+ @patch (
330
+ "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest"
331
+ )
314
332
@patch (
315
333
"model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent"
316
334
)
@@ -330,7 +348,7 @@ async def test_batch_inference_tool_completion(
330
348
mock_get_s3_client ,
331
349
mock_generate_with_vllm ,
332
350
mock_create_batch_completions_request_content ,
333
- mock_create_batch_completions_request ,
351
+ mock_create_batch_completions_engine_request ,
334
352
mock_vllm ,
335
353
create_batch_completions_tool_completion_request ,
336
354
create_batch_completions_tool_completion_request_content ,
@@ -344,7 +362,7 @@ async def test_batch_inference_tool_completion(
344
362
mock_run .return_value = mock_run_output
345
363
mock_popen .return_value = mock_process
346
364
mock_get_s3_client .return_value = mock_s3_client
347
- mock_create_batch_completions_request .parse_file .return_value = (
365
+ mock_create_batch_completions_engine_request .parse_file .return_value = (
348
366
create_batch_completions_tool_completion_request
349
367
)
350
368
mock_create_batch_completions_request_content .parse_raw .return_value = (
@@ -361,7 +379,7 @@ async def test_batch_inference_tool_completion(
361
379
await batch_inference ()
362
380
363
381
# Assertions
364
- mock_create_batch_completions_request .parse_file .assert_called_once ()
382
+ mock_create_batch_completions_engine_request .parse_file .assert_called_once ()
365
383
mock_open_func .assert_has_calls (
366
384
[
367
385
call ("input_data_path" , "r" ),
0 commit comments