Skip to content

Commit c2b870b

Browse files
author
Jicheng Lu
committed
resolve conflicts
2 parents ddd1d0a + 0e156c3 commit c2b870b

File tree

6 files changed

+81
-168
lines changed

6 files changed

+81
-168
lines changed

src/Plugins/BotSharp.Plugin.Planner/Functions/PrimaryStagePlanFn.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ public PrimaryStagePlanFn(IServiceProvider services, ILogger<PrimaryStagePlanFn>
1818
public async Task<bool> Execute(RoleDialogModel message)
1919
{
2020
// Debug
21+
var agentService = _services.GetRequiredService<IAgentService>();
2122
var state = _services.GetRequiredService<IConversationStateService>();
2223
var knowledgeService = _services.GetRequiredService<IKnowledgeService>();
2324
var knowledgeSettings = _services.GetRequiredService<KnowledgeBaseSettings>();
@@ -28,16 +29,14 @@ public async Task<bool> Execute(RoleDialogModel message)
2829

2930
// Get knowledge from vectordb
3031
var collectionName = knowledgeSettings.Default.CollectionName ?? KnowledgeCollectionName.BotSharp; ;
31-
var knowledges = await knowledgeService.SearchVectorKnowledge(task.Question, collectionName, new VectorSearchOptions
32+
var knowledges = await knowledgeService.SearchVectorKnowledge(task.Requirements, collectionName, new VectorSearchOptions
3233
{
3334
Confidence = 0.1f
3435
});
3536
message.Content = string.Join("\r\n\r\n=====\r\n", knowledges.Select(x => x.ToQuestionAnswer()));
3637

37-
var agentService = _services.GetRequiredService<IAgentService>();
38-
var currentAgent = await agentService.LoadAgent(message.CurrentAgentId);
39-
4038
// Send knowledge to AI to refine and summarize the primary planning
39+
var currentAgent = await agentService.LoadAgent(message.CurrentAgentId);
4140
var firstPlanningPrompt = await GetFirstStagePlanPrompt(task, message);
4241
var plannerAgent = new Agent
4342
{

src/Plugins/BotSharp.Plugin.Planner/Functions/SummaryPlanFn.cs

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,44 +5,65 @@ namespace BotSharp.Plugin.Planner.Functions;
55
public class SummaryPlanFn : IFunctionCallback
66
{
77
public string Name => "plan_summary";
8+
89
private readonly IServiceProvider _services;
9-
private readonly ILogger _logger;
10-
private object aiAssistant;
10+
private readonly ILogger<SummaryPlanFn> _logger;
1111

12-
public SummaryPlanFn(IServiceProvider services, ILogger<PrimaryStagePlanFn> logger)
12+
public SummaryPlanFn(
13+
IServiceProvider services,
14+
ILogger<SummaryPlanFn> logger)
1315
{
1416
_services = services;
1517
_logger = logger;
1618
}
1719

1820
public async Task<bool> Execute(RoleDialogModel message)
1921
{
20-
//debug
22+
var fn = _services.GetRequiredService<IRoutingService>();
23+
var agentService = _services.GetRequiredService<IAgentService>();
2124
var state = _services.GetRequiredService<IConversationStateService>();
25+
26+
var currentAgent = await agentService.LoadAgent(message.CurrentAgentId);
2227
state.SetState("max_tokens", "4096");
2328

2429
var task = state.GetState("requirement_detail");
2530

26-
// summarize and generate query
27-
var summaryPlanningPrompt = await GetPlanSummaryPrompt(task, message);
28-
_logger.LogInformation(summaryPlanningPrompt);
31+
// Get DDL
32+
var steps = message.Content.JsonArrayContent<SecondStagePlan>();
33+
34+
// Get all the related tables
35+
var allTables = new List<string>();
36+
foreach (var step in steps)
37+
{
38+
allTables.AddRange(step.Tables);
39+
}
40+
message.Data = allTables.Distinct().ToList();
41+
42+
// Get table DDL and stores in content
43+
var msgCopy = RoleDialogModel.From(message);
44+
await fn.InvokeFunction("get_table_definition", msgCopy);
45+
var ddlStatements = msgCopy.Content;
46+
47+
// Summarize and generate query
48+
var summaryPlanPrompt = await GetPlanSummaryPrompt(task, message.Content, ddlStatements);
49+
_logger.LogInformation($"Summary plan prompt:\r\n{summaryPlanPrompt}");
2950

3051
var plannerAgent = new Agent
3152
{
3253
Id = BuiltInAgentId.Planner,
33-
Name = "planner_summary",
34-
Instruction = summaryPlanningPrompt,
35-
TemplateDict = new Dictionary<string, object>()
54+
Name = "Planner Summary",
55+
Instruction = summaryPlanPrompt,
56+
LlmConfig = currentAgent.LlmConfig
3657
};
37-
var response_summary = await GetAiResponse(plannerAgent);
3858

39-
message.Content = response_summary.Content;
59+
var summary = await GetAiResponse(plannerAgent);
60+
message.Content = summary.Content;
4061
message.StopCompletion = true;
4162

4263
return true;
4364
}
4465

45-
private async Task<string> GetPlanSummaryPrompt(string task, RoleDialogModel message)
66+
private async Task<string> GetPlanSummaryPrompt(string task, string knowledge, string ddlStatement)
4667
{
4768
// save to knowledge base
4869
var agentService = _services.GetRequiredService<IAgentService>();
@@ -58,9 +79,9 @@ private async Task<string> GetPlanSummaryPrompt(string task, RoleDialogModel mes
5879

5980
return render.Render(template, new Dictionary<string, object>
6081
{
61-
{ "table_structure", message.SecondaryContent }, ////check
62-
{ "task_description", task},
63-
{ "relevant_knowledges", message.Content },
82+
{ "table_structure", ddlStatement },
83+
{ "task_description", task },
84+
{ "relevant_knowledges", knowledge },
6485
{ "response_format", responseFormat }
6586
});
6687
}
@@ -69,17 +90,17 @@ private async Task<RoleDialogModel> GetAiResponse(Agent plannerAgent)
6990
var conv = _services.GetRequiredService<IConversationService>();
7091
var wholeDialogs = conv.GetDialogHistory();
7192

72-
//add "test" to wholeDialogs' last element
93+
// Add "test" to wholeDialogs' last element
7394
if (plannerAgent.Name == "planner_summary")
7495
{
75-
//add "test" to wholeDialogs' last element in a new paragraph
96+
// Add "test" to wholeDialogs' last element in a new paragraph
7697
wholeDialogs.Last().Content += "\n\nIf the table structure didn't mention auto incremental, the data field id needs to insert id manually and you need to use max(id) instead of LAST_INSERT_ID function.\nFor example, you should use SET @id = select max(id) from table;";
7798
wholeDialogs.Last().Content += "\n\nTry if you can generate a single query to fulfill the needs";
7899
}
79100

80101
if (plannerAgent.Name == "planning_1st")
81102
{
82-
//add "test" to wholeDialogs' last element in a new paragraph
103+
// Add "test" to wholeDialogs' last element in a new paragraph
83104
wholeDialogs.Last().Content += "\n\nYou must analyze the table description to infer the table relations.";
84105
}
85106

Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
Use the TwoStagePlanner approach to plan the overall implementation steps, call plan_primary_stage.
22
If need_additional_information is true, call plan_secondary_stage for the specific primary stage.
3-
Call plan_summary to summarize the final planning steps.
3+
You must Call plan_summary as the last step to summarize the final planning steps.

src/Plugins/BotSharp.Plugin.SqlDriver/Functions/AddDatabaseKnowledge.cs

Lines changed: 0 additions & 122 deletions
This file was deleted.
Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Microsoft.Extensions.Logging;
12
using MySqlConnector;
23
using static Dapper.SqlMapper;
34

@@ -7,10 +8,14 @@ public class GetTableDefinitionFn : IFunctionCallback
78
{
89
public string Name => "get_table_definition";
910
private readonly IServiceProvider _services;
11+
private readonly ILogger<GetTableDefinitionFn> _logger;
1012

11-
public GetTableDefinitionFn(IServiceProvider services)
13+
public GetTableDefinitionFn(
14+
IServiceProvider services,
15+
ILogger<GetTableDefinitionFn> logger)
1216
{
1317
_services = services;
18+
_logger = logger;
1419
}
1520

1621
public async Task<bool> Execute(RoleDialogModel message)
@@ -20,37 +25,47 @@ public async Task<bool> Execute(RoleDialogModel message)
2025
var settings = _services.GetRequiredService<SqlDriverSetting>();
2126

2227
// Get table DDL from database
28+
var tables = message.Data as List<string>;
29+
if (tables.IsNullOrEmpty()) return false;
30+
2331
var tableDdls = new List<string>();
2432
using var connection = new MySqlConnection(settings.MySqlConnectionString);
2533
connection.Open();
2634

27-
foreach (var table in (List<string>)message.Data)
35+
foreach (var table in tables)
2836
{
29-
var escapedTableName = MySqlHelper.EscapeString(table);
30-
31-
var sql = $"select * from information_schema.tables where table_name = @tableName";
32-
var result = connection.QueryFirstOrDefault(sql, new
37+
try
3338
{
34-
tableName = escapedTableName
35-
});
39+
var escapedTableName = MySqlHelper.EscapeString(table);
40+
41+
var sql = $"select * from information_schema.tables where table_name = @tableName";
42+
var result = connection.QueryFirstOrDefault(sql, new
43+
{
44+
tableName = escapedTableName
45+
});
3646

37-
if (result == null) continue;
47+
if (result == null) continue;
3848

39-
sql = $"SHOW CREATE TABLE `{escapedTableName}`";
40-
using var command = new MySqlCommand(sql, connection);
41-
using var reader = command.ExecuteReader();
42-
if (reader.Read())
49+
sql = $"SHOW CREATE TABLE `{escapedTableName}`";
50+
using var command = new MySqlCommand(sql, connection);
51+
using var reader = command.ExecuteReader();
52+
if (reader.Read())
53+
{
54+
result = reader.GetString(1);
55+
tableDdls.Add(result);
56+
}
57+
58+
reader.Close();
59+
command.Dispose();
60+
}
61+
catch (Exception ex)
4362
{
44-
result = reader.GetString(1);
45-
tableDdls.Add(result);
63+
_logger.LogWarning($"Error when getting ddl statement of table {table}.");
4664
}
47-
48-
reader.Close();
49-
command.Dispose();
5065
}
5166

5267
connection.Close();
53-
message.Content = string.Join("\r\n", tableDdls);
68+
message.Content = string.Join("\r\n\r\n", tableDdls);
5469
return true;
5570
}
5671
}

src/Plugins/BotSharp.Plugin.SqlDriver/Services/DbKnowledgeService.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,21 @@ private string GetTableStructure(string table)
9494
using var connection = new MySqlConnection(settings.MySqlConnectionString);
9595
connection.Open();
9696

97-
var result = string.Empty;
97+
var ddl = string.Empty;
9898
var escapedTableName = MySqlHelper.EscapeString(table);
9999
var sql = $"SHOW CREATE TABLE `{escapedTableName}`";
100100

101101
using var command = new MySqlCommand(sql, connection);
102102
using var reader = command.ExecuteReader();
103103
if (reader.Read())
104104
{
105-
result = reader.GetString(1);
105+
ddl = reader.GetString(1);
106106
}
107107

108108
reader.Close();
109109
command.Dispose();
110110
connection.Close();
111-
return result;
111+
return ddl;
112112
}
113113

114114
private async Task<string> GetPrompt(string content)

0 commit comments

Comments
 (0)