|
16 | 16 | from functools import lru_cache |
17 | 17 | from typing import Any |
18 | 18 |
|
19 | | -from lightrag.utils import logger, remove_think_tags, safe_unicode_decode |
| 19 | +import numpy as np |
| 20 | +from tenacity import ( |
| 21 | + retry, |
| 22 | + stop_after_attempt, |
| 23 | + wait_exponential, |
| 24 | + retry_if_exception_type, |
| 25 | +) |
| 26 | + |
| 27 | +from lightrag.utils import ( |
| 28 | + logger, |
| 29 | + remove_think_tags, |
| 30 | + safe_unicode_decode, |
| 31 | + wrap_embedding_func_with_attrs, |
| 32 | +) |
20 | 33 |
|
21 | 34 | import pipmaster as pm |
22 | 35 |
|
@@ -416,7 +429,136 @@ async def gemini_model_complete( |
416 | 429 | ) |
417 | 430 |
|
418 | 431 |
|
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(Exception)  # Gemini SDK raises generic exceptions
    ),
)
async def gemini_embed(
    texts: list[str],
    model: str = "gemini-embedding-001",
    base_url: str | None = None,
    api_key: str | None = None,
    embedding_dim: int | None = None,
    task_type: str = "RETRIEVAL_DOCUMENT",
    timeout: int | None = None,
    token_tracker: Any | None = None,
) -> np.ndarray:
    """Generate embeddings for a list of texts using Gemini's API.

    Uses Google's Gemini embedding model with optional dynamic dimension
    reduction and automatic L2 normalization for dimensions below 3072.

    Args:
        texts: List of texts to embed. An empty list returns an empty array
            without contacting the API.
        model: The Gemini embedding model to use. Default is "gemini-embedding-001".
        base_url: Optional custom API endpoint.
        api_key: Optional Gemini API key. If None, uses environment variables.
        embedding_dim: Optional embedding dimension for dynamic dimension reduction.
            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc
            wrapper. Do NOT manually pass it when calling the function directly; it is
            controlled by the @wrap_embedding_func_with_attrs decorator or the
            EMBEDDING_DIM environment variable.
            Supported range: 128-3072. Recommended values: 768, 1536, 3072.
        task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
            Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
            RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
            QUESTION_ANSWERING, FACT_VERIFICATION.
        timeout: Request timeout in seconds (converted to milliseconds for the
            Gemini API).
        token_tracker: Optional token usage tracker for monitoring API usage.

    Returns:
        A numpy array of embeddings, one per input text. Embeddings whose actual
        dimension is below 3072 are L2-normalized so cosine similarity works as
        expected.

    Raises:
        ValueError: If API key is not provided or configured.
        RuntimeError: If the response from Gemini is invalid or empty.

    Note:
        - Dimension 3072 embeddings are already normalized by the API.
        - Normalization is decided from the ACTUAL returned dimension (not just
          the requested one), so reduced-dimension responses are normalized even
          when ``embedding_dim`` was not explicitly passed.
    """
    # Fast path: avoid a pointless (and billable) API round-trip for empty input.
    if not texts:
        return np.empty((0, embedding_dim or 0), dtype=np.float32)

    loop = asyncio.get_running_loop()

    key = _ensure_api_key(api_key)
    # Gemini API expects milliseconds; caller passes seconds.
    timeout_ms = timeout * 1000 if timeout else None
    client = _get_gemini_client(key, base_url, timeout_ms)

    # Prepare embedding configuration.
    config_kwargs: dict[str, Any] = {}
    if task_type:
        config_kwargs["task_type"] = task_type
    if embedding_dim is not None:
        config_kwargs["output_dimensionality"] = embedding_dim

    # Only build a config object when there is something to configure.
    config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None

    def _call_embed() -> Any:
        """Synchronous Gemini embedding call, executed in a worker thread."""
        request_kwargs: dict[str, Any] = {
            "model": model,
            "contents": texts,
        }
        if config_obj is not None:
            request_kwargs["config"] = config_obj

        return client.models.embed_content(**request_kwargs)

    # The google-genai client is synchronous; run it off the event loop.
    response = await loop.run_in_executor(None, _call_embed)

    if not hasattr(response, "embeddings") or not response.embeddings:
        raise RuntimeError("Gemini response did not contain embeddings.")

    embeddings = np.array(
        [np.array(e.values, dtype=np.float32) for e in response.embeddings]
    )

    # Apply L2 normalization for dimensions < 3072.  Decide from the ACTUAL
    # returned width: the previous check (`embedding_dim and embedding_dim < 3072`)
    # skipped normalization whenever embedding_dim was None even if the API
    # returned a reduced dimension, contradicting the documented contract.
    # The 3072-dimension embedding is already normalized by the Gemini API.
    actual_dim = embeddings.shape[1] if embeddings.ndim == 2 else 0
    if 0 < actual_dim < 3072:
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        # Avoid division by zero for degenerate all-zero vectors.
        norms = np.where(norms == 0, 1, norms)
        embeddings = embeddings / norms
        logger.debug(
            f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {actual_dim}"
        )

    # Track token usage if a tracker is provided.
    # Note: the Gemini embedding API may not provide usage metadata.
    if token_tracker and hasattr(response, "usage_metadata"):
        usage = response.usage_metadata
        token_counts = {
            "prompt_tokens": getattr(usage, "prompt_token_count", 0),
            "total_tokens": getattr(usage, "total_token_count", 0),
        }
        token_tracker.add_usage(token_counts)

    logger.debug(
        f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
    )

    return embeddings
| 558 | + |
| 559 | + |
# Public API of this module: completion helpers plus the embedding function.
__all__ = [
    "gemini_complete_if_cache",
    "gemini_model_complete",
    "gemini_embed",
]
0 commit comments