From eb159301310dd6750d6dc003e80517cbf45f90e1 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Tue, 25 Mar 2025 16:13:19 -0400 Subject: [PATCH 1/4] implement centml --- src/huggingface_hub/inference/_client.py | 30 +++---- .../inference/_providers/__init__.py | 84 +++++++++++++------ .../inference/_providers/centml.py | 50 +++++++++++ tests/test_inference_client.py | 19 +++-- 4 files changed, 137 insertions(+), 46 deletions(-) create mode 100644 src/huggingface_hub/inference/_providers/centml.py diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index eb26e8e94e..0c4c0dc003 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -133,7 +133,7 @@ class InferenceClient: path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) documentation for details). When passing a URL as `model`, the client will not append any suffix path to it. provider (`str`, *optional*): - Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. + Name of the provider to use for inference. Can be `"black-forest-labs"`, `"centml"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. defaults to hf-inference (Hugging Face Serverless Inference API). If model is a URL or `base_url` is passed, then `provider` is not used. token (`str`, *optional*): @@ -224,7 +224,7 @@ def post( # type: ignore[misc] model: Optional[str] = None, task: Optional[str] = None, stream: Literal[False] = ..., - ) -> bytes: ... + ) -> bytes: ... @overload def post( # type: ignore[misc] @@ -235,7 +235,7 @@ def post( # type: ignore[misc] model: Optional[str] = None, task: Optional[str] = None, stream: Literal[True] = ..., - ) -> Iterable[bytes]: ... + ) -> Iterable[bytes]: ... @overload def post( @@ -246,7 +246,7 @@ def post( model: Optional[str] = None, task: Optional[str] = None, stream: bool = False, - ) -> Union[bytes, Iterable[bytes]]: ... + ) -> Union[bytes, Iterable[bytes]]: ... @_deprecate_method( version="0.31.0", @@ -295,17 +295,17 @@ def post( @overload def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[False] = ... - ) -> bytes: ... + ) -> bytes: ... @overload def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[True] = ... - ) -> Iterable[bytes]: ... + ) -> Iterable[bytes]: ... @overload def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False - ) -> Union[bytes, Iterable[bytes]]: ... + ) -> Union[bytes, Iterable[bytes]]: ... def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False @@ -520,7 +520,7 @@ def chat_completion( # type: ignore top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, - ) -> ChatCompletionOutput: ... + ) -> ChatCompletionOutput: ... @overload def chat_completion( # type: ignore @@ -546,7 +546,7 @@ def chat_completion( # type: ignore top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, - ) -> Iterable[ChatCompletionStreamOutput]: ... + ) -> Iterable[ChatCompletionStreamOutput]: ... @overload def chat_completion( @@ -572,7 +572,7 @@ def chat_completion( top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, - ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ... + ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ... def chat_completion( self, @@ -1918,7 +1918,7 @@ def text_generation( # type: ignore truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, - ) -> str: ... + ) -> str: ... @overload def text_generation( # type: ignore @@ -1948,7 +1948,7 @@ def text_generation( # type: ignore truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, - ) -> TextGenerationOutput: ... + ) -> TextGenerationOutput: ... @overload def text_generation( # type: ignore @@ -1978,7 +1978,7 @@ def text_generation( # type: ignore truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, - ) -> Iterable[str]: ... + ) -> Iterable[str]: ... @overload def text_generation( # type: ignore @@ -2008,7 +2008,7 @@ def text_generation( # type: ignore truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, - ) -> Iterable[TextGenerationStreamOutput]: ... + ) -> Iterable[TextGenerationStreamOutput]: ... @overload def text_generation( @@ -2038,7 +2038,7 @@ def text_generation( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, - ) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ... + ) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ... def text_generation( self, diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py index c96b8700e1..c6c56259ef 100644 --- a/src/huggingface_hub/inference/_providers/__init__.py +++ b/src/huggingface_hub/inference/_providers/__init__.py @@ -19,10 +19,12 @@ from .replicate import ReplicateTask, ReplicateTextToSpeechTask from .sambanova import SambanovaConversationalTask from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask +from .centml import CentmlConversationalTask, CentmlTextGenerationTask PROVIDER_T = Literal[ "black-forest-labs", + "centml", "cerebras", "cohere", "fal-ai", @@ -41,6 +43,10 @@ "black-forest-labs": { "text-to-image": BlackForestLabsTextToImageTask(), }, + "centml": { + "conversational": CentmlConversationalTask(), + "text-generation": CentmlTextGenerationTask(), + }, "cerebras": { "conversational": CerebrasConversationalTask(), }, @@ -57,32 +63,58 @@ "conversational": FireworksAIConversationalTask(), }, "hf-inference": { - "text-to-image": HFInferenceTask("text-to-image"), - "conversational": HFInferenceConversational(), - "text-generation": HFInferenceTask("text-generation"), - "text-classification": HFInferenceTask("text-classification"), - "question-answering": HFInferenceTask("question-answering"), - "audio-classification": HFInferenceBinaryInputTask("audio-classification"), - "automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"), - "fill-mask": HFInferenceTask("fill-mask"), - "feature-extraction": HFInferenceTask("feature-extraction"), - "image-classification": HFInferenceBinaryInputTask("image-classification"), - "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"), - "document-question-answering": HFInferenceTask("document-question-answering"), - "image-to-text": HFInferenceBinaryInputTask("image-to-text"), - "object-detection": HFInferenceBinaryInputTask("object-detection"), - "audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"), - "zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"), - "zero-shot-classification": HFInferenceTask("zero-shot-classification"), - "image-to-image": HFInferenceBinaryInputTask("image-to-image"), - "sentence-similarity": HFInferenceTask("sentence-similarity"), - "table-question-answering": HFInferenceTask("table-question-answering"), - "tabular-classification": HFInferenceTask("tabular-classification"), - "text-to-speech": HFInferenceTask("text-to-speech"), - "token-classification": HFInferenceTask("token-classification"), - "translation": HFInferenceTask("translation"), - "summarization": HFInferenceTask("summarization"), - "visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"), + "text-to-image": + HFInferenceTask("text-to-image"), + "conversational": + HFInferenceConversational(), + "text-generation": + HFInferenceTask("text-generation"), + "text-classification": + HFInferenceTask("text-classification"), + "question-answering": + HFInferenceTask("question-answering"), + "audio-classification": + HFInferenceBinaryInputTask("audio-classification"), + "automatic-speech-recognition": + HFInferenceBinaryInputTask("automatic-speech-recognition"), + "fill-mask": + HFInferenceTask("fill-mask"), + "feature-extraction": + HFInferenceTask("feature-extraction"), + "image-classification": + HFInferenceBinaryInputTask("image-classification"), + "image-segmentation": + HFInferenceBinaryInputTask("image-segmentation"), + "document-question-answering": + HFInferenceTask("document-question-answering"), + "image-to-text": + HFInferenceBinaryInputTask("image-to-text"), + "object-detection": + HFInferenceBinaryInputTask("object-detection"), + "audio-to-audio": + HFInferenceBinaryInputTask("audio-to-audio"), + "zero-shot-image-classification": + HFInferenceBinaryInputTask("zero-shot-image-classification"), + "zero-shot-classification": + HFInferenceTask("zero-shot-classification"), + "image-to-image": + HFInferenceBinaryInputTask("image-to-image"), + "sentence-similarity": + HFInferenceTask("sentence-similarity"), + "table-question-answering": + HFInferenceTask("table-question-answering"), + "tabular-classification": + HFInferenceTask("tabular-classification"), + "text-to-speech": + HFInferenceTask("text-to-speech"), + "token-classification": + HFInferenceTask("token-classification"), + "translation": + HFInferenceTask("translation"), + "summarization": + HFInferenceTask("summarization"), + "visual-question-answering": + HFInferenceBinaryInputTask("visual-question-answering"), }, "hyperbolic": { "text-to-image": HyperbolicTextToImageTask(), diff --git a/src/huggingface_hub/inference/_providers/centml.py b/src/huggingface_hub/inference/_providers/centml.py new file mode 100644 index 0000000000..8619bf2b8f --- /dev/null +++ b/src/huggingface_hub/inference/_providers/centml.py @@ -0,0 +1,50 @@ +from typing import Optional + +from huggingface_hub.inference._providers._common import ( + BaseConversationalTask, + BaseTextGenerationTask, +) + + +class CentmlConversationalTask(BaseConversationalTask): + """ + Provider helper for centml conversational (chat completions) tasks. + This helper builds requests in the OpenAI API format. + """ + + def __init__(self): + # Set the provider name to "centml" and use the centml serverless endpoint URL. + super().__init__(provider="centml", base_url="https://api.centml.com/openai") + + def _prepare_api_key(self, api_key: Optional[str]) -> str: + if api_key is None: + raise ValueError( + "An API key must be provided to use the centml provider.") + return api_key + + def _prepare_mapped_model(self, model: Optional[str]) -> str: + if model is None: + raise ValueError("Please provide a centml model ID.") + return model + + +class CentmlTextGenerationTask(BaseTextGenerationTask): + """ + Provider helper for centml text generation (completions) tasks. + This helper builds requests in the OpenAI API format. + """ + + def __init__(self): + super().__init__(provider="centml", base_url="https://api.centml.com/openai") + + def _prepare_api_key(self, api_key: Optional[str]) -> str: + if api_key is None: + raise ValueError( + "An API key must be provided to use the centml provider.") + return api_key + + def _prepare_mapped_model(self, model: Optional[str]) -> str: + if model is None: + raise ValueError("Please provide a centml model ID.") + return model + diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 85b9593b55..6b41e4e200 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -63,6 +63,10 @@ "black-forest-labs": { "text-to-image": "black-forest-labs/FLUX.1-dev", }, + "centml": { + "conversational": "meta-llama/Llama-3.3-70B-Instruct", + "text-generation": "meta-llama/Llama-3.2-3B-Instruct", + }, "cerebras": { "conversational": "meta-llama/Llama-3.3-70B-Instruct", }, @@ -79,11 +83,14 @@ "conversational": "meta-llama/Llama-3.3-70B-Instruct", }, "hf-inference": { - "audio-classification": "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech", + "audio-classification": + "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech", "audio-to-audio": "speechbrain/sepformer-wham", - "automatic-speech-recognition": "jonatasgrosman/wav2vec2-large-xlsr-53-english", + "automatic-speech-recognition": + "jonatasgrosman/wav2vec2-large-xlsr-53-english", "conversational": "meta-llama/Llama-3.1-8B-Instruct", - "document-question-answering": "naver-clova-ix/donut-base-finetuned-docvqa", + "document-question-answering": + "naver-clova-ix/donut-base-finetuned-docvqa", "feature-extraction": "facebook/bart-base", "image-classification": "google/vit-base-patch16-224", "image-to-text": "Salesforce/blip-image-captioning-base", @@ -94,10 +101,12 @@ "table-question-answering": "google/tapas-base-finetuned-wtq", "tabular-classification": "julien-c/wine-quality", "tabular-regression": "scikit-learn/Fish-Weight", - "text-classification": "distilbert/distilbert-base-uncased-finetuned-sst-2-english", + "text-classification": + "distilbert/distilbert-base-uncased-finetuned-sst-2-english", "text-to-image": "CompVis/stable-diffusion-v1-4", "text-to-speech": "espnet/kan-bayashi_ljspeech_vits", - "token-classification": "dbmdz/bert-large-cased-finetuned-conll03-english", + "token-classification": + "dbmdz/bert-large-cased-finetuned-conll03-english", "translation": "t5-small", "visual-question-answering": "dandelin/vilt-b32-finetuned-vqa", "zero-shot-classification": "facebook/bart-large-mnli", From 63227e5181fae8db656ea3dab48590e6c9c05a4a Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Tue, 25 Mar 2025 16:18:15 -0400 Subject: [PATCH 2/4] add VCR cassettes --- ...tion_no_stream[centml,conversational].yaml | 75 +++++++++++++++++++ ...on_with_stream[centml,conversational].yaml | 70 +++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 tests/cassettes/TestInferenceClient.test_chat_completion_no_stream[centml,conversational].yaml create mode 100644 tests/cassettes/TestInferenceClient.test_chat_completion_with_stream[centml,conversational].yaml diff --git a/tests/cassettes/TestInferenceClient.test_chat_completion_no_stream[centml,conversational].yaml b/tests/cassettes/TestInferenceClient.test_chat_completion_no_stream[centml,conversational].yaml new file mode 100644 index 0000000000..b3a977f1de --- /dev/null +++ b/tests/cassettes/TestInferenceClient.test_chat_completion_no_stream[centml,conversational].yaml @@ -0,0 +1,75 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is deep learning?"}], "model": "meta-llama/Llama-3.3-70B-Instruct", + "stream": false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '195' + Content-Type: + - application/json + X-Amzn-Trace-Id: + - 9f95510c-8aae-4df7-820e-eafbc8ad396f + method: POST + uri: https://api.centml.com/openai/v1/chat/completions + response: + body: + string: '{"id":"chatcmpl-de1b282d4615cdcf51313490db81295a","object":"chat.completion","created":1742933815,"model":"meta-llama/Llama-3.3-70B-Instruct","choices":[{"index":0,"message":{"role":"assistant","reasoning_content":null,"content":"**Deep + Learning: An Overview**\n=====================================\n\nDeep learning + is a subset of machine learning that involves the use of artificial neural + networks to analyze and interpret data. These neural networks are designed + to mimic the structure and function of the human brain, with multiple layers + of interconnected nodes (neurons) that process and transmit information.\n\n**Key + Characteristics:**\n\n1. **Artificial Neural Networks**: Deep learning models + are based on artificial neural networks, which are composed of multiple layers + of nodes (neurons) that process and transmit information.\n2. **Multiple Layers**: + Deep learning models have multiple layers, each of which performs a specific + function, such as feature extraction, feature transformation, or classification.\n3. + **Hierarchical Representation**: Deep learning models learn hierarchical representations + of data, with early layers learning low-level features and later layers learning + higher-level features.\n4. **Large Amounts of Data**: Deep learning models + require large amounts of data to train, as they need to learn complex patterns + and relationships in the data.\n\n**Types of Deep Learning Models:**\n\n1. + **Convolutional Neural Networks (CNNs)**: Used for image and video processing, + CNNs are designed to extract features from spatially structured data.\n2. + **Recurrent Neural Networks (RNNs)**: Used for sequential data, such as speech + or text, RNNs are designed to model temporal relationships in data.\n3. **Autoencoders**: + Used for dimensionality reduction and generative modeling, autoencoders are + designed to learn compact representations of data.\n\n**Applications:**\n\n1. + **Computer Vision**: Deep learning models are widely used in computer vision + applications, such as image classification, object detection, and segmentation.\n2. + **Natural Language Processing**: Deep learning models are used in NLP applications, + such as language modeling, text classification, and machine translation.\n3. + **Speech Recognition**: Deep learning models are used in speech recognition + applications, such as speech-to-text and voice recognition.\n\n**Advantages:**\n\n1. + **High Accuracy**: Deep learning models can achieve high accuracy in complex + tasks, such as image recognition and speech recognition.\n2. **Flexibility**: + Deep learning models can be used in a wide range of applications, from computer + vision to NLP.\n3. **Scalability**: Deep learning models can be trained on + large datasets and can scale to large applications.\n\n**Challenges:**\n\n1. + **Computational Requirements**: Deep learning models require significant computational + resources to train and deploy.\n2. **Data Requirements**: Deep learning models + require large amounts of data to train, which can be difficult to obtain.\n3. + **Interpretability**: Deep learning models can be difficult to interpret, + making it challenging to understand why a particular decision was made.","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":595,"completion_tokens":549,"prompt_tokens_details":null},"prompt_logprobs":null}' + headers: + content-type: + - application/json + date: + - Tue, 25 Mar 2025 20:16:54 GMT + server: + - istio-envoy + transfer-encoding: + - chunked + x-envoy-upstream-service-time: + - '3844' + status: + code: 200 + message: OK +version: 1 diff --git a/tests/cassettes/TestInferenceClient.test_chat_completion_with_stream[centml,conversational].yaml b/tests/cassettes/TestInferenceClient.test_chat_completion_with_stream[centml,conversational].yaml new file mode 100644 index 0000000000..a1debc14af --- /dev/null +++ b/tests/cassettes/TestInferenceClient.test_chat_completion_with_stream[centml,conversational].yaml @@ -0,0 +1,70 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is deep learning?"}], "model": "meta-llama/Llama-3.3-70B-Instruct", + "max_tokens": 20, "stream": true}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '212' + Content-Type: + - application/json + X-Amzn-Trace-Id: + - ad425e9f-bc1f-48df-b84d-a5ebad74cd66 + method: POST + uri: https://api.centml.com/openai/v1/chat/completions + response: + body: + string: 'data: {"id":"chatcmpl-7694e7d5663b4d0e22706f8260bae6df","object":"chat.completion.chunk","created":1742933819,"model":"meta-llama/Llama-3.3-70B-Instruct","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-7694e7d5663b4d0e22706f8260bae6df","object":"chat.completion.chunk","created":1742933819,"model":"meta-llama/Llama-3.3-70B-Instruct","choices":[{"index":0,"delta":{"content":"**"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-7694e7d5663b4d0e22706f8260bae6df","object":"chat.completion.chunk","created":1742933819,"model":"meta-llama/Llama-3.3-70B-Instruct","choices":[{"index":0,"delta":{"content":"Deep"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-7694e7d5663b4d0e22706f8260bae6df","object":"chat.completion.chunk","created":1742933819,"model":"meta-llama/Llama-3.3-70B-Instruct","choices":[{"index":0,"delta":{"content":" + Learning Overview"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-7694e7d5663b4d0e22706f8260bae6df","object":"chat.completion.chunk","created":1742933819,"model":"meta-llama/Llama-3.3-70B-Instruct","choices":[{"index":0,"delta":{"content":"**\n=========================\n\nDeep + learning is a"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-7694e7d5663b4d0e22706f8260bae6df","object":"chat.completion.chunk","created":1742933819,"model":"meta-llama/Llama-3.3-70B-Instruct","choices":[{"index":0,"delta":{"content":" + subset"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-7694e7d5663b4d0e22706f8260bae6df","object":"chat.completion.chunk","created":1742933819,"model":"meta-llama/Llama-3.3-70B-Instruct","choices":[{"index":0,"delta":{"content":" + of machine learning that involves"},"logprobs":null,"finish_reason":null}]} + + + data: {"id":"chatcmpl-7694e7d5663b4d0e22706f8260bae6df","object":"chat.completion.chunk","created":1742933819,"model":"meta-llama/Llama-3.3-70B-Instruct","choices":[{"index":0,"delta":{"content":" + the use"},"logprobs":null,"finish_reason":"length","stop_reason":null}]} + + + data: [DONE] + + + ' + headers: + content-type: + - text/event-stream; charset=utf-8; charset=utf-8 + date: + - Tue, 25 Mar 2025 20:16:58 GMT + server: + - istio-envoy + transfer-encoding: + - chunked + x-envoy-upstream-service-time: + - '328' + status: + code: 200 + message: OK +version: 1 From f4c49a8fe3081cbcf26e0bf643ad51f9142c626f Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Wed, 26 Mar 2025 21:07:08 -0400 Subject: [PATCH 3/4] revert auto-format --- src/huggingface_hub/inference/_client.py | 28 +++---- .../inference/_providers/__init__.py | 79 ++++++------------- 2 files changed, 40 insertions(+), 67 deletions(-) diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 0c4c0dc003..a2576169f7 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -224,7 +224,7 @@ def post( # type: ignore[misc] model: Optional[str] = None, task: Optional[str] = None, stream: Literal[False] = ..., - ) -> bytes: ... + ) -> bytes: ... @overload def post( # type: ignore[misc] @@ -235,7 +235,7 @@ def post( # type: ignore[misc] model: Optional[str] = None, task: Optional[str] = None, stream: Literal[True] = ..., - ) -> Iterable[bytes]: ... + ) -> Iterable[bytes]: ... @overload def post( @@ -246,7 +246,7 @@ def post( model: Optional[str] = None, task: Optional[str] = None, stream: bool = False, - ) -> Union[bytes, Iterable[bytes]]: ... + ) -> Union[bytes, Iterable[bytes]]: ... @_deprecate_method( version="0.31.0", @@ -295,17 +295,17 @@ def post( @overload def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[False] = ... - ) -> bytes: ... + ) -> bytes: ... @overload def _inner_post( # type: ignore[misc] self, request_parameters: RequestParameters, *, stream: Literal[True] = ... - ) -> Iterable[bytes]: ... + ) -> Iterable[bytes]: ... @overload def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False - ) -> Union[bytes, Iterable[bytes]]: ... + ) -> Union[bytes, Iterable[bytes]]: ... def _inner_post( self, request_parameters: RequestParameters, *, stream: bool = False @@ -520,7 +520,7 @@ def chat_completion( # type: ignore top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, - ) -> ChatCompletionOutput: ... + ) -> ChatCompletionOutput: ... @overload def chat_completion( # type: ignore @@ -546,7 +546,7 @@ def chat_completion( # type: ignore top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, - ) -> Iterable[ChatCompletionStreamOutput]: ... + ) -> Iterable[ChatCompletionStreamOutput]: ... @overload def chat_completion( @@ -572,7 +572,7 @@ def chat_completion( top_logprobs: Optional[int] = None, top_p: Optional[float] = None, extra_body: Optional[Dict] = None, - ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ... + ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ... def chat_completion( self, @@ -1918,7 +1918,7 @@ def text_generation( # type: ignore truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, - ) -> str: ... + ) -> str: ... @overload def text_generation( # type: ignore @@ -1948,7 +1948,7 @@ def text_generation( # type: ignore truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, - ) -> TextGenerationOutput: ... + ) -> TextGenerationOutput: ... @overload def text_generation( # type: ignore @@ -1978,7 +1978,7 @@ def text_generation( # type: ignore truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, - ) -> Iterable[str]: ... + ) -> Iterable[str]: ... @overload def text_generation( # type: ignore @@ -2008,7 +2008,7 @@ def text_generation( # type: ignore truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, - ) -> Iterable[TextGenerationStreamOutput]: ... + ) -> Iterable[TextGenerationStreamOutput]: ... @overload def text_generation( @@ -2038,7 +2038,7 @@ def text_generation( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: Optional[bool] = None, - ) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ... + ) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ... def text_generation( self, diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py index c6c56259ef..aa3d9b4392 100644 --- a/src/huggingface_hub/inference/_providers/__init__.py +++ b/src/huggingface_hub/inference/_providers/__init__.py @@ -21,7 +21,6 @@ from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask from .centml import CentmlConversationalTask, CentmlTextGenerationTask - PROVIDER_T = Literal[ "black-forest-labs", "centml", @@ -63,58 +62,32 @@ "conversational": FireworksAIConversationalTask(), }, "hf-inference": { - "text-to-image": - HFInferenceTask("text-to-image"), - "conversational": - HFInferenceConversational(), - "text-generation": - HFInferenceTask("text-generation"), - "text-classification": - HFInferenceTask("text-classification"), - "question-answering": - HFInferenceTask("question-answering"), - "audio-classification": - HFInferenceBinaryInputTask("audio-classification"), - "automatic-speech-recognition": - HFInferenceBinaryInputTask("automatic-speech-recognition"), - "fill-mask": - HFInferenceTask("fill-mask"), - "feature-extraction": - HFInferenceTask("feature-extraction"), - "image-classification": - HFInferenceBinaryInputTask("image-classification"), - "image-segmentation": - HFInferenceBinaryInputTask("image-segmentation"), - "document-question-answering": - HFInferenceTask("document-question-answering"), - "image-to-text": - HFInferenceBinaryInputTask("image-to-text"), - "object-detection": - HFInferenceBinaryInputTask("object-detection"), - "audio-to-audio": - HFInferenceBinaryInputTask("audio-to-audio"), - "zero-shot-image-classification": - HFInferenceBinaryInputTask("zero-shot-image-classification"), - "zero-shot-classification": - HFInferenceTask("zero-shot-classification"), - "image-to-image": - HFInferenceBinaryInputTask("image-to-image"), - "sentence-similarity": - HFInferenceTask("sentence-similarity"), - "table-question-answering": - HFInferenceTask("table-question-answering"), - "tabular-classification": - HFInferenceTask("tabular-classification"), - "text-to-speech": - HFInferenceTask("text-to-speech"), - "token-classification": - HFInferenceTask("token-classification"), - "translation": - HFInferenceTask("translation"), - "summarization": - HFInferenceTask("summarization"), - "visual-question-answering": - HFInferenceBinaryInputTask("visual-question-answering"), + "text-to-image": HFInferenceTask("text-to-image"), + "conversational": HFInferenceConversational(), + "text-generation": HFInferenceTask("text-generation"), + "text-classification": HFInferenceTask("text-classification"), + "question-answering": HFInferenceTask("question-answering"), + "audio-classification": HFInferenceBinaryInputTask("audio-classification"), + "automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"), + "fill-mask": HFInferenceTask("fill-mask"), + "feature-extraction": HFInferenceTask("feature-extraction"), + "image-classification": HFInferenceBinaryInputTask("image-classification"), + "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"), + "document-question-answering": HFInferenceTask("document-question-answering"), + "image-to-text": HFInferenceBinaryInputTask("image-to-text"), + "object-detection": HFInferenceBinaryInputTask("object-detection"), + "audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"), + "zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"), + "zero-shot-classification": HFInferenceTask("zero-shot-classification"), + "image-to-image": HFInferenceBinaryInputTask("image-to-image"), + "sentence-similarity": HFInferenceTask("sentence-similarity"), + "table-question-answering": HFInferenceTask("table-question-answering"), + "tabular-classification": HFInferenceTask("tabular-classification"), + "text-to-speech": HFInferenceTask("text-to-speech"), + "token-classification": HFInferenceTask("token-classification"), + "translation": HFInferenceTask("translation"), + "summarization": HFInferenceTask("summarization"), + "visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"), }, "hyperbolic": { "text-to-image": HyperbolicTextToImageTask(), From 4b2b72471f7d77ad99994e357041f60d90ffbcc3 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Wed, 26 Mar 2025 21:07:56 -0400 Subject: [PATCH 4/4] revert auto-format --- tests/test_inference_client.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 6b41e4e200..4e328d519b 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -83,14 +83,11 @@ "conversational": "meta-llama/Llama-3.3-70B-Instruct", }, "hf-inference": { - "audio-classification": - "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech", + "audio-classification": "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech", "audio-to-audio": "speechbrain/sepformer-wham", - "automatic-speech-recognition": - "jonatasgrosman/wav2vec2-large-xlsr-53-english", + "automatic-speech-recognition": "jonatasgrosman/wav2vec2-large-xlsr-53-english", "conversational": "meta-llama/Llama-3.1-8B-Instruct", - "document-question-answering": - "naver-clova-ix/donut-base-finetuned-docvqa", + "document-question-answering": "naver-clova-ix/donut-base-finetuned-docvqa", "feature-extraction": "facebook/bart-base", "image-classification": "google/vit-base-patch16-224", "image-to-text": "Salesforce/blip-image-captioning-base", @@ -101,12 +98,10 @@ "table-question-answering": "google/tapas-base-finetuned-wtq", "tabular-classification": "julien-c/wine-quality", "tabular-regression": "scikit-learn/Fish-Weight", - "text-classification": - "distilbert/distilbert-base-uncased-finetuned-sst-2-english", + "text-classification": "distilbert/distilbert-base-uncased-finetuned-sst-2-english", "text-to-image": "CompVis/stable-diffusion-v1-4", "text-to-speech": "espnet/kan-bayashi_ljspeech_vits", - "token-classification": - "dbmdz/bert-large-cased-finetuned-conll03-english", + "token-classification": "dbmdz/bert-large-cased-finetuned-conll03-english", "translation": "t5-small", "visual-question-answering": "dandelin/vilt-b32-finetuned-vqa", "zero-shot-classification": "facebook/bart-large-mnli",