 from datasets import Dataset
 from dotenv import load_dotenv
 from ragas import evaluate
-from ragas.metrics import answer_relevancy, faithfulness
+from ragas.metrics import answer_relevancy, faithfulness, context_entity_recall
 from src.shared.common_fn import load_embedding_model
 from ragas.dataset_schema import SingleTurnSample
 from ragas.metrics import RougeScore, SemanticSimilarity, ContextEntityRecall
@@ -24,25 +24,29 @@ def get_ragas_metrics(question: str, context: list, answer: list, model: str):
     try:
         start_time = time.time()
         dataset = Dataset.from_dict(
-            {"question": [question] * len(answer), "answer": answer, "contexts": [[ctx] for ctx in context]}
+            {"question": [question] * len(answer), "reference": answer, "answer": answer, "contexts": [[ctx] for ctx in context]}
         )
         logging.info("Evaluation dataset created successfully.")
         if ("diffbot" in model) or ("ollama" in model):
             raise ValueError(f"Unsupported model for evaluation: {model}")
+        elif ("gemini" in model):
+            llm, model_name = get_llm(model=model)
+            llm = LangchainLLMWrapper(llm, is_finished_parser=custom_is_finished_parser)
         else:
             llm, model_name = get_llm(model=model)
+            llm = LangchainLLMWrapper(llm)

         logging.info(f"Evaluating with model: {model_name}")

         score = evaluate(
             dataset=dataset,
-            metrics=[faithfulness, answer_relevancy],
+            metrics=[faithfulness, answer_relevancy, context_entity_recall],
             llm=llm,
             embeddings=EMBEDDING_FUNCTION,
         )

         score_dict = (
-            score.to_pandas()[["faithfulness", "answer_relevancy"]]
+            score.to_pandas()[["faithfulness", "answer_relevancy", "context_entity_recall"]]
             .fillna(0)
             .round(4)
             .to_dict(orient="list")
@@ -67,13 +71,10 @@ async def get_additional_metrics(question: str, contexts: list, answers: list, r
         if ("diffbot" in model_name) or ("ollama" in model_name):
             raise ValueError(f"Unsupported model for evaluation: {model_name}")
         llm, model_name = get_llm(model=model_name)
-        ragas_llm = LangchainLLMWrapper(llm)
         embeddings = EMBEDDING_FUNCTION
         embedding_model = LangchainEmbeddingsWrapper(embeddings=embeddings)
         rouge_scorer = RougeScore()
         semantic_scorer = SemanticSimilarity()
-        entity_recall_scorer = ContextEntityRecall()
-        entity_recall_scorer.llm = ragas_llm
         semantic_scorer.embeddings = embedding_model
         metrics = []
         for response, context in zip(answers, contexts):
@@ -82,18 +83,35 @@ async def get_additional_metrics(question: str, contexts: list, answers: list, r
             rouge_score = round(rouge_score, 4)
             semantic_score = await semantic_scorer.single_turn_ascore(sample)
             semantic_score = round(semantic_score, 4)
-            if "gemini" in model_name:
-                entity_recall_score = "Not Available"
-            else:
-                entity_sample = SingleTurnSample(reference=reference, retrieved_contexts=[context])
-                entity_recall_score = await entity_recall_scorer.single_turn_ascore(entity_sample)
-                entity_recall_score = round(entity_recall_score, 4)
             metrics.append({
                 "rouge_score": rouge_score,
                 "semantic_score": semantic_score,
-                "context_entity_recall_score": entity_recall_score
             })
         return metrics
     except Exception as e:
         logging.exception("Error in get_additional_metrics")
-        return {"error": str(e)}
+        return {"error": str(e)}
+
+
+def custom_is_finished_parser(response):
+    is_finished_list = []
+    for g in response.flatten():
+        resp = g.generations[0][0]
+        if resp.generation_info is not None:
+            if resp.generation_info.get("finish_reason") is not None:
+                is_finished_list.append(
+                    resp.generation_info.get("finish_reason") == "STOP"
+                )
+
+        elif (
+            isinstance(resp, ChatGeneration)
+            and t.cast(ChatGeneration, resp).message is not None
+        ):
+            resp_message: BaseMessage = t.cast(ChatGeneration, resp).message
+            if resp_message.response_metadata.get("finish_reason") is not None:
+                is_finished_list.append(
+                    resp_message.response_metadata.get("finish_reason") == "STOP"
+                )
+        else:
+            is_finished_list.append(True)
+    return all(is_finished_list)
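The custom_is_finished_parser added above treats a generation as finished only when the reported finish_reason equals "STOP", the capitalized value Gemini models return (presumably the reason the gemini branch wires it into LangchainLLMWrapper). Below is an illustrative smoke test, not part of the PR, that fakes a Gemini-style response with langchain-core's LLMResult/ChatGeneration types:

# Illustrative sketch only; the fake payloads are assumptions, not PR code.
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult

completed = LLMResult(
    generations=[[
        ChatGeneration(
            message=AIMessage(content="Paris is the capital of France."),
            generation_info={"finish_reason": "STOP"},
        )
    ]]
)
assert custom_is_finished_parser(completed) is True

# A truncated generation (e.g. finish_reason "MAX_TOKENS") is rejected.
truncated = LLMResult(
    generations=[[
        ChatGeneration(
            message=AIMessage(content="Paris is"),
            generation_info={"finish_reason": "MAX_TOKENS"},
        )
    ]]
)
assert custom_is_finished_parser(truncated) is False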
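For reference, a hypothetical call to the updated get_ragas_metrics (question, context, answer, and the model identifier are made up for illustration). The new "reference": answer field in the dataset is what allows context_entity_recall to be scored alongside faithfulness and answer_relevancy:

# Hypothetical usage sketch; argument values and the model name are assumptions.
scores = get_ragas_metrics(
    question="Who founded Neo4j?",
    context=["Neo4j was founded by Emil Eifrem."],
    answer=["Emil Eifrem founded Neo4j."],
    model="openai_gpt_4o",
)
# The score_dict assembled in the function then holds one list per metric, e.g.
# {"faithfulness": [1.0], "answer_relevancy": [0.97], "context_entity_recall": [1.0]}
# (values illustrative).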