
Commit 62105d4

Enhance providers with output streaming support
1 parent 1df1e41 commit 62105d4

File tree

5 files changed: +346 −159 lines

aisuite/providers/deepgram_provider.py

Lines changed: 207 additions & 65 deletions
@@ -1,5 +1,9 @@
 import os
-import asyncio
+import json
+import numpy as np
+import queue
+import threading
+import time
 from typing import Union, BinaryIO, Optional, AsyncGenerator
 
 from aisuite.provider import Provider, ASRError, Audio
@@ -99,57 +103,136 @@ async def create_stream_output(
         model: str,
         file: Union[str, BinaryIO],
         options: Optional[TranscriptionOptions] = None,
+        chunk_size_minutes: float = 3.0,
         **kwargs,
     ) -> AsyncGenerator[StreamingTranscriptionChunk, None]:
-        """Create streaming audio transcription using Deepgram SDK."""
+        """Create streaming audio transcription using Deepgram SDK with chunked processing."""
         try:
             from deepgram import LiveOptions
+            from deepgram.clients.listen import LiveTranscriptionEvents
 
-            api_params = self._prepare_api_params(model, options, kwargs)
-            live_options = LiveOptions(**api_params)
-            connection = self.client.listen.websocket.v("1")
+            # Load and prepare audio
+            audio_data, sample_rate = await self._load_and_prepare_audio(file)
 
-            transcript_queue = asyncio.Queue()
-            finished_event = asyncio.Event()
-
-            def on_message(self, result, **kwargs):
-                sentence = result.channel.alternatives[0].transcript
-                if sentence:
-                    chunk = StreamingTranscriptionChunk(
-                        text=sentence,
-                        is_final=result.is_final,
-                        confidence=getattr(
-                            result.channel.alternatives[0], "confidence", None
-                        ),
+            # Calculate chunking strategy
+            duration_seconds = len(audio_data) / sample_rate
+            chunk_duration_seconds = chunk_size_minutes * 60
+
+            if duration_seconds <= chunk_duration_seconds:
+                chunks = [audio_data]
+            else:
+                chunk_size_samples = int(chunk_duration_seconds * sample_rate)
+                chunks = []
+                num_chunks = int(np.ceil(duration_seconds / chunk_duration_seconds))
+                for i in range(num_chunks):
+                    start_sample = i * chunk_size_samples
+                    end_sample = min(
+                        start_sample + chunk_size_samples, len(audio_data)
                     )
-                    asyncio.create_task(transcript_queue.put(chunk))
+                    chunks.append(audio_data[start_sample:end_sample])
+
+            # Setup API parameters
+            api_params = self._prepare_api_params(
+                model, options, kwargs, is_streaming=True
+            )
+            api_params["interim_results"] = (
+                True  # Enable interim results for streaming
+            )
+
+            # Add critical audio format parameters (matching reference)
+            api_params["encoding"] = "linear16"  # PCM16 format
+            api_params["sample_rate"] = 16000  # Match our target sample rate
+            api_params["channels"] = 1  # Mono audio
 
-            def on_error(self, error, **kwargs):
-                asyncio.create_task(
+            live_options = LiveOptions(**api_params)
+
+            # Create single connection for all chunks
+            connection = self.client.listen.websocket.v("1")
+
+            # Use thread-safe queue instead of asyncio.Queue for cross-thread communication
+            transcript_queue = queue.Queue()
+            connection_closed = threading.Event()
+
+            def on_message(*args, **kwargs):
+                """Handle transcript events"""
+                # Extract result from args or kwargs (following reference pattern)
+                result = None
+                if len(args) >= 2:
+                    result = args[1]
+                elif "result" in kwargs:
+                    result = kwargs["result"]
+                else:
+                    return
+
+                if hasattr(result, "channel") and result.channel.alternatives:
+                    alt = result.channel.alternatives[0]
+                    if alt.transcript:
+                        chunk = StreamingTranscriptionChunk(
+                            text=alt.transcript,
+                            is_final=getattr(result, "is_final", False),
+                            confidence=getattr(alt, "confidence", None),
+                        )
+                        transcript_queue.put(chunk)  # Thread-safe put
+
+            def on_error(*args, **kwargs):
+                """Handle error events"""
+                # Extract error from args or kwargs
+                error = None
+                if len(args) >= 2:
+                    error = args[1]
+                elif "error" in kwargs:
+                    error = kwargs["error"]
+
+                if error:
                     transcript_queue.put(
                         ASRError(f"Deepgram streaming error: {error}")
-                    )
-                )
-                finished_event.set()
+                    )  # Thread-safe put
 
-            def on_close(self, close, **kwargs):
-                finished_event.set()
+            def on_close(*args, **kwargs):
+                """Handle connection close events"""
+                connection_closed.set()
 
-            connection.on(connection.event.TRANSCRIPT_RECEIVED, on_message)
-            connection.on(connection.event.ERROR, on_error)
-            connection.on(connection.event.CLOSE, on_close)
+            # Register event handlers
+            connection.on(LiveTranscriptionEvents.Transcript, on_message)
+            connection.on(LiveTranscriptionEvents.Error, on_error)
+            connection.on(LiveTranscriptionEvents.Close, on_close)
 
+            # Start connection
             if not connection.start(live_options):
                 raise ASRError("Failed to start Deepgram streaming connection")
 
-            audio_data = self._read_audio_data(file)
-            await self._send_audio_chunks(connection, audio_data)
-            connection.finish()
-
-            async for chunk in self._yield_transcription_chunks(
-                transcript_queue, finished_event
-            ):
-                yield chunk
+            # Send all chunks through single connection
+            try:
+                for audio_chunk in chunks:
+                    self._send_audio_chunk(connection, audio_chunk)
+
+                # Send CloseStream message to signal end of all chunks
+                close_stream_message = json.dumps({"type": "CloseStream"})
+                connection.send(close_stream_message)
+
+                # Yield results until connection closes naturally
+                while not connection_closed.is_set():
+                    try:
+                        # Use thread-safe queue with timeout
+                        chunk = transcript_queue.get(timeout=0.1)
+                        if isinstance(chunk, Exception):
+                            raise chunk
+                        yield chunk
+                    except queue.Empty:
+                        continue
+
+                # Get any remaining results
+                while not transcript_queue.empty():
+                    try:
+                        chunk = transcript_queue.get_nowait()
+                        if isinstance(chunk, Exception):
+                            raise chunk
+                        yield chunk
+                    except queue.Empty:
+                        break
+
+            except Exception as e:
+                raise ASRError(f"Error during audio streaming: {e}")
 
         except Exception as e:
             raise ASRError(f"Deepgram streaming transcription error: {e}")
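
For orientation, a minimal consumer sketch of the reworked method, assuming a constructed DeepgramProvider (credentials via environment) and a local audio file; the model name and path are illustrative. With the default chunk_size_minutes=3.0, a 7-minute file at 16 kHz is split into ceil(420 / 180) = 3 chunks (180 s, 180 s, 60 s), all sent over one websocket connection:

    import asyncio

    async def main():
        provider = DeepgramProvider()  # assumes DEEPGRAM_API_KEY is configured
        # create_stream_output is an async generator: iterate to receive
        # interim and final StreamingTranscriptionChunk objects as they arrive
        async for chunk in provider.create_stream_output(
            model="nova-2",          # illustrative model name
            file="speech.wav",       # illustrative local audio file
            chunk_size_minutes=3.0,  # parameter added in this commit
        ):
            marker = "final" if chunk.is_final else "interim"
            print(f"[{marker}] {chunk.text}")

    asyncio.run(main())
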
@@ -159,14 +242,23 @@ def _extract_model_name(self, model: str) -> str:
         return model
 
     def _prepare_api_params(
-        self, model: str, options: Optional[TranscriptionOptions], kwargs: dict
+        self,
+        model: str,
+        options: Optional[TranscriptionOptions],
+        kwargs: dict,
+        is_streaming: bool = False,
     ) -> dict:
         """Prepare API parameters for Deepgram."""
         if options is not None:
             api_params = ParameterMapper.map_to_deepgram(options)
         else:
             api_params = self._map_openai_to_deepgram_params(kwargs)
 
+        # Remove parameters not supported by LiveOptions (streaming)
+        if is_streaming:
+            # utterances is only supported in batch/prerecorded, not streaming
+            api_params.pop("utterances", None)
+
         model_name = self._extract_model_name(model)
         api_params.setdefault("smart_format", True)
         api_params.setdefault("punctuate", True)
@@ -188,34 +280,80 @@ def _prepare_audio_payload(self, file: Union[str, BinaryIO]) -> dict:
             )
         return {"buffer": buffer_data}
 
-    def _read_audio_data(self, file: Union[str, BinaryIO]) -> bytes:
-        """Read audio data from file or file-like object."""
-        if isinstance(file, str):
-            with open(file, "rb") as audio_file:
-                return audio_file.read()
-        else:
-            return file.read()
-
-    async def _send_audio_chunks(self, connection, audio_data: bytes) -> None:
-        """Send audio data in chunks to Deepgram connection."""
-        chunk_size = 8192
-        for i in range(0, len(audio_data), chunk_size):
-            chunk = audio_data[i : i + chunk_size]
-            connection.send(chunk)
-            await asyncio.sleep(0.01)
-
-    async def _yield_transcription_chunks(
-        self, transcript_queue: asyncio.Queue, finished_event: asyncio.Event
-    ) -> AsyncGenerator[StreamingTranscriptionChunk, None]:
-        """Yield transcription chunks as they arrive."""
-        while not finished_event.is_set():
+    async def _load_and_prepare_audio(
+        self, file: Union[str, BinaryIO]
+    ) -> tuple[np.ndarray, int]:
+        """Load and prepare audio file for streaming.
+
+        Conversions performed only when necessary:
+        - Stereo to mono: Required for multi-channel audio
+        - Sample rate conversion: Required when input != 16kHz
+        - Other formats: Error out as unsupported
+        """
+        try:
             try:
-                chunk = await asyncio.wait_for(transcript_queue.get(), timeout=1.0)
-                if isinstance(chunk, Exception):
-                    raise chunk
-                yield chunk
-            except asyncio.TimeoutError:
-                continue
+                import soundfile as sf
+            except ImportError:
+                raise ASRError(
+                    "soundfile is required for audio processing. Install with: pip install soundfile"
+                )
+
+            if isinstance(file, str):
+                audio_data, original_sample_rate = sf.read(file)
+            else:
+                audio_data, original_sample_rate = sf.read(file)
+
+            audio_data = np.asarray(audio_data, dtype=np.float32)
+
+            # Convert to mono if stereo
+            if len(audio_data.shape) > 1:
+                if audio_data.shape[1] == 2:
+                    audio_data = np.mean(audio_data, axis=1)
+                else:
+                    raise ASRError(
+                        f"Unsupported audio format: {audio_data.shape[1]} channels. Only mono and stereo are supported."
+                    )
+
+            # Resample to 16kHz if needed
+            target_sample_rate = 16000
+            if original_sample_rate != target_sample_rate:
+                try:
+                    from scipy import signal
+
+                    num_samples = int(
+                        len(audio_data) * target_sample_rate / original_sample_rate
+                    )
+                    audio_data = signal.resample(audio_data, num_samples)
+                except ImportError:
+                    raise ASRError(
+                        f"Audio resampling required but scipy not available. "
+                        f"Input is {original_sample_rate}Hz, need {target_sample_rate}Hz. "
+                        f"Install scipy or provide audio at {target_sample_rate}Hz."
+                    )
+
+            return np.asarray(audio_data, dtype=np.float32), target_sample_rate
+
+        except Exception as e:
+            if isinstance(e, ASRError):
+                raise
+            raise ASRError(f"Error loading audio file: {e}")
+
+    def _send_audio_chunk(self, connection, audio_chunk: np.ndarray) -> None:
+        """Send audio chunk data through the connection."""
+        streaming_chunk_size = 8000  # Match reference BLOCKSIZE (~0.5s @16kHz mono)
+        send_delay = 0.01
+
+        for i in range(0, len(audio_chunk), streaming_chunk_size):
+            piece = audio_chunk[i : i + streaming_chunk_size]
+
+            if len(piece) < streaming_chunk_size:
+                piece = np.pad(
+                    piece, (0, streaming_chunk_size - len(piece)), mode="constant"
+                )
+
+            pcm16 = (piece * 32767).astype(np.int16).tobytes()
+            connection.send(pcm16)
+            time.sleep(send_delay)  # Use synchronous sleep like reference
 
     def _map_openai_to_deepgram_params(self, openai_params: dict) -> dict:
         """Map OpenAI-style parameters to Deepgram parameters."""
@@ -229,7 +367,11 @@ def _map_openai_to_deepgram_params(self, openai_params: dict) -> dict:
             granularities = openai_params["timestamp_granularities"]
             if "word" in granularities:
                 deepgram_params["punctuate"] = True
-                deepgram_params["utterances"] = True
+                # Note: utterances is only for batch/prerecorded, not streaming
+
+        # Essential for streaming - map interim_results
+        if "interim_results" in openai_params:
+            deepgram_params["interim_results"] = openai_params["interim_results"]
 
         return deepgram_params
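
Because interim_results is now mapped through, callers see both provisional and finalized text for the same audio. A self-contained sketch of collapsing such a stream, with Chunk as a stand-in for StreamingTranscriptionChunk:

    from dataclasses import dataclass

    @dataclass
    class Chunk:
        text: str
        is_final: bool

    stream = [
        Chunk("hello wor", is_final=False),   # interim: may still be revised
        Chunk("hello world", is_final=True),  # final: safe to commit
        Chunk("how are", is_final=False),
        Chunk("how are you", is_final=True),
    ]
    print(" ".join(c.text for c in stream if c.is_final))  # hello world how are you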

aisuite/providers/google_provider.py

Lines changed: 1 addition & 5 deletions
@@ -398,7 +398,7 @@ async def create_stream_output(
             )
 
             responses = self.provider.speech_client.streaming_recognize(
-                requests=request_generator
+                config=streaming_config, requests=request_generator
             )
 
             for response in responses:
@@ -498,10 +498,6 @@ def _create_streaming_requests(
         """Create streaming requests generator for Google Speech API."""
 
         def request_generator():
-            yield speech.StreamingRecognizeRequest(
-                streaming_config=streaming_config
-            )
-
             chunk_size = 8192
             for i in range(0, len(audio_data), chunk_size):
                 chunk = audio_data[i : i + chunk_size]
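
Both google_provider.py hunks make the same fix: the streaming_recognize helper in google-cloud-speech 2.x accepts the streaming config as a config= argument and prepends the config request itself, which is why the explicit first StreamingRecognizeRequest was removed from the generator. A minimal sketch of the corrected call shape, assuming that helper; the client setup and audio bytes are illustrative:

    from google.cloud import speech

    client = speech.SpeechClient()
    streaming_config = speech.StreamingRecognitionConfig(
        config=speech.RecognitionConfig(
            encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
            sample_rate_hertz=16000,
            language_code="en-US",
        )
    )

    def request_generator(audio_data: bytes, chunk_size: int = 8192):
        # Audio-only requests; the helper injects the config request first
        for i in range(0, len(audio_data), chunk_size):
            yield speech.StreamingRecognizeRequest(
                audio_content=audio_data[i : i + chunk_size]
            )

    responses = client.streaming_recognize(
        config=streaming_config, requests=request_generator(b"\x00\x00" * 16000)
    )
    for response in responses:
        for result in response.results:
            print(result.alternatives[0].transcript)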

0 commit comments