Skip to content

Commit aaf8c59

Browse files
authored
Merge pull request #743 from iceljc/features/add-evaluation
Features/add evaluation
2 parents 48e7b6e + cab99c6 commit aaf8c59

File tree

6 files changed

+166
-35
lines changed

6 files changed

+166
-35
lines changed

src/Infrastructure/BotSharp.Abstraction/Evaluations/Models/EvaluationRequest.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@ public class EvaluationRequest : LlmBaseRequest
1010
[JsonPropertyName("states")]
1111
public IEnumerable<MessageState> States { get; set; } = [];
1212

13+
[JsonPropertyName("chat")]
14+
public ChatEvaluationRequest Chat { get; set; } = new ChatEvaluationRequest();
15+
16+
[JsonPropertyName("metric")]
17+
public MetricEvaluationRequest Metric { get; set; } = new MetricEvaluationRequest();
18+
}
19+
20+
21+
public class ChatEvaluationRequest
22+
{
1323
[JsonPropertyName("duplicate_limit")]
1424
public int DuplicateLimit { get; set; } = 2;
1525

@@ -24,4 +34,26 @@ public class EvaluationRequest : LlmBaseRequest
2434
[JsonPropertyName("stop_criteria")]
2535
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
2636
public string? StopCriteria { get; set; }
37+
38+
public ChatEvaluationRequest()
39+
{
40+
41+
}
2742
}
43+
44+
45+
public class MetricEvaluationRequest
46+
{
47+
[JsonPropertyName("additional_instruction")]
48+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
49+
public string? AdditionalInstruction { get; set; }
50+
51+
[JsonPropertyName("metrics")]
52+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
53+
public IEnumerable<NameDesc>? Metrics { get; set; } = [];
54+
55+
public MetricEvaluationRequest()
56+
{
57+
58+
}
59+
}

src/Infrastructure/BotSharp.Abstraction/Evaluations/Models/EvaluationResult.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ public class EvaluationResult
66
public string TaskInstruction { get; set; }
77
public string SystemPrompt { get; set; }
88
public string GeneratedConversationId { get; set; }
9+
public string? MetricResult { get; set; }
910
}

src/Infrastructure/BotSharp.Core/Evaluations/Services/EvaluatingService.Evaluate.cs

Lines changed: 76 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using BotSharp.Abstraction.Evaluations.Models;
22
using BotSharp.Abstraction.Instructs;
33
using BotSharp.Abstraction.Instructs.Models;
4+
using BotSharp.Abstraction.Models;
45

56
namespace BotSharp.Core.Evaluations.Services;
67

@@ -31,15 +32,19 @@ public async Task<EvaluationResult> Evaluate(string conversationId, EvaluationRe
3132
return result;
3233
}
3334

34-
var generatedConvId = await SimulateConversation(initMessage, refDialogContents, request);
35+
var initialStates = GetInitialStates(conversationId);
36+
var generatedConvId = await SimulateConversation(initMessage, refDialogContents, request, initialStates);
37+
var metricResult = await EvaluateMetrics(generatedConvId, refDialogContents, request);
3538

3639
return new EvaluationResult
3740
{
38-
GeneratedConversationId = generatedConvId
41+
GeneratedConversationId = generatedConvId,
42+
MetricResult = metricResult
3943
};
4044
}
4145

