Skip to content

Commit de4ed73

Browse files
committed
Add Gemini embedding support
- Implement gemini_embed function - Add gemini to embedding binding choices - Add L2 normalization for dims < 3072
1 parent f4492d4 commit de4ed73

File tree

4 files changed

+216
-6
lines changed

4 files changed

+216
-6
lines changed

lightrag/api/config.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dotenv import load_dotenv
99
from lightrag.utils import get_env_value
1010
from lightrag.llm.binding_options import (
11+
GeminiEmbeddingOptions,
1112
GeminiLLMOptions,
1213
OllamaEmbeddingOptions,
1314
OllamaLLMOptions,
@@ -238,7 +239,15 @@ def parse_args() -> argparse.Namespace:
238239
"--embedding-binding",
239240
type=str,
240241
default=get_env_value("EMBEDDING_BINDING", "ollama"),
241-
choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
242+
choices=[
243+
"lollms",
244+
"ollama",
245+
"openai",
246+
"azure_openai",
247+
"aws_bedrock",
248+
"jina",
249+
"gemini",
250+
],
242251
help="Embedding binding type (default: from env or ollama)",
243252
)
244253
parser.add_argument(
@@ -265,12 +274,19 @@ def parse_args() -> argparse.Namespace:
265274
if "--embedding-binding" in sys.argv:
266275
try:
267276
idx = sys.argv.index("--embedding-binding")
268-
if idx + 1 < len(sys.argv) and sys.argv[idx + 1] == "ollama":
269-
OllamaEmbeddingOptions.add_args(parser)
277+
if idx + 1 < len(sys.argv):
278+
if sys.argv[idx + 1] == "ollama":
279+
OllamaEmbeddingOptions.add_args(parser)
280+
elif sys.argv[idx + 1] == "gemini":
281+
GeminiEmbeddingOptions.add_args(parser)
270282
except IndexError:
271283
pass
272-
elif os.environ.get("EMBEDDING_BINDING") == "ollama":
273-
OllamaEmbeddingOptions.add_args(parser)
284+
else:
285+
env_embedding_binding = os.environ.get("EMBEDDING_BINDING")
286+
if env_embedding_binding == "ollama":
287+
OllamaEmbeddingOptions.add_args(parser)
288+
elif env_embedding_binding == "gemini":
289+
GeminiEmbeddingOptions.add_args(parser)
274290

275291
# Add OpenAI LLM options when llm-binding is openai or azure_openai
276292
if "--llm-binding" in sys.argv:

lightrag/api/lightrag_server.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(self, args):
8989
# Initialize configurations based on binding conditions
9090
self.openai_llm_options = None
9191
self.gemini_llm_options = None
92+
self.gemini_embedding_options = None
9293
self.ollama_llm_options = None
9394
self.ollama_embedding_options = None
9495

@@ -135,6 +136,23 @@ def __init__(self, args):
135136
)
136137
self.ollama_embedding_options = {}
137138

139+
# Only initialize and log Gemini Embedding options when using Gemini Embedding binding
140+
if args.embedding_binding == "gemini":
141+
try:
142+
from lightrag.llm.binding_options import GeminiEmbeddingOptions
143+
144+
self.gemini_embedding_options = GeminiEmbeddingOptions.options_dict(
145+
args
146+
)
147+
logger.info(
148+
f"Gemini Embedding Options: {self.gemini_embedding_options}"
149+
)
150+
except ImportError:
151+
logger.warning(
152+
"GeminiEmbeddingOptions not available, using default configuration"
153+
)
154+
self.gemini_embedding_options = {}
155+
138156

