Skip to content

Commit 7abd0f2

Browse files
committed
Introduce unified TranscriptionOptions for better parameter management across providers.
1 parent 716cff7 commit 7abd0f2

15 files changed

+1752
-593
lines changed

aisuite/client.py

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from .provider import ProviderFactory
22
import os
33
from .utils.tools import Tools
4-
from typing import Union, BinaryIO
4+
from typing import Union, BinaryIO, Optional, Any
5+
from .framework.message import (
6+
TranscriptionOptions,
7+
TranscriptionResponse,
8+
)
59

610

711
class Client:
@@ -51,7 +55,7 @@ def _validate_provider_key(self, provider_key):
5155

5256
return provider_key
5357

54-
def configure(self, provider_configs: dict = None):
58+
def configure(self, provider_configs: Optional[dict] = None):
5559
"""
5660
Configure the client with provider configurations.
5761
"""
@@ -124,7 +128,7 @@ def _tool_runner(
124128
provider,
125129
model_name: str,
126130
messages: list,
127-
tools: any,
131+
tools: Any,
128132
max_turns: int,
129133
**kwargs,
130134
):
@@ -273,18 +277,47 @@ class Transcriptions:
273277
def __init__(self, client: "Client"):
274278
self.client = client
275279

276-
def create(self, *, model: str, file: Union[str, BinaryIO], **kwargs):
280+
def create(
281+
self,
282+
*,
283+
model: str,
284+
file: Union[str, BinaryIO],
285+
options: Optional[TranscriptionOptions] = None,
286+
**kwargs,
287+
) -> TranscriptionResponse:
277288
"""
278289
Create a transcription using the specified model and file.
279290
280291
Args:
281292
model: Provider and model in format 'provider:model' (e.g., 'openai:whisper-1')
282293
file: Audio file to transcribe (file path or file-like object)
283-
**kwargs: Additional parameters to pass to the provider
294+
options: TranscriptionOptions instance with unified parameters (includes stream control)
295+
**kwargs: Additional parameters (used if options is None, assumed to be OpenAI format)
284296
285297
Returns:
286-
TranscriptionResult: Unified transcription result
298+
TranscriptionResponse: Unified response (batch or streaming based on options.stream)
287299
"""
300+
# Validate options and kwargs
301+
if options is not None:
302+
if not options.has_any_parameters():
303+
raise ValueError(
304+
"TranscriptionOptions provided but no parameters are set. "
305+
"Please set at least one parameter or pass None to use kwargs."
306+
)
307+
# TranscriptionOptions takes precedence, ignore kwargs
308+
if kwargs:
309+
import warnings
310+
311+
warnings.warn(
312+
"Both TranscriptionOptions and kwargs provided. Using TranscriptionOptions and ignoring kwargs.",
313+
UserWarning,
314+
)
315+
elif not kwargs:
316+
# Neither options nor kwargs provided
317+
raise ValueError(
318+
"Either TranscriptionOptions or kwargs must be provided for transcription parameters."
319+
)
320+
288321
# Check that correct format is used
289322
if ":" not in model:
290323
raise ValueError(
@@ -294,29 +327,61 @@ def create(self, *, model: str, file: Union[str, BinaryIO], **kwargs):
294327
# Extract the provider key from the model identifier
295328
provider_key, model_name = model.split(":", 1)
296329

297-
# Validate if the provider is supported
298-
supported_providers = ProviderFactory.get_supported_providers()
299-
if provider_key not in supported_providers:
300-
raise ValueError(
301-
f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
302-
"Make sure the model string is formatted correctly as 'provider:model'."
303-
)
304-
305330
# Initialize provider if not already initialized
306331
if provider_key not in self.client.providers:
307332
config = self.client.provider_configs.get(provider_key, {})
308-
self.client.providers[provider_key] = ProviderFactory.create_provider(
309-
provider_key, config
310-
)
333+
try:
334+
self.client.providers[provider_key] = ProviderFactory.create_provider(
335+
provider_key, config
336+
)
337+
except ImportError as e:
338+
raise ValueError(f"Provider '{provider_key}' is not available: {e}")
311339

312340
provider = self.client.providers.get(provider_key)
313341
if not provider:
314342
raise ValueError(f"Could not load provider for '{provider_key}'.")
315343

344+
# Check if provider supports audio transcription
345+
if not hasattr(provider, "audio") or provider.audio is None:
346+
raise ValueError(
347+
f"Provider '{provider_key}' does not support audio transcription."
348+
)
349+
350+
# Determine if streaming is requested
351+
should_stream = False # Default to batch processing
352+
if options and options.stream is not None:
353+
should_stream = options.stream
354+
elif kwargs.get("stream"):
355+
should_stream = kwargs.get("stream", False)
356+
316357
# Delegate the transcription to the correct provider's implementation
317-
# The provider will raise NotImplementedError if it doesn't support ASR
318358
try:
319-
return provider.audio_transcriptions_create(model_name, file, **kwargs)
359+
if should_stream:
360+
# Check if provider supports output streaming
361+
if (
362+
hasattr(provider.audio, "transcriptions")
363+
and hasattr(provider.audio.transcriptions, "create_stream_output")
364+
):
365+
return provider.audio.transcriptions.create_stream_output(
366+
model_name, file, options=options, **kwargs
367+
)
368+
else:
369+
raise ValueError(
370+
f"Provider '{provider_key}' does not support output streaming transcription."
371+
)
372+
else:
373+
# Non-streaming (batch) transcription
374+
if (
375+
hasattr(provider.audio, "transcriptions")
376+
and hasattr(provider.audio.transcriptions, "create")
377+
):
378+
return provider.audio.transcriptions.create(
379+
model_name, file, options=options, **kwargs
380+
)
381+
else:
382+
raise ValueError(
383+
f"Provider '{provider_key}' does not support audio transcription."
384+
)
320385
except NotImplementedError:
321386
raise ValueError(
322387
f"Provider '{provider_key}' does not support audio transcription."

aisuite/framework/message.py

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
to the OpenAI style response.
44
"""
55

6-
from typing import Literal, Optional, List
6+
from typing import Literal, Optional, List, AsyncGenerator, Union, Dict, Any
77
from pydantic import BaseModel
8+
from dataclasses import dataclass, field
89

910

1011
class Function(BaseModel):
@@ -168,3 +169,129 @@ class TranscriptionResult(BaseModel):
168169
# Metadata
169170
metadata: Optional[dict] = None # Provider-specific metadata
170171
model_info: Optional[dict] = None # Model information
172+
173+
174+
class StreamingTranscriptionChunk(BaseModel):
175+
"""Represents a single chunk of streaming transcription data."""
176+
177+
text: str
178+
is_final: bool
179+
confidence: Optional[float] = None
180+
start_time: Optional[float] = None
181+
end_time: Optional[float] = None
182+
speaker_id: Optional[int] = None
183+
speaker_confidence: Optional[float] = None
184+
words: Optional[List[Word]] = None
185+
sequence_number: Optional[int] = None
186+
channel: Optional[int] = None
187+
provider_data: Optional[dict] = None
188+
189+
190+
# Type alias for streaming transcription responses
191+
StreamingTranscriptionResponse = AsyncGenerator[StreamingTranscriptionChunk, None]
192+
193+
# Union type for both batch and streaming responses
194+
TranscriptionResponse = Union[TranscriptionResult, StreamingTranscriptionResponse]
195+
196+
197+
@dataclass
198+
class TranscriptionOptions:
199+
"""Unified transcription options for ASR providers."""
200+
201+
# Core parameters
202+
language: Optional[str] = None
203+
204+
# Audio format parameters
205+
audio_format: Optional[str] = None
206+
sample_rate: Optional[int] = None
207+
channels: Optional[int] = None
208+
encoding: Optional[str] = None # Audio encoding type
209+
210+
# Output format
211+
response_format: Optional[str] = None
212+
include_word_timestamps: Optional[bool] = None
213+
include_segment_timestamps: Optional[bool] = None
214+
215+
# Context and guidance
216+
prompt: Optional[str] = None
217+
context_phrases: Optional[List[str]] = None
218+
boost_phrases: Optional[List[str]] = None
219+
220+
# Speaker features
221+
enable_speaker_diarization: Optional[bool] = None
222+
max_speakers: Optional[int] = None
223+
min_speakers: Optional[int] = None
224+
225+
# Text processing
226+
enable_automatic_punctuation: Optional[bool] = None
227+
enable_profanity_filter: Optional[bool] = None
228+
enable_smart_formatting: Optional[bool] = None
229+
enable_word_confidence: Optional[bool] = None
230+
enable_spoken_punctuation: Optional[bool] = None
231+
enable_spoken_emojis: Optional[bool] = None
232+
233+
# Advanced features
234+
enable_sentiment_analysis: Optional[bool] = None
235+
enable_topic_detection: Optional[bool] = None
236+
enable_intent_recognition: Optional[bool] = None
237+
enable_summarization: Optional[bool] = None
238+
enable_translation: Optional[bool] = None
239+
translation_target_language: Optional[str] = None
240+
241+
# Confidence and alternatives
242+
include_confidence_scores: Optional[bool] = None
243+
max_alternatives: Optional[int] = None
244+
245+
# Processing options
246+
temperature: Optional[float] = None
247+
interim_results: Optional[bool] = None
248+
vad_sensitivity: Optional[float] = None
249+
stream: Optional[bool] = None # Enable streaming output
250+
251+
# Custom parameters
252+
custom_parameters: Dict[str, Any] = field(default_factory=dict)
253+
254+
def __post_init__(self):
255+
"""Validate parameters and constraints."""
256+
# Validate constraints
257+
if self.temperature is not None and not (0.0 <= self.temperature <= 1.0):
258+
raise ValueError("temperature must be between 0.0 and 1.0")
259+
260+
if self.max_speakers is not None and self.max_speakers < 1:
261+
raise ValueError("max_speakers must be at least 1")
262+
263+
if self.min_speakers is not None and self.min_speakers < 1:
264+
raise ValueError("min_speakers must be at least 1")
265+
266+
if (
267+
self.max_speakers is not None
268+
and self.min_speakers is not None
269+
and self.min_speakers > self.max_speakers
270+
):
271+
raise ValueError("min_speakers cannot be greater than max_speakers")
272+
273+
if self.vad_sensitivity is not None and not (
274+
0.0 <= self.vad_sensitivity <= 1.0
275+
):
276+
raise ValueError("vad_sensitivity must be between 0.0 and 1.0")
277+
278+
def has_any_parameters(self) -> bool:
279+
"""Check if any parameters are set."""
280+
for field_name, field_value in self.__dict__.items():
281+
if field_name == "custom_parameters":
282+
if field_value:
283+
return True
284+
elif field_value is not None:
285+
return True
286+
return False
287+
288+
def get_set_parameters(self) -> Dict[str, Any]:
289+
"""Get only the parameters that are set."""
290+
set_params = {}
291+
for field_name, field_value in self.__dict__.items():
292+
if field_name == "custom_parameters":
293+
if field_value:
294+
set_params[field_name] = field_value
295+
elif field_value is not None:
296+
set_params[field_name] = field_value
297+
return set_params

0 commit comments

Comments
 (0)