|
2 | 2 |
|
3 | 3 | import os
|
4 | 4 | import json
|
5 |
| -from typing import List, Dict, Any, Optional |
| 5 | +from typing import List, Dict, Any, Optional, Union, BinaryIO |
6 | 6 |
|
7 | 7 | import vertexai
|
8 | 8 | from vertexai.generative_models import (
|
|
16 | 16 | import pprint
|
17 | 17 |
|
18 | 18 | from aisuite.framework import ProviderInterface, ChatCompletionResponse, Message
|
| 19 | +from aisuite.framework.message import TranscriptionResult, Word, Segment, Alternative |
| 20 | +from aisuite.provider import ASRError |
19 | 21 |
|
20 | 22 |
|
21 | 23 | DEFAULT_TEMPERATURE = 0.7
|
@@ -211,6 +213,9 @@ def __init__(self, **config):
|
211 | 213 |
|
212 | 214 | self.transformer = GoogleMessageConverter()
|
213 | 215 |
|
| 216 | + # Initialize Speech client lazily |
| 217 | + self._speech_client = None |
| 218 | + |
214 | 219 | def chat_completions_create(self, model, messages, **kwargs):
|
215 | 220 | """Request chat completions from the Google AI API.
|
216 | 221 |
|
@@ -296,3 +301,113 @@ def chat_completions_create(self, model, messages, **kwargs):
|
296 | 301 |
|
297 | 302 | # Convert and return the response
|
298 | 303 | return self.transformer.convert_response(response)
|
| 304 | + |
@property
def speech_client(self):
    """Return the Google Cloud Speech client, creating it on first access.

    The client is cached on ``self._speech_client``; later accesses return
    the cached instance. ``google.cloud.speech`` is imported inside this
    property, so the package is only needed once speech features are used.

    Returns:
        The cached ``speech.SpeechClient`` instance.

    Raises:
        ImportError: If ``google.cloud.speech`` cannot be imported. The
            original import failure is attached as ``__cause__``.
    """
    if self._speech_client is None:
        # Guard only the import: a failure while constructing the client
        # (credentials, transport, ...) must not be reported as a missing
        # package.
        try:
            from google.cloud import speech
        except ImportError as exc:
            raise ImportError(
                "google-cloud-speech is required for ASR functionality. "
                "Install it with: pip install google-cloud-speech"
            ) from exc

        self._speech_client = speech.SpeechClient()
    return self._speech_client
| 319 | + |
def audio_transcriptions_create(
    self, model: str, file: Union[str, BinaryIO], **kwargs
) -> TranscriptionResult:
    """Create an audio transcription using Google Cloud Speech-to-Text.

    Args:
        model: Recognition model name. The value ``"default"`` is mapped to
            ``"latest_long"``; any other value is passed through unchanged.
        file: Audio source. A ``str`` or ``os.PathLike`` is opened and read
            as a binary file, ``bytes``/``bytearray`` are used directly, and
            anything else is expected to expose ``read()`` returning the
            audio bytes.
        **kwargs: Optional settings read by this method:
            ``sample_rate_hertz`` (default ``16000``),
            ``language`` (default ``"en-US"``, sent as ``language_code``),
            ``punctuate`` (default ``True``, sent as
            ``enable_automatic_punctuation``).

    Returns:
        The unified ``TranscriptionResult`` built by
        ``_parse_google_response``.

    Raises:
        ASRError: If ``google-cloud-speech`` is not importable, or if reading
            the audio, building the request, or the API call fails. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        from google.cloud import speech

        # Normalise the supported input kinds to raw audio bytes.
        if isinstance(file, (str, os.PathLike)):
            with open(file, "rb") as audio_file:
                audio_data = audio_file.read()
        elif isinstance(file, (bytes, bytearray)):
            audio_data = bytes(file)
        else:
            audio_data = file.read()

        audio = speech.RecognitionAudio(content=audio_data)

        # NOTE: the encoding is fixed to LINEAR16; word offsets and word
        # confidence are always requested so the parser can populate words.
        config = speech.RecognitionConfig(
            encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
            sample_rate_hertz=kwargs.get("sample_rate_hertz", 16000),
            language_code=kwargs.get("language", "en-US"),
            enable_word_time_offsets=True,
            enable_word_confidence=True,
            enable_automatic_punctuation=kwargs.get("punctuate", True),
            model=model if model != "default" else "latest_long",
        )

        response = self.speech_client.recognize(config=config, audio=audio)
        return self._parse_google_response(response)

    except ImportError as exc:
        raise ASRError(
            "google-cloud-speech is required for ASR functionality. "
            "Install it with: pip install google-cloud-speech"
        ) from exc
    except Exception as exc:
        # Provider boundary: surface every failure as ASRError, but keep the
        # original exception chained for debugging.
        raise ASRError(f"Google Speech-to-Text error: {exc}") from exc
| 359 | + |
def _parse_google_response(self, response) -> TranscriptionResult:
    """Convert a Google Speech-to-Text response to a TranscriptionResult.

    Google returns one entry in ``response.results`` per sequential portion
    of the audio, so the transcript and word timings are assembled from the
    top alternative of every result, not only the first one.

    Args:
        response: Object exposing ``results``; each result exposes
            ``alternatives`` whose items carry ``transcript`` and optionally
            ``confidence`` and ``words``.

    Returns:
        A ``TranscriptionResult``. When no result has any alternative, an
        empty result (``text=""``, ``language=None``) is returned. Otherwise:
        ``text`` is the concatenation of the top transcripts, ``confidence``
        is the mean of the top alternatives' confidences (``None`` if none
        are available), ``words`` covers all segments, and ``alternatives``
        lists the hypotheses of the first segment that has any.
    """
    # Keep only segments that actually produced a hypothesis.
    segments = [
        result for result in (response.results or []) if result.alternatives
    ]
    if not segments:
        return TranscriptionResult(text="", language=None)

    top_hypotheses = [segment.alternatives[0] for segment in segments]

    # Join segment transcripts, adding a space only when neither side
    # already has whitespace at the boundary. A single-segment response is
    # therefore returned unchanged.
    text = ""
    for hypothesis in top_hypotheses:
        piece = hypothesis.transcript
        if text and piece and not (text[-1].isspace() or piece[0].isspace()):
            text += " "
        text += piece

    scores = [
        score
        for score in (
            getattr(hypothesis, "confidence", None)
            for hypothesis in top_hypotheses
        )
        if score is not None
    ]
    confidence = sum(scores) / len(scores) if scores else None

    # Collect word-level timings from every segment, in order.
    words = []
    for hypothesis in top_hypotheses:
        for item in getattr(hypothesis, "words", None) or []:
            words.append(
                Word(
                    word=item.word,
                    start=(
                        item.start_time.total_seconds()
                        if hasattr(item, "start_time")
                        else 0.0
                    ),
                    end=(
                        item.end_time.total_seconds()
                        if hasattr(item, "end_time")
                        else 0.0
                    ),
                    confidence=getattr(item, "confidence", None),
                )
            )

    # Alternatives are reported per segment by Google; expose those of the
    # first segment.
    alternatives = [
        Alternative(
            transcript=candidate.transcript,
            confidence=getattr(candidate, "confidence", None),
        )
        for candidate in segments[0].alternatives
    ]

    return TranscriptionResult(
        text=text,
        language=None,  # Google doesn't return detected language in this format
        confidence=confidence,
        task="transcribe",
        words=words if words else None,
        alternatives=alternatives if alternatives else None,
    )
0 commit comments