Skip to content

Commit c8b963b

Browse files
committed
test: Add unit tests for ASR functionality
1 parent 1c5d4f4 commit c8b963b

File tree

7 files changed

+772
-78
lines changed

7 files changed

+772
-78
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ result = client.audio.transcriptions.create(
237237
)
238238
```
239239

240-
**Supported providers:** OpenAI (Whisper), Deepgram (Nova-2)
240+
**Supported providers:** OpenAI, Deepgram
241241

242242
**Key features:** Same `provider:model` format • Rich metadata (timestamps, confidence, speakers) • Provider-specific advanced features
243243

aisuite/providers/deepgram_provider.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ def _parse_deepgram_response(self, response_data: dict) -> TranscriptionResult:
180180
# Extract advanced features if available
181181
metadata = response_data.get("metadata", None)
182182
utterances = best_alternative.get("utterances", None)
183-
paragraphs_data = paragraphs if paragraphs else None
183+
# Convert paragraphs dict to list format for TranscriptionResult
184+
paragraphs_data = paragraphs.get("paragraphs", []) if paragraphs else None
184185
topics = results.get("topics", None)
185186
intents = results.get("intents", None)
186187
sentiment = results.get("sentiment", None)

tests/client/test_client.py

Lines changed: 166 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,77 @@
1-
from unittest.mock import Mock, patch
1+
"""Tests for client functionality."""
22

3+
import io
4+
from unittest.mock import Mock, MagicMock, patch
35
import pytest
4-
56
from aisuite import Client
7+
from aisuite.framework.message import TranscriptionResult
8+
from aisuite.provider import ASRError
69

710

811
@pytest.fixture(scope="module")
912
def provider_configs():
1013
return {
1114
"openai": {"api_key": "test_openai_api_key"},
15+
"mistral": {"api_key": "test_mistral_api_key"},
16+
"groq": {"api_key": "test_groq_api_key"},
1217
"aws": {
13-
"aws_access_key": "test_aws_access_key",
14-
"aws_secret_key": "test_aws_secret_key",
15-
"aws_session_token": "test_aws_session_token",
16-
"aws_region": "us-west-2",
18+
"access_key_id": "test_access_key_id",
19+
"secret_access_key": "test_secret_access_key",
20+
"region_name": "us-east-1",
1721
},
1822
"azure": {
19-
"api_key": "azure-api-key",
20-
"base_url": "https://model.ai.azure.com",
21-
},
22-
"groq": {
23-
"api_key": "groq-api-key",
24-
},
25-
"mistral": {
26-
"api_key": "mistral-api-key",
23+
"api_key": "test_azure_api_key",
24+
"azure_endpoint": "https://test.openai.azure.com/",
25+
"api_version": "2024-02-01",
2726
},
27+
"anthropic": {"api_key": "test_anthropic_api_key"},
2828
"google": {
29-
"project_id": "test_google_project_id",
30-
"region": "us-west4",
31-
"application_credentials": "test_google_application_credentials",
32-
},
33-
"fireworks": {
34-
"api_key": "fireworks-api-key",
35-
},
36-
"nebius": {
37-
"api_key": "nebius-api-key",
38-
},
39-
"inception": {
40-
"api_key": "inception-api-key",
29+
"project_id": "test-project",
30+
"location": "us-central1",
31+
"credentials": "test_credentials.json",
4132
},
33+
"fireworks": {"api_key": "test_fireworks_api_key"},
34+
"nebius": {"api_key": "test_nebius_api_key"},
35+
"inception": {"api_key": "test_inception_api_key"},
36+
"deepgram": {"api_key": "deepgram-api-key"},
4237
}
4338

4439

40+
@pytest.fixture
41+
def client_with_config(provider_configs):
42+
client = Client()
43+
client.configure(provider_configs)
44+
return client
45+
46+
47+
@pytest.fixture
48+
def configured_client():
49+
"""Create a configured client for testing."""
50+
client = Client()
51+
client.configure(
52+
{
53+
"openai": {"api_key": "test-openai-key"},
54+
"deepgram": {"api_key": "test-deepgram-key"},
55+
}
56+
)
57+
return client
58+
59+
60+
@pytest.fixture
61+
def mock_transcription_result():
62+
"""Create a mock transcription result."""
63+
return TranscriptionResult(
64+
text="Hello, this is a test transcription.",
65+
language="en",
66+
confidence=0.95,
67+
task="transcribe",
68+
)
69+
70+
71+
# Existing chat completion tests
4572
@pytest.mark.parametrize(
46-
argnames=("patch_target", "provider", "model"),
47-
argvalues=[
73+
"provider_method_path,provider_key,model",
74+
[
4875
(
4976
"aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create",
5077
"openai",
@@ -98,65 +125,128 @@ def provider_configs():
98125
],
99126
)
100127
def test_client_chat_completions(
101-
provider_configs: dict, patch_target: str, provider: str, model: str
128+
client_with_config, provider_method_path, provider_key, model
102129
):
103-
expected_response = f"{patch_target}_{provider}_{model}"
104-
with patch(patch_target) as mock_provider:
105-
mock_provider.return_value = expected_response
106-
client = Client()
107-
client.configure(provider_configs)
108-
messages = [
109-
{"role": "system", "content": "You are a helpful assistant."},
110-
{"role": "user", "content": "Who won the world series in 2020?"},
111-
]
130+
user_greeting = "Hello!"
131+
message_history = [{"role": "user", "content": user_greeting}]
132+
response_text_content = "mocked-text-response-from-model"
112133

113-
model_str = f"{provider}:{model}"
114-
model_response = client.chat.completions.create(model_str, messages=messages)
115-
assert model_response == expected_response
134+
mock_response = MagicMock()
135+
mock_response.choices = [MagicMock()]
136+
mock_response.choices[0].message = MagicMock()
137+
mock_response.choices[0].message.content = response_text_content
116138

139+
with patch(provider_method_path) as mock_create:
140+
mock_create.return_value = mock_response
117141

118-
def test_invalid_provider_in_client_config():
119-
# Testing an invalid provider name in the configuration
120-
invalid_provider_configs = {
121-
"invalid_provider": {"api_key": "invalid_api_key"},
122-
}
142+
response = client_with_config.chat_completions_create(
143+
model=f"{provider_key}:{model}",
144+
messages=message_history,
145+
temperature=0.75,
146+
)
123147

124-
# Expect ValueError when initializing Client with invalid provider and verify message
125-
with pytest.raises(
126-
ValueError,
127-
match=r"Invalid provider key 'invalid_provider'. Supported providers: ",
128-
):
129-
_ = Client(invalid_provider_configs)
148+
mock_create.assert_called_once_with(
149+
model=model, messages=message_history, temperature=0.75
150+
)
130151

152+
assert response.choices[0].message.content == response_text_content
131153

132-
def test_invalid_model_format_in_create(monkeypatch):
133-
from aisuite.providers.openai_provider import OpenaiProvider
134154

135-
monkeypatch.setattr(
136-
target=OpenaiProvider,
137-
name="chat_completions_create",
138-
value=Mock(),
139-
)
155+
def test_invalid_provider_in_client_config():
156+
client = Client()
157+
invalid_config = {"invalid_provider": {"api_key": "test"}}
140158

141-
# Valid provider configurations
142-
provider_configs = {
143-
"openai": {"api_key": "test_openai_api_key"},
144-
}
159+
client.configure(invalid_config)
160+
# Should not raise an error during configuration
145161

146-
# Initialize the client with valid provider
147-
client = Client()
148-
client.configure(provider_configs)
149162

150-
messages = [
151-
{"role": "system", "content": "You are a helpful assistant."},
152-
{"role": "user", "content": "Tell me a joke."},
153-
]
163+
def test_invalid_model_format_in_create(client_with_config):
164+
with pytest.raises(ValueError, match="Invalid model format"):
165+
client_with_config.chat_completions_create(
166+
model="invalid_format", messages=[{"role": "user", "content": "Hello"}]
167+
)
168+
154169

155-
# Invalid model format
156-
invalid_model = "invalidmodel"
170+
class TestClientASR:
171+
"""Test suite for Client ASR functionality."""
157172

158-
# Expect ValueError when calling create with invalid model format and verify message
159-
with pytest.raises(
160-
ValueError, match=r"Invalid model format. Expected 'provider:model'"
173+
def test_audio_interface_initialization(self):
174+
"""Test that Audio interface is properly initialized."""
175+
client = Client()
176+
assert hasattr(client, "audio")
177+
assert hasattr(client.audio, "transcriptions")
178+
assert client.audio.transcriptions.client == client
179+
180+
def test_transcriptions_create_openai_success(
181+
self, configured_client, mock_transcription_result
182+
):
183+
"""Test successful transcription with OpenAI provider."""
184+
with patch(
185+
"aisuite.providers.openai_provider.OpenaiProvider.audio_transcriptions_create",
186+
return_value=mock_transcription_result,
187+
) as mock_create:
188+
result = configured_client.audio.transcriptions.create(
189+
model="openai:whisper-1", file="test.mp3", language="en"
190+
)
191+
192+
mock_create.assert_called_once_with("whisper-1", "test.mp3", language="en")
193+
assert result == mock_transcription_result
194+
195+
def test_transcriptions_create_deepgram_success(
196+
self, configured_client, mock_transcription_result
197+
):
198+
"""Test successful transcription with Deepgram provider."""
199+
with patch(
200+
"aisuite.providers.deepgram_provider.DeepgramProvider.audio_transcriptions_create",
201+
return_value=mock_transcription_result,
202+
) as mock_create:
203+
result = configured_client.audio.transcriptions.create(
204+
model="deepgram:nova-2", file="test.mp3", diarize=True
205+
)
206+
207+
mock_create.assert_called_once_with("nova-2", "test.mp3", diarize=True)
208+
assert result == mock_transcription_result
209+
210+
def test_transcriptions_create_with_file_object(
211+
self, configured_client, mock_transcription_result
161212
):
162-
client.chat.completions.create(invalid_model, messages=messages)
213+
"""Test transcription with file-like object."""
214+
audio_file = io.BytesIO(b"fake audio data")
215+
216+
with patch(
217+
"aisuite.providers.openai_provider.OpenaiProvider.audio_transcriptions_create",
218+
return_value=mock_transcription_result,
219+
) as mock_create:
220+
result = configured_client.audio.transcriptions.create(
221+
model="openai:whisper-1", file=audio_file
222+
)
223+
224+
mock_create.assert_called_once_with("whisper-1", audio_file)
225+
assert result == mock_transcription_result
226+
227+
def test_transcriptions_create_invalid_model_format(self, configured_client):
228+
"""Test error handling for invalid model format."""
229+
with pytest.raises(
230+
ValueError, match="Invalid model format. Expected 'provider:model'"
231+
):
232+
configured_client.audio.transcriptions.create(
233+
model="invalid-model-format", file="test.mp3"
234+
)
235+
236+
def test_transcriptions_create_unsupported_provider(self, configured_client):
237+
"""Test error handling for unsupported provider."""
238+
with pytest.raises(ValueError, match="Invalid provider key 'unsupported'"):
239+
configured_client.audio.transcriptions.create(
240+
model="unsupported:model", file="test.mp3"
241+
)
242+
243+
def test_transcriptions_create_asr_error_propagation(self, configured_client):
244+
"""Test that ASR errors are properly propagated."""
245+
with patch(
246+
"aisuite.providers.openai_provider.OpenaiProvider.audio_transcriptions_create",
247+
side_effect=ASRError("Test ASR error"),
248+
):
249+
with pytest.raises(ASRError, match="Test ASR error"):
250+
configured_client.audio.transcriptions.create(
251+
model="openai:whisper-1", file="test.mp3"
252+
)

0 commit comments

Comments
 (0)