Conversation

@saiatmakuri (Contributor) commented Nov 7, 2023

Pull Request Summary

For all completions, return the number of tokens in the prompt as num_prompt_tokens.

We get the prompt token count through a waterfall of fallbacks (a sketch follows the list below):

  • get prompt tokens count from the inference framework
  • load a tokenizer from HF weights
  • load a tokenizer from weights stored in S3
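
A minimal sketch of that waterfall, assuming a module-level tokenizer cache like the one added in this PR; the helper names and signatures here are illustrative, not the PR's exact API:

```python
from typing import Dict, Optional

from transformers import AutoTokenizer

tokenizer_cache: Dict[str, AutoTokenizer] = {}


def load_tokenizer(model_name: str) -> AutoTokenizer:
    """Try Hugging Face weights first; fall back to tokenizer files mirrored in S3."""
    try:
        return AutoTokenizer.from_pretrained(model_name)
    except OSError:
        # download_tokenizer_from_s3 is a hypothetical helper that pulls the tokenizer
        # files from S3 into a local directory and returns that path.
        local_dir = download_tokenizer_from_s3(model_name)
        return AutoTokenizer.from_pretrained(local_dir)


def count_tokens(prompt: str, model_name: str, framework_count: Optional[int] = None) -> int:
    # 1. Prefer the prompt token count reported by the inference framework itself.
    if framework_count is not None:
        return framework_count
    # 2./3. Otherwise tokenize the prompt locally with a cached tokenizer.
    if model_name not in tokenizer_cache:
        tokenizer_cache[model_name] = load_tokenizer(model_name)
    return len(tokenizer_cache[model_name].encode(prompt))
```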

Test Plan and Usage Guide

Ran the skipped integration test locally. Still need to figure out how to enable it for CircleCI.
Compared token counts to https://huggingface.co/spaces/Xenova/the-tokenizer-playground

@saiatmakuri self-assigned this Nov 14, 2023
@saiatmakuri added the enhancement (New feature or request) label Nov 14, 2023
@@ -178,6 +266,82 @@
DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes


# Hack to count prompt tokens
tokenizer_cache: Dict[str, AutoTokenizer] = {}
Contributor:

nit: would we want to make this an LRU cache or something? Not sure how big the tokenizers can get.

Contributor (Author):

~3MB per model. Good call to add an LRU cache.
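
For illustration, the module-level dict could be swapped for a bounded cache built on functools.lru_cache; the loader name and maxsize below are assumptions, not what the PR actually implements:

```python
from functools import lru_cache

from transformers import AutoTokenizer


@lru_cache(maxsize=32)  # at roughly 3MB per tokenizer, this caps memory at ~100MB
def get_cached_tokenizer(model_name: str) -> AutoTokenizer:
    # load_tokenizer is the hypothetical HF-then-S3 loader sketched in the summary above;
    # only the caching strategy changes here.
    return load_tokenizer(model_name)
```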

@@ -234,3 +239,10 @@ def test_sync_streaming_model_endpoint(capsys):
)
finally:
delete_model_endpoint(create_endpoint_request["name"], user)


@pytest.mark.skip(reason="test doesn't currently work, needs to figure out s3 fallback")

Contributor (Author):

For the integration tests before deploy, I'd like to check S3 as well. Since this comment, I've changed the skip to a skipif that only skips in the CircleCI env.
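
A sketch of that skipif, assuming the standard CIRCLECI environment variable that CircleCI sets on its runners; the test name is illustrative:

```python
import os

import pytest


@pytest.mark.skipif(
    os.getenv("CIRCLECI") == "true",
    reason="S3 fallback is not reachable from the CircleCI environment",
)
def test_completions_return_num_prompt_tokens():
    ...
```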



def mock_boto3_session(fake_files: List[str]):
    mock_session = mock.Mock()
Member:

moto could be an option here, though this mocking can be fine too since you're already at the last layer (the S3 artifact gateway), and we want to avoid mocking across multiple layers.

Contributor (Author):

Ah, moto seems cool; useful to know for the future. I wanted to create a custom side effect, and this approach lends itself more easily to that.
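
As an illustration of the custom-side-effect approach, a mock session might look roughly like this; the client method being faked and the failure behavior are assumptions, not the PR's actual code:

```python
from typing import List
from unittest import mock


def mock_boto3_session(fake_files: List[str]) -> mock.Mock:
    """Build a mock boto3 Session whose S3 client only "knows about" the fake files."""
    mock_session = mock.Mock()
    mock_client = mock.Mock()

    def fake_download_file(bucket: str, key: str, filename: str) -> None:
        # Custom side effect: refuse to "download" keys that were not declared as fake files.
        if key not in fake_files:
            raise FileNotFoundError(f"s3://{bucket}/{key} does not exist in the fake store")

    mock_client.download_file.side_effect = fake_download_file
    mock_session.client.return_value = mock_client
    return mock_session
```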

@@ -1471,6 +1508,11 @@ async def execute(
args["parameters"]["do_sample"] = False
if request.return_token_log_probs:
args["parameters"]["return_details"] = True
num_prompt_tokens = count_tokens(
Contributor:

I'm okay with us doing tokenization here for less-used frameworks, but for the more important models can we move tokenization into the framework itself?

Contributor:

What about moving this fallback tokenization into the forwarder? That would mean less overhead in the gateway, which also supports high-QPS routes like get/post tasks.

Contributor:

Yeah, IMO the forwarder feels like the more "natural" place to put token counting: we'd only have to download one tokenizer in the forwarder, and we'd offload the computation to something that scales up more in proportion with load.

This does mean the forwarder has to know to carry out the token-counting logic exactly when it's forwarding to an LLM, which implies different "modes" for the forwarder (e.g. non-LLM, where it just passes requests through, and LLM, where it may do some specific processing before passing requests through).

Contributor (Author):

Upstreaming all changes into the framework will be the goal, but this is a temporary stopgap.

Contributor (Author):

> This does mean the forwarder has to know to carry out the token-counting logic exactly when it's forwarding to an LLM, which implies different "modes" for the forwarder (e.g. non-LLM, where it just passes requests through, and LLM, where it may do some specific processing before passing requests through).

I thought about this for emitting token metrics in the gateway vs. the forwarder as well. I think it's worth a larger discussion after this PR.

@yunfeng-scale (Contributor) left a comment:

What's the current ephemeral disk size for model engine pods? Should we add more?

@saiatmakuri (Contributor, Author):

> What's the current ephemeral disk size for model engine pods? Should we add more?

We've set it to 128Mi; this should be sufficient for now.

@yunfeng-scale (Contributor):

> What's the current ephemeral disk size for model engine pods? Should we add more?
>
> We've set it to 128Mi; this should be sufficient for now.

It's unclear to me whether this is enough, since I think model engine previously ran out of disk with about 100MB of usage?

@yixu34 (Member) left a comment:

Looks good for a V0 of token counting. If needed later on, we can consider in-framework tokenization and/or doing tokenization in the forwarder.

We should carefully monitor error rates and token latency/throughput during this rollout.

"model_engine_server.infra.gateways.s3_llm_artifact_gateway.os.makedirs",
lambda *args, **kwargs: None, # noqa
)
def test_s3_llm_artifact_gateway_download_folder(llm_artifact_gateway, fake_files):
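
For illustration, the body of that test might wire in the mock session along these lines; the gateway method name, the S3 URI, and the assertion are assumptions rather than the PR's actual code:

```python
from unittest import mock


def test_s3_llm_artifact_gateway_download_folder(llm_artifact_gateway, fake_files):
    # Patch boto3 inside the gateway module so it only "downloads" the declared fake files.
    with mock.patch(
        "model_engine_server.infra.gateways.s3_llm_artifact_gateway.boto3.Session",
        return_value=mock_boto3_session(fake_files),
    ):
        downloaded = llm_artifact_gateway.download_files(
            "s3://fake-bucket/model/tokenizer",  # hypothetical S3 URI
            "/tmp/tokenizer",                    # hypothetical local target directory
        )
    # The real test presumably asserts on the exact set of downloaded files.
    assert downloaded
```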
Member:

Thank you for bumping up our test coverage! 🤜🏻 🤛🏻

@saiatmakuri merged commit 257ea6c into main Nov 15, 2023
@saiatmakuri deleted the saiatmakuri/count-prompt-tokens branch Nov 15, 2023 22:40