42-
private async Task<string> SimulateConversation(string initMessage, IEnumerable<string> refDialogs, EvaluationRequest request)
46+
private async Task<string> SimulateConversation(string initMessage, IEnumerable<string> refDialogs,
47+
EvaluationRequest request, IEnumerable<MessageState>? states = null)
4348
{
4449
var count = 0;
4550
var duplicateCount = 0;
@@ -49,20 +54,22 @@ private async Task<string> SimulateConversation(string initMessage, IEnumerable<
4954
var prevUserMsg = string.Empty;
5055
var curBotMsg = string.Empty;
5156
var prevBotMsg = string.Empty;
57+
var initialStates = states?.ToList() ?? [];
5258

5359
var storage = _services.GetRequiredService<IConversationStorage>();
5460
var agentService = _services.GetRequiredService<IAgentService>();
5561
var instructService = _services.GetRequiredService<IInstructService>();
5662

5763
var query = "Please see yourself as a user and follow the instruction to generate a message.";
5864
var targetAgentId = request.AgentId;
59-
var evaluatorAgent = await agentService.GetAgent(BuiltInAgentId.Evaluator);
60-
var simulatorPrompt = evaluatorAgent.Templates.FirstOrDefault(x => x.Name == "instruction.simulator")?.Content ?? string.Empty;
65+
var evaluator = await agentService.GetAgent(BuiltInAgentId.Evaluator);
66+
var simulatorPrompt = evaluator.Templates.FirstOrDefault(x => x.Name == "instruction.simulator")?.Content ?? string.Empty;
6167

6268
while (true)
6369
{
6470
curDialogs.Add($"{AgentRole.User}: {curUserMsg}");
65-
var dialog = await SendMessage(targetAgentId, convId, curUserMsg);
71+
var dialog = await SendMessage(targetAgentId, convId, curUserMsg, states: initialStates);
72+
initialStates = [];
6673

6774
prevBotMsg = curBotMsg;
6875
curBotMsg = dialog?.RichContent?.Message?.Text ?? dialog?.Content ?? string.Empty;
@@ -80,30 +87,20 @@ private async Task<string> SimulateConversation(string initMessage, IEnumerable<
8087
{
8188
{ "ref_conversation", refDialogs },
8289
{ "cur_conversation", curDialogs },
83-
{ "additional_instruction", request.AdditionalInstruction },
84-
{ "stop_criteria", request.StopCriteria }
90+
{ "additional_instruction", request.Chat.AdditionalInstruction },
91+
{ "stop_criteria", request.Chat.StopCriteria }
8592
}
8693
});
8794

8895
_logger.LogInformation($"Generated message: {result?.GeneratedMessage}, stop: {result?.Stop}, reason: {result?.Reason}");
8996

90-
if (count > request.MaxRounds || (result != null && result.Stop))
97+
if (count > request.Chat.MaxRounds || (result != null && result.Stop))
9198
{
9299
break;
93100
}
94101

95-
96-
if (curUserMsg.IsEqualTo(prevUserMsg) || curBotMsg.IsEqualTo(prevBotMsg))
97-
{
98-
duplicateCount++;
99-
}
100-
else
101-
{
102-
duplicateCount = 0;
103-
}
104-
105-
106-
if (duplicateCount >= request.DuplicateLimit)
102+
duplicateCount = curBotMsg.IsEqualTo(prevBotMsg) ? duplicateCount + 1 : 0;
103+
if (duplicateCount >= request.Chat.DuplicateLimit)
107104
{
108105
break;
109106
}
@@ -115,6 +112,38 @@ private async Task<string> SimulateConversation(string initMessage, IEnumerable<
115112
return convId;
116113
}
117114

115+
116+
private async Task<string?> EvaluateMetrics(string curConversationId, IEnumerable<string> refDialogs, EvaluationRequest request)
117+
{
118+
var storage = _services.GetRequiredService<IConversationStorage>();
119+
var agentService = _services.GetRequiredService<IAgentService>();
120+
var instructService = _services.GetRequiredService<IInstructService>();
121+
122+
var curDialogs = storage.GetDialogs(curConversationId);
123+
var curDialogContents = GetConversationContent(curDialogs);
124+
125+
var evaluator = await agentService.GetAgent(BuiltInAgentId.Evaluator);
126+
var metricPrompt = evaluator.Templates.FirstOrDefault(x => x.Name == "instruction.metrics")?.Content ?? string.Empty;
127+
var query = "Please follow the instruction for evaluation.";
128+
129+
var result = await instructService.Instruct<JsonDocument>(metricPrompt, BuiltInAgentId.Evaluator,
130+
new InstructOptions
131+
{
132+
Provider = request.Provider,
133+
Model = request.Model,
134+
Message = query,
135+
Data = new Dictionary<string, object>
136+
{
137+
{ "ref_conversation", refDialogs },
138+
{ "cur_conversation", curDialogs },
139+
{ "additional_instruction", request.Metric.AdditionalInstruction },
140+
{ "metrics", request.Metric.Metrics }
141+
}
142+
});
143+
144+
return result != null ? result.RootElement.GetRawText() : null;
145+
}
146+
118147
private IEnumerable<string> GetConversationContent(IEnumerable<RoleDialogModel> dialogs)
119148
{
120149
var contents = new List<string>();
@@ -134,4 +163,30 @@ private IEnumerable<string> GetConversationContent(IEnumerable<RoleDialogModel>
134163

135164
return contents;
136165
}
166+
167+
private IEnumerable<MessageState> GetInitialStates(string conversationId)
168+
{
169+
if (string.IsNullOrWhiteSpace(conversationId))
170+
{
171+
return [];
172+
}
173+
174+
var db = _services.GetRequiredService<IBotSharpRepository>();
175+
var states = db.GetConversationStates(conversationId);
176+
var initialStates = new List<MessageState>();
177+
178+
foreach (var state in states)
179+
{
180+
var value = state.Value?.Values?.FirstOrDefault(x => string.IsNullOrEmpty(x.MessageId));
181+
182+
if (string.IsNullOrEmpty(value?.Data))
183+
{
184+
continue;
185+
}
186+
187+
initialStates.Add(new MessageState(state.Key, value.Data, value.ActiveRounds));
188+
}
189+
190+
return initialStates;
191+
}
137192
}

src/Infrastructure/BotSharp.Core/Evaluations/Services/EvaluatingService.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ public async Task<Conversation> Execute(string task, EvaluationRequest request)
6161
dialogs.Add(new RoleDialogModel(AgentRole.User, question));
6262
prompt += question.Trim();
6363

64-
response = await SendMessage(request.AgentId, conv.Id, question);
64+
response = await SendMessage(request.AgentId, conv.Id, question, states: new List<MessageState>
65+
{
66+
new MessageState("channel", ConversationChannel.OpenAPI)
67+
});
6568
dialogs.Add(new RoleDialogModel(AgentRole.Assistant, response.Content));
6669
prompt += $"\r\n{AgentRole.Assistant}: {response.Content.Trim()}";
6770
prompt += $"\r\n{AgentRole.User}: ";
@@ -86,17 +89,16 @@ public async Task<Conversation> Execute(string task, EvaluationRequest request)
8689
return conv;
8790
}
8891

89-
private async Task<RoleDialogModel> SendMessage(string agentId, string conversationId, string text)
92+
private async Task<RoleDialogModel> SendMessage(string agentId, string conversationId, string text,
93+
PostbackMessageModel? postback = null,
94+
List<MessageState>? states = null)
9095
{
9196
var conv = _services.GetRequiredService<IConversationService>();
9297
var routing = _services.GetRequiredService<IRoutingService>();
9398

9499
var inputMsg = new RoleDialogModel(AgentRole.User, text);
95100
routing.Context.SetMessageId(conversationId, inputMsg.MessageId);
96-
conv.SetConversationId(conversationId, new List<MessageState>
97-
{
98-
new MessageState("channel", ConversationChannel.OpenAPI)
99-
});
101+
conv.SetConversationId(conversationId, states ?? []);
100102

101103
RoleDialogModel response = default;
102104

Original file line numberDiff line numberDiff line change
@@ -1 +1,45 @@
1-
You are a conversation evaluator.
1+
You are a conversaton evaluator.
2+
Please take the content in the [REFERENCE CONVERSATION] section and [ONGOING CONVERSATION] section, and evaluate the metrics defined in [OUTPUT JSON FORMAT].
3+
4+
** You need to take a close look at the content in both [REFERENCE CONVERSATION] and [ONGOING CONVERSATION], and evaluate the metrics listed in [OUTPUT JSON FORMAT].
5+
6+
7+
=================
8+
[ADDITIONAL INSTRUCTION]
9+
{{ "\r\n" }}
10+
{%- if additional_instruction != empty -%}
11+
{{ additional_instruction }}
12+
{%- endif -%}
13+
{{ "\r\n" }}
14+
15+
16+
=================
17+
[OUTPUT JSON FORMAT]
18+
19+
** The output must be in JSON format:
20+
{
21+
{%- if metrics != empty -%}
22+
{{ "\r\n" }}
23+
{% for metric in metrics -%}
24+
{{ metric.name }}: {{ metric.description }},{{ "\r\n" }}
25+
{%- endfor %}
26+
{%- else -%}
27+
"summary": a short summary that summarizes the [ONGOING CONVERSATION] content compared to the [REFERENCE CONVERSATION]
28+
{%- endif -%}
29+
}
30+
31+
32+
=================
33+
[REFERENCE CONVERSATION]
34+
35+
{% for text in ref_conversation -%}
36+
{{ text }}{{ "\r\n" }}
37+
{%- endfor %}
38+
39+
40+
=================
41+
[ONGOING CONVERSATION]
42+
43+
{% for text in cur_conversation -%}
44+
{{ text }}{{ "\r\n" }}
45+
{%- endfor %}

src/Plugins/BotSharp.Plugin.SqlDriver/Hooks/SqlDriverPlanningHook.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ await HookEmitter.Emit<ISqlDriverHook>(_services, async (hook) =>
5353
// Invoke "execute_sql"
5454
var routing = _services.GetRequiredService<IRoutingService>();
5555
await routing.InvokeFunction(response.FunctionName, response);
56+
5657
msg.CurrentAgentId = agent.Id;
5758
msg.FunctionName = response.FunctionName;
5859
msg.FunctionArgs = response.FunctionArgs;
@@ -64,13 +65,10 @@ public async Task<string> GetSummaryAdditionalRequirements(string planner, RoleD
6465
{
6566
var settings = _services.GetRequiredService<SqlDriverSetting>();
6667
var sqlHooks = _services.GetServices<ISqlDriverHook>();
67-
68-
var dbType = sqlHooks.Any() ?
69-
sqlHooks.First().GetDatabaseType(message) :
70-
settings.DatabaseType;
68+
var agentService = _services.GetRequiredService<IAgentService>();
7169

72-
var agent = await _services.GetRequiredService<IAgentService>()
73-
.LoadAgent(BuiltInAgentId.SqlDriver);
70+
var dbType = !sqlHooks.IsNullOrEmpty() ? sqlHooks.First().GetDatabaseType(message) : settings.DatabaseType;
71+
var agent = await agentService.LoadAgent(BuiltInAgentId.SqlDriver);
7472

7573
return agent.Templates.FirstOrDefault(x => x.Name == $"database.summarize.{dbType}")?.Content ?? string.Empty;
7674
}
@@ -101,7 +99,6 @@ private RichContent<IRichMessage> BuildRunQueryButton(string conversationId, str
10199
Type = "text",
102100
Title = "Execute the SQL Statement",
103101
Payload = sql,
104-
105102
IsPrimary = true
106103
},
107104
new ElementButton

0 commit comments

Comments
 (0)