Commit 11fda49

community[minor]: New model parameters and dynamic batching for VertexAIEmbeddings (#13999)
- **Description:** VertexAIEmbeddings performance improvements
- **Twitter handle:** @vladkol

## Improvements

- Dynamic batch size, starting from 250 and lowering down to 5. Batch size varies across regions: some regions support larger batches, and that significantly improves performance. When running large batches of texts in `us-central1`, the performance gain can be up to 3.5x. Dynamic batching also ensures every batch stays below the 20K-token limit.
- New model parameter `embeddings_type` that translates to the `task_type` parameter of the API. Newer model versions support [different embeddings task types](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings#api_changes_to_models_released_on_or_after_august_2023).
1 parent 2e6a9e6 commit 11fda49
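
Based on the new signatures in the diff below, usage looks roughly like this. A minimal sketch, not part of the commit: it assumes a configured GCP project (`gcloud auth login`), `google-cloud-aiplatform>=1.35.0`, and placeholder texts.

```python
from langchain_community.embeddings import VertexAIEmbeddings

embeddings = VertexAIEmbeddings(location="us-central1")

# batch_size=0 (the new default) turns on dynamic batching: the largest
# workable batch size is probed at the first request, from 250 down to 5.
doc_vectors = embeddings.embed_documents(["doc one", "doc two"])

# embed() exposes the new task type, which maps to the API's task_type
# on models newer than textembedding-gecko@001.
query_vector = embeddings.embed(
    ["what is a gecko?"], batch_size=1, embeddings_task_type="RETRIEVAL_QUERY"
)[0]
```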

File tree

3 files changed: +366 −17 lines changed

libs/community/langchain_community/embeddings/vertexai.py

Lines changed: 291 additions & 15 deletions

```diff
@@ -1,48 +1,324 @@
-from typing import Dict, List
+import logging
+import re
+import string
+import threading
+from concurrent.futures import ThreadPoolExecutor, wait
+from typing import Any, Dict, List, Literal, Optional, Tuple
 
 from langchain_core.embeddings import Embeddings
+from langchain_core.language_models.llms import create_base_retry_decorator
 from langchain_core.pydantic_v1 import root_validator
 
 from langchain_community.llms.vertexai import _VertexAICommon
 from langchain_community.utilities.vertexai import raise_vertex_import_error
 
+logger = logging.getLogger(__name__)
+
+_MAX_TOKENS_PER_BATCH = 20000
+_MAX_BATCH_SIZE = 250
+_MIN_BATCH_SIZE = 5
+
 
 class VertexAIEmbeddings(_VertexAICommon, Embeddings):
     """Google Cloud VertexAI embedding models."""
 
-    model_name: str = "textembedding-gecko"
+    # Instance context
+    instance: Dict[str, Any] = {}  #: :meta private:
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validates that the python package exists in environment."""
         cls._try_init_vertexai(values)
         try:
             from vertexai.language_models import TextEmbeddingModel
+
+            values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
         except ImportError:
             raise_vertex_import_error()
-        values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
         return values
 
-    def embed_documents(
-        self, texts: List[str], batch_size: int = 5
+    def __init__(
+        self,
+        project: Optional[str] = None,
+        location: str = "us-central1",
+        request_parallelism: int = 5,
+        max_retries: int = 6,
+        model_name: str = "textembedding-gecko",
+        credentials: Optional[Any] = None,
+        **kwargs: Any,
+    ):
+        """Initialize the Vertex AI embeddings client."""
+        super().__init__(
+            project=project,
+            location=location,
+            credentials=credentials,
+            request_parallelism=request_parallelism,
+            max_retries=max_retries,
+            model_name=model_name,
+            **kwargs,
+        )
+        self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
+        self.instance["batch_size"] = self.instance["max_batch_size"]
+        self.instance["min_batch_size"] = kwargs.get("min_batch_size", _MIN_BATCH_SIZE)
+        self.instance["min_good_batch_size"] = self.instance["min_batch_size"]
+        self.instance["lock"] = threading.Lock()
+        self.instance["batch_size_validated"] = False
+        self.instance["task_executor"] = ThreadPoolExecutor(
+            max_workers=request_parallelism
+        )
+        self.instance[
+            "embeddings_task_type_supported"
+        ] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")
+
+    @staticmethod
+    def _split_by_punctuation(text: str) -> List[str]:
+        """Splits a string by punctuation and whitespace characters."""
+        split_by = string.punctuation + "\t\n "
+        pattern = f"([{split_by}])"
+        # Use re.split to split the text, keeping each separator as a segment.
+        return [segment for segment in re.split(pattern, text) if segment]
+
+    @staticmethod
+    def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
+        """Splits texts into batches based on the current maximum batch size
+        and the maximum number of tokens per request.
+        """
+        text_index = 0
+        texts_len = len(texts)
+        batch_token_len = 0
+        batches: List[List[str]] = []
+        current_batch: List[str] = []
+        if texts_len == 0:
+            return []
+        while text_index < texts_len:
+            current_text = texts[text_index]
+            # The number of tokens per text is conservatively estimated
+            # as 2 times the number of words, punctuation and whitespace
+            # characters. Using the `count_tokens` API would make batching
+            # too expensive, and a tokenizer would add a dependency not
+            # necessarily reused by the application using this class.
+            current_text_token_cnt = (
+                len(VertexAIEmbeddings._split_by_punctuation(current_text)) * 2
+            )
+            end_of_batch = False
+            if current_text_token_cnt > _MAX_TOKENS_PER_BATCH:
+                # The current text is too big even for a single batch.
+                # Such a request will fail, but we still make a batch
+                # so that the app can get the error from the API.
+                if len(current_batch) > 0:
+                    # Add the current batch if it's not empty.
+                    batches.append(current_batch)
+                current_batch = [current_text]
+                text_index += 1
+                end_of_batch = True
+            elif (
+                batch_token_len + current_text_token_cnt > _MAX_TOKENS_PER_BATCH
+                or len(current_batch) == batch_size
+            ):
+                end_of_batch = True
+            else:
+                if text_index == texts_len - 1:
+                    # Last element - even though the batch may not be big,
+                    # we still need to make it.
+                    end_of_batch = True
+                batch_token_len += current_text_token_cnt
+                current_batch.append(current_text)
+                text_index += 1
+            if end_of_batch:
+                batches.append(current_batch)
+                current_batch = []
+                batch_token_len = 0
+        return batches
+
+    def _get_embeddings_with_retry(
+        self, texts: List[str], embeddings_type: Optional[str] = None
+    ) -> List[List[float]]:
+        """Makes a Vertex AI model request with retry logic."""
+        from google.api_core.exceptions import (
+            Aborted,
+            DeadlineExceeded,
+            ResourceExhausted,
+            ServiceUnavailable,
+        )
+
+        errors = [
+            ResourceExhausted,
+            ServiceUnavailable,
+            Aborted,
+            DeadlineExceeded,
+        ]
+        retry_decorator = create_base_retry_decorator(
+            error_types=errors, max_retries=self.max_retries
+        )
+
+        @retry_decorator
+        def _completion_with_retry(texts_to_process: List[str]) -> Any:
+            if embeddings_type and self.instance["embeddings_task_type_supported"]:
+                from vertexai.language_models import TextEmbeddingInput
+
+                requests = [
+                    TextEmbeddingInput(text=t, task_type=embeddings_type)
+                    for t in texts_to_process
+                ]
+            else:
+                requests = texts_to_process
+            embeddings = self.client.get_embeddings(requests)
+            return [embs.values for embs in embeddings]
+
+        return _completion_with_retry(texts)
+
+    def _prepare_and_validate_batches(
+        self, texts: List[str], embeddings_type: Optional[str] = None
+    ) -> Tuple[List[List[float]], List[List[str]]]:
+        """Prepares text batches with one-time validation of the batch size.
+        Batch size varies between GCP regions and individual project quotas.
+        Returns the embeddings of the first batch that went through,
+        and text batches for the rest of the texts.
+        """
+        from google.api_core.exceptions import InvalidArgument
+
+        batches = VertexAIEmbeddings._prepare_batches(
+            texts, self.instance["batch_size"]
+        )
+        # If the batch size is less than or equal to one that already
+        # went through, keep the batches as they are.
+        if len(batches[0]) <= self.instance["min_good_batch_size"]:
+            return [], batches
+        with self.instance["lock"]:
+            # If the largest possible batch size was validated
+            # while waiting for the lock, rebuild our batches
+            # if necessary, and return.
+            if self.instance["batch_size_validated"]:
+                if len(batches[0]) <= self.instance["batch_size"]:
+                    return [], batches
+                else:
+                    return [], VertexAIEmbeddings._prepare_batches(
+                        texts, self.instance["batch_size"]
+                    )
+            # Figure out the largest possible batch size by trying to push
+            # batches through, halving their size after every failure.
+            first_batch = batches[0]
+            first_result = []
+            had_failure = False
+            while True:
+                try:
+                    first_result = self._get_embeddings_with_retry(
+                        first_batch, embeddings_type
+                    )
+                    break
+                except InvalidArgument:
+                    had_failure = True
+                    first_batch_len = len(first_batch)
+                    if first_batch_len == self.instance["min_batch_size"]:
+                        raise
+                    first_batch_len = max(
+                        self.instance["min_batch_size"], int(first_batch_len / 2)
+                    )
+                    first_batch = first_batch[:first_batch_len]
+            first_batch_len = len(first_batch)
+            self.instance["min_good_batch_size"] = max(
+                self.instance["min_good_batch_size"], first_batch_len
+            )
+            # If we had a failure and recovered, or went through
+            # with the max size, then it's a legitimate batch size.
+            if had_failure or first_batch_len == self.instance["max_batch_size"]:
+                self.instance["batch_size"] = first_batch_len
+                self.instance["batch_size_validated"] = True
+                # If the batch size was updated,
+                # rebuild the batches with the new batch size
+                # (texts that already went through are excluded here).
+                if first_batch_len != self.instance["max_batch_size"]:
+                    batches = VertexAIEmbeddings._prepare_batches(
+                        texts[first_batch_len:], self.instance["batch_size"]
+                    )
+            else:
+                # Still figuring out the max batch size.
+                batches = batches[1:]
+        # Return the embeddings of the first batch that went through,
+        # and text batches for the rest of the texts.
+        return first_result, batches
+
+    def embed(
+        self,
+        texts: List[str],
+        batch_size: int = 0,
+        embeddings_task_type: Optional[
+            Literal[
+                "RETRIEVAL_QUERY",
+                "RETRIEVAL_DOCUMENT",
+                "SEMANTIC_SIMILARITY",
+                "CLASSIFICATION",
+                "CLUSTERING",
+            ]
+        ] = None,
     ) -> List[List[float]]:
-        """Embed a list of strings. Vertex AI currently
-        sets a max batch size of 5 strings.
+        """Embed a list of strings.
 
         Args:
             texts: List[str] The list of strings to embed.
-            batch_size: [int] The batch size of embeddings to send to the model
+            batch_size: [int] The batch size of embeddings to send to the model.
+                If zero, the largest batch size will be detected dynamically
+                at the first request, starting from 250 and going down to 5.
+            embeddings_task_type: [str] Optional embeddings task type,
+                one of the following:
+                    RETRIEVAL_QUERY - Text is a query
+                        in a search/retrieval setting.
+                    RETRIEVAL_DOCUMENT - Text is a document
+                        in a search/retrieval setting.
+                    SEMANTIC_SIMILARITY - Embeddings will be used
+                        for Semantic Textual Similarity (STS).
+                    CLASSIFICATION - Embeddings will be used for classification.
+                    CLUSTERING - Embeddings will be used for clustering.
 
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = []
-        for batch in range(0, len(texts), batch_size):
-            text_batch = texts[batch : batch + batch_size]
-            embeddings_batch = self.client.get_embeddings(text_batch)
-            embeddings.extend([el.values for el in embeddings_batch])
+        if len(texts) == 0:
+            return []
+        embeddings: List[List[float]] = []
+        first_batch_result: List[List[float]] = []
+        if batch_size > 0:
+            # Fixed batch size.
+            batches = VertexAIEmbeddings._prepare_batches(texts, batch_size)
+        else:
+            # Dynamic batch size, starting from 250 at the first call.
+            first_batch_result, batches = self._prepare_and_validate_batches(
+                texts, embeddings_task_type
+            )
+        # The first batch result may already hold some embeddings. In that
+        # case, `batches` holds the texts that were not processed yet.
+        embeddings.extend(first_batch_result)
+        tasks = []
+        for batch in batches:
+            tasks.append(
+                self.instance["task_executor"].submit(
+                    self._get_embeddings_with_retry,
+                    texts=batch,
+                    embeddings_type=embeddings_task_type,
+                )
+            )
+        if len(tasks) > 0:
+            wait(tasks)
+        for t in tasks:
+            embeddings.extend(t.result())
         return embeddings
 
+    def embed_documents(
+        self, texts: List[str], batch_size: int = 0
+    ) -> List[List[float]]:
+        """Embed a list of documents.
+
+        Args:
+            texts: List[str] The list of texts to embed.
+            batch_size: [int] The batch size of embeddings to send to the model.
+                If zero, the largest batch size will be detected dynamically
+                at the first request, starting from 250 and going down to 5.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")
+
     def embed_query(self, text: str) -> List[float]:
         """Embed a text.
 
@@ -52,5 +328,5 @@ def embed_query(self, text: str) -> List[float]:
         Returns:
             Embedding for the text.
         """
-        embeddings = self.client.get_embeddings([text])
-        return embeddings[0].values
+        embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
+        return embeddings[0]
```
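
The `* 2` heuristic in `_prepare_batches` above is cheap to compute and easy to check in isolation. A minimal sketch mirroring the splitting logic from the diff; the helper name `estimate_tokens` is hypothetical, for illustration only:

```python
import re
import string

def estimate_tokens(text: str) -> int:
    # Mirror _split_by_punctuation: split on every punctuation or whitespace
    # character, keeping each separator as its own segment.
    pattern = f"([{string.punctuation}\t\n ])"
    segments = [s for s in re.split(pattern, text) if s]
    # Conservative estimate: 2 tokens per word/punctuation/whitespace segment.
    return len(segments) * 2

# "Hello" "," " " "world" "!" -> 5 segments -> estimated 10 tokens
print(estimate_tokens("Hello, world!"))
```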

libs/community/tests/integration_tests/embeddings/test_vertexai.py

Lines changed: 12 additions & 2 deletions

```diff
@@ -1,8 +1,8 @@
 """Test Vertex AI API wrapper.
-In order to run this test, you need to install VertexAI SDK 
+In order to run this test, you need to install VertexAI SDK
 pip install google-cloud-aiplatform>=1.35.0
 
-Your end-user credentials would be used to make the calls (make sure you've run 
+Your end-user credentials would be used to make the calls (make sure you've run
 `gcloud auth login` first).
 """
 from langchain_community.embeddings import VertexAIEmbeddings
@@ -24,6 +24,16 @@ def test_embedding_query() -> None:
     assert len(output) == 768
 
 
+def test_large_batches() -> None:
+    documents = ["foo bar" for _ in range(0, 251)]
+    model_uscentral1 = VertexAIEmbeddings(location="us-central1")
+    model_asianortheast1 = VertexAIEmbeddings(location="asia-northeast1")
+    model_uscentral1.embed_documents(documents)
+    model_asianortheast1.embed_documents(documents)
+    assert model_uscentral1.instance["batch_size"] >= 250
+    assert model_asianortheast1.instance["batch_size"] < 50
+
+
 def test_paginated_texts() -> None:
     documents = [
         "foo bar",
```
