Skip to content

Commit f752c1a

Browse files
feat(langchain): Add support to google_genai provider in init_embeddings (#34388)
1 parent 7902fa3 commit f752c1a

File tree

4 files changed

+39
-24
lines changed

4 files changed

+39
-24
lines changed

libs/langchain/langchain_classic/embeddings/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"azure_openai": "langchain_openai",
1010
"bedrock": "langchain_aws",
1111
"cohere": "langchain_cohere",
12+
"google_genai": "langchain_google_genai",
1213
"google_vertexai": "langchain_google_vertexai",
1314
"huggingface": "langchain_huggingface",
1415
"mistralai": "langchain_mistralai",
@@ -155,6 +156,7 @@ def init_embeddings(
155156
- `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
156157
- `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
157158
- `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
159+
- `google_genai` -> [`langchain-google-genai`](https://docs.langchain.com/oss/python/integrations/providers/google)
158160
- `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
159161
- `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
160162
- `mistraiai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
@@ -207,6 +209,10 @@ def init_embeddings(
207209
from langchain_openai import AzureOpenAIEmbeddings
208210

209211
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
212+
if provider == "google_genai":
213+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
214+
215+
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
210216
if provider == "google_vertexai":
211217
from langchain_google_vertexai import VertexAIEmbeddings
212218

libs/langchain/tests/unit_tests/embeddings/test_base.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,22 @@
99
)
1010

1111

12-
def test_parse_model_string() -> None:
12+
@pytest.mark.parametrize(
13+
("model_string", "expected_provider", "expected_model"),
14+
[
15+
("openai:text-embedding-3-small", "openai", "text-embedding-3-small"),
16+
("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"),
17+
("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"),
18+
("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"),
19+
],
20+
)
21+
def test_parse_model_string(
22+
model_string: str, expected_provider: str, expected_model: str
23+
) -> None:
1324
"""Test parsing model strings into provider and model components."""
14-
assert _parse_model_string("openai:text-embedding-3-small") == (
15-
"openai",
16-
"text-embedding-3-small",
17-
)
18-
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
19-
"bedrock",
20-
"amazon.titan-embed-text-v1",
21-
)
22-
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
23-
"huggingface",
24-
"BAAI/bge-base-en:v1.5",
25+
assert _parse_model_string(model_string) == (
26+
expected_provider,
27+
expected_model,
2528
)
2629

2730

libs/langchain_v1/langchain/embeddings/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"azure_openai": "langchain_openai",
1111
"bedrock": "langchain_aws",
1212
"cohere": "langchain_cohere",
13+
"google_genai": "langchain_google_genai",
1314
"google_vertexai": "langchain_google_vertexai",
1415
"huggingface": "langchain_huggingface",
1516
"mistralai": "langchain_mistralai",
@@ -207,6 +208,10 @@ def init_embeddings(
207208
from langchain_openai import AzureOpenAIEmbeddings
208209

209210
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
211+
if provider == "google_genai":
212+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
213+
214+
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
210215
if provider == "google_vertexai":
211216
from langchain_google_vertexai import VertexAIEmbeddings
212217

libs/langchain_v1/tests/unit_tests/embeddings/test_base.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,20 @@
99
)
1010

1111

12-
def test_parse_model_string() -> None:
12+
@pytest.mark.parametrize(
13+
("model_string", "expected_provider", "expected_model"),
14+
[
15+
("openai:text-embedding-3-small", "openai", "text-embedding-3-small"),
16+
("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"),
17+
("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"),
18+
("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"),
19+
],
20+
)
21+
def test_parse_model_string(model_string: str, expected_provider: str, expected_model: str) -> None:
1322
"""Test parsing model strings into provider and model components."""
14-
assert _parse_model_string("openai:text-embedding-3-small") == (
15-
"openai",
16-
"text-embedding-3-small",
17-
)
18-
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
19-
"bedrock",
20-
"amazon.titan-embed-text-v1",
21-
)
22-
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
23-
"huggingface",
24-
"BAAI/bge-base-en:v1.5",
23+
assert _parse_model_string(model_string) == (
24+
expected_provider,
25+
expected_model,
2526
)
2627

2728

0 commit comments

Comments
 (0)