Skip to content

Commit bf274fa

Browse files
authored
Merge pull request #744 from Joannall/master
Add Sql Validator
2 parents 16b9703 + cbd181c commit bf274fa

File tree

5 files changed

+117
-12
lines changed

5 files changed

+117
-12
lines changed

src/Plugins/BotSharp.Plugin.Planner/Functions/SummaryPlanFn.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using BotSharp.Abstraction.Planning;
22
using BotSharp.Plugin.Planner.TwoStaging;
33
using BotSharp.Plugin.Planner.TwoStaging.Models;
4+
using static System.Net.Mime.MediaTypeNames;
5+
using System.Text.RegularExpressions;
46

57
namespace BotSharp.Plugin.Planner.Functions;
68

@@ -24,14 +26,13 @@ public async Task<bool> Execute(RoleDialogModel message)
2426
{
2527
var fn = _services.GetRequiredService<IRoutingService>();
2628
var agentService = _services.GetRequiredService<IAgentService>();
27-
var state = _services.GetRequiredService<IConversationStateService>();
29+
var states = _services.GetRequiredService<IConversationStateService>();
2830

29-
state.SetState("max_tokens", "4096");
31+
states.SetState("max_tokens", "4096");
3032
var currentAgent = await agentService.LoadAgent(message.CurrentAgentId);
31-
var taskRequirement = state.GetState("requirement_detail");
33+
var taskRequirement = states.GetState("requirement_detail");
3234

3335
// Get table names
34-
var states = _services.GetRequiredService<IConversationStateService>();
3536
var steps = states.GetState("planning_result").JsonArrayContent<SecondStagePlan>();
3637
var allTables = new List<string>();
3738
var ddlStatements = string.Empty;
@@ -53,6 +54,7 @@ public async Task<bool> Execute(RoleDialogModel message)
5354
});
5455
await fn.InvokeFunction("sql_table_definition", msgCopy);
5556
ddlStatements += "\r\n" + msgCopy.Content;
57+
states.SetState("table_ddls", ddlStatements);
5658

5759
// Summarize and generate query
5860
var prompt = await GetSummaryPlanPrompt(msgCopy, taskRequirement, domainKnowledge, dictionaryItems, ddlStatements, excelImportResult);
@@ -69,6 +71,9 @@ public async Task<bool> Execute(RoleDialogModel message)
6971
var summary = await GetAiResponse(plannerAgent);
7072
message.Content = summary.Content;
7173

74+
// Validate the sql result
75+
await fn.InvokeFunction("validate_sql", message);
76+
7277
await HookEmitter.Emit<IPlanningHook>(_services, async hook =>
7378
await hook.OnPlanningCompleted(nameof(TwoStageTaskPlanner), message)
7479
);
@@ -119,8 +124,8 @@ private async Task<RoleDialogModel> GetAiResponse(Agent plannerAgent)
119124
wholeDialogs.Last().Content += "\n\nIf the table structure didn't mention auto incremental, the data field id needs to insert id manually and you need to use max(id).\nFor example, you should use SET @id = select max(id) from table;";
120125
wholeDialogs.Last().Content += "\n\nTry if you can generate a single query to fulfill the needs.";
121126

122-
var completion = CompletionProvider.GetChatCompletion(_services,
123-
provider: plannerAgent.LlmConfig.Provider,
127+
var completion = CompletionProvider.GetChatCompletion(_services,
128+
provider: plannerAgent.LlmConfig.Provider,
124129
model: plannerAgent.LlmConfig.Model);
125130

126131
return await completion.GetChatCompletions(plannerAgent, wholeDialogs);

src/Plugins/BotSharp.Plugin.Planner/data/agents/282a7128-69a1-44b0-878c-a9159b88f3b9/instructions/instruction.liquid

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Use the TwoStagePlanner approach to plan the overall implementation steps, follo
1414
Don't run the planning process repeatedly if you have already got the result of user's request.
1515
Function verify_dictionary_term CAN'T generate INSERT SQL Statement.
1616
The table name must come from the relevant knowledge. has_found_relevant_knowledge must be true.
17+
Do not introduce your actions or intentions in any way.
1718

1819
{% if global_knowledges != empty -%}
1920
=====

src/Plugins/BotSharp.Plugin.SqlDriver/Functions/ExecuteQueryFn.cs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using BotSharp.Abstraction.Agents.Enums;
2-
using BotSharp.Abstraction.Repositories;
32
using BotSharp.Abstraction.Routing;
43
using BotSharp.Core.Infrastructures;
54
using BotSharp.Plugin.SqlDriver.Interfaces;
@@ -9,6 +8,7 @@
98
using Microsoft.Extensions.Logging;
109
using MySqlConnector;
1110
using Npgsql;
11+
using System.Data.Common;
1212

1313
namespace BotSharp.Plugin.SqlDriver.Functions;
1414

@@ -58,10 +58,18 @@ public async Task<bool> Execute(RoleDialogModel message)
5858

5959
message.Content = JsonSerializer.Serialize(results);
6060
}
61+
catch (DbException ex)
62+
{
63+
_logger.LogError(ex, "Error occurred while executing SQL query.");
64+
message.Content = $"Error occurred while executing SQL query: {ex.Message}";
65+
message.Data = ex;
66+
message.StopCompletion = true;
67+
return false;
68+
}
6169
catch (Exception ex)
6270
{
6371
_logger.LogError(ex, "Error occurred while executing SQL query.");
64-
message.Content = "Error occurred while retrieving information.";
72+
message.Content = $"Error occurred while executing SQL query: {ex.Message}";
6573
message.StopCompletion = true;
6674
return false;
6775
}
@@ -141,11 +149,11 @@ private async Task<ExecuteQueryArgs> RefineSqlStatement(RoleDialogModel message,
141149
provider: agent.LlmConfig.Provider,
142150
model: agent.LlmConfig.Model);
143151

