
Commit 60ac144

Validate quantization (#315)
* Validate quantization
* comments

1 parent 65afe0a · commit 60ac144

5 files changed: +83 -4 lines changed

clients/python/llmengine/model.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def create(
 
            num_shards (`int`):
                Number of shards for the LLM. When bigger than 1, LLM will be sharded
-                to multiple GPUs. Number of GPUs must be larger than num_shards.
+                to multiple GPUs. Number of GPUs must be equal or larger than num_shards.
                Only affects behavior for text-generation-inference models
 
            quantize (`Optional[Quantization]`):
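
For context, a minimal, hypothetical sketch of a client call that satisfies the relaxed constraint (num_shards equal to gpus) and requests a quantization this commit marks as supported for text-generation-inference. Every keyword argument and value below is illustrative; the names mirror the request fields used in the conftest fixtures later in this commit, and the exact Model.create() signature should be checked against the client docs rather than taken from this sketch.

import os

from llmengine import Model  # llm-engine Python client; create() is the method whose docstring is edited above

os.environ.setdefault("SCALE_API_KEY", "...")  # placeholder credential, not part of this diff

# Hypothetical endpoint creation: num_shards (2) does not exceed gpus (2), and
# "bitsandbytes" is the quantization this commit allows for text-generation-inference.
# Model name, image tag, and resource values are placeholders.
response = Model.create(
    name="llama-2-7b-4bit",
    model="llama-2-7b",
    inference_framework_image_tag="latest",
    num_shards=2,
    gpus=2,                    # must be equal to or larger than num_shards
    quantize="bitsandbytes",   # string form assumed to coerce to the documented Quantization enum
    cpus=8,
    memory="24Gi",
    storage="40Gi",
    min_workers=1,
    max_workers=1,
    per_worker=1,
    endpoint_type="streaming",
)
print(response)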

model-engine/model_engine_server/common/env_vars.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 A place for defining, setting, and referencing all environment variables used in Launch.
 """
 import os
+import sys
 from typing import Optional, Sequence
 
 from model_engine_server.common.constants import PROJECT_ROOT
@@ -73,5 +74,5 @@ def get_boolean_env_var(name: str) -> bool:
     logger.warning("LOCAL development & testing mode is ON")
 
 GIT_TAG: str = os.environ.get("GIT_TAG", "GIT_TAG_NOT_FOUND")
-if GIT_TAG == "GIT_TAG_NOT_FOUND":
+if GIT_TAG == "GIT_TAG_NOT_FOUND" and "pytest" not in sys.modules:
     raise ValueError("GIT_TAG environment variable must be set")
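
The new condition works because pytest has already imported itself before it imports any module under test, so its presence in sys.modules marks a test run and the import-time GIT_TAG requirement can be skipped. A tiny standalone illustration of the same trick (running_under_pytest is a hypothetical helper, not part of this codebase):

import sys

def running_under_pytest() -> bool:
    # pytest is present in sys.modules by the time it imports modules under test,
    # so import-time guards like the GIT_TAG check above can be relaxed during tests.
    return "pytest" in sys.modules

if __name__ == "__main__":
    # Prints False when run as a script; returns True inside a pytest session.
    print(running_under_pytest())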

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 23 additions & 1 deletion
@@ -136,6 +136,13 @@
     },
 }
 
+_SUPPORTED_QUANTIZATIONS: Dict[LLMInferenceFramework, List[Quantization]] = {
+    LLMInferenceFramework.DEEPSPEED: [],
+    LLMInferenceFramework.TEXT_GENERATION_INFERENCE: [Quantization.BITSANDBYTES],
+    LLMInferenceFramework.VLLM: [Quantization.AWQ],
+    LLMInferenceFramework.LIGHTLLM: [],
+}
+
 
 NUM_DOWNSTREAM_REQUEST_RETRIES = 80  # has to be high enough so that the retries take the 5 minutes
 DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60  # 5 minutes
@@ -198,8 +205,21 @@ def validate_num_shards(
             raise ObjectHasInvalidValueException("DeepSpeed requires more than 1 GPU.")
         if num_shards != gpus:
             raise ObjectHasInvalidValueException(
-                f"DeepSpeed requires num shard {num_shards} to be the same as number of GPUs {gpus}."
+                f"Num shard {num_shards} must be the same as number of GPUs {gpus} for DeepSpeed."
             )
+    if num_shards > gpus:
+        raise ObjectHasInvalidValueException(
+            f"Num shard {num_shards} must be less than or equal to the number of GPUs {gpus}."
+        )
+
+
+def validate_quantization(
+    quantize: Optional[Quantization], inference_framework: LLMInferenceFramework
+) -> None:
+    if quantize is not None and quantize not in _SUPPORTED_QUANTIZATIONS[inference_framework]:
+        raise ObjectHasInvalidValueException(
+            f"Quantization {quantize} is not supported for inference framework {inference_framework}. Supported quantization types are {_SUPPORTED_QUANTIZATIONS[inference_framework]}."
+        )
 
 
 class CreateLLMModelEndpointV1UseCase:
@@ -667,10 +687,12 @@ async def execute(
         validate_post_inference_hooks(user, request.post_inference_hooks)
         validate_model_name(request.model_name, request.inference_framework)
         validate_num_shards(request.num_shards, request.inference_framework, request.gpus)
+        validate_quantization(request.quantize, request.inference_framework)
 
         if request.inference_framework in [
             LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
             LLMInferenceFramework.VLLM,
+            LLMInferenceFramework.LIGHTLLM,
         ]:
             if request.endpoint_type != ModelEndpointType.STREAMING:
                 raise ObjectHasInvalidValueException(
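
To make the new check concrete, here is a self-contained sketch that mirrors the _SUPPORTED_QUANTIZATIONS lookup and validate_quantization added above. The Framework and Quantization enums below are simplified stand-ins for the real LLMInferenceFramework and Quantization entities, and ValueError stands in for ObjectHasInvalidValueException.

from enum import Enum
from typing import Dict, List, Optional


class Framework(str, Enum):  # stand-in for LLMInferenceFramework
    DEEPSPEED = "deepspeed"
    TEXT_GENERATION_INFERENCE = "text_generation_inference"
    VLLM = "vllm"
    LIGHTLLM = "lightllm"


class Quantization(str, Enum):  # stand-in for the real Quantization enum
    BITSANDBYTES = "bitsandbytes"
    AWQ = "awq"


SUPPORTED: Dict[Framework, List[Quantization]] = {
    Framework.DEEPSPEED: [],
    Framework.TEXT_GENERATION_INFERENCE: [Quantization.BITSANDBYTES],
    Framework.VLLM: [Quantization.AWQ],
    Framework.LIGHTLLM: [],
}


def validate_quantization(quantize: Optional[Quantization], framework: Framework) -> None:
    # No quantization always passes; otherwise the pair must appear in the support table.
    if quantize is not None and quantize not in SUPPORTED[framework]:
        raise ValueError(f"{quantize} is not supported for {framework}")


validate_quantization(None, Framework.DEEPSPEED)         # ok: no quantization requested
validate_quantization(Quantization.AWQ, Framework.VLLM)  # ok: supported pair
try:
    validate_quantization(Quantization.BITSANDBYTES, Framework.VLLM)  # vLLM only supports AWQ
except ValueError as e:
    print(e)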

model-engine/tests/unit/domain/conftest.py

Lines changed: 28 additions & 1 deletion
@@ -18,6 +18,7 @@
 )
 from model_engine_server.domain.entities import (
     GpuType,
+    LLMInferenceFramework,
     ModelBundle,
     ModelBundleEnvironmentParams,
     ModelBundleFrameworkType,
@@ -283,7 +284,6 @@ def create_llm_model_endpoint_text_generation_inference_request_streaming() -> (
         inference_framework="deepspeed",
         inference_framework_image_tag="test_tag",
         num_shards=2,
-        quantize=Quantization.BITSANDBYTES,
         endpoint_type=ModelEndpointType.STREAMING,
         metadata={},
         post_inference_hooks=["billing"],
@@ -356,6 +356,33 @@ def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndp
     )
 
 
+@pytest.fixture
+def create_llm_model_endpoint_request_invalid_quantization() -> CreateLLMModelEndpointV1Request:
+    return CreateLLMModelEndpointV1Request(
+        name="test_llm_endpoint_name_1",
+        model_name="nonexist",
+        source="hugging_face",
+        inference_framework=LLMInferenceFramework.VLLM,
+        inference_framework_image_tag="test_tag",
+        num_shards=2,
+        quantize=Quantization.BITSANDBYTES,
+        endpoint_type=ModelEndpointType.SYNC,
+        metadata={},
+        post_inference_hooks=["billing"],
+        cpus=1,
+        gpus=2,
+        memory="8G",
+        gpu_type=GpuType.NVIDIA_TESLA_T4,
+        storage=None,
+        min_workers=1,
+        max_workers=3,
+        per_worker=2,
+        labels={"team": "infra", "product": "my_product"},
+        aws_role="test_aws_role",
+        results_s3_bucket="test_s3_bucket",
+    )
+
+
 @pytest.fixture
 def completion_sync_request() -> CompletionSyncV1Request:
     return CompletionSyncV1Request(
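
Note on the fixtures above: the new one deliberately pairs LLMInferenceFramework.VLLM with Quantization.BITSANDBYTES, which the new _SUPPORTED_QUANTIZATIONS map rejects (vLLM only allows AWQ), hence the "invalid_quantization" name. The earlier fixture, which uses inference_framework="deepspeed", drops its quantize=Quantization.BITSANDBYTES argument for the same reason: under the new map DeepSpeed supports no quantization at all.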

model-engine/tests/unit/domain/test_llm_use_cases.py

Lines changed: 29 additions & 0 deletions
@@ -232,6 +232,35 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception
     )
 
 
+@pytest.mark.asyncio
+async def test_create_llm_model_endpoint_use_case_quantization_exception(
+    test_api_key: str,
+    fake_model_bundle_repository,
+    fake_model_endpoint_service,
+    fake_docker_repository_image_always_exists,
+    fake_model_primitive_gateway,
+    fake_llm_artifact_gateway,
+    create_llm_model_endpoint_request_invalid_quantization: CreateLLMModelEndpointV1Request,
+):
+    fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository
+    bundle_use_case = CreateModelBundleV2UseCase(
+        model_bundle_repository=fake_model_bundle_repository,
+        docker_repository=fake_docker_repository_image_always_exists,
+        model_primitive_gateway=fake_model_primitive_gateway,
+    )
+    use_case = CreateLLMModelEndpointV1UseCase(
+        create_model_bundle_use_case=bundle_use_case,
+        model_bundle_repository=fake_model_bundle_repository,
+        model_endpoint_service=fake_model_endpoint_service,
+        llm_artifact_gateway=fake_llm_artifact_gateway,
+    )
+    user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
+    with pytest.raises(ObjectHasInvalidValueException):
+        await use_case.execute(
+            user=user, request=create_llm_model_endpoint_request_invalid_quantization
+        )
+
+
 @pytest.mark.asyncio
 async def test_get_llm_model_endpoint_use_case_raises_not_found(
     test_api_key: str,
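
The new test reuses the invalid-quantization fixture from conftest.py and only asserts that execute() raises ObjectHasInvalidValueException before any endpoint is created. It can be run in isolation with pytest's -k filter, e.g. pytest model-engine/tests/unit/domain/test_llm_use_cases.py -k quantization_exception (the invocation is illustrative; -k simply matches on the test name).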
