diff --git a/.env.sample b/.env.sample
index cc933d9b..f45e5755 100644
--- a/.env.sample
+++ b/.env.sample
@@ -36,3 +36,6 @@ XAI_API_KEY=
 
 # Sambanova
 SAMBANOVA_API_KEY=
+
+# OpenVINO model server API
+OVMS_API_URL=
diff --git a/aisuite/providers/openvino_provider.py b/aisuite/providers/openvino_provider.py
new file mode 100644
index 00000000..4549b579
--- /dev/null
+++ b/aisuite/providers/openvino_provider.py
@@ -0,0 +1,49 @@
+import os
+
+from openai import OpenAI
+
+from aisuite.provider import LLMError, Provider
+from aisuite.providers.message_converter import OpenAICompliantMessageConverter
+
+
+class OpenvinoMessageConverter(OpenAICompliantMessageConverter):
+    """
+    OpenVINO-specific message converter.
+    """
+
+    pass
+
+
+class OpenvinoProvider(Provider):
+    """
+    OpenVINO Provider that makes chat completions requests using the OpenAI client.
+    This provider can be used with OpenVINO Model Server.
+    """
+
+    def __init__(self, **config):
+        """
+        Initialize the OpenVINO provider with the given configuration.
+        """
+        self.url = config.get("api_url") or os.getenv(
+            "OVMS_API_URL", "http://localhost:8000/v3"
+        )
+        config["base_url"] = self.url
+        # OVMS does not require authentication, but the OpenAI client requires a key.
+        config["api_key"] = "unused"
+        self.client = OpenAI(**config)
+        self.transformer = OpenvinoMessageConverter()
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        """
+        Makes a request to the OpenVINO Model Server chat completions endpoint using the OpenAI client.
+        """
+        try:
+            transformed_messages = self.transformer.convert_request(messages)
+            response = self.client.chat.completions.create(
+                model=model,
+                messages=transformed_messages,
+                **kwargs,  # Pass any additional arguments to the OpenVINO Chat API
+            )
+            return self.transformer.convert_response(response.model_dump())
+        except Exception as e:
+            raise LLMError(f"An error occurred: {e}")
diff --git a/tests/providers/test_openvino_provider.py b/tests/providers/test_openvino_provider.py
new file mode 100644
index 00000000..3732429e
--- /dev/null
+++ b/tests/providers/test_openvino_provider.py
@@ -0,0 +1,48 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from aisuite.providers.openvino_provider import OpenvinoProvider
+
+
+@pytest.fixture(autouse=True)
+def set_api_url_var(monkeypatch):
+    """Fixture to set environment variables for tests."""
+    monkeypatch.setenv("OVMS_API_URL", "http://localhost:8000/v3")
+
+
+def test_openvino_provider():
+    """High-level test that the provider is initialized and chat completions are requested successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.75
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = OpenvinoProvider()
+    mock_response = MagicMock()
+    mock_response.model_dump.return_value = {
+        "choices": [
+            {"message": {"content": response_text_content, "role": "assistant"}}
+        ]
+    }
+
+    with patch.object(
+        provider.client.chat.completions,
+        "create",
+        return_value=mock_response,
+    ) as mock_create:
+        response = provider.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        mock_create.assert_called_with(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        assert response.choices[0].message.content == response_text_content
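
Not part of the diff: a minimal usage sketch of the new provider through aisuite's public client. It assumes the provider registers under the "openvino" key (derived from the OpenvinoProvider class name, per aisuite's provider-naming convention) and that an OVMS instance is already serving the named model; the model identifier and the provider_configs values below are illustrative, not taken from this PR.

    import aisuite as ai

    # Option 1: rely on OVMS_API_URL (or the http://localhost:8000/v3 default).
    client = ai.Client()

    # Option 2: pass the endpoint explicitly (URL shown is the default, for illustration).
    client = ai.Client(
        provider_configs={"openvino": {"api_url": "http://localhost:8000/v3"}}
    )

    messages = [{"role": "user", "content": "Hello!"}]
    response = client.chat.completions.create(
        model="openvino:meta-llama/Llama-3.1-8B-Instruct",  # aisuite's "provider:model" form; model id is hypothetical
        messages=messages,
        temperature=0.75,
    )
    print(response.choices[0].message.content)

Note the design choice reflected in the diff: OVMS exposes an OpenAI-compatible REST API (hence the /v3 base path and the reuse of the OpenAI client), and since it performs no authentication, the provider passes a placeholder api_key solely to satisfy the OpenAI client's constructor.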