Skip to content

knowledge learner #549

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,33 @@ namespace BotSharp.Abstraction.Agents.Enums;

public class BuiltInAgentId
{
/// <summary>
/// A routing agent can be used as a base router.
/// </summary>
public const string AIAssistant = "01fcc3e5-9af7-49e6-ad7a-a760bd12dc4a";

/// <summary>
/// A demo agent used for open domain chatting
/// </summary>
public const string Chatbot = "01e2fc5c-2c89-4ec7-8470-7688608b496c";

/// <summary>
/// Human customer service
/// </summary>
public const string HumanSupport = "01dcc3e5-0af7-49e6-ad7a-a760bd12dc4b";

/// <summary>
/// Used as a container to host the shared tools/ utilities built in different plugins.
/// </summary>
public const string UtilityAssistant = "6745151e-6d46-4a02-8de4-1c4f21c7da95";

/// <summary>
/// Used when router can't route to any existing task agent
/// </summary>
public const string Fallback = "01fcc3e5-0af7-49e6-ad7a-a760bd12dc4d";

/// <summary>
/// Used by knowledgebase plugin to acquire domain knowledge
/// </summary>
public const string Learner = "01acc3e5-0af7-49e6-ad7a-a760bd12dc40";
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ public class StateConst
public const string NEXT_ACTION_AGENT = "next_action_agent";
public const string NEXT_ACTION_REASON = "next_action_reason";
public const string USER_GOAL_AGENT = "user_goal_agent";
public const string AGENT_REDIRECTION_REASON = "agent_redirection_reason";

public const string LANGUAGE = "language";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace BotSharp.Abstraction.Knowledges.Models;

public class ExtractedKnowledge
{
[JsonPropertyName("question")]
public string Question { get; set; } = string.Empty;

[JsonPropertyName("answer")]
public string Answer { get; set; } = string.Empty;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ public interface IVectorDb
{
Task<List<string>> GetCollections();
Task CreateCollection(string collectionName, int dim);
Task Upsert(string collectionName, string id, float[] vector, string text, Dictionary<string, string>? payload = null);
Task<List<string>> Search(string collectionName, float[] vector, int limit = 5);
Task<bool> Upsert(string collectionName, string id, float[] vector, string text, Dictionary<string, string>? payload = null);
Task<List<string>> Search(string collectionName, float[] vector, string returnFieldName, int limit = 5, float confidence = 0.5f);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@ public async Task<bool> SendMessage(string agentId,
Agent agent = await agentService.LoadAgent(agentId);

var content = $"Received [{agent.Name}] {message.Role}: {message.Content}";
#if DEBUG
Console.WriteLine(content);
#else
_logger.LogInformation(content);
#endif

message.CurrentAgentId = agent.Id;
if (string.IsNullOrEmpty(message.SenderId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public async Task<bool> Execute(RoleDialogModel message)
// Stack redirection agent
_context.Push(agentId, reason: $"REDIRECTION {reason}");
message.Content = reason;
states.SetState(StateConst.AGENT_REDIRECTION_REASON, reason, isNeedVersion: false);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
You are a smart AI Assistant.
You are a smart AI Assistant.
{% if agent_redirection_reason %}
You've been reached out because: {{ agent_redirection_reason }}
{% endif %}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Break down the user’s most recent needs and figure out the instruction of next step.
Break down the user’s most recent needs and figure out the instruction of next step without explanation.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Anthropic.SDK" Version="3.2.1" />
<PackageReference Include="Anthropic.SDK" Version="3.2.3" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Anthropic.SDK.Common;
using BotSharp.Abstraction.Conversations;
using BotSharp.Abstraction.MLTasks.Settings;
using System.Text.Json;
using System.Text.Json.Nodes;
Expand Down Expand Up @@ -160,13 +161,17 @@ public Task<bool> GetChatCompletionsStreamingAsync(Agent agent, List<RoleDialogM
}
}

var state = _services.GetRequiredService<IConversationStateService>();
var temperature = decimal.Parse(state.GetState("temperature", "0.0"));
var maxToken = int.Parse(state.GetState("max_tokens", "512"));

var parameters = new MessageParameters()
{
Messages = messages,
MaxTokens = 256,
Model = settings.Version, // AnthropicModels.Claude3Haiku
MaxTokens = maxToken,
Model = settings.Name,
Stream = false,
Temperature = 0m,
Temperature = temperature,
SystemMessage = instruction,
Tools = new List<Function>() { }
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,39 @@
<None Remove="agents\**" />
</ItemGroup>

<ItemGroup>
<None Remove="data\agents\01acc3e5-0af7-49e6-ad7a-a760bd12dc40\agent.json" />
<None Remove="data\agents\01acc3e5-0af7-49e6-ad7a-a760bd12dc40\functions\memorize_knowledge.json" />
<None Remove="data\agents\01acc3e5-0af7-49e6-ad7a-a760bd12dc40\instruction.liquid" />
<None Remove="data\agents\6745151e-6d46-4a02-8de4-1c4f21c7da95\templates\knowledge_retrieval.fn.liquid" />
</ItemGroup>

<ItemGroup>
<Content Include="data\agents\01acc3e5-0af7-49e6-ad7a-a760bd12dc40\agent.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="data\agents\01acc3e5-0af7-49e6-ad7a-a760bd12dc40\functions\memorize_knowledge.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="data\agents\01acc3e5-0af7-49e6-ad7a-a760bd12dc40\instruction.liquid">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="data\agents\6745151e-6d46-4a02-8de4-1c4f21c7da95\functions\knowledge_retrieval.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="data\agents\6745151e-6d46-4a02-8de4-1c4f21c7da95\templates\knowledge_retrieval.fn.liquid">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
</ItemGroup>

<ItemGroup>
<PackageReference Include="PdfPig" Version="0.1.8" />
<PackageReference Include="TensorFlow.Keras" Version="0.15.0" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\Infrastructure\BotSharp.Abstraction\BotSharp.Abstraction.csproj" />
<ProjectReference Include="..\..\Infrastructure\BotSharp.Core\BotSharp.Core.csproj" />
</ItemGroup>

</Project>
6 changes: 6 additions & 0 deletions src/Plugins/BotSharp.Plugin.KnowledgeBase/Enum/UtilityName.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace BotSharp.Plugin.KnowledgeBase.Enum;

public class UtilityName
{
public const string KnowledgeRetrieval = "knowledge-retrieval";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using BotSharp.Abstraction.Functions;
using BotSharp.Core.Infrastructures;

namespace BotSharp.Plugin.KnowledgeBase.Functions;

public class KnowledgeRetrievalFn : IFunctionCallback
{
public string Name => "knowledge_retrieval";

private readonly IServiceProvider _services;
private readonly KnowledgeBaseSettings _settings;

public KnowledgeRetrievalFn(IServiceProvider services, KnowledgeBaseSettings settings)
{
_services = services;
_settings = settings;
}

public async Task<bool> Execute(RoleDialogModel message)
{
var args = JsonSerializer.Deserialize<ExtractedKnowledge>(message.FunctionArgs ?? "{}");

var embedding = _services.GetServices<ITextEmbedding>()
.FirstOrDefault(x => x.GetType().FullName.EndsWith(_settings.TextEmbedding));

var vector = await embedding.GetVectorsAsync(new List<string>
{
args.Question
});

var vectorDb = _services.GetRequiredService<IVectorDb>();

var id = Utilities.HashTextMd5(args.Question);
var knowledges = await vectorDb.Search("lessen", vector[0], "answer");

message.Content = string.Join("\r\n\r\n=====\r\n", knowledges);

return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using BotSharp.Abstraction.Functions;
using BotSharp.Core.Infrastructures;

namespace BotSharp.Plugin.KnowledgeBase.Functions;

public class MemorizeKnowledgeFn : IFunctionCallback
{
public string Name => "memorize_knowledge";

private readonly IServiceProvider _services;
private readonly KnowledgeBaseSettings _settings;

public MemorizeKnowledgeFn(IServiceProvider services, KnowledgeBaseSettings settings)
{
_services = services;
_settings = settings;
}

public async Task<bool> Execute(RoleDialogModel message)
{
var args = JsonSerializer.Deserialize<ExtractedKnowledge>(message.FunctionArgs ?? "{}");

var embedding = _services.GetServices<ITextEmbedding>()
.First(x => x.GetType().FullName.EndsWith(_settings.TextEmbedding));

var vector = await embedding.GetVectorsAsync(new List<string>
{
args.Question
});

var vectorDb = _services.GetRequiredService<IVectorDb>();

await vectorDb.CreateCollection("lessen", vector[0].Length);

var id = Utilities.HashTextMd5(args.Question);
var result = await vectorDb.Upsert("lessen", id, vector[0],
args.Question,
new Dictionary<string, string>
{
{ "answer", args.Answer }
});

message.Content = $"Save result: {(result ? "success" : "failed")}";

return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using BotSharp.Abstraction.Agents.Enums;
using BotSharp.Abstraction.Agents.Models;
using BotSharp.Abstraction.Functions.Models;
using BotSharp.Abstraction.Repositories;
using BotSharp.Plugin.KnowledgeBase.Enum;

namespace BotSharp.Plugin.KnowledgeBase.Hooks;

public class KnowledgeBaseAgentHook : AgentHookBase, IAgentHook
{
public override string SelfId => string.Empty;
public KnowledgeBaseAgentHook(IServiceProvider services, AgentSettings settings) : base(services, settings)
{

}

public override void OnAgentLoaded(Agent agent)
{
var conv = _services.GetRequiredService<IConversationService>();
var isConvMode = conv.IsConversationMode();

if (isConvMode)
{
AddUtility(agent, UtilityName.KnowledgeRetrieval, "knowledge_retrieval");
}

base.OnAgentLoaded(agent);
}

private void AddUtility(Agent agent, string utility, string functionName)
{
if (!IsEnableUtility(agent, utility)) return;

var (prompt, fn) = GetPromptAndFunction(functionName);
if (fn != null)
{
if (!string.IsNullOrWhiteSpace(prompt))
{
agent.Instruction += $"\r\n\r\n{prompt}\r\n\r\n";
}

if (agent.Functions == null)
{
agent.Functions = new List<FunctionDef> { fn };
}
else
{
agent.Functions.Add(fn);
}
}
}

private bool IsEnableUtility(Agent agent, string utility)
{
return !agent.Utilities.IsNullOrEmpty() && agent.Utilities.Contains(utility);
}

private (string, FunctionDef?) GetPromptAndFunction(string functionName)
{
var db = _services.GetRequiredService<IBotSharpRepository>();
var agent = db.GetAgent(BuiltInAgentId.UtilityAssistant);
var prompt = agent?.Templates?.FirstOrDefault(x => x.Name.IsEqualTo($"{functionName}.fn"))?.Content ?? string.Empty;
var fn = agent?.Functions?.FirstOrDefault(x => x.Name.IsEqualTo(functionName));
return (prompt, fn);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using BotSharp.Plugin.KnowledgeBase.Enum;

namespace BotSharp.Plugin.KnowledgeBase.Hooks;

public class KnowledgeBaseUtilityHook : IAgentUtilityHook
{
public void AddUtilities(List<string> utilities)
{
utilities.Add(UtilityName.KnowledgeRetrieval);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using BotSharp.Abstraction.Plugins.Models;
using BotSharp.Abstraction.Settings;
using BotSharp.Plugin.KnowledgeBase.Hooks;
using Microsoft.Extensions.Configuration;

namespace BotSharp.Plugin.KnowledgeBase;
Expand All @@ -24,6 +25,8 @@ public void RegisterDI(IServiceCollection services, IConfiguration config)
services.AddScoped<ITextChopper, TextChopperService>();
services.AddScoped<IKnowledgeService, KnowledgeService>();
services.AddSingleton<IPdf2TextConverter, PigPdf2TextConverter>();
services.AddScoped<IAgentUtilityHook, KnowledgeBaseUtilityHook>();
services.AddScoped<IAgentHook, KnowledgeBaseAgentHook>();
}

public bool AttachMenu(List<PluginMenuDef> menu)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public async Task<List<string>> GetCollections()
return _collections.Select(x => x.Key).ToList();
}

public async Task<List<string>> Search(string collectionName, float[] vector, int limit = 5)
public async Task<List<string>> Search(string collectionName, float[] vector, string returnFieldName, int limit = 5, float confidence = 0.5f)
{
if (!_vectors.ContainsKey(collectionName))
{
Expand All @@ -37,14 +37,16 @@ public async Task<List<string>> Search(string collectionName, float[] vector, in
return texts;
}

public async Task Upsert(string collectionName, string id, float[] vector, string text, Dictionary<string, string>? payload = null)
public async Task<bool> Upsert(string collectionName, string id, float[] vector, string text, Dictionary<string, string>? payload = null)
{
_vectors[collectionName].Add(new VecRecord
{
Id = id,
Vector = vector,
Text = text
});

return true;
}

private float[] CalEuclideanDistance(float[] vec, List<VecRecord> records)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public async Task<string> GetKnowledges(KnowledgeRetrievalModel retrievalModel)

// Vector search
var db = GetVectorDb();
var result = await db.Search("shared", vector, limit: 10);
var result = await db.Search("shared", vector, "answer", limit: 10);

// Restore
return string.Join("\n\n", result.Select((x, i) => $"### Paragraph {i + 1} ###\n{x.Trim()}"));
Expand Down
Loading