Skip to content

Commit 5132288

Browse files
taprosoftfzowl
andauthored
feat: add VoyageAI's rerank and embeddings models (#733) #none
* Introducing VoyageAI's rerank and embeddings models * fix: comfort CI * fix: update test case --------- Co-authored-by: fzowl <[email protected]>
1 parent c33bedc commit 5132288

File tree

11 files changed

+194
-5
lines changed

11 files changed

+194
-5
lines changed

.env.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ COHERE_API_KEY=<COHERE_API_KEY>
1919
# settings for Mistral
2020
# MISTRAL_API_KEY=placeholder
2121

22+
# settings for VoyageAI
23+
VOYAGE_API_KEY=<VOYAGE_API_KEY>
24+
2225
# settings for local models
2326
LOCAL_MODEL=qwen2.5:7b
2427
LOCAL_MODEL_EMBEDDINGS=nomic-embed-text

flowsettings.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,25 @@
172172
"default": IS_OPENAI_DEFAULT,
173173
}
174174

175+
VOYAGE_API_KEY = config("VOYAGE_API_KEY", default="")
176+
if VOYAGE_API_KEY:
177+
KH_EMBEDDINGS["voyageai"] = {
178+
"spec": {
179+
"__type__": "kotaemon.embeddings.VoyageAIEmbeddings",
180+
"api_key": VOYAGE_API_KEY,
181+
"model": config("VOYAGE_EMBEDDINGS_MODEL", default="voyage-3-large"),
182+
},
183+
"default": False,
184+
}
185+
KH_RERANKINGS["voyageai"] = {
186+
"spec": {
187+
"__type__": "kotaemon.rerankings.VoyageAIReranking",
188+
"model_name": "rerank-2",
189+
"api_key": VOYAGE_API_KEY,
190+
},
191+
"default": False,
192+
}
193+
175194
if config("LOCAL_MODEL", default=""):
176195
KH_LLMS["ollama"] = {
177196
"spec": {

libs/kotaemon/kotaemon/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
1313
from .tei_endpoint_embed import TeiEndpointEmbeddings
14+
from .voyageai import VoyageAIEmbeddings
1415

1516
__all__ = [
1617
"BaseEmbeddings",
@@ -25,4 +26,5 @@
2526
"OpenAIEmbeddings",
2627
"AzureOpenAIEmbeddings",
2728
"FastEmbedEmbeddings",
29+
"VoyageAIEmbeddings",
2830
]
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Implements embeddings from [Voyage AI](https://voyageai.com).
2+
"""
3+
4+
import importlib
5+
6+
from kotaemon.base import Document, DocumentWithEmbedding, Param
7+
8+
from .base import BaseEmbeddings
9+
10+
vo = None
11+
12+
13+
def _import_voyageai():
14+
global vo
15+
if not vo:
16+
vo = importlib.import_module("voyageai")
17+
return vo
18+
19+
20+
def _format_output(texts: list[str], embeddings: list[list]):
21+
"""Formats the output of all `.embed` calls.
22+
Args:
23+
texts: List of original documents
24+
embeddings: Embeddings corresponding to each document
25+
"""
26+
return [
27+
DocumentWithEmbedding(content=text, embedding=embedding)
28+
for text, embedding in zip(texts, embeddings)
29+
]
30+
31+
32+
class VoyageAIEmbeddings(BaseEmbeddings):
33+
"""Voyage AI provides best-in-class embedding models and rerankers."""
34+
35+
api_key: str = Param(None, help="Voyage API key", required=False)
36+
model: str = Param(
37+
"voyage-3",
38+
help=(
39+
"Model name to use. The Voyage "
40+
"[documentation](https://docs.voyageai.com/docs/embeddings) "
41+
"provides a list of all available embedding models."
42+
),
43+
required=True,
44+
)
45+
46+
def __init__(self, *args, **kwargs):
47+
super().__init__(*args, **kwargs)
48+
if not self.api_key:
49+
raise ValueError("API key must be provided for VoyageAIEmbeddings.")
50+
51+
self._client = _import_voyageai().Client(api_key=self.api_key)
52+
self._aclient = _import_voyageai().AsyncClient(api_key=self.api_key)
53+
54+
def invoke(
55+
self, text: str | list[str] | Document | list[Document], *args, **kwargs
56+
) -> list[DocumentWithEmbedding]:
57+
texts = [t.content for t in self.prepare_input(text)]
58+
embeddings = self._client.embed(texts, model=self.model).embeddings
59+
return _format_output(texts, embeddings)
60+
61+
async def ainvoke(
62+
self, text: str | list[str] | Document | list[Document], *args, **kwargs
63+
) -> list[DocumentWithEmbedding]:
64+
texts = [t.content for t in self.prepare_input(text)]
65+
embeddings = await self._aclient.embed(texts, model=self.model).embeddings
66+
return _format_output(texts, embeddings)
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base import BaseReranking
22
from .cohere import CohereReranking
33
from .tei_fast_rerank import TeiFastReranking
4+
from .voyageai import VoyageAIReranking
45

5-
__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]
6+
__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking", "VoyageAIReranking"]
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
import importlib
4+
5+
from decouple import config
6+
7+
from kotaemon.base import Document, Param
8+
9+
from .base import BaseReranking
10+
11+
vo = None
12+
13+
14+
def _import_voyageai():
15+
global vo
16+
if not vo:
17+
vo = importlib.import_module("voyageai")
18+
return vo
19+
20+
21+
class VoyageAIReranking(BaseReranking):
22+
"""VoyageAI Reranking model"""
23+
24+
model_name: str = Param(
25+
"rerank-2",
26+
help=(
27+
"ID of the model to use. You can go to [Supported Models]"
28+
"(https://docs.voyageai.com/docs/reranker) to see the supported models"
29+
),
30+
required=True,
31+
)
32+
api_key: str = Param(
33+
config("VOYAGE_API_KEY", ""),
34+
help="VoyageAI API key",
35+
required=True,
36+
)
37+
38+
def __init__(self, *args, **kwargs):
39+
super().__init__(*args, **kwargs)
40+
if not self.api_key:
41+
raise ValueError("API key must be provided for VoyageAIEmbeddings.")
42+
43+
self._client = _import_voyageai().Client(api_key=self.api_key)
44+
self._aclient = _import_voyageai().AsyncClient(api_key=self.api_key)
45+
46+
def run(self, documents: list[Document], query: str) -> list[Document]:
47+
"""Use VoyageAI Reranker model to re-order documents
48+
with their relevance score"""
49+
compressed_docs: list[Document] = []
50+
51+
if not documents: # to avoid empty api call
52+
return compressed_docs
53+
54+
_docs = [d.content for d in documents]
55+
response = self._client.rerank(
56+
model=self.model_name, query=query, documents=_docs
57+
)
58+
for r in response.results:
59+
doc = documents[r.index]
60+
doc.metadata["reranking_score"] = r.relevance_score
61+
compressed_docs.append(doc)
62+
63+
return compressed_docs

libs/kotaemon/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ adv = [
9090
"tabulate",
9191
"unstructured>=0.15.8,<0.16",
9292
"wikipedia>=1.4.0,<1.5",
93+
"voyageai>=0.3.0",
9394
]
9495
dev = [
9596
"black",

libs/kotaemon/tests/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ def if_llama_cpp_not_installed():
7070
return False
7171

7272

73+
def if_voyageai_not_installed():
74+
try:
75+
import voyageai # noqa: F401
76+
except ImportError:
77+
return True
78+
else:
79+
return False
80+
81+
7382
skip_when_haystack_not_installed = pytest.mark.skipif(
7483
if_haystack_not_installed(), reason="Haystack is not installed"
7584
)
@@ -97,3 +106,7 @@ def if_llama_cpp_not_installed():
97106
skip_llama_cpp_not_installed = pytest.mark.skipif(
98107
if_llama_cpp_not_installed(), reason="llama_cpp is not installed"
99108
)
109+
110+
skip_when_voyageai_not_installed = pytest.mark.skipif(
111+
if_voyageai_not_installed(), reason="voyageai is not installed"
112+
)

libs/kotaemon/tests/test_embedding_models.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
import json
22
from pathlib import Path
3-
from unittest.mock import patch
3+
from unittest.mock import Mock, patch
44

55
from openai.types.create_embedding_response import CreateEmbeddingResponse
66

7-
from kotaemon.base import Document
7+
from kotaemon.base import Document, DocumentWithEmbedding
88
from kotaemon.embeddings import (
99
AzureOpenAIEmbeddings,
1010
FastEmbedEmbeddings,
1111
LCCohereEmbeddings,
1212
LCHuggingFaceEmbeddings,
1313
OpenAIEmbeddings,
14+
VoyageAIEmbeddings,
1415
)
1516

1617
from .conftest import (
1718
skip_when_cohere_not_installed,
1819
skip_when_fastembed_not_installed,
1920
skip_when_sentence_bert_not_installed,
21+
skip_when_voyageai_not_installed,
2022
)
2123

2224
with open(Path(__file__).parent / "resources" / "embedding_openai_batch.json") as f:
@@ -155,3 +157,16 @@ def test_fastembed_embeddings():
155157
model = FastEmbedEmbeddings()
156158
output = model("Hello World")
157159
assert_embedding_result(output)
160+
161+
162+
voyage_output_mock = Mock()
163+
voyage_output_mock.embeddings = [[1.0, 2.1, 3.2]]
164+
165+
166+
@skip_when_voyageai_not_installed
167+
@patch("voyageai.Client.embed", return_value=voyage_output_mock)
168+
@patch("voyageai.AsyncClient.embed", return_value=voyage_output_mock)
169+
def test_voyageai_embeddings(sync_call, async_call):
170+
model = VoyageAIEmbeddings(api_key="test")
171+
output = model("Hello, world!")
172+
assert all(isinstance(doc, DocumentWithEmbedding) for doc in output)

libs/ktem/ktem/embeddings/manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def load_vendors(self):
6262
LCMistralEmbeddings,
6363
OpenAIEmbeddings,
6464
TeiEndpointEmbeddings,
65+
VoyageAIEmbeddings,
6566
)
6667

6768
self._vendors = [
@@ -73,6 +74,7 @@ def load_vendors(self):
7374
LCGoogleEmbeddings,
7475
LCMistralEmbeddings,
7576
TeiEndpointEmbeddings,
77+
VoyageAIEmbeddings,
7678
]
7779

7880
def __getitem__(self, key: str) -> BaseEmbeddings:

0 commit comments

Comments
 (0)