139157
def check_frontend_build():
140158
"""Check if frontend is built and optionally check if source is up-to-date
@@ -296,6 +314,7 @@ def create_app(args):
296314
"azure_openai",
297315
"aws_bedrock",
298316
"jina",
317+
"gemini",
299318
]:
300319
raise Exception("embedding binding not supported")
301320

@@ -649,6 +668,26 @@ async def optimized_embedding_function(texts, embedding_dim=None):
649668
base_url=host,
650669
api_key=api_key,
651670
)
671+
elif binding == "gemini":
672+
from lightrag.llm.gemini import gemini_embed
673+
674+
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
675+
if config_cache.gemini_embedding_options is not None:
676+
gemini_options = config_cache.gemini_embedding_options
677+
else:
678+
# Fallback for cases where config cache wasn't initialized properly
679+
from lightrag.llm.binding_options import GeminiEmbeddingOptions
680+
681+
gemini_options = GeminiEmbeddingOptions.options_dict(args)
682+
683+
return await gemini_embed(
684+
texts,
685+
model=model,
686+
base_url=host,
687+
api_key=api_key,
688+
embedding_dim=embedding_dim,
689+
task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
690+
)
652691
else: # openai and compatible
653692
from lightrag.llm.openai import openai_embed
654693

lightrag/llm/binding_options.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,19 @@ class GeminiLLMOptions(BindingOptions):
508508
}
509509

510510

511+
@dataclass
class GeminiEmbeddingOptions(BindingOptions):
    """Options for Google Gemini embedding models.

    Fields declared here become CLI arguments via the ``BindingOptions``
    machinery (see ``add_args`` / ``options_dict`` usage at the call sites).
    """

    # Prefix under which BindingOptions registers this group's CLI arguments.
    _binding_name: ClassVar[str] = "gemini_embedding"

    # Task-type hint forwarded to the Gemini embedding API; optimizes the
    # embedding for its downstream use (document indexing vs. query, etc.).
    task_type: str = "RETRIEVAL_DOCUMENT"

    # Per-field help text surfaced by BindingOptions when building the parser.
    _help: ClassVar[dict[str, str]] = {
        "task_type": "Task type for embedding optimization (RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING, CODE_RETRIEVAL_QUERY, QUESTION_ANSWERING, FACT_VERIFICATION)",
    }
522+
523+
511524
# =============================================================================
512525
# Binding Options for OpenAI
513526
# =============================================================================

lightrag/llm/gemini.py

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,20 @@
1616
from functools import lru_cache
1717
from typing import Any
1818

19-
from lightrag.utils import logger, remove_think_tags, safe_unicode_decode
19+
import numpy as np
20+
from tenacity import (
21+
retry,
22+
stop_after_attempt,
23+
wait_exponential,
24+
retry_if_exception_type,
25+
)
26+
27+
from lightrag.utils import (
28+
logger,
29+
remove_think_tags,
30+
safe_unicode_decode,
31+
wrap_embedding_func_with_attrs,
32+
)
2033

2134
import pipmaster as pm
2235

@@ -416,7 +429,136 @@ async def gemini_model_complete(
416429
)
417430

418431

432+
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        # NOTE(review): the google-genai SDK raises mostly generic exception
        # types, so we retry broadly. This also retries non-transient failures
        # (e.g. a missing API key raised as ValueError) — narrow this once the
        # SDK exposes stable transient-error classes.
        retry_if_exception_type(Exception)
    ),
)
async def gemini_embed(
    texts: list[str],
    model: str = "gemini-embedding-001",
    base_url: str | None = None,
    api_key: str | None = None,
    embedding_dim: int | None = None,
    task_type: str = "RETRIEVAL_DOCUMENT",
    timeout: int | None = None,
    token_tracker: Any | None = None,
) -> np.ndarray:
    """Generate embeddings for a list of texts using Gemini's API.

    Uses Google's Gemini embedding model with optional dynamic dimension
    control. Reduced-dimension outputs are L2-normalized so cosine/dot-product
    similarity remains meaningful.

    Args:
        texts: List of texts to embed. An empty list returns an empty array
            without contacting the API.
        model: Gemini embedding model name. Default "gemini-embedding-001".
        base_url: Optional custom API endpoint.
        api_key: Optional Gemini API key; falls back to environment variables
            (resolved by ``_ensure_api_key``).
        embedding_dim: Optional output dimensionality, forwarded to the API as
            ``output_dimensionality``. **IMPORTANT**: this parameter is
            injected automatically by the EmbeddingFunc wrapper
            (``@wrap_embedding_func_with_attrs`` / EMBEDDING_DIM env var) —
            do not pass it manually. Supported range per Gemini docs: 128-3072.
        task_type: Task type for embedding optimization. Supported types:
            SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
            RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
            QUESTION_ANSWERING, FACT_VERIFICATION.
        timeout: Request timeout in seconds (converted to milliseconds for the
            Gemini client).
        token_tracker: Optional token usage tracker; only used when the
            response carries ``usage_metadata``.

    Returns:
        A 2-D numpy float32 array of embeddings, one row per input text.
        Rows with dimension < 3072 are L2-normalized; the native 3072-dim
        output is already normalized by the API.

    Raises:
        ValueError: If no API key is provided or configured.
        RuntimeError: If the Gemini response contains no embeddings.
    """
    # Fast path: nothing to embed. Avoids a pointless API round-trip and the
    # shape errors an empty response would cause downstream.
    if not texts:
        return np.empty((0, embedding_dim or 0), dtype=np.float32)

    loop = asyncio.get_running_loop()

    key = _ensure_api_key(api_key)
    # The Gemini client expects milliseconds; callers pass seconds.
    timeout_ms = timeout * 1000 if timeout else None
    client = _get_gemini_client(key, base_url, timeout_ms)

    # Build the optional EmbedContentConfig only when there is something to set.
    config_kwargs: dict[str, Any] = {}
    if task_type:
        config_kwargs["task_type"] = task_type
    if embedding_dim is not None:
        config_kwargs["output_dimensionality"] = embedding_dim
    config_obj = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None

    def _call_embed() -> Any:
        """Synchronous Gemini embedding call, executed in a worker thread."""
        request_kwargs: dict[str, Any] = {
            "model": model,
            "contents": texts,
        }
        if config_obj is not None:
            request_kwargs["config"] = config_obj
        return client.models.embed_content(**request_kwargs)

    # The google-genai client is synchronous; run it off the event loop.
    response = await loop.run_in_executor(None, _call_embed)

    if not hasattr(response, "embeddings") or not response.embeddings:
        raise RuntimeError("Gemini response did not contain embeddings.")

    embeddings = np.array(
        [np.array(e.values, dtype=np.float32) for e in response.embeddings]
    )

    # L2-normalize based on the *actual* returned dimension, not the requested
    # one. (Fix: the previous check used `embedding_dim`, which skipped
    # normalization when the caller passed None yet the API returned a reduced
    # dimension.) The native 3072-dim output is already normalized by the API.
    actual_dim = embeddings.shape[1]
    if actual_dim < 3072:
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms = np.where(norms == 0, 1, norms)  # avoid division by zero
        embeddings = embeddings / norms
        logger.debug(
            f"Applied L2 normalization to {len(embeddings)} embeddings of dimension {actual_dim}"
        )

    # Track token usage when available; the Gemini embedding API does not
    # always provide usage metadata.
    if token_tracker and hasattr(response, "usage_metadata"):
        usage = response.usage_metadata
        token_counts = {
            "prompt_tokens": getattr(usage, "prompt_token_count", 0),
            "total_tokens": getattr(usage, "total_token_count", 0),
        }
        token_tracker.add_usage(token_counts)

    logger.debug(
        f"Generated {len(embeddings)} Gemini embeddings with dimension {embeddings.shape[1]}"
    )

    return embeddings
558+
559+
419560
# Public API of this module; keep in sync with the functions defined above.
__all__ = [
    "gemini_complete_if_cache",
    "gemini_model_complete",
    "gemini_embed",
]

0 commit comments

Comments
 (0)