44import uuid
55from typing import Any , Dict , Iterable , List , Optional , Tuple
66
7- from langchain .docstore .document import Document
7+ import numpy as np
8+
89from langchain .embeddings .base import Embeddings
10+ from langchain .schema import Document
911from langchain .utils import get_from_dict_or_env
1012from langchain .vectorstores .base import VectorStore
13+ from langchain .vectorstores .utils import maximal_marginal_relevance
1114
1215IMPORT_OPENSEARCH_PY_ERROR = (
1316 "Could not import OpenSearch. Please install it with `pip install opensearch-py`."
@@ -76,9 +79,12 @@ def _bulk_ingest_embeddings(
7679 metadatas : Optional [List [dict ]] = None ,
7780 vector_field : str = "vector_field" ,
7881 text_field : str = "text" ,
79- mapping : Dict = {} ,
82+ mapping : Optional [ Dict ] = None ,
8083) -> List [str ]:
8184 """Bulk Ingest Embeddings into given index."""
85+ if not mapping :
86+ mapping = dict ()
87+
8288 bulk = _import_bulk ()
8389 not_found_error = _import_not_found_error ()
8490 requests = []
@@ -201,10 +207,14 @@ def _approximate_search_query_with_lucene_filter(
201207def _default_script_query (
202208 query_vector : List [float ],
203209 space_type : str = "l2" ,
204- pre_filter : Dict = MATCH_ALL_QUERY ,
210+ pre_filter : Optional [ Dict ] = None ,
205211 vector_field : str = "vector_field" ,
206212) -> Dict :
207213 """For Script Scoring Search, this is the default query."""
214+
215+ if not pre_filter :
216+ pre_filter = MATCH_ALL_QUERY
217+
208218 return {
209219 "query" : {
210220 "script_score" : {
@@ -245,10 +255,14 @@ def __get_painless_scripting_source(
245255def _default_painless_scripting_query (
246256 query_vector : List [float ],
247257 space_type : str = "l2Squared" ,
248- pre_filter : Dict = MATCH_ALL_QUERY ,
258+ pre_filter : Optional [ Dict ] = None ,
249259 vector_field : str = "vector_field" ,
250260) -> Dict :
251261 """For Painless Scripting Search, this is the default query."""
262+
263+ if not pre_filter :
264+ pre_filter = MATCH_ALL_QUERY
265+
252266 source = __get_painless_scripting_source (space_type , query_vector )
253267 return {
254268 "query" : {
@@ -355,7 +369,7 @@ def similarity_search(
355369 ) -> List [Document ]:
356370 """Return docs most similar to query.
357371
358- By default supports Approximate Search.
372+ By default, supports Approximate Search.
359373 Also supports Script Scoring and Painless Scripting.
360374
361375 Args:
@@ -413,7 +427,7 @@ def similarity_search_with_score(
413427 ) -> List [Tuple [Document , float ]]:
414428 """Return docs and their scores most similar to query.
415429
416- By default supports Approximate Search.
430+ By default, supports Approximate Search.
417431 Also supports Script Scoring and Painless Scripting.
418432
419433 Args:
@@ -426,10 +440,47 @@ def similarity_search_with_score(
426440 Optional Args:
427441 same as `similarity_search`
428442 """
429- embedding = self .embedding_function .embed_query (query )
430- search_type = _get_kwargs_value (kwargs , "search_type" , "approximate_search" )
443+
431444 text_field = _get_kwargs_value (kwargs , "text_field" , "text" )
432445 metadata_field = _get_kwargs_value (kwargs , "metadata_field" , "metadata" )
446+
447+ hits = self ._raw_similarity_search_with_score (query = query , k = k , ** kwargs )
448+
449+ documents_with_scores = [
450+ (
451+ Document (
452+ page_content = hit ["_source" ][text_field ],
453+ metadata = hit ["_source" ]
454+ if metadata_field == "*" or metadata_field not in hit ["_source" ]
455+ else hit ["_source" ][metadata_field ],
456+ ),
457+ hit ["_score" ],
458+ )
459+ for hit in hits
460+ ]
461+ return documents_with_scores
462+
463+ def _raw_similarity_search_with_score (
464+ self , query : str , k : int = 4 , ** kwargs : Any
465+ ) -> List [dict ]:
466+ """Return raw opensearch documents (dict) including vectors,
467+ scores most similar to query.
468+
469+ By default, supports Approximate Search.
470+ Also supports Script Scoring and Painless Scripting.
471+
472+ Args:
473+ query: Text to look up documents similar to.
474+ k: Number of Documents to return. Defaults to 4.
475+
476+ Returns:
477+ List of dict with its scores most similar to the query.
478+
479+ Optional Args:
480+ same as `similarity_search`
481+ """
482+ embedding = self .embedding_function .embed_query (query )
483+ search_type = _get_kwargs_value (kwargs , "search_type" , "approximate_search" )
433484 vector_field = _get_kwargs_value (kwargs , "vector_field" , "vector_field" )
434485
435486 if search_type == "approximate_search" :
@@ -473,20 +524,59 @@ def similarity_search_with_score(
473524 raise ValueError ("Invalid `search_type` provided as an argument" )
474525
475526 response = self .client .search (index = self .index_name , body = search_query )
476- hits = [hit for hit in response ["hits" ]["hits" ][:k ]]
477- documents_with_scores = [
478- (
479- Document (
480- page_content = hit ["_source" ][text_field ],
481- metadata = hit ["_source" ]
482- if metadata_field == "*" or metadata_field not in hit ["_source" ]
483- else hit ["_source" ][metadata_field ],
484- ),
485- hit ["_score" ],
527+
528+ return [hit for hit in response ["hits" ]["hits" ][:k ]]
529+
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
                 Defaults to 20.
        lambda_mult: Number between 0 and 1 that determines the degree
                    of diversity among the results with 0 corresponding
                    to maximum diversity and 1 to minimum diversity.
                    Defaults to 0.5.

    Returns:
        List of Documents selected by maximal marginal relevance.

    Optional Args:
        vector_field: Name of the field holding the stored embedding
            ("vector_field" by default).
        text_field: Name of the field holding the document text
            ("text" by default).
        metadata_field: Name of the field holding the metadata
            ("metadata" by default). When it is "*" or the field is missing
            from a hit, the whole `_source` is used as metadata, matching
            `similarity_search_with_score`.
        Remaining keyword arguments are forwarded unchanged to the
        underlying similarity search.
    """
    vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
    text_field = _get_kwargs_value(kwargs, "text_field", "text")
    metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")

    # Embedding of the user query; MMR needs it to score relevance.
    # NOTE(review): `_raw_similarity_search_with_score` embeds the same query
    # again internally, so the embedding model is invoked twice per call.
    embedding = self.embedding_function.embed_query(query)

    # ANN/KNN search for the top `fetch_k` candidates (fetch_k should be >= k).
    results = self._raw_similarity_search_with_score(query, fetch_k, **kwargs)

    # Stored vectors of the candidates, in the same order as `results`.
    embeddings = [result["_source"][vector_field] for result in results]

    # Rerank with MMR; `mmr_selected` is a list of indices into `results`.
    mmr_selected = maximal_marginal_relevance(
        np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
    )

    documents = []
    for i in mmr_selected:
        source = results[i]["_source"]
        # Same fallback as `similarity_search_with_score`: avoid a KeyError
        # when the metadata field is the wildcard or absent from this hit.
        if metadata_field == "*" or metadata_field not in source:
            metadata = source
        else:
            metadata = source[metadata_field]
        documents.append(Document(page_content=source[text_field], metadata=metadata))
    return documents
489- return documents_with_scores
490580
491581 @classmethod
492582 def from_texts (
0 commit comments