Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ XAI_API_KEY=

# Sambanova
SAMBANOVA_API_KEY=

# OpenVINO model server API
OVMS_API_URL=
48 changes: 48 additions & 0 deletions aisuite/providers/openvino_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os

from openai import OpenAI

from aisuite.provider import LLMError, Provider
from aisuite.providers.message_converter import OpenAICompliantMessageConverter


class OpenvinoMessageConverter(OpenAICompliantMessageConverter):
"""
Openvino-specific message converter.
"""

pass


class OpenvinoProvider(Provider):
"""
OpenVINO Provider that makes chat completions requests using the OpenAI client.
This provider can be used with OpenVINO Model Server.
"""

def __init__(self, **config):
"""
Initialize the OpenVINO provider with the given configuration.
"""
self.url = config.get("api_url") or os.getenv(
"OVMS_API_URL", "http://localhost:8000/v3"
)
config["base_url"] = self.url
config["api_key"] = "unused"
self.client = OpenAI(**config)
self.transformer = OpenvinoMessageConverter()

def chat_completions_create(self, model, messages, **kwargs):
"""
Makes a request to the OpenVINO Model Server chat completions endpoint using the OpenAI client.
"""
try:
transformed_messages = self.transformer.convert_request(messages)
response = self.client.chat.completions.create(
model=model,
messages=transformed_messages,
**kwargs, # Pass any additional arguments to the OpenVINO Chat API
)
return self.transformer.convert_response(response.model_dump())
except Exception as e:
raise LLMError(f"An error occurred: {e}")
48 changes: 48 additions & 0 deletions tests/providers/test_openvino_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from unittest.mock import MagicMock, patch

import pytest

from aisuite.providers.openvino_provider import OpenvinoProvider


@pytest.fixture(autouse=True)
def set_api_url_var(monkeypatch):
"""Fixture to set environment variables for tests."""
monkeypatch.setenv("OVMS_API_URL", "http://localhost:8000/v3")


def test_openvino_provider():
"""High-level test that the provider is initialized and chat completions are requested successfully."""

user_greeting = "Hello!"
message_history = [{"role": "user", "content": user_greeting}]
selected_model = "our-favorite-model"
chosen_temperature = 0.75
response_text_content = "mocked-text-response-from-model"

provider = OpenvinoProvider()
mock_response = MagicMock()
mock_response.model_dump.return_value = {
"choices": [
{"message": {"content": response_text_content, "role": "assistant"}}
]
}

with patch.object(
provider.client.chat.completions,
"create",
return_value=mock_response,
) as mock_create:
response = provider.chat_completions_create(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

mock_create.assert_called_with(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

assert response.choices[0].message.content == response_text_content