Merged

Changes from all commits (42 commits)
980dafe
count prompt tokens, use tokenizer if needed
saiatmakuri Nov 7, 2023
fcbdca1
merge conflict
saiatmakuri Nov 7, 2023
b6ebaa3
docstrings
saiatmakuri Nov 7, 2023
b02ac14
fix tests and code cov
saiatmakuri Nov 7, 2023
e1e9efa
Merge branch 'main' into saiatmakuri/count-prompt-tokens
saiatmakuri Nov 13, 2023
37be04a
add download files from s3 fn
saiatmakuri Nov 13, 2023
c3bcc98
use same helpers and add docstring
saiatmakuri Nov 13, 2023
0d00465
change to namedtuple
saiatmakuri Nov 13, 2023
5dab779
add s3 repo locations
saiatmakuri Nov 14, 2023
1cc7381
fallback read from s3
saiatmakuri Nov 14, 2023
45e3502
refactor tokenizer load
saiatmakuri Nov 14, 2023
bd2bb73
edit tests
saiatmakuri Nov 14, 2023
3e9123a
refactor _SUPPORTED_MODELS_BY_FRAMEWORK
saiatmakuri Nov 14, 2023
1d12b73
updates for tests
saiatmakuri Nov 14, 2023
04e078e
move to utils file
saiatmakuri Nov 14, 2023
58c1fbf
move some fns over
saiatmakuri Nov 14, 2023
0344a97
use lru cache
saiatmakuri Nov 14, 2023
8d4ab8a
move model info
saiatmakuri Nov 14, 2023
e53394e
root to opt
saiatmakuri Nov 14, 2023
7842ec1
add log and adjust integration test
saiatmakuri Nov 14, 2023
b610e13
Merge branch 'main' into saiatmakuri/count-prompt-tokens
saiatmakuri Nov 14, 2023
f5da2f9
refocus logs
saiatmakuri Nov 14, 2023
e13d65f
change empty string to optional
saiatmakuri Nov 14, 2023
215f7ff
mock count tokens for unit tests
saiatmakuri Nov 14, 2023
ed168c9
change 1 mock
saiatmakuri Nov 14, 2023
a057792
add unit tests
saiatmakuri Nov 14, 2023
1c32f0f
config change
saiatmakuri Nov 14, 2023
37d4c21
comments pt 1
saiatmakuri Nov 14, 2023
c1d9f06
move internal logic to plugins file
saiatmakuri Nov 14, 2023
a4019e1
replace usage of utils file
saiatmakuri Nov 14, 2023
3c1a4f5
rearrange test mock
saiatmakuri Nov 14, 2023
5a672dd
only return prompt tokens count on last token in stream
saiatmakuri Nov 14, 2023
8b3a5e9
fix mock
saiatmakuri Nov 14, 2023
4bb2d15
reorganize imports
saiatmakuri Nov 14, 2023
723c786
inject in external interfaces
saiatmakuri Nov 15, 2023
2f821fb
make changes to tests
saiatmakuri Nov 15, 2023
8001dbe
Merge branch 'main' into saiatmakuri/count-prompt-tokens
saiatmakuri Nov 15, 2023
e8ff055
fix tests
saiatmakuri Nov 15, 2023
4922b65
adjust test
saiatmakuri Nov 15, 2023
99d6dfb
oops test
saiatmakuri Nov 15, 2023
6916978
add more tests
saiatmakuri Nov 15, 2023
fd94350
Merge branch 'main' into saiatmakuri/count-prompt-tokens
saiatmakuri Nov 15, 2023
2 changes: 1 addition & 1 deletion charts/model-engine/values_circleci.yaml
@@ -151,7 +151,7 @@ config:
user_inference_pytorch_repository: "hosted-model-inference/async-pytorch"
user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu"
docker_image_layer_cache_repository: "kaniko-cache"
- hf_user_fine_tuned_weights_prefix: "s3://$CIRCLECI_AWS_S3_BUCKET"
+ hf_user_fine_tuned_weights_prefix: "s3://$CIRCLECI_AWS_S3_BUCKET/model-weights"

# Service Account
serviceAccount:
15 changes: 15 additions & 0 deletions integration_tests/test_endpoints.py
@@ -2,6 +2,7 @@
import time

import pytest
from model_engine_server.common.env_vars import CIRCLECI
from tenacity import RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_fixed

from .rest_api_utils import (
@@ -234,3 +235,17 @@ def test_sync_streaming_model_endpoint(capsys):
)
finally:
delete_model_endpoint(create_endpoint_request["name"], user)


@pytest.mark.skipif(CIRCLECI, reason="skip on circleci since need to figure out s3 access")
def test_models_tokenizers() -> None:
from model_engine_server.infra.gateways.s3_llm_artifact_gateway import S3LLMArtifactGateway
from model_engine_server.infra.repositories import LiveTokenizerRepository
from model_engine_server.infra.repositories.live_tokenizer_repository import (
SUPPORTED_MODELS_INFO,
)

llm_artifact_gateway = S3LLMArtifactGateway()
tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway)
for model_name in SUPPORTED_MODELS_INFO:
tokenizer_repository.load_tokenizer(model_name)
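
For reference, this PR's headline change (commit 980dafe, "count prompt tokens, use tokenizer if needed") uses tokenizers like the ones this test loads to count prompt tokens. A minimal sketch of that counting step — the helper name is hypothetical, not taken from the diff:

from transformers import AutoTokenizer

def count_prompt_tokens(tokenizer: AutoTokenizer, prompt: str) -> int:
    # encode() maps the prompt to token ids; the prompt token count is their number.
    return len(tokenizer.encode(prompt))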
6 changes: 6 additions & 0 deletions model-engine/model_engine_server/api/dependencies.py
@@ -34,6 +34,7 @@
DockerRepository,
LLMFineTuneEventsRepository,
ModelBundleRepository,
TokenizerRepository,
TriggerRepository,
)
from model_engine_server.domain.services import (
@@ -87,6 +88,7 @@
DbTriggerRepository,
ECRDockerRepository,
FakeDockerRepository,
LiveTokenizerRepository,
RedisModelEndpointCacheRepository,
S3FileLLMFineTuneEventsRepository,
S3FileLLMFineTuneRepository,
@@ -134,6 +136,7 @@ class ExternalInterfaces:
llm_artifact_gateway: LLMArtifactGateway
cron_job_gateway: CronJobGateway
monitoring_metrics_gateway: MonitoringMetricsGateway
tokenizer_repository: TokenizerRepository


def get_default_monitoring_metrics_gateway() -> MonitoringMetricsGateway:
@@ -260,6 +263,8 @@ def _get_external_interfaces(

docker_repository = ECRDockerRepository() if not CIRCLECI else FakeDockerRepository()

tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway)

external_interfaces = ExternalInterfaces(
docker_repository=docker_repository,
model_bundle_repository=model_bundle_repository,
@@ -281,6 +286,7 @@ def _get_external_interfaces(
trigger_repository=trigger_repository,
cron_job_gateway=cron_job_gateway,
monitoring_metrics_gateway=monitoring_metrics_gateway,
tokenizer_repository=tokenizer_repository,
)
return external_interfaces

2 changes: 2 additions & 0 deletions model-engine/model_engine_server/api/llms_v1.py
@@ -247,6 +247,7 @@ async def create_completion_sync_task(
use_case = CompletionSyncV1UseCase(
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
tokenizer_repository=external_interfaces.tokenizer_repository,
)
return await use_case.execute(
user=auth, model_endpoint_name=model_endpoint_name, request=request
@@ -290,6 +291,7 @@ async def create_completion_stream_task(
use_case = CompletionStreamV1UseCase(
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
tokenizer_repository=external_interfaces.tokenizer_repository,
)
response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)

4 changes: 3 additions & 1 deletion model-engine/model_engine_server/common/dtos/llms.py
@@ -80,7 +80,7 @@ class GetLLMModelEndpointV1Response(BaseModel):
"""

name: str
- model_name: Optional[str] = None
+ model_name: str
source: LLMSource
status: ModelEndpointStatus
inference_framework: LLMInferenceFramework
@@ -143,6 +143,7 @@ class TokenOutput(BaseModel):

class CompletionOutput(BaseModel):
text: str
num_prompt_tokens: int
num_completion_tokens: int
tokens: Optional[List[TokenOutput]] = None

@@ -198,6 +199,7 @@ class CompletionStreamV1Request(BaseModel):
class CompletionStreamOutput(BaseModel):
text: str
finished: bool
num_prompt_tokens: Optional[int] = None
num_completion_tokens: Optional[int] = None
token: Optional[TokenOutput] = None

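
Per commit 5a672dd ("only return prompt tokens count on last token in stream"), num_prompt_tokens on CompletionStreamOutput stays None until the final chunk. A minimal consumer sketch under that assumption (the chunk iterable is hypothetical):

from typing import Iterable

def consume_stream(chunks: Iterable[CompletionStreamOutput]) -> str:
    text = ""
    for chunk in chunks:
        text += chunk.text
        if chunk.finished:
            # The prompt token count is only populated on the last streamed chunk.
            print(f"prompt tokens: {chunk.num_prompt_tokens}")
    return text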
@@ -11,12 +11,31 @@ class LLMArtifactGateway(ABC):
def list_files(self, path: str, **kwargs) -> List[str]:
"""
Gets a list of files from a given path.

Args:
path (str): path to list files
"""
pass

@abstractmethod
def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]:
"""
Download files from a given path to a target path.

Args:
path (str): path to download files from
target_path (str): local path to download files
overwrite (bool): whether to overwrite existing local files
"""
pass

@abstractmethod
def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
"""
Gets a list of URLs for all files associated with a given model.

Args:
owner (str): owner of the model
model_name (str): name of the model
"""
pass
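
For orientation, a minimal sketch of what an S3-backed download_files could look like — boto3-based; the URI parsing and client setup are assumptions, not the PR's actual S3LLMArtifactGateway:

import os
from typing import List

import boto3

def download_files_from_s3(path: str, target_path: str, overwrite: bool = False) -> List[str]:
    # path is an s3://bucket/prefix URI; split it into bucket name and key prefix.
    bucket_name, _, prefix = path.removeprefix("s3://").partition("/")
    bucket = boto3.resource("s3").Bucket(bucket_name)
    downloaded: List[str] = []
    for obj in bucket.objects.filter(Prefix=prefix):
        local_file = os.path.join(target_path, os.path.relpath(obj.key, prefix))
        if overwrite or not os.path.exists(local_file):
            os.makedirs(os.path.dirname(local_file), exist_ok=True)
            bucket.download_file(obj.key, local_file)
        downloaded.append(local_file)
    return downloaded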
@@ -4,12 +4,14 @@
from .docker_repository import DockerRepository
from .llm_fine_tune_events_repository import LLMFineTuneEventsRepository
from .model_bundle_repository import ModelBundleRepository
from .tokenizer_repository import TokenizerRepository
from .trigger_repository import TriggerRepository

__all__: Sequence[str] = [
"DockerRepository",
"DockerImageBatchJobBundleRepository",
"LLMFineTuneEventsRepository",
"ModelBundleRepository",
"TokenizerRepository",
"TriggerRepository",
]
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod

from transformers import AutoTokenizer


class TokenizerRepository(ABC):
@abstractmethod
def load_tokenizer(self, model_name: str) -> AutoTokenizer:
"""
Loads a tokenizer from a model name.

Args:
model_name: The model name to load the tokenizer for.

Returns:
A tokenizer.
"""
pass
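
LiveTokenizerRepository (wired in elsewhere in this PR) implements this interface; judging from the commit messages ("use lru cache", "fallback read from s3"), its load path is roughly the sketch below. The cache size and fallback mechanics are assumptions, not the merged code:

from functools import lru_cache

from transformers import AutoTokenizer

@lru_cache(maxsize=32)
def load_tokenizer_with_fallback(model_name: str, s3_local_dir: str) -> AutoTokenizer:
    try:
        # Prefer the Hugging Face Hub (or its local cache) for known models.
        return AutoTokenizer.from_pretrained(model_name)
    except OSError:
        # Fall back to tokenizer files previously downloaded from S3.
        return AutoTokenizer.from_pretrained(s3_local_dir)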