Commit a2bbe3d

Harrison/mmr support for opensearch (#6349)
Co-authored-by: Mehmet Öner Yalçın <[email protected]>
1 parent 2eea5d4 commit a2bbe3d

2 files changed (+140, -60 lines)

docs/extras/modules/data_connection/vectorstores/integrations/opensearch.ipynb

Lines changed: 30 additions & 40 deletions
@@ -129,11 +129,7 @@
 "cell_type": "code",
 "execution_count": null,
 "id": "db3fa309",
-"metadata": {
-"pycharm": {
-"name": "#%%\n"
-}
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "query = \"What did the president say about Ketanji Brown Jackson\"\n",
@@ -144,11 +140,7 @@
 "cell_type": "code",
 "execution_count": null,
 "id": "c160d5bb",
-"metadata": {
-"pycharm": {
-"name": "#%%\n"
-}
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "print(docs[0].page_content)"
@@ -158,11 +150,7 @@
 "cell_type": "code",
 "execution_count": null,
 "id": "96215c90",
-"metadata": {
-"pycharm": {
-"name": "#%%\n"
-}
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "docsearch = OpenSearchVectorSearch.from_documents(\n",
@@ -183,11 +171,7 @@
 "cell_type": "code",
 "execution_count": null,
 "id": "62a7cea0",
-"metadata": {
-"pycharm": {
-"name": "#%%\n"
-}
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "print(docs[0].page_content)"
@@ -207,11 +191,7 @@
 "cell_type": "code",
 "execution_count": null,
 "id": "0a8e3c0e",
-"metadata": {
-"pycharm": {
-"name": "#%%\n"
-}
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "docsearch = OpenSearchVectorSearch.from_documents(\n",
@@ -230,11 +210,7 @@
 "cell_type": "code",
 "execution_count": null,
 "id": "92bc40db",
-"metadata": {
-"pycharm": {
-"name": "#%%\n"
-}
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "print(docs[0].page_content)"
@@ -254,11 +230,7 @@
 "cell_type": "code",
 "execution_count": null,
 "id": "6d9f436e",
-"metadata": {
-"pycharm": {
-"name": "#%%\n"
-}
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "docsearch = OpenSearchVectorSearch.from_documents(\n",
@@ -278,16 +250,34 @@
 "cell_type": "code",
 "execution_count": null,
 "id": "8ca50bce",
-"metadata": {
-"pycharm": {
-"name": "#%%\n"
-}
-},
+"metadata": {},
 "outputs": [],
 "source": [
 "print(docs[0].page_content)"
 ]
 },
+{
+"cell_type": "markdown",
+"source": [
+"### Maximum marginal relevance search (MMR)\n",
+"If you’d like to look up some similar documents, but you’d also like to receive diverse results, MMR is the method you should consider. Maximal marginal relevance optimizes for similarity to the query AND diversity among the selected documents."
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"outputs": [],
+"source": [
+"query = \"What did the president say about Ketanji Brown Jackson\"\n",
+"docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10, lambda_mult=0.5)"
+],
+"metadata": {
+"collapsed": false
+}
+},
 {
 "cell_type": "markdown",
 "id": "73264864",

langchain/vectorstores/opensearch_vector_search.py

Lines changed: 110 additions & 20 deletions
@@ -4,10 +4,13 @@
 import uuid
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
-from langchain.docstore.document import Document
+import numpy as np
+
 from langchain.embeddings.base import Embeddings
+from langchain.schema import Document
 from langchain.utils import get_from_dict_or_env
 from langchain.vectorstores.base import VectorStore
+from langchain.vectorstores.utils import maximal_marginal_relevance
 
 IMPORT_OPENSEARCH_PY_ERROR = (
     "Could not import OpenSearch. Please install it with `pip install opensearch-py`."
@@ -76,9 +79,12 @@ def _bulk_ingest_embeddings(
     metadatas: Optional[List[dict]] = None,
     vector_field: str = "vector_field",
     text_field: str = "text",
-    mapping: Dict = {},
+    mapping: Optional[Dict] = None,
 ) -> List[str]:
     """Bulk Ingest Embeddings into given index."""
+    if not mapping:
+        mapping = dict()
+
     bulk = _import_bulk()
     not_found_error = _import_not_found_error()
     requests = []
@@ -201,10 +207,14 @@
 def _default_script_query(
     query_vector: List[float],
     space_type: str = "l2",
-    pre_filter: Dict = MATCH_ALL_QUERY,
+    pre_filter: Optional[Dict] = None,
     vector_field: str = "vector_field",
 ) -> Dict:
     """For Script Scoring Search, this is the default query."""
+
+    if not pre_filter:
+        pre_filter = MATCH_ALL_QUERY
+
     return {
         "query": {
             "script_score": {
@@ -245,10 +255,14 @@ def __get_painless_scripting_source(
 def _default_painless_scripting_query(
     query_vector: List[float],
     space_type: str = "l2Squared",
-    pre_filter: Dict = MATCH_ALL_QUERY,
+    pre_filter: Optional[Dict] = None,
    vector_field: str = "vector_field",
 ) -> Dict:
     """For Painless Scripting Search, this is the default query."""
+
+    if not pre_filter:
+        pre_filter = MATCH_ALL_QUERY
+
     source = __get_painless_scripting_source(space_type, query_vector)
     return {
         "query": {
@@ -355,7 +369,7 @@ def similarity_search(
     ) -> List[Document]:
         """Return docs most similar to query.
 
-        By default supports Approximate Search.
+        By default, supports Approximate Search.
         Also supports Script Scoring and Painless Scripting.
 
         Args:
@@ -413,7 +427,7 @@ def similarity_search_with_score(
     ) -> List[Tuple[Document, float]]:
         """Return docs and its scores most similar to query.
 
-        By default supports Approximate Search.
+        By default, supports Approximate Search.
         Also supports Script Scoring and Painless Scripting.
 
         Args:
@@ -426,10 +440,47 @@ def similarity_search_with_score(
         Optional Args:
             same as `similarity_search`
         """
-        embedding = self.embedding_function.embed_query(query)
-        search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
+
         text_field = _get_kwargs_value(kwargs, "text_field", "text")
         metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
+
+        hits = self._raw_similarity_search_with_score(query=query, k=k, **kwargs)
+
+        documents_with_scores = [
+            (
+                Document(
+                    page_content=hit["_source"][text_field],
+                    metadata=hit["_source"]
+                    if metadata_field == "*" or metadata_field not in hit["_source"]
+                    else hit["_source"][metadata_field],
+                ),
+                hit["_score"],
+            )
+            for hit in hits
+        ]
+        return documents_with_scores
+
+    def _raw_similarity_search_with_score(
+        self, query: str, k: int = 4, **kwargs: Any
+    ) -> List[dict]:
+        """Return raw opensearch documents (dict) including vectors,
+        scores most similar to query.
+
+        By default, supports Approximate Search.
+        Also supports Script Scoring and Painless Scripting.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+
+        Returns:
+            List of dict with its scores most similar to the query.
+
+        Optional Args:
+            same as `similarity_search`
+        """
+        embedding = self.embedding_function.embed_query(query)
+        search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
         vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
 
         if search_type == "approximate_search":
@@ -473,20 +524,59 @@ def similarity_search_with_score(
             raise ValueError("Invalid `search_type` provided as an argument")
 
         response = self.client.search(index=self.index_name, body=search_query)
-        hits = [hit for hit in response["hits"]["hits"][:k]]
-        documents_with_scores = [
-            (
-                Document(
-                    page_content=hit["_source"][text_field],
-                    metadata=hit["_source"]
-                    if metadata_field == "*" or metadata_field not in hit["_source"]
-                    else hit["_source"][metadata_field],
-                ),
-                hit["_score"],
+
+        return [hit for hit in response["hits"]["hits"][:k]]
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+                Defaults to 20.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+
+        vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
+        text_field = _get_kwargs_value(kwargs, "text_field", "text")
+        metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
+
+        # Get embedding of the user query
+        embedding = self.embedding_function.embed_query(query)
+
+        # Do ANN/KNN search to get top fetch_k results where fetch_k >= k
+        results = self._raw_similarity_search_with_score(query, fetch_k, **kwargs)
+
+        embeddings = [result["_source"][vector_field] for result in results]
+
+        # Rerank top k results using MMR, (mmr_selected is a list of indices)
+        mmr_selected = maximal_marginal_relevance(
+            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
+        )
+
+        return [
+            Document(
+                page_content=results[i]["_source"][text_field],
+                metadata=results[i]["_source"][metadata_field],
             )
-            for hit in hits
+            for i in mmr_selected
         ]
-        return documents_with_scores
 
     @classmethod
     def from_texts(
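
The new method delegates the re-ranking itself to `maximal_marginal_relevance` from `langchain.vectorstores.utils`. To make that step concrete, here is a small, self-contained sketch of greedy MMR selection over the fetched candidate vectors; it illustrates the algorithm only and is not the library helper, and the names `cosine_sim`, `mmr_select`, and the toy data are invented for the example:

from typing import List

import numpy as np


def cosine_sim(query: np.ndarray, candidates: np.ndarray) -> np.ndarray:
    """Cosine similarity between one vector and each row of a matrix."""
    query = query / np.linalg.norm(query)
    candidates = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    return candidates @ query


def mmr_select(
    query_vec: np.ndarray,
    candidate_vecs: np.ndarray,
    k: int = 4,
    lambda_mult: float = 0.5,
) -> List[int]:
    """Greedily pick k indices, trading query relevance against redundancy."""
    relevance = cosine_sim(query_vec, candidate_vecs)
    selected = [int(np.argmax(relevance))]
    while len(selected) < min(k, len(candidate_vecs)):
        best_idx, best_score = -1, -np.inf
        for i in range(len(candidate_vecs)):
            if i in selected:
                continue
            # Highest similarity to anything already selected = redundancy.
            redundancy = cosine_sim(candidate_vecs[i], candidate_vecs[selected]).max()
            score = lambda_mult * relevance[i] - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_idx, best_score = i, score
        selected.append(best_idx)
    return selected


# Toy usage: 20 random 8-dimensional candidates, keep the 4 MMR picks.
rng = np.random.default_rng(0)
print(mmr_select(rng.normal(size=8), rng.normal(size=(20, 8)), k=4))

Each round adds the candidate with the best trade-off between similarity to the query and its worst-case similarity to anything already picked, which is the relevance-versus-diversity balance that `lambda_mult` controls.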
