diff --git a/src/Plugins/BotSharp.Plugin.Planner/Functions/SummaryPlanFn.cs b/src/Plugins/BotSharp.Plugin.Planner/Functions/SummaryPlanFn.cs index 27bb86ec1..7b5134a00 100644 --- a/src/Plugins/BotSharp.Plugin.Planner/Functions/SummaryPlanFn.cs +++ b/src/Plugins/BotSharp.Plugin.Planner/Functions/SummaryPlanFn.cs @@ -1,6 +1,8 @@ using BotSharp.Abstraction.Planning; using BotSharp.Plugin.Planner.TwoStaging; using BotSharp.Plugin.Planner.TwoStaging.Models; +using static System.Net.Mime.MediaTypeNames; +using System.Text.RegularExpressions; namespace BotSharp.Plugin.Planner.Functions; @@ -24,14 +26,13 @@ public async Task Execute(RoleDialogModel message) { var fn = _services.GetRequiredService(); var agentService = _services.GetRequiredService(); - var state = _services.GetRequiredService(); + var states = _services.GetRequiredService(); - state.SetState("max_tokens", "4096"); + states.SetState("max_tokens", "4096"); var currentAgent = await agentService.LoadAgent(message.CurrentAgentId); - var taskRequirement = state.GetState("requirement_detail"); + var taskRequirement = states.GetState("requirement_detail"); // Get table names - var states = _services.GetRequiredService(); var steps = states.GetState("planning_result").JsonArrayContent(); var allTables = new List(); var ddlStatements = string.Empty; @@ -53,6 +54,7 @@ public async Task Execute(RoleDialogModel message) }); await fn.InvokeFunction("sql_table_definition", msgCopy); ddlStatements += "\r\n" + msgCopy.Content; + states.SetState("table_ddls", ddlStatements); // Summarize and generate query var prompt = await GetSummaryPlanPrompt(msgCopy, taskRequirement, domainKnowledge, dictionaryItems, ddlStatements, excelImportResult); @@ -69,6 +71,9 @@ public async Task Execute(RoleDialogModel message) var summary = await GetAiResponse(plannerAgent); message.Content = summary.Content; + // Validate the sql result + await fn.InvokeFunction("validate_sql", message); + await HookEmitter.Emit(_services, async hook => await hook.OnPlanningCompleted(nameof(TwoStageTaskPlanner), message) ); @@ -119,8 +124,8 @@ private async Task GetAiResponse(Agent plannerAgent) wholeDialogs.Last().Content += "\n\nIf the table structure didn't mention auto incremental, the data field id needs to insert id manually and you need to use max(id).\nFor example, you should use SET @id = select max(id) from table;"; wholeDialogs.Last().Content += "\n\nTry if you can generate a single query to fulfill the needs."; - var completion = CompletionProvider.GetChatCompletion(_services, - provider: plannerAgent.LlmConfig.Provider, + var completion = CompletionProvider.GetChatCompletion(_services, + provider: plannerAgent.LlmConfig.Provider, model: plannerAgent.LlmConfig.Model); return await completion.GetChatCompletions(plannerAgent, wholeDialogs); diff --git a/src/Plugins/BotSharp.Plugin.Planner/data/agents/282a7128-69a1-44b0-878c-a9159b88f3b9/instructions/instruction.liquid b/src/Plugins/BotSharp.Plugin.Planner/data/agents/282a7128-69a1-44b0-878c-a9159b88f3b9/instructions/instruction.liquid index 9d8768f5a..89bcd49bb 100644 --- a/src/Plugins/BotSharp.Plugin.Planner/data/agents/282a7128-69a1-44b0-878c-a9159b88f3b9/instructions/instruction.liquid +++ b/src/Plugins/BotSharp.Plugin.Planner/data/agents/282a7128-69a1-44b0-878c-a9159b88f3b9/instructions/instruction.liquid @@ -14,6 +14,7 @@ Use the TwoStagePlanner approach to plan the overall implementation steps, follo Don't run the planning process repeatedly if you have already got the result of user's request. Function verify_dictionary_term CAN'T generate INSERT SQL Statement. The table name must come from the relevant knowledge. has_found_relevant_knowledge must be true. +Do not introduce your actions or intentions in any way. {% if global_knowledges != empty -%} ===== diff --git a/src/Plugins/BotSharp.Plugin.SqlDriver/Functions/ExecuteQueryFn.cs b/src/Plugins/BotSharp.Plugin.SqlDriver/Functions/ExecuteQueryFn.cs index 60351806e..af337babf 100644 --- a/src/Plugins/BotSharp.Plugin.SqlDriver/Functions/ExecuteQueryFn.cs +++ b/src/Plugins/BotSharp.Plugin.SqlDriver/Functions/ExecuteQueryFn.cs @@ -1,5 +1,4 @@ using BotSharp.Abstraction.Agents.Enums; -using BotSharp.Abstraction.Repositories; using BotSharp.Abstraction.Routing; using BotSharp.Core.Infrastructures; using BotSharp.Plugin.SqlDriver.Interfaces; @@ -9,6 +8,7 @@ using Microsoft.Extensions.Logging; using MySqlConnector; using Npgsql; +using System.Data.Common; namespace BotSharp.Plugin.SqlDriver.Functions; @@ -58,10 +58,18 @@ public async Task Execute(RoleDialogModel message) message.Content = JsonSerializer.Serialize(results); } + catch (DbException ex) + { + _logger.LogError(ex, "Error occurred while executing SQL query."); + message.Content = $"Error occurred while executing SQL query: {ex.Message}"; + message.Data = ex; + message.StopCompletion = true; + return false; + } catch (Exception ex) { _logger.LogError(ex, "Error occurred while executing SQL query."); - message.Content = "Error occurred while retrieving information."; + message.Content = $"Error occurred while executing SQL query: {ex.Message}"; message.StopCompletion = true; return false; } @@ -141,11 +149,11 @@ private async Task RefineSqlStatement(RoleDialogModel message, provider: agent.LlmConfig.Provider, model: agent.LlmConfig.Model); - var refinedMessage = await completion.GetChatCompletions(agent, new List - { - new RoleDialogModel(AgentRole.User, "Check and output the correct SQL statements") + var refinedMessage = await completion.GetChatCompletions(agent, new List + { + new RoleDialogModel(AgentRole.User, "Check and output the correct SQL statements") }); - + return refinedMessage.Content.JsonContent(); } diff --git a/src/Plugins/BotSharp.Plugin.SqlDriver/Functions/SqlValidateFn.cs b/src/Plugins/BotSharp.Plugin.SqlDriver/Functions/SqlValidateFn.cs new file mode 100644 index 000000000..8778b5a54 --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.SqlDriver/Functions/SqlValidateFn.cs @@ -0,0 +1,87 @@ +using BotSharp.Abstraction.Agents.Enums; +using BotSharp.Abstraction.Agents.Models; +using BotSharp.Abstraction.Instructs; +using BotSharp.Abstraction.Instructs.Models; +using BotSharp.Abstraction.Routing; +using BotSharp.Core.Agents.Services; +using BotSharp.Core.Infrastructures; +using BotSharp.Core.Instructs; +using BotSharp.Plugin.SqlDriver.Interfaces; +using BotSharp.Plugin.SqlDriver.Models; +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Text.RegularExpressions; + +namespace BotSharp.Plugin.SqlDriver.Functions; + +public class SqlValidateFn : IFunctionCallback +{ + public string Name => "validate_sql"; + public string Indication => "Performing data validate operation."; + private readonly IServiceProvider _services; + private readonly ILogger _logger; + public SqlValidateFn(IServiceProvider services) + { + _services = services; + } + + public async Task Execute(RoleDialogModel message) + { + string pattern = @"```sql\s*([\s\S]*?)\s*```"; + var sqls = Regex.Match(message.Content, pattern); + if (!sqls.Success) + { + return false; + } + var sql = sqls.Groups[1].Value; + + var dbHook = _services.GetRequiredService(); + var dbType = dbHook.GetDatabaseType(message); + var validateSql = dbType.ToLower() switch + { + "mysql" => $"explain\r\n{sql}", + "sqlserver" => $"SET PARSEONLY ON;\r\n{sql}\r\nSET PARSEONLY OFF;", + "redshift" => $"explain\r\n{sql}", + _ => throw new NotImplementedException($"Database type {dbType} is not supported.") + }; + var msgCopy = RoleDialogModel.From(message); + msgCopy.FunctionArgs = JsonSerializer.Serialize(new ExecuteQueryArgs + { + SqlStatements = new string[] { validateSql } + }); + + var fn = _services.GetRequiredService(); + await fn.InvokeFunction("execute_sql", msgCopy); + + if (msgCopy.Data != null && msgCopy.Data is DbException ex) + { + + var instructService = _services.GetRequiredService(); + var agentService = _services.GetRequiredService(); + var states = _services.GetRequiredService(); + + var agent = await agentService.GetAgent(BuiltInAgentId.SqlDriver); + var template = agent.Templates.FirstOrDefault(x => x.Name == "sql_statement_correctness")?.Content ?? string.Empty; + var ddl = states.GetState("table_ddls"); + + var correctedSql = await instructService.Instruct(template, BuiltInAgentId.SqlDriver, + new InstructOptions + { + Provider = agent?.LlmConfig?.Provider ?? "openai", + Model = agent?.LlmConfig?.Model ?? "gpt-4o", + Message = "Correct SQL Statement", + Data = new Dictionary + { + { "original_sql", validateSql }, + { "error_message", ex.Message }, + { "table_structure", ddl } + } + }); + message.Content = correctedSql; + } + + return true; + } +} diff --git a/src/Plugins/BotSharp.Plugin.SqlDriver/data/agents/beda4c12-e1ec-4b4b-b328-3df4a6687c4f/templates/sql_statement_correctness.liquid b/src/Plugins/BotSharp.Plugin.SqlDriver/data/agents/beda4c12-e1ec-4b4b-b328-3df4a6687c4f/templates/sql_statement_correctness.liquid index cf61eb1b6..2e6f28e1e 100644 --- a/src/Plugins/BotSharp.Plugin.SqlDriver/data/agents/beda4c12-e1ec-4b4b-b328-3df4a6687c4f/templates/sql_statement_correctness.liquid +++ b/src/Plugins/BotSharp.Plugin.SqlDriver/data/agents/beda4c12-e1ec-4b4b-b328-3df4a6687c4f/templates/sql_statement_correctness.liquid @@ -6,6 +6,10 @@ Make sure all the column names are defined in the Table Structure. Original SQL statements: {{ original_sql }} +===== +Error Message: +{{ error_message }} + ===== Table Structure: {{ table_structure }}