144-
var refinedMessage = await completion.GetChatCompletions(agent, new List<RoleDialogModel>
145-
{
146-
new RoleDialogModel(AgentRole.User, "Check and output the correct SQL statements")
152+
var refinedMessage = await completion.GetChatCompletions(agent, new List<RoleDialogModel>
153+
{
154+
new RoleDialogModel(AgentRole.User, "Check and output the correct SQL statements")
147155
});
148-
156+
149157
return refinedMessage.Content.JsonContent<ExecuteQueryArgs>();
150158
}
151159

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using BotSharp.Abstraction.Agents.Enums;
2+
using BotSharp.Abstraction.Agents.Models;
3+
using BotSharp.Abstraction.Instructs;
4+
using BotSharp.Abstraction.Instructs.Models;
5+
using BotSharp.Abstraction.Routing;
6+
using BotSharp.Core.Agents.Services;
7+
using BotSharp.Core.Infrastructures;
8+
using BotSharp.Core.Instructs;
9+
using BotSharp.Plugin.SqlDriver.Interfaces;
10+
using BotSharp.Plugin.SqlDriver.Models;
11+
using Microsoft.Extensions.Logging;
12+
using System;
13+
using System.Collections.Generic;
14+
using System.Data.Common;
15+
using System.Text.RegularExpressions;
16+
17+
namespace BotSharp.Plugin.SqlDriver.Functions;
18+
19+
public class SqlValidateFn : IFunctionCallback
20+
{
21+
public string Name => "validate_sql";
22+
public string Indication => "Performing data validate operation.";
23+
private readonly IServiceProvider _services;
24+
private readonly ILogger _logger;
25+
public SqlValidateFn(IServiceProvider services)
26+
{
27+
_services = services;
28+
}
29+
30+
public async Task<bool> Execute(RoleDialogModel message)
31+
{
32+
string pattern = @"```sql\s*([\s\S]*?)\s*```";
33+
var sqls = Regex.Match(message.Content, pattern);
34+
if (!sqls.Success)
35+
{
36+
return false;
37+
}
38+
var sql = sqls.Groups[1].Value;
39+
40+
var dbHook = _services.GetRequiredService<ISqlDriverHook>();
41+
var dbType = dbHook.GetDatabaseType(message);
42+
var validateSql = dbType.ToLower() switch
43+
{
44+
"mysql" => $"explain\r\n{sql}",
45+
"sqlserver" => $"SET PARSEONLY ON;\r\n{sql}\r\nSET PARSEONLY OFF;",
46+
"redshift" => $"explain\r\n{sql}",
47+
_ => throw new NotImplementedException($"Database type {dbType} is not supported.")
48+
};
49+
var msgCopy = RoleDialogModel.From(message);
50+
msgCopy.FunctionArgs = JsonSerializer.Serialize(new ExecuteQueryArgs
51+
{
52+
SqlStatements = new string[] { validateSql }
53+
});
54+
55+
var fn = _services.GetRequiredService<IRoutingService>();
56+
await fn.InvokeFunction("execute_sql", msgCopy);
57+
58+
if (msgCopy.Data != null && msgCopy.Data is DbException ex)
59+
{
60+
61+
var instructService = _services.GetRequiredService<IInstructService>();
62+
var agentService = _services.GetRequiredService<IAgentService>();
63+
var states = _services.GetRequiredService<IConversationStateService>();
64+
65+
var agent = await agentService.GetAgent(BuiltInAgentId.SqlDriver);
66+
var template = agent.Templates.FirstOrDefault(x => x.Name == "sql_statement_correctness")?.Content ?? string.Empty;
67+
var ddl = states.GetState("table_ddls");
68+
69+
var correctedSql = await instructService.Instruct<string>(template, BuiltInAgentId.SqlDriver,
70+
new InstructOptions
71+
{
72+
Provider = agent?.LlmConfig?.Provider ?? "openai",
73+
Model = agent?.LlmConfig?.Model ?? "gpt-4o",
74+
Message = "Correct SQL Statement",
75+
Data = new Dictionary<string, object>
76+
{
77+
{ "original_sql", validateSql },
78+
{ "error_message", ex.Message },
79+
{ "table_structure", ddl }
80+
}
81+
});
82+
message.Content = correctedSql;
83+
}
84+
85+
return true;
86+
}
87+
}

src/Plugins/BotSharp.Plugin.SqlDriver/data/agents/beda4c12-e1ec-4b4b-b328-3df4a6687c4f/templates/sql_statement_correctness.liquid

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ Make sure all the column names are defined in the Table Structure.
66
Original SQL statements:
77
{{ original_sql }}
88

9+
=====
10+
Error Message:
11+
{{ error_message }}
12+
913
=====
1014
Table Structure:
1115
{{ table_structure }}

0 commit comments

Comments
 (0)