Commit 18dda54

feat: Update google provider for ASR

1 parent c8b963b

File tree

6 files changed: +353, -7 lines

README.md

Lines changed: 1 addition & 1 deletion

@@ -237,7 +237,7 @@ result = client.audio.transcriptions.create(
 )
 ```

-**Supported providers:** OpenAI, Deepgram
+**Supported providers:** OpenAI, Deepgram, Google

 **Key features:** Same `provider:model` format • Rich metadata (timestamps, confidence, speakers) • Provider-specific advanced features


aisuite/client.py

Lines changed: 1 addition & 2 deletions

@@ -28,7 +28,6 @@ def __init__(self, provider_configs: dict = {}):
         self.provider_configs = provider_configs
         self._chat = None
         self._audio = None
-        self._initialize_providers()

     def _initialize_providers(self):
         """Helper method to initialize or update providers."""
@@ -60,7 +59,7 @@ def configure(self, provider_configs: dict = None):
             return

         self.provider_configs.update(provider_configs)
-        self._initialize_providers()  # NOTE: This will override existing provider instances.
+        # Providers will be lazily initialized when needed

     @property
     def chat(self):
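
With the eager `_initialize_providers()` calls removed from `__init__` and `configure()`, providers are presumably built on first access to the `chat`/`audio` properties. A minimal sketch of that lazy pattern, using a simplified stand-in class rather than aisuite's actual implementation:

```python
# Illustrative only: LazyClient mimics the initialization flow implied by the
# diff; the real aisuite.Client builds concrete provider objects here.
class LazyClient:
    def __init__(self, provider_configs: dict = None):
        self.provider_configs = provider_configs or {}
        self.providers = {}
        self._chat = None

    def _initialize_providers(self):
        # Stand-in for the real provider factory.
        for name, config in self.provider_configs.items():
            self.providers[name] = f"<{name} provider: {config}>"

    def configure(self, provider_configs: dict = None):
        if provider_configs is None:
            return
        self.provider_configs.update(provider_configs)
        # No eager rebuild: initialization is deferred until first use.

    @property
    def chat(self):
        if self._chat is None:
            self._initialize_providers()  # runs once, on first access
            self._chat = "<chat interface>"
        return self._chat


client = LazyClient()
client.configure({"google": {"project_id": "your-project-id"}})
print(client.chat)  # providers are constructed only now
```

This also sidesteps the behavior the removed NOTE warned about, where each eager re-initialization overwrote existing provider instances.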

aisuite/providers/google_provider.py

Lines changed: 116 additions & 1 deletion

@@ -2,7 +2,7 @@

 import os
 import json
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Union, BinaryIO

 import vertexai
 from vertexai.generative_models import (
@@ -16,6 +16,8 @@
 import pprint

 from aisuite.framework import ProviderInterface, ChatCompletionResponse, Message
+from aisuite.framework.message import TranscriptionResult, Word, Segment, Alternative
+from aisuite.provider import ASRError


 DEFAULT_TEMPERATURE = 0.7
@@ -211,6 +213,9 @@ def __init__(self, **config):

         self.transformer = GoogleMessageConverter()

+        # Initialize Speech client lazily
+        self._speech_client = None
+
     def chat_completions_create(self, model, messages, **kwargs):
         """Request chat completions from the Google AI API.

@@ -296,3 +301,113 @@ def chat_completions_create(self, model, messages, **kwargs):

         # Convert and return the response
         return self.transformer.convert_response(response)
+
+    @property
+    def speech_client(self):
+        """Lazy initialization of Google Cloud Speech client."""
+        if self._speech_client is None:
+            try:
+                from google.cloud import speech
+
+                self._speech_client = speech.SpeechClient()
+            except ImportError:
+                raise ImportError(
+                    "google-cloud-speech is required for ASR functionality. "
+                    "Install it with: pip install google-cloud-speech"
+                )
+        return self._speech_client
+
+    def audio_transcriptions_create(
+        self, model: str, file: Union[str, BinaryIO], **kwargs
+    ) -> TranscriptionResult:
+        """Create audio transcription using Google Cloud Speech-to-Text API."""
+        try:
+            from google.cloud import speech
+
+            # Handle file input
+            if isinstance(file, str):
+                with open(file, "rb") as audio_file:
+                    audio_data = audio_file.read()
+            else:
+                audio_data = file.read()
+
+            # Create audio object
+            audio = speech.RecognitionAudio(content=audio_data)
+
+            # Configure recognition settings
+            config = speech.RecognitionConfig(
+                encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
+                sample_rate_hertz=kwargs.get("sample_rate_hertz", 16000),
+                language_code=kwargs.get("language", "en-US"),
+                enable_word_time_offsets=True,
+                enable_word_confidence=True,
+                enable_automatic_punctuation=kwargs.get("punctuate", True),
+                model=model if model != "default" else "latest_long",
+            )
+
+            # Make API request
+            response = self.speech_client.recognize(config=config, audio=audio)
+            return self._parse_google_response(response)
+
+        except ImportError:
+            raise ASRError(
+                "google-cloud-speech is required for ASR functionality. "
+                "Install it with: pip install google-cloud-speech"
+            )
+        except Exception as e:
+            raise ASRError(f"Google Speech-to-Text error: {e}")
+
+    def _parse_google_response(self, response) -> TranscriptionResult:
+        """Convert Google Speech-to-Text response to unified TranscriptionResult."""
+        if not response.results:
+            return TranscriptionResult(text="", language=None)
+
+        # Get the best result
+        best_result = response.results[0]
+        if not best_result.alternatives:
+            return TranscriptionResult(text="", language=None)
+
+        # Get the best alternative
+        best_alternative = best_result.alternatives[0]
+        text = best_alternative.transcript
+        confidence = getattr(best_alternative, "confidence", None)
+
+        # Parse words if available
+        words = []
+        if hasattr(best_alternative, "words") and best_alternative.words:
+            for word in best_alternative.words:
+                words.append(
+                    Word(
+                        word=word.word,
+                        start=(
+                            word.start_time.total_seconds()
+                            if hasattr(word, "start_time")
+                            else 0.0
+                        ),
+                        end=(
+                            word.end_time.total_seconds()
+                            if hasattr(word, "end_time")
+                            else 0.0
+                        ),
+                        confidence=getattr(word, "confidence", None),
+                    )
+                )
+
+        # Create alternatives list
+        alternatives = []
+        for alt in best_result.alternatives:
+            alternatives.append(
+                Alternative(
+                    transcript=alt.transcript,
+                    confidence=getattr(alt, "confidence", None),
+                )
+            )
+
+        return TranscriptionResult(
+            text=text,
+            language=None,  # Google doesn't return detected language in this format
+            confidence=confidence,
+            task="transcribe",
+            words=words if words else None,
+            alternatives=alternatives if alternatives else None,
+        )
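
For orientation, here is a minimal usage sketch of the new path through the unified client. It assumes the `provider:model` routing described in the README dispatches `google:latest_long` to `audio_transcriptions_create`; `audio.wav` and the credential values are placeholders:

```python
import aisuite as ai

client = ai.Client()
client.configure({
    "google": {
        "project_id": "your-project-id",                        # placeholder
        "region": "us-central1",
        "application_credentials": "path/to/credentials.json",  # placeholder
    }
})

# language and sample_rate_hertz map onto the RecognitionConfig built in
# audio_transcriptions_create; the encoding there is fixed to LINEAR16
# (16-bit PCM), so uncompressed WAV input is assumed.
result = client.audio.transcriptions.create(
    model="google:latest_long",
    file="audio.wav",  # a path (str) or an open binary file object
    language="en-US",
    sample_rate_hertz=16000,
)
print(result.text)
print(result.confidence)
```

Because the encoding is hard-coded to LINEAR16, compressed formats such as MP3 or FLAC would need conversion (or an extended config) before this path can handle them.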

examples/asr_example.ipynb

Lines changed: 110 additions & 0 deletions

@@ -60,6 +60,116 @@
     "for word in result.words[:3]:\n",
     "    print(f\"{word.word}: {word.start:.1f}s-{word.end:.1f}s\")\n"
    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "32ed3f0f",
+    "metadata": {},
+    "source": [
+     "## Google Cloud Speech-to-Text\n",
+     "\n",
+     "The Google provider supports ASR via the Google Cloud Speech-to-Text API. Make sure you have:\n",
+     "- the `google-cloud-speech` library installed: `pip install google-cloud-speech`\n",
+     "- Google Cloud credentials configured\n",
+     "- the required environment variables: `GOOGLE_PROJECT_ID`, `GOOGLE_REGION`, `GOOGLE_APPLICATION_CREDENTIALS`\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "0540ca09",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Configure Google provider\n",
+     "client.configure({\n",
+     "    \"google\": {\n",
+     "        \"project_id\": \"your-project-id\",\n",
+     "        \"region\": \"us-central1\",\n",
+     "        \"application_credentials\": \"path/to/credentials.json\"\n",
+     "    }\n",
+     "})\n",
+     "\n",
+     "# Basic Google transcription\n",
+     "result = client.audio.transcriptions.create(\n",
+     "    model=\"google:latest_long\",\n",
+     "    file=audio_file,\n",
+     "    language=\"en-US\"\n",
+     ")\n",
+     "print(f\"Google transcription: {result.text}\")\n",
+     "print(f\"Confidence: {result.confidence}\")\n",
+     "print(f\"Task: {result.task}\")\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "2bd3aa6a",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Google transcription with advanced options\n",
+     "result = client.audio.transcriptions.create(\n",
+     "    model=\"google:latest_long\",\n",
+     "    file=audio_file,\n",
+     "    language=\"en-US\",\n",
+     "    sample_rate_hertz=44100,\n",
+     "    punctuate=True\n",
+     ")\n",
+     "\n",
+     "print(f\"Text: {result.text}\")\n",
+     "print(f\"Language: {result.language}\")\n",
+     "\n",
+     "# Show word-level timestamps if available\n",
+     "if result.words:\n",
+     "    print(f\"Words with timestamps: {len(result.words)}\")\n",
+     "    for word in result.words[:5]:  # Show first 5 words\n",
+     "        print(f\"  {word.word}: {word.start:.1f}s-{word.end:.1f}s (confidence: {word.confidence:.2f})\")\n",
+     "\n",
+     "# Show alternatives if available\n",
+     "if result.alternatives:\n",
+     "    print(f\"Alternatives: {len(result.alternatives)}\")\n",
+     "    for i, alt in enumerate(result.alternatives[:3]):  # Show first 3 alternatives\n",
+     "        print(f\"  Alt {i+1}: {alt.transcript} (confidence: {alt.confidence:.2f})\")\n"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "f777a880",
+    "metadata": {},
+    "source": [
+     "## Deepgram Provider\n",
+     "\n",
+     "You can also use Deepgram for ASR, with advanced features like speaker diarization.\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "bd51b0ed",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Deepgram transcription with speaker diarization\n",
+     "result = client.audio.transcriptions.create(\n",
+     "    model=\"deepgram:nova-2\",\n",
+     "    file=audio_file,\n",
+     "    diarize=True,\n",
+     "    punctuate=True,\n",
+     "    language=\"en-US\"\n",
+     ")\n",
+     "\n",
+     "print(f\"Deepgram transcription: {result.text}\")\n",
+     "print(f\"Confidence: {result.confidence}\")\n",
+     "\n",
+     "# Show speaker information if available\n",
+     "if result.words:\n",
+     "    speakers = set(word.speaker for word in result.words if word.speaker is not None)\n",
+     "    print(f\"Detected speakers: {len(speakers)}\")\n",
+     "    for word in result.words[:5]:\n",
+     "        if word.speaker is not None:\n",
+     "            print(f\"  {word.word} (Speaker {word.speaker}): {word.start:.1f}s-{word.end:.1f}s\")\n"
+    ]
   }
  ],
  "metadata": {

pyproject.toml

Lines changed: 3 additions & 1 deletion

@@ -11,6 +11,7 @@ anthropic = { version = "^0.30.1", optional = true }
 boto3 = { version = "^1.34.144", optional = true }
 cohere = { version = "^5.12.0", optional = true }
 vertexai = { version = "^1.63.0", optional = true }
+google-cloud-speech = { version = "^2.33.0", optional = true }
 groq = { version = "^0.9.0", optional = true }
 mistralai = { version = "^1.0.3", optional = true }
 openai = { version = "^1.35.8", optional = true }
@@ -27,7 +28,7 @@ azure = []
 cerebras = ["cerebras_cloud_sdk"]
 cohere = ["cohere"]
 deepseek = ["openai"]
-google = ["vertexai"]
+google = ["vertexai", "google-cloud-speech"]
 groq = ["groq"]
 huggingface = []
 mistral = ["mistralai"]
@@ -52,6 +53,7 @@ chromadb = "^0.5.4"
 sentence-transformers = "^3.0.1"
 datasets = "^2.20.0"
 vertexai = "^1.63.0"
+google-cloud-speech = "^2.33.0"
 ibm-watsonx-ai = "^1.1.16"
 cerebras_cloud_sdk = "^1.19.0"

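
With the `google` extra updated as above, installing the provider's full ASR stack should be a single command, e.g. `pip install 'aisuite[google]'`, which pulls in both `vertexai` and `google-cloud-speech` (assuming the published package exposes the extra exactly as declared in the `[tool.poetry.extras]` table).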
