- from typing import Dict, List
+ import logging
+ import re
+ import string
+ import threading
+ from concurrent.futures import ThreadPoolExecutor, wait
+ from typing import Any, Dict, List, Literal, Optional, Tuple

from langchain_core.embeddings import Embeddings
+ from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import root_validator

from langchain_community.llms.vertexai import _VertexAICommon
from langchain_community.utilities.vertexai import raise_vertex_import_error

+ logger = logging.getLogger(__name__)
+
+ _MAX_TOKENS_PER_BATCH = 20000
+ _MAX_BATCH_SIZE = 250
+ _MIN_BATCH_SIZE = 5
+

class VertexAIEmbeddings(_VertexAICommon, Embeddings):
    """Google Cloud VertexAI embedding models."""

-     model_name: str = "textembedding-gecko"
+     # Instance context
+     instance: Dict[str, Any] = {}  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validates that the python package exists in environment."""
        cls._try_init_vertexai(values)
        try:
            from vertexai.language_models import TextEmbeddingModel
+
+             values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
        except ImportError:
            raise_vertex_import_error()
-         values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
        return values

-     def embed_documents(
-         self, texts: List[str], batch_size: int = 5
+     def __init__(
+         self,
+         project: Optional[str] = None,
+         location: str = "us-central1",
+         request_parallelism: int = 5,
+         max_retries: int = 6,
+         model_name: str = "textembedding-gecko",
+         credentials: Optional[Any] = None,
+         **kwargs: Any,
+     ):
+         """Initialize the Vertex AI embeddings client."""
+         super().__init__(
+             project=project,
+             location=location,
+             credentials=credentials,
+             request_parallelism=request_parallelism,
+             max_retries=max_retries,
+             model_name=model_name,
+             **kwargs,
+         )
+         self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
+         self.instance["batch_size"] = self.instance["max_batch_size"]
+         self.instance["min_batch_size"] = kwargs.get("min_batch_size", _MIN_BATCH_SIZE)
+         self.instance["min_good_batch_size"] = self.instance["min_batch_size"]
+         self.instance["lock"] = threading.Lock()
+         self.instance["batch_size_validated"] = False
+         self.instance["task_executor"] = ThreadPoolExecutor(
+             max_workers=request_parallelism
+         )
+         self.instance[
+             "embeddings_task_type_supported"
+         ] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")
+
+     @staticmethod
+     def _split_by_punctuation(text: str) -> List[str]:
+         """Splits a string by punctuation and whitespace characters."""
+         split_by = string.punctuation + "\t\n "
+         pattern = f"([{split_by}])"
+         # Using re.split to split the text based on the pattern
+         return [segment for segment in re.split(pattern, text) if segment]
+
+     @staticmethod
+     def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
+         """Splits texts into batches based on the current maximum batch size
+         and maximum tokens per request.
+         """
+         text_index = 0
+         texts_len = len(texts)
+         batch_token_len = 0
+         batches: List[List[str]] = []
+         current_batch: List[str] = []
+         if texts_len == 0:
+             return []
+         while text_index < texts_len:
+             current_text = texts[text_index]
+             # The number of tokens per text is conservatively estimated
+             # as 2 times the number of words, punctuation and whitespace characters.
+             # Using the `count_tokens` API would make batching too expensive.
+             # Utilizing a tokenizer would add a dependency that would not
+             # necessarily be reused by the application using this class.
+             current_text_token_cnt = (
+                 len(VertexAIEmbeddings._split_by_punctuation(current_text)) * 2
+             )
+             end_of_batch = False
+             if current_text_token_cnt > _MAX_TOKENS_PER_BATCH:
+                 # The current text is too big even for a single batch.
+                 # Such a request will fail, but we still make a batch
+                 # so that the app can get the error from the API.
+                 if len(current_batch) > 0:
+                     # Add the current batch if it is not empty.
+                     batches.append(current_batch)
+                 current_batch = [current_text]
+                 text_index += 1
+                 end_of_batch = True
+             elif (
+                 batch_token_len + current_text_token_cnt > _MAX_TOKENS_PER_BATCH
+                 or len(current_batch) == batch_size
+             ):
+                 end_of_batch = True
+             else:
+                 if text_index == texts_len - 1:
+                     # Last element - even though the batch may not be big,
+                     # we still need to make it.
+                     end_of_batch = True
+                 batch_token_len += current_text_token_cnt
+                 current_batch.append(current_text)
+                 text_index += 1
+             if end_of_batch:
+                 batches.append(current_batch)
+                 current_batch = []
+                 batch_token_len = 0
+         return batches
+
+     def _get_embeddings_with_retry(
+         self, texts: List[str], embeddings_type: Optional[str] = None
+     ) -> List[List[float]]:
+         """Makes a Vertex AI model request with retry logic."""
+         from google.api_core.exceptions import (
+             Aborted,
+             DeadlineExceeded,
+             ResourceExhausted,
+             ServiceUnavailable,
+         )
+
+         errors = [
+             ResourceExhausted,
+             ServiceUnavailable,
+             Aborted,
+             DeadlineExceeded,
+         ]
+         retry_decorator = create_base_retry_decorator(
+             error_types=errors, max_retries=self.max_retries
+         )
+
+         @retry_decorator
+         def _completion_with_retry(texts_to_process: List[str]) -> Any:
+             if embeddings_type and self.instance["embeddings_task_type_supported"]:
+                 from vertexai.language_models import TextEmbeddingInput
+
+                 requests = [
+                     TextEmbeddingInput(text=t, task_type=embeddings_type)
+                     for t in texts_to_process
+                 ]
+             else:
+                 requests = texts_to_process
+             embeddings = self.client.get_embeddings(requests)
+             return [embs.values for embs in embeddings]
+
+         return _completion_with_retry(texts)
+
+     def _prepare_and_validate_batches(
+         self, texts: List[str], embeddings_type: Optional[str] = None
+     ) -> Tuple[List[List[float]], List[List[str]]]:
+         """Prepares text batches with one-time validation of the batch size.
+         Batch size varies between GCP regions and individual project quotas.
+         Returns embeddings of the first text batch that went through,
+         and text batches for the rest of the texts.
+         """
+         from google.api_core.exceptions import InvalidArgument
+
+         batches = VertexAIEmbeddings._prepare_batches(
+             texts, self.instance["batch_size"]
+         )
+         # If the batch size is less than or equal to one that went through before,
+         # then keep the batches as they are.
+         if len(batches[0]) <= self.instance["min_good_batch_size"]:
+             return [], batches
+         with self.instance["lock"]:
+             # If the largest possible batch size was validated
+             # while waiting for the lock, then rebuild
+             # our batches if needed, and return.
+             if self.instance["batch_size_validated"]:
+                 if len(batches[0]) <= self.instance["batch_size"]:
+                     return [], batches
+                 else:
+                     return [], VertexAIEmbeddings._prepare_batches(
+                         texts, self.instance["batch_size"]
+                     )
+             # Figure out the largest possible batch size by trying to push
+             # batches and halving their size after every failure.
+             first_batch = batches[0]
+             first_result = []
+             had_failure = False
+             while True:
+                 try:
+                     first_result = self._get_embeddings_with_retry(
+                         first_batch, embeddings_type
+                     )
+                     break
+                 except InvalidArgument:
+                     had_failure = True
+                     first_batch_len = len(first_batch)
+                     if first_batch_len == self.instance["min_batch_size"]:
+                         raise
+                     first_batch_len = max(
+                         self.instance["min_batch_size"], int(first_batch_len / 2)
+                     )
+                     first_batch = first_batch[:first_batch_len]
+             first_batch_len = len(first_batch)
+             self.instance["min_good_batch_size"] = max(
+                 self.instance["min_good_batch_size"], first_batch_len
+             )
+             # If we had a failure and recovered,
+             # or went through with the max size, then it's a legit batch size.
+             if had_failure or first_batch_len == self.instance["max_batch_size"]:
+                 self.instance["batch_size"] = first_batch_len
+                 self.instance["batch_size_validated"] = True
+                 # If the batch size was updated,
+                 # rebuild the batches with the new batch size
+                 # (texts that went through are excluded here).
+                 if first_batch_len != self.instance["max_batch_size"]:
+                     batches = VertexAIEmbeddings._prepare_batches(
+                         texts[first_batch_len:], self.instance["batch_size"]
+                     )
+             else:
+                 # Still figuring out the max batch size.
+                 batches = batches[1:]
+         # Return embeddings of the first text batch that went through,
+         # and text batches for the rest of the texts.
+         return first_result, batches
+
+     def embed(
+         self,
+         texts: List[str],
+         batch_size: int = 0,
+         embeddings_task_type: Optional[
+             Literal[
+                 "RETRIEVAL_QUERY",
+                 "RETRIEVAL_DOCUMENT",
+                 "SEMANTIC_SIMILARITY",
+                 "CLASSIFICATION",
+                 "CLUSTERING",
+             ]
+         ] = None,
    ) -> List[List[float]]:
-         """Embed a list of strings. Vertex AI currently
-         sets a max batch size of 5 strings.
+         """Embed a list of strings.

        Args:
            texts: List[str] The list of strings to embed.
-             batch_size: [int] The batch size of embeddings to send to the model
+             batch_size: [int] The batch size of embeddings to send to the model.
+                 If zero, then the largest batch size will be detected dynamically
+                 at the first request, starting from 250, down to 5.
+             embeddings_task_type: [str] optional embeddings task type,
+                 one of the following:
+                     RETRIEVAL_QUERY - Text is a query
+                         in a search/retrieval setting.
+                     RETRIEVAL_DOCUMENT - Text is a document
+                         in a search/retrieval setting.
+                     SEMANTIC_SIMILARITY - Embeddings will be used
+                         for Semantic Textual Similarity (STS).
+                     CLASSIFICATION - Embeddings will be used for classification.
+                     CLUSTERING - Embeddings will be used for clustering.

        Returns:
            List of embeddings, one for each text.
        """
-         embeddings = []
-         for batch in range(0, len(texts), batch_size):
-             text_batch = texts[batch : batch + batch_size]
-             embeddings_batch = self.client.get_embeddings(text_batch)
-             embeddings.extend([el.values for el in embeddings_batch])
+         if len(texts) == 0:
+             return []
+         embeddings: List[List[float]] = []
+         first_batch_result: List[List[float]] = []
+         if batch_size > 0:
+             # Fixed batch size.
+             batches = VertexAIEmbeddings._prepare_batches(texts, batch_size)
+         else:
+             # Dynamic batch size, starting from 250 at the first call.
+             first_batch_result, batches = self._prepare_and_validate_batches(
+                 texts, embeddings_task_type
+             )
+         # The first batch result may already contain some embeddings.
+         # In that case, batches hold the texts that were not processed yet.
+         embeddings.extend(first_batch_result)
+         tasks = []
+         for batch in batches:
+             tasks.append(
+                 self.instance["task_executor"].submit(
+                     self._get_embeddings_with_retry,
+                     texts=batch,
+                     embeddings_type=embeddings_task_type,
+                 )
+             )
+         if len(tasks) > 0:
+             wait(tasks)
+         for t in tasks:
+             embeddings.extend(t.result())
        return embeddings

+     def embed_documents(
+         self, texts: List[str], batch_size: int = 0
+     ) -> List[List[float]]:
+         """Embed a list of documents.
+
+         Args:
+             texts: List[str] The list of texts to embed.
+             batch_size: [int] The batch size of embeddings to send to the model.
+                 If zero, then the largest batch size will be detected dynamically
+                 at the first request, starting from 250, down to 5.
+
+         Returns:
+             List of embeddings, one for each text.
+         """
+         return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")
+
    def embed_query(self, text: str) -> List[float]:
        """Embed a text.

@@ -52,5 +328,5 @@ def embed_query(self, text: str) -> List[float]:
        Returns:
            Embedding for the text.
        """
-         embeddings = self.client.get_embeddings([text])
-         return embeddings[0].values
+         embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
+         return embeddings[0]
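For reference, a minimal usage sketch of the updated class. This assumes a GCP project with the Vertex AI API enabled and default application credentials configured, and that `VertexAIEmbeddings` is exported from `langchain_community.embeddings`; the project ID and texts below are placeholders, not values from this change.

```python
from langchain_community.embeddings import VertexAIEmbeddings

# Hypothetical project ID, for illustration only.
embeddings = VertexAIEmbeddings(project="my-gcp-project", request_parallelism=5)

# batch_size=0 (the default) lets the class probe the largest batch size the
# region/quota allows, starting from 250 and halving down to 5 on failures.
doc_vectors = embeddings.embed_documents(
    ["first document", "second document"], batch_size=0
)

# embed_query() issues a single-text request tagged as RETRIEVAL_QUERY.
query_vector = embeddings.embed_query("what is in the documents?")
```

Batches beyond the first are submitted to the instance's thread pool, so documents are embedded with up to `request_parallelism` concurrent requests.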