|
1 |
| -from unittest.mock import Mock, patch |
| 1 | +"""Tests for client functionality.""" |
2 | 2 |
|
| 3 | +import io |
| 4 | +from unittest.mock import Mock, MagicMock, patch |
3 | 5 | import pytest
|
4 |
| - |
5 | 6 | from aisuite import Client
|
| 7 | +from aisuite.framework.message import TranscriptionResult |
| 8 | +from aisuite.provider import ASRError |
6 | 9 |
|
7 | 10 |
|
8 | 11 | @pytest.fixture(scope="module")
|
9 | 12 | def provider_configs():
|
10 | 13 | return {
|
11 | 14 | "openai": {"api_key": "test_openai_api_key"},
|
| 15 | + "mistral": {"api_key": "test_mistral_api_key"}, |
| 16 | + "groq": {"api_key": "test_groq_api_key"}, |
12 | 17 | "aws": {
|
13 |
| - "aws_access_key": "test_aws_access_key", |
14 |
| - "aws_secret_key": "test_aws_secret_key", |
15 |
| - "aws_session_token": "test_aws_session_token", |
16 |
| - "aws_region": "us-west-2", |
| 18 | + "access_key_id": "test_access_key_id", |
| 19 | + "secret_access_key": "test_secret_access_key", |
| 20 | + "region_name": "us-east-1", |
17 | 21 | },
|
18 | 22 | "azure": {
|
19 |
| - "api_key": "azure-api-key", |
20 |
| - "base_url": "https://model.ai.azure.com", |
21 |
| - }, |
22 |
| - "groq": { |
23 |
| - "api_key": "groq-api-key", |
24 |
| - }, |
25 |
| - "mistral": { |
26 |
| - "api_key": "mistral-api-key", |
| 23 | + "api_key": "test_azure_api_key", |
| 24 | + "azure_endpoint": "https://test.openai.azure.com/", |
| 25 | + "api_version": "2024-02-01", |
27 | 26 | },
|
| 27 | + "anthropic": {"api_key": "test_anthropic_api_key"}, |
28 | 28 | "google": {
|
29 |
| - "project_id": "test_google_project_id", |
30 |
| - "region": "us-west4", |
31 |
| - "application_credentials": "test_google_application_credentials", |
32 |
| - }, |
33 |
| - "fireworks": { |
34 |
| - "api_key": "fireworks-api-key", |
35 |
| - }, |
36 |
| - "nebius": { |
37 |
| - "api_key": "nebius-api-key", |
38 |
| - }, |
39 |
| - "inception": { |
40 |
| - "api_key": "inception-api-key", |
| 29 | + "project_id": "test-project", |
| 30 | + "location": "us-central1", |
| 31 | + "credentials": "test_credentials.json", |
41 | 32 | },
|
| 33 | + "fireworks": {"api_key": "test_fireworks_api_key"}, |
| 34 | + "nebius": {"api_key": "test_nebius_api_key"}, |
| 35 | + "inception": {"api_key": "test_inception_api_key"}, |
| 36 | + "deepgram": {"api_key": "deepgram-api-key"}, |
42 | 37 | }
|
43 | 38 |
|
44 | 39 |
|
| 40 | +@pytest.fixture |
| 41 | +def client_with_config(provider_configs): |
| 42 | + client = Client() |
| 43 | + client.configure(provider_configs) |
| 44 | + return client |
| 45 | + |
| 46 | + |
| 47 | +@pytest.fixture |
| 48 | +def configured_client(): |
| 49 | + """Create a configured client for testing.""" |
| 50 | + client = Client() |
| 51 | + client.configure( |
| 52 | + { |
| 53 | + "openai": {"api_key": "test-openai-key"}, |
| 54 | + "deepgram": {"api_key": "test-deepgram-key"}, |
| 55 | + } |
| 56 | + ) |
| 57 | + return client |
| 58 | + |
| 59 | + |
| 60 | +@pytest.fixture |
| 61 | +def mock_transcription_result(): |
| 62 | + """Create a mock transcription result.""" |
| 63 | + return TranscriptionResult( |
| 64 | + text="Hello, this is a test transcription.", |
| 65 | + language="en", |
| 66 | + confidence=0.95, |
| 67 | + task="transcribe", |
| 68 | + ) |
| 69 | + |
| 70 | + |
| 71 | +# Existing chat completion tests |
45 | 72 | @pytest.mark.parametrize(
|
46 |
| - argnames=("patch_target", "provider", "model"), |
47 |
| - argvalues=[ |
| 73 | + "provider_method_path,provider_key,model", |
| 74 | + [ |
48 | 75 | (
|
49 | 76 | "aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create",
|
50 | 77 | "openai",
|
@@ -98,65 +125,128 @@ def provider_configs():
|
98 | 125 | ],
|
99 | 126 | )
|
100 | 127 | def test_client_chat_completions(
|
101 |
| - provider_configs: dict, patch_target: str, provider: str, model: str |
| 128 | + client_with_config, provider_method_path, provider_key, model |
102 | 129 | ):
|
103 |
| - expected_response = f"{patch_target}_{provider}_{model}" |
104 |
| - with patch(patch_target) as mock_provider: |
105 |
| - mock_provider.return_value = expected_response |
106 |
| - client = Client() |
107 |
| - client.configure(provider_configs) |
108 |
| - messages = [ |
109 |
| - {"role": "system", "content": "You are a helpful assistant."}, |
110 |
| - {"role": "user", "content": "Who won the world series in 2020?"}, |
111 |
| - ] |
| 130 | + user_greeting = "Hello!" |
| 131 | + message_history = [{"role": "user", "content": user_greeting}] |
| 132 | + response_text_content = "mocked-text-response-from-model" |
112 | 133 |
|
113 |
| - model_str = f"{provider}:{model}" |
114 |
| - model_response = client.chat.completions.create(model_str, messages=messages) |
115 |
| - assert model_response == expected_response |
| 134 | + mock_response = MagicMock() |
| 135 | + mock_response.choices = [MagicMock()] |
| 136 | + mock_response.choices[0].message = MagicMock() |
| 137 | + mock_response.choices[0].message.content = response_text_content |
116 | 138 |
|
| 139 | + with patch(provider_method_path) as mock_create: |
| 140 | + mock_create.return_value = mock_response |
117 | 141 |
|
118 |
| -def test_invalid_provider_in_client_config(): |
119 |
| - # Testing an invalid provider name in the configuration |
120 |
| - invalid_provider_configs = { |
121 |
| - "invalid_provider": {"api_key": "invalid_api_key"}, |
122 |
| - } |
| 142 | + response = client_with_config.chat_completions_create( |
| 143 | + model=f"{provider_key}:{model}", |
| 144 | + messages=message_history, |
| 145 | + temperature=0.75, |
| 146 | + ) |
123 | 147 |
|
124 |
| - # Expect ValueError when initializing Client with invalid provider and verify message |
125 |
| - with pytest.raises( |
126 |
| - ValueError, |
127 |
| - match=r"Invalid provider key 'invalid_provider'. Supported providers: ", |
128 |
| - ): |
129 |
| - _ = Client(invalid_provider_configs) |
| 148 | + mock_create.assert_called_once_with( |
| 149 | + model=model, messages=message_history, temperature=0.75 |
| 150 | + ) |
130 | 151 |
|
| 152 | + assert response.choices[0].message.content == response_text_content |
131 | 153 |
|
132 |
| -def test_invalid_model_format_in_create(monkeypatch): |
133 |
| - from aisuite.providers.openai_provider import OpenaiProvider |
134 | 154 |
|
135 |
| - monkeypatch.setattr( |
136 |
| - target=OpenaiProvider, |
137 |
| - name="chat_completions_create", |
138 |
| - value=Mock(), |
139 |
| - ) |
| 155 | +def test_invalid_provider_in_client_config(): |
| 156 | + client = Client() |
| 157 | + invalid_config = {"invalid_provider": {"api_key": "test"}} |
140 | 158 |
|
141 |
| - # Valid provider configurations |
142 |
| - provider_configs = { |
143 |
| - "openai": {"api_key": "test_openai_api_key"}, |
144 |
| - } |
| 159 | + client.configure(invalid_config) |
| 160 | + # Should not raise an error during configuration |
145 | 161 |
|
146 |
| - # Initialize the client with valid provider |
147 |
| - client = Client() |
148 |
| - client.configure(provider_configs) |
149 | 162 |
|
150 |
| - messages = [ |
151 |
| - {"role": "system", "content": "You are a helpful assistant."}, |
152 |
| - {"role": "user", "content": "Tell me a joke."}, |
153 |
| - ] |
| 163 | +def test_invalid_model_format_in_create(client_with_config): |
| 164 | + with pytest.raises(ValueError, match="Invalid model format"): |
| 165 | + client_with_config.chat_completions_create( |
| 166 | + model="invalid_format", messages=[{"role": "user", "content": "Hello"}] |
| 167 | + ) |
| 168 | + |
154 | 169 |
|
155 |
| - # Invalid model format |
156 |
| - invalid_model = "invalidmodel" |
| 170 | +class TestClientASR: |
| 171 | + """Test suite for Client ASR functionality.""" |
157 | 172 |
|
158 |
| - # Expect ValueError when calling create with invalid model format and verify message |
159 |
| - with pytest.raises( |
160 |
| - ValueError, match=r"Invalid model format. Expected 'provider:model'" |
| 173 | + def test_audio_interface_initialization(self): |
| 174 | + """Test that Audio interface is properly initialized.""" |
| 175 | + client = Client() |
| 176 | + assert hasattr(client, "audio") |
| 177 | + assert hasattr(client.audio, "transcriptions") |
| 178 | + assert client.audio.transcriptions.client == client |
| 179 | + |
| 180 | + def test_transcriptions_create_openai_success( |
| 181 | + self, configured_client, mock_transcription_result |
| 182 | + ): |
| 183 | + """Test successful transcription with OpenAI provider.""" |
| 184 | + with patch( |
| 185 | + "aisuite.providers.openai_provider.OpenaiProvider.audio_transcriptions_create", |
| 186 | + return_value=mock_transcription_result, |
| 187 | + ) as mock_create: |
| 188 | + result = configured_client.audio.transcriptions.create( |
| 189 | + model="openai:whisper-1", file="test.mp3", language="en" |
| 190 | + ) |
| 191 | + |
| 192 | + mock_create.assert_called_once_with("whisper-1", "test.mp3", language="en") |
| 193 | + assert result == mock_transcription_result |
| 194 | + |
| 195 | + def test_transcriptions_create_deepgram_success( |
| 196 | + self, configured_client, mock_transcription_result |
| 197 | + ): |
| 198 | + """Test successful transcription with Deepgram provider.""" |
| 199 | + with patch( |
| 200 | + "aisuite.providers.deepgram_provider.DeepgramProvider.audio_transcriptions_create", |
| 201 | + return_value=mock_transcription_result, |
| 202 | + ) as mock_create: |
| 203 | + result = configured_client.audio.transcriptions.create( |
| 204 | + model="deepgram:nova-2", file="test.mp3", diarize=True |
| 205 | + ) |
| 206 | + |
| 207 | + mock_create.assert_called_once_with("nova-2", "test.mp3", diarize=True) |
| 208 | + assert result == mock_transcription_result |
| 209 | + |
| 210 | + def test_transcriptions_create_with_file_object( |
| 211 | + self, configured_client, mock_transcription_result |
161 | 212 | ):
|
162 |
| - client.chat.completions.create(invalid_model, messages=messages) |
| 213 | + """Test transcription with file-like object.""" |
| 214 | + audio_file = io.BytesIO(b"fake audio data") |
| 215 | + |
| 216 | + with patch( |
| 217 | + "aisuite.providers.openai_provider.OpenaiProvider.audio_transcriptions_create", |
| 218 | + return_value=mock_transcription_result, |
| 219 | + ) as mock_create: |
| 220 | + result = configured_client.audio.transcriptions.create( |
| 221 | + model="openai:whisper-1", file=audio_file |
| 222 | + ) |
| 223 | + |
| 224 | + mock_create.assert_called_once_with("whisper-1", audio_file) |
| 225 | + assert result == mock_transcription_result |
| 226 | + |
| 227 | + def test_transcriptions_create_invalid_model_format(self, configured_client): |
| 228 | + """Test error handling for invalid model format.""" |
| 229 | + with pytest.raises( |
| 230 | + ValueError, match="Invalid model format. Expected 'provider:model'" |
| 231 | + ): |
| 232 | + configured_client.audio.transcriptions.create( |
| 233 | + model="invalid-model-format", file="test.mp3" |
| 234 | + ) |
| 235 | + |
| 236 | + def test_transcriptions_create_unsupported_provider(self, configured_client): |
| 237 | + """Test error handling for unsupported provider.""" |
| 238 | + with pytest.raises(ValueError, match="Invalid provider key 'unsupported'"): |
| 239 | + configured_client.audio.transcriptions.create( |
| 240 | + model="unsupported:model", file="test.mp3" |
| 241 | + ) |
| 242 | + |
| 243 | + def test_transcriptions_create_asr_error_propagation(self, configured_client): |
| 244 | + """Test that ASR errors are properly propagated.""" |
| 245 | + with patch( |
| 246 | + "aisuite.providers.openai_provider.OpenaiProvider.audio_transcriptions_create", |
| 247 | + side_effect=ASRError("Test ASR error"), |
| 248 | + ): |
| 249 | + with pytest.raises(ASRError, match="Test ASR error"): |
| 250 | + configured_client.audio.transcriptions.create( |
| 251 | + model="openai:whisper-1", file="test.mp3" |
| 252 | + ) |
0 commit comments