From 79ca0a006e4f5536bd6cf1cdaaadff142c5ba4c8 Mon Sep 17 00:00:00 2001
From: Curtis Castrapel
Date: Mon, 22 Dec 2025 14:34:11 -0800
Subject: [PATCH] Configuration flag to skip token count for batch jobs

---
 litellm/__init__.py                           |  1 +
 litellm/proxy/hooks/batch_rate_limiter.py     | 14 +++-
 tests/batches_tests/test_batch_rate_limits.py | 70 +++++++++++++++++++
 3 files changed, 83 insertions(+), 2 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index b20b3c5f8e11..25139ef554b3 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -305,6 +305,7 @@ ####################
 logging: bool = True
 enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
+skip_batch_token_counting_providers: Optional[List[str]] = None
 enable_caching_on_provider_specific_optional_params: bool = (
     False  # feature-flag for caching on optional params - e.g. 'top_k'
 )
diff --git a/litellm/proxy/hooks/batch_rate_limiter.py b/litellm/proxy/hooks/batch_rate_limiter.py
index ecad8bc1b119..45a02346a861 100644
--- a/litellm/proxy/hooks/batch_rate_limiter.py
+++ b/litellm/proxy/hooks/batch_rate_limiter.py
@@ -244,14 +244,24 @@ async def count_input_file_usage(
     ) -> BatchFileUsage:
         """
         Count number of requests and tokens in a batch input file.
-        
+
         Args:
             file_id: The file ID to read
             custom_llm_provider: The custom LLM provider to use for token encoding
-        
+
         Returns:
             BatchFileUsage with total_tokens and request_count
         """
+        skip_providers = litellm.skip_batch_token_counting_providers or []
+        if custom_llm_provider in skip_providers:
+            verbose_proxy_logger.debug(
+                f"Skipping batch token counting for provider: {custom_llm_provider}"
+            )
+            return BatchFileUsage(
+                total_tokens=0,
+                request_count=0,
+            )
+
         try:
             # Read file content
             file_content = await litellm.afile_content(
diff --git a/tests/batches_tests/test_batch_rate_limits.py b/tests/batches_tests/test_batch_rate_limits.py
index 776aba438c2c..ff698f5b3c12 100644
--- a/tests/batches_tests/test_batch_rate_limits.py
+++ b/tests/batches_tests/test_batch_rate_limits.py
@@ -389,3 +389,73 @@ async def test_batch_rate_limit_multiple_requests():
         print(f" Error: {exc_info.value.detail}")
     finally:
         os.unlink(file_path_2)
+
+
+@pytest.mark.asyncio()
+async def test_skip_batch_token_counting_for_providers():
+    """
+    Test that batch token counting can be skipped for configured providers.
+
+    When skip_batch_token_counting_providers includes a provider, the batch rate limiter
+    should return zero tokens and requests without attempting to download the file.
+    This is useful for providers like vertex_ai where batch files are stored in GCS
+    and downloading large files for token counting is impractical.
+    """
+    import litellm
+
+    original_value = litellm.skip_batch_token_counting_providers
+
+    try:
+        litellm.skip_batch_token_counting_providers = ["vertex_ai"]
+
+        batch_limiter = _PROXY_BatchRateLimiter(
+            internal_usage_cache=None,
+            parallel_request_limiter=None,
+        )
+
+        result = await batch_limiter.count_input_file_usage(
+            file_id="gs://test-bucket/test.jsonl",
+            custom_llm_provider="vertex_ai",
+        )
+
+        assert result.total_tokens == 0, "Should return 0 tokens when provider is in the skip list"
+        assert result.request_count == 0, "Should return 0 requests when provider is in the skip list"
+        print("✓ Token counting skipped for vertex_ai provider")
+    finally:
+        litellm.skip_batch_token_counting_providers = original_value
+
+
+@pytest.mark.asyncio()
+async def test_skip_batch_token_counting_multiple_providers():
+    """
+    Test that multiple providers can be configured in the skip list.
+    """
+    import litellm
+
+    original_value = litellm.skip_batch_token_counting_providers
+
+    try:
+        litellm.skip_batch_token_counting_providers = ["vertex_ai", "azure"]
+
+        batch_limiter = _PROXY_BatchRateLimiter(
+            internal_usage_cache=None,
+            parallel_request_limiter=None,
+        )
+
+        result_vertex = await batch_limiter.count_input_file_usage(
+            file_id="gs://test-bucket/test.jsonl",
+            custom_llm_provider="vertex_ai",
+        )
+        assert result_vertex.total_tokens == 0
+        assert result_vertex.request_count == 0
+
+        result_azure = await batch_limiter.count_input_file_usage(
+            file_id="azure-file-id",
+            custom_llm_provider="azure",
+        )
+        assert result_azure.total_tokens == 0
+        assert result_azure.request_count == 0
+
+        print("✓ Token counting skipped for multiple providers")
+    finally:
+        litellm.skip_batch_token_counting_providers = original_value
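
Usage sketch (illustrative, not part of the patch): with this flag set, count_input_file_usage()
returns BatchFileUsage(total_tokens=0, request_count=0) for the listed providers instead of
downloading the batch input file. The library-level setting below is taken directly from the
patch; the config.yaml form is an assumption based on the usual pattern of litellm_settings
keys being applied as module-level litellm attributes.

    import litellm

    # Skip batch-file token counting for providers whose input files are
    # impractical to fetch (e.g. vertex_ai batch files stored in GCS).
    litellm.skip_batch_token_counting_providers = ["vertex_ai"]

    # Assumed proxy equivalent in config.yaml:
    #
    # litellm_settings:
    #   skip_batch_token_counting_providers: ["vertex_ai"]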