Skip to content

Add Sql Validator #744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/Plugins/BotSharp.Plugin.Planner/Functions/SummaryPlanFn.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using BotSharp.Abstraction.Planning;
using BotSharp.Plugin.Planner.TwoStaging;
using BotSharp.Plugin.Planner.TwoStaging.Models;
using static System.Net.Mime.MediaTypeNames;
using System.Text.RegularExpressions;

namespace BotSharp.Plugin.Planner.Functions;

Expand All @@ -24,14 +26,13 @@ public async Task<bool> Execute(RoleDialogModel message)
{
var fn = _services.GetRequiredService<IRoutingService>();
var agentService = _services.GetRequiredService<IAgentService>();
var state = _services.GetRequiredService<IConversationStateService>();
var states = _services.GetRequiredService<IConversationStateService>();

state.SetState("max_tokens", "4096");
states.SetState("max_tokens", "4096");
var currentAgent = await agentService.LoadAgent(message.CurrentAgentId);
var taskRequirement = state.GetState("requirement_detail");
var taskRequirement = states.GetState("requirement_detail");

// Get table names
var states = _services.GetRequiredService<IConversationStateService>();
var steps = states.GetState("planning_result").JsonArrayContent<SecondStagePlan>();
var allTables = new List<string>();
var ddlStatements = string.Empty;
Expand All @@ -53,6 +54,7 @@ public async Task<bool> Execute(RoleDialogModel message)
});
await fn.InvokeFunction("sql_table_definition", msgCopy);
ddlStatements += "\r\n" + msgCopy.Content;
states.SetState("table_ddls", ddlStatements);

// Summarize and generate query
var prompt = await GetSummaryPlanPrompt(msgCopy, taskRequirement, domainKnowledge, dictionaryItems, ddlStatements, excelImportResult);
Expand All @@ -69,6 +71,9 @@ public async Task<bool> Execute(RoleDialogModel message)
var summary = await GetAiResponse(plannerAgent);
message.Content = summary.Content;

// Validate the sql result
await fn.InvokeFunction("validate_sql", message);

await HookEmitter.Emit<IPlanningHook>(_services, async hook =>
await hook.OnPlanningCompleted(nameof(TwoStageTaskPlanner), message)
);
Expand Down Expand Up @@ -119,8 +124,8 @@ private async Task<RoleDialogModel> GetAiResponse(Agent plannerAgent)
wholeDialogs.Last().Content += "\n\nIf the table structure didn't mention auto incremental, the data field id needs to insert id manually and you need to use max(id).\nFor example, you should use SET @id = select max(id) from table;";
wholeDialogs.Last().Content += "\n\nTry if you can generate a single query to fulfill the needs.";

var completion = CompletionProvider.GetChatCompletion(_services,
provider: plannerAgent.LlmConfig.Provider,
var completion = CompletionProvider.GetChatCompletion(_services,
provider: plannerAgent.LlmConfig.Provider,
model: plannerAgent.LlmConfig.Model);

return await completion.GetChatCompletions(plannerAgent, wholeDialogs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Use the TwoStagePlanner approach to plan the overall implementation steps, follo
Don't run the planning process repeatedly if you have already got the result of user's request.
Function verify_dictionary_term CAN'T generate INSERT SQL Statement.
The table name must come from the relevant knowledge. has_found_relevant_knowledge must be true.
Do not introduce your actions or intentions in any way.

{% if global_knowledges != empty -%}
=====
Expand Down
20 changes: 14 additions & 6 deletions src/Plugins/BotSharp.Plugin.SqlDriver/Functions/ExecuteQueryFn.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using BotSharp.Abstraction.Agents.Enums;
using BotSharp.Abstraction.Repositories;
using BotSharp.Abstraction.Routing;
using BotSharp.Core.Infrastructures;
using BotSharp.Plugin.SqlDriver.Interfaces;
Expand All @@ -9,6 +8,7 @@
using Microsoft.Extensions.Logging;
using MySqlConnector;
using Npgsql;
using System.Data.Common;

namespace BotSharp.Plugin.SqlDriver.Functions;

Expand Down Expand Up @@ -58,10 +58,18 @@ public async Task<bool> Execute(RoleDialogModel message)

message.Content = JsonSerializer.Serialize(results);
}
catch (DbException ex)
{
_logger.LogError(ex, "Error occurred while executing SQL query.");
message.Content = $"Error occurred while executing SQL query: {ex.Message}";
message.Data = ex;
message.StopCompletion = true;
return false;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error occurred while executing SQL query.");
message.Content = "Error occurred while retrieving information.";
message.Content = $"Error occurred while executing SQL query: {ex.Message}";
message.StopCompletion = true;
return false;
}
Expand Down Expand Up @@ -141,11 +149,11 @@ private async Task<ExecuteQueryArgs> RefineSqlStatement(RoleDialogModel message,
provider: agent.LlmConfig.Provider,
model: agent.LlmConfig.Model);

var refinedMessage = await completion.GetChatCompletions(agent, new List<RoleDialogModel>
{
new RoleDialogModel(AgentRole.User, "Check and output the correct SQL statements")
var refinedMessage = await completion.GetChatCompletions(agent, new List<RoleDialogModel>
{
new RoleDialogModel(AgentRole.User, "Check and output the correct SQL statements")
});

return refinedMessage.Content.JsonContent<ExecuteQueryArgs>();
}

Expand Down
87 changes: 87 additions & 0 deletions src/Plugins/BotSharp.Plugin.SqlDriver/Functions/SqlValidateFn.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using BotSharp.Abstraction.Agents.Enums;
using BotSharp.Abstraction.Agents.Models;
using BotSharp.Abstraction.Instructs;
using BotSharp.Abstraction.Instructs.Models;
using BotSharp.Abstraction.Routing;
using BotSharp.Core.Agents.Services;
using BotSharp.Core.Infrastructures;
using BotSharp.Core.Instructs;
using BotSharp.Plugin.SqlDriver.Interfaces;
using BotSharp.Plugin.SqlDriver.Models;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Text.RegularExpressions;

namespace BotSharp.Plugin.SqlDriver.Functions;

public class SqlValidateFn : IFunctionCallback
{
public string Name => "validate_sql";
public string Indication => "Performing data validate operation.";
private readonly IServiceProvider _services;
private readonly ILogger _logger;
public SqlValidateFn(IServiceProvider services)
{
_services = services;
}

public async Task<bool> Execute(RoleDialogModel message)
{
string pattern = @"```sql\s*([\s\S]*?)\s*```";
var sqls = Regex.Match(message.Content, pattern);
if (!sqls.Success)
{
return false;
}
var sql = sqls.Groups[1].Value;

var dbHook = _services.GetRequiredService<ISqlDriverHook>();
var dbType = dbHook.GetDatabaseType(message);
var validateSql = dbType.ToLower() switch
{
"mysql" => $"explain\r\n{sql}",
"sqlserver" => $"SET PARSEONLY ON;\r\n{sql}\r\nSET PARSEONLY OFF;",
"redshift" => $"explain\r\n{sql}",
_ => throw new NotImplementedException($"Database type {dbType} is not supported.")
};
var msgCopy = RoleDialogModel.From(message);
msgCopy.FunctionArgs = JsonSerializer.Serialize(new ExecuteQueryArgs
{
SqlStatements = new string[] { validateSql }
});

var fn = _services.GetRequiredService<IRoutingService>();
await fn.InvokeFunction("execute_sql", msgCopy);

if (msgCopy.Data != null && msgCopy.Data is DbException ex)
{

var instructService = _services.GetRequiredService<IInstructService>();
var agentService = _services.GetRequiredService<IAgentService>();
var states = _services.GetRequiredService<IConversationStateService>();

var agent = await agentService.GetAgent(BuiltInAgentId.SqlDriver);
var template = agent.Templates.FirstOrDefault(x => x.Name == "sql_statement_correctness")?.Content ?? string.Empty;
var ddl = states.GetState("table_ddls");

var correctedSql = await instructService.Instruct<string>(template, BuiltInAgentId.SqlDriver,
new InstructOptions
{
Provider = agent?.LlmConfig?.Provider ?? "openai",
Model = agent?.LlmConfig?.Model ?? "gpt-4o",
Message = "Correct SQL Statement",
Data = new Dictionary<string, object>
{
{ "original_sql", validateSql },
{ "error_message", ex.Message },
{ "table_structure", ddl }
}
});
message.Content = correctedSql;
}

return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ Make sure all the column names are defined in the Table Structure.
Original SQL statements:
{{ original_sql }}

=====
Error Message:
{{ error_message }}

=====
Table Structure:
{{ table_structure }}