Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lightrag/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__api_version__ = "0256"
__api_version__ = "0257"
8 changes: 6 additions & 2 deletions lightrag/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,12 @@ def parse_args() -> argparse.Namespace:

# Inject model configuration
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
# EMBEDDING_MODEL defaults to None - each binding will use its own default model
# e.g., OpenAI uses "text-embedding-3-small", Jina uses "jina-embeddings-v4"
args.embedding_model = get_env_value("EMBEDDING_MODEL", None, special_none=True)
# EMBEDDING_DIM defaults to None - each binding will use its own default dimension
# Value is inherited from provider defaults via wrap_embedding_func_with_attrs decorator
args.embedding_dim = get_env_value("EMBEDDING_DIM", None, int, special_none=True)
args.embedding_send_dim = get_env_value("EMBEDDING_SEND_DIM", False, bool)

# Inject chunk configuration
Expand Down
101 changes: 68 additions & 33 deletions lightrag/api/lightrag_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,17 @@ def create_optimized_embedding_function(
2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
4. Returns a properly configured EmbeddingFunc instance

Configuration Rules:
- When EMBEDDING_MODEL is not set: Uses provider's default model and dimension
(e.g., jina-embeddings-v4 with 2048 dims, text-embedding-3-small with 1536 dims)
- When EMBEDDING_MODEL is set to a custom model: User MUST also set EMBEDDING_DIM
to match the custom model's dimension (e.g., for jina-embeddings-v3, set EMBEDDING_DIM=1024)

Note: The embedding_dim parameter is automatically injected by EmbeddingFunc wrapper
when send_dimensions=True (enabled for Jina and Gemini bindings). This wrapper calls
the underlying provider function directly (.func) to avoid double-wrapping, so we must
explicitly pass embedding_dim to the provider's underlying function.
"""

# Step 1: Import provider function and extract default attributes
Expand Down Expand Up @@ -713,6 +724,7 @@ def create_optimized_embedding_function(
)

# Step 3: Create optimized embedding function (calls underlying function directly)
# Note: When model is None, each binding will use its own default model
async def optimized_embedding_function(texts, embedding_dim=None):
try:
if binding == "lollms":
Expand All @@ -724,9 +736,9 @@ async def optimized_embedding_function(texts, embedding_dim=None):
if isinstance(lollms_embed, EmbeddingFunc)
else lollms_embed
)
return await actual_func(
texts, embed_model=model, host=host, api_key=api_key
)
# lollms embed_model is not used (server uses configured vectorizer)
# Only pass base_url and api_key
return await actual_func(texts, base_url=host, api_key=api_key)
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed

Expand All @@ -745,13 +757,16 @@ async def optimized_embedding_function(texts, embedding_dim=None):

ollama_options = OllamaEmbeddingOptions.options_dict(args)

return await actual_func(
texts,
embed_model=model,
host=host,
api_key=api_key,
options=ollama_options,
)
# Pass embed_model only if provided, let function use its default (bge-m3:latest)
kwargs = {
"texts": texts,
"host": host,
"api_key": api_key,
"options": ollama_options,
}
if model:
kwargs["embed_model"] = model
return await actual_func(**kwargs)
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed

Expand All @@ -760,7 +775,11 @@ async def optimized_embedding_function(texts, embedding_dim=None):
if isinstance(azure_openai_embed, EmbeddingFunc)
else azure_openai_embed
)
return await actual_func(texts, model=model, api_key=api_key)
# Pass model only if provided, let function use its default otherwise
kwargs = {"texts": texts, "api_key": api_key}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed

Expand All @@ -769,7 +788,11 @@ async def optimized_embedding_function(texts, embedding_dim=None):
if isinstance(bedrock_embed, EmbeddingFunc)
else bedrock_embed
)
return await actual_func(texts, model=model)
# Pass model only if provided, let function use its default otherwise
kwargs = {"texts": texts}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "jina":
from lightrag.llm.jina import jina_embed

Expand All @@ -778,12 +801,16 @@ async def optimized_embedding_function(texts, embedding_dim=None):
if isinstance(jina_embed, EmbeddingFunc)
else jina_embed
)
return await actual_func(
texts,
embedding_dim=embedding_dim,
base_url=host,
api_key=api_key,
)
# Pass model only if provided, let function use its default (jina-embeddings-v4)
kwargs = {
"texts": texts,
"embedding_dim": embedding_dim,
"base_url": host,
"api_key": api_key,
}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed

Expand All @@ -801,14 +828,19 @@ async def optimized_embedding_function(texts, embedding_dim=None):

gemini_options = GeminiEmbeddingOptions.options_dict(args)

return await actual_func(
texts,
model=model,
base_url=host,
api_key=api_key,
embedding_dim=embedding_dim,
task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
)
# Pass model only if provided, let function use its default (gemini-embedding-001)
kwargs = {
"texts": texts,
"base_url": host,
"api_key": api_key,
"embedding_dim": embedding_dim,
"task_type": gemini_options.get(
"task_type", "RETRIEVAL_DOCUMENT"
),
}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
else: # openai and compatible
from lightrag.llm.openai import openai_embed

Expand All @@ -817,13 +849,16 @@ async def optimized_embedding_function(texts, embedding_dim=None):
if isinstance(openai_embed, EmbeddingFunc)
else openai_embed
)
return await actual_func(
texts,
model=model,
base_url=host,
api_key=api_key,
embedding_dim=embedding_dim,
)
# Pass model only if provided, let function use its default (text-embedding-3-small)
kwargs = {
"texts": texts,
"base_url": host,
"api_key": api_key,
"embedding_dim": embedding_dim,
}
if model:
kwargs["model"] = model
return await actual_func(**kwargs)
except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}")

Expand Down
5 changes: 4 additions & 1 deletion lightrag/llm/jina.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ async def fetch_data(url, headers, data):
)
async def jina_embed(
texts: list[str],
model: str = "jina-embeddings-v4",
embedding_dim: int = 2048,
late_chunking: bool = False,
base_url: str = None,
Expand All @@ -78,6 +79,8 @@ async def jina_embed(

Args:
texts: List of texts to embed.
model: The Jina embedding model to use (default: jina-embeddings-v4).
Supported models: jina-embeddings-v3, jina-embeddings-v4, etc.
embedding_dim: The embedding dimensions (default: 2048 for jina-embeddings-v4).
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
Expand Down Expand Up @@ -107,7 +110,7 @@ async def jina_embed(
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
}
data = {
"model": "jina-embeddings-v4",
"model": model,
"task": "text-matching",
"dimensions": embedding_dim,
"embedding_type": "base64",
Expand Down
4 changes: 3 additions & 1 deletion lightrag/llm/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ async def ollama_model_complete(


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
async def ollama_embed(
texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
) -> np.ndarray:
api_key = kwargs.pop("api_key", None)
if not api_key:
api_key = os.getenv("OLLAMA_API_KEY")
Expand Down
Loading