Skip to content

add active rounds to state #368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public interface IConversationService
IConversationStateService States { get; }
string ConversationId { get; }
Task<Conversation> NewConversation(Conversation conversation);
void SetConversationId(string conversationId, List<string> states);
void SetConversationId(string conversationId, List<MessageState> states);
Task<Conversation> GetConversation(string id);
Task<PagedItems<Conversation>> GetConversations(ConversationFilter filter);
Task<Conversation> UpdateConversationTitle(string id, string title);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public interface IConversationStateService
string GetState(string name, string defaultValue = "");
bool ContainsState(string name);
Dictionary<string, string> GetStates();
IConversationStateService SetState<T>(string name, T value, bool isNeedVersion = true);
IConversationStateService SetState<T>(string name, T value, bool isNeedVersion = true, int activeRounds = -1);
void SaveStateByArgs(JsonDocument args);
void CleanStates();
void Save();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ public class StateValue
public string Data { get; set; }

[JsonPropertyName("message_id")]
public string MessageId { get; set; }
public string? MessageId { get; set; }

public bool Active { get; set; }

[JsonPropertyName("active_rounds")]
public int ActiveRounds { get; set; }

[JsonPropertyName("update_time")]
public DateTime UpdateTime { get; set; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class MessageConfig : TruncateMessageRequest
/// <summary>
/// Conversation states from input
/// </summary>
public List<string> States { get; set; } = new List<string>();
public List<MessageState> States { get; set; } = new List<MessageState>();

/// <summary>
/// Agent task id
Expand Down
22 changes: 22 additions & 0 deletions src/Infrastructure/BotSharp.Abstraction/Models/MessageState.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
namespace BotSharp.Abstraction.Models;

public class MessageState
{
public string Key { get; set; }
public string Value { get; set; }

[JsonPropertyName("active_rounds")]
public int ActiveRounds { get; set; } = -1;

public MessageState()
{

}

public MessageState(string key, string value, int activeRounds = -1)
{
Key = key;
Value = value;
ActiveRounds = activeRounds;
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using BotSharp.Abstraction.Models;

namespace BotSharp.Core.Conversations.Services;

public partial class ConversationService : IConversationService
Expand Down Expand Up @@ -120,10 +122,10 @@ public List<RoleDialogModel> GetDialogHistory(int lastCount = 50, bool fromBreak
.ToList();
}

public void SetConversationId(string conversationId, List<string> states)
public void SetConversationId(string conversationId, List<MessageState> states)
{
_conversationId = conversationId;
_state.Load(_conversationId);
states.ForEach(x => _state.SetState(x.Split('=')[0], x.Split('=')[1]));
states.ForEach(x => _state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds));
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using BotSharp.Abstraction.Conversations.Models;
using BotSharp.Abstraction.Users.Enums;

namespace BotSharp.Core.Conversations.Services;

Expand Down Expand Up @@ -33,7 +33,7 @@ public ConversationStateService(ILogger<ConversationStateService> logger,
/// <param name="value"></param>
/// <param name="isNeedVersion">whether the state is related to message or not</param>
/// <returns></returns>
public IConversationStateService SetState<T>(string name, T value, bool isNeedVersion = true)
public IConversationStateService SetState<T>(string name, T value, bool isNeedVersion = true, int activeRounds = -1)
{
if (value == null)
{
Expand Down Expand Up @@ -69,6 +69,7 @@ public IConversationStateService SetState<T>(string name, T value, bool isNeedVe
Data = currentValue,
MessageId = routingCtx.MessageId,
Active = true,
ActiveRounds = activeRounds > 0 ? activeRounds : -1,
UpdateTime = DateTime.UtcNow,
};

Expand All @@ -90,7 +91,16 @@ public Dictionary<string, string> Load(string conversationId)
{
_conversationId = conversationId;

var routingCtx = _services.GetRequiredService<IRoutingContext>();
var curMsgId = routingCtx.MessageId;
_states = _db.GetConversationStates(_conversationId);
var dialogs = _db.GetConversationDialogs(_conversationId);
var userDialogs = dialogs.Where(x => x.MetaData?.Role == AgentRole.User || x.MetaData?.Role == UserRole.Client)
.OrderBy(x => x.MetaData?.CreateTime)
.ToList();

var curMsgIndex = userDialogs.FindIndex(x => !string.IsNullOrEmpty(curMsgId) && x.MetaData?.MessageId == curMsgId);
curMsgIndex = curMsgIndex < 0 ? userDialogs.Count() : curMsgIndex;
var curStates = new Dictionary<string, string>();

if (!_states.IsNullOrEmpty())
Expand All @@ -100,6 +110,23 @@ public Dictionary<string, string> Load(string conversationId)
var value = state.Value?.Values?.LastOrDefault();
if (value == null || !value.Active) continue;

if (value.ActiveRounds > 0)
{
var stateMsgIndex = userDialogs.FindIndex(x => !string.IsNullOrEmpty(x.MetaData?.MessageId) && x.MetaData.MessageId == value.MessageId);
if (stateMsgIndex >= 0 && curMsgIndex - stateMsgIndex >= value.ActiveRounds)
{
state.Value.Values.Add(new StateValue
{
Data = value.Data,
MessageId = value.MessageId,
Active = false,
ActiveRounds = value.ActiveRounds,
UpdateTime = DateTime.UtcNow
});
continue;
}
}

var data = value.Data ?? string.Empty;
curStates[state.Key] = data;
_logger.LogInformation($"[STATE] {state.Key} : {data}");
Expand Down Expand Up @@ -150,6 +177,7 @@ public void CleanStates()
Data = lastValue.Data,
MessageId = lastValue.MessageId,
Active = false,
ActiveRounds = lastValue.ActiveRounds,
UpdateTime = utcNow
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using BotSharp.Abstraction.Evaluations;
using BotSharp.Abstraction.Evaluations.Models;
using BotSharp.Abstraction.Evaluations.Settings;
using BotSharp.Abstraction.Models;
using BotSharp.Abstraction.Templating;
using System.Drawing;

Expand Down Expand Up @@ -88,15 +89,19 @@ public async Task<EvaluationResult> Evaluate(string conversationId, EvaluationRe
private async Task<RoleDialogModel> SendMessage(string agentId, string conversationId, string text)
{
var conv = _services.GetRequiredService<IConversationService>();
conv.SetConversationId(conversationId, new List<string>

var inputMsg = new RoleDialogModel(AgentRole.User, text);
var routing = _services.GetRequiredService<IRoutingService>();
routing.Context.SetMessageId(conversationId, inputMsg.MessageId);
conv.SetConversationId(conversationId, new List<MessageState>
{
$"channel={ConversationChannel.OpenAPI}"
new MessageState("channel", ConversationChannel.OpenAPI)
});

RoleDialogModel response = default;

await conv.SendMessage(agentId,
new RoleDialogModel(AgentRole.User, text),
inputMsg,
replyMessage: null,
async msg => response = msg,
_ => Task.CompletedTask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public async Task<PagedItems<ConversationViewModel>> GetConversations([FromBody]
public async Task<IEnumerable<ChatResponseModel>> GetDialogs([FromRoute] string conversationId)
{
var conv = _services.GetRequiredService<IConversationService>();
conv.SetConversationId(conversationId, new List<string>());
conv.SetConversationId(conversationId, new List<MessageState>());
var history = conv.GetDialogHistory(fromBreakpoint: false);

var userService = _services.GetRequiredService<IUserService>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public async Task<InstructResult> InstructCompletion([FromRoute] string agentId,
[FromBody] InstructMessageModel input)
{
var state = _services.GetRequiredService<IConversationStateService>();
input.States.ForEach(x => state.SetState(x.Split('=')[0], x.Split('=')[1]));
input.States.ForEach(x => state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds));
state.SetState("provider", input.Provider)
.SetState("model", input.Model)
.SetState("model_id", input.ModelId)
Expand All @@ -44,7 +44,7 @@ public async Task<InstructResult> InstructCompletion([FromRoute] string agentId,
public async Task<string> TextCompletion([FromBody] IncomingMessageModel input)
{
var state = _services.GetRequiredService<IConversationStateService>();
input.States.ForEach(x => state.SetState(x.Split('=')[0], x.Split('=')[1]));
input.States.ForEach(x => state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds));
state.SetState("provider", input.Provider)
.SetState("model", input.Model)
.SetState("model_id", input.ModelId);
Expand All @@ -57,7 +57,7 @@ public async Task<string> TextCompletion([FromBody] IncomingMessageModel input)
public async Task<string> ChatCompletion([FromBody] IncomingMessageModel input)
{
var state = _services.GetRequiredService<IConversationStateService>();
input.States.ForEach(x => state.SetState(x.Split('=')[0], x.Split('=')[1]));
input.States.ForEach(x => state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds));
state.SetState("provider", input.Provider)
.SetState("model", input.Model)
.SetState("model_id", input.ModelId);
Expand Down
4 changes: 4 additions & 0 deletions src/Plugins/BotSharp.Plugin.ChatbotUI/ChatbotUiController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using BotSharp.Abstraction.Agents.Enums;
using BotSharp.Abstraction.MLTasks;
using BotSharp.Abstraction.MLTasks.Settings;
using BotSharp.Abstraction.Routing;

namespace BotSharp.Plugin.ChatbotUI.Controllers;

Expand Down Expand Up @@ -72,6 +73,9 @@ public async Task SendMessage([FromBody] OpenAiMessageInput input)
.Name;

var conv = _services.GetRequiredService<IConversationService>();
var routing = _services.GetRequiredService<IRoutingService>();
routing.Context.SetMessageId(input.ConversationId, message.MessageId);

conv.SetConversationId(input.ConversationId, input.States);
conv.States.SetState("channel", input.Channel)
.SetState("provider", "azure-openai")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
using BotSharp.Abstraction.Agents.Enums;
using BotSharp.Abstraction.Conversations.Enums;
using BotSharp.Abstraction.Conversations.Models;
using BotSharp.Abstraction.Messaging.JsonConverters;
using BotSharp.Abstraction.Messaging.Models.RichContent;
using BotSharp.Abstraction.Models;
using BotSharp.Abstraction.Routing;
using System.Text.Json.Serialization.Metadata;

namespace BotSharp.Plugin.MetaMessenger.Services;
Expand Down Expand Up @@ -52,15 +55,18 @@ await messenger.SendMessage(setting.ApiVersion, setting.PageId,
});

// Go to LLM
var inputMsg = new RoleDialogModel(AgentRole.User, message);
var conv = _services.GetRequiredService<IConversationService>();
conv.SetConversationId(sender, new List<string>
var routing = _services.GetRequiredService<IRoutingService>();
routing.Context.SetMessageId(sender, inputMsg.MessageId);
conv.SetConversationId(sender, new List<MessageState>
{
$"channel={ConversationChannel.Messenger}"
new MessageState("channel", ConversationChannel.Messenger)
});

var replies = new List<IRichMessage>();
var result = await conv.SendMessage(agentId,
new RoleDialogModel(AgentRole.User, message),
inputMsg,
replyMessage: null,
async msg =>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ public static StateKeyValue ToDomainElement(StateMongoElement state)
public class StateValueMongoElement
{
public string Data { get; set; }
public string MessageId { get; set; }
public string? MessageId { get; set; }
public bool Active { get; set; }
public int ActiveRounds { get; set; }
public DateTime UpdateTime { get; set; }

public static StateValueMongoElement ToMongoElement(StateValue element)
Expand All @@ -43,6 +44,7 @@ public static StateValueMongoElement ToMongoElement(StateValue element)
Data = element.Data,
MessageId = element.MessageId,
Active = element.Active,
ActiveRounds = element.ActiveRounds,
UpdateTime = element.UpdateTime
};
}
Expand All @@ -54,6 +56,7 @@ public static StateValue ToDomainElement(StateValueMongoElement element)
Data = element.Data,
MessageId = element.MessageId,
Active = element.Active,
ActiveRounds = element.ActiveRounds,
UpdateTime = element.UpdateTime
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.AspNetCore.Mvc;
using System.IdentityModel.Tokens.Jwt;
using BotSharp.Plugin.Twilio.Services;
using BotSharp.Abstraction.Routing;

namespace BotSharp.Plugin.Twilio.Controllers;

Expand Down Expand Up @@ -47,18 +48,22 @@ public async Task<TwiMLResult> ReceivedVoiceMessage([FromRoute] string agentId,
{
string sessionId = $"TwilioVoice_{input.CallSid}";

var inputMsg = new RoleDialogModel(AgentRole.User, input.SpeechResult);
var conv = _services.GetRequiredService<IConversationService>();
conv.SetConversationId(sessionId, new List<string>
var routing = _services.GetRequiredService<IRoutingService>();
routing.Context.SetMessageId(sessionId, inputMsg.MessageId);

conv.SetConversationId(sessionId, new List<MessageState>
{
$"channel={ConversationChannel.Phone}",
$"calling_phone={input.DialCallSid}"
new MessageState("channel", ConversationChannel.Phone),
new MessageState("calling_phone", input.DialCallSid)
});

var twilio = _services.GetRequiredService<TwilioService>();
VoiceResponse response = default;

var result = await conv.SendMessage(agentId,
new RoleDialogModel(AgentRole.User, input.SpeechResult),
inputMsg,
replyMessage: null,
async msg =>
{
Expand Down
15 changes: 11 additions & 4 deletions src/Plugins/BotSharp.Plugin.WeChat/WeChatBackgroundService.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using BotSharp.Abstraction.Agents.Enums;
using BotSharp.Abstraction.Conversations;
using BotSharp.Abstraction.Conversations.Models;
using BotSharp.Abstraction.Models;
using BotSharp.Abstraction.Repositories.Filters;
using BotSharp.Abstraction.Routing;
using BotSharp.Abstraction.Users.Models;
using BotSharp.Plugin.WeChat.Users;
using Microsoft.AspNetCore.Http;
Expand Down Expand Up @@ -50,9 +53,13 @@ private async Task HandleTextMessageAsync(string openid, string message)
.OrderByDescending(_ => _.CreatedTime)
.FirstOrDefault()?.Id;

conversationService.SetConversationId(latestConversationId, new List<string>
var inputMsg = new RoleDialogModel(AgentRole.User, message);
var routing = _service.GetRequiredService<IRoutingService>();
routing.Context.SetMessageId(latestConversationId, inputMsg.MessageId);

conversationService.SetConversationId(latestConversationId, new List<MessageState>
{
"channel=wechat"
new MessageState("channel", "wechat")
});

latestConversationId ??= (await conversationService.NewConversation(new Conversation()
Expand All @@ -61,8 +68,8 @@ private async Task HandleTextMessageAsync(string openid, string message)
AgentId = AgentId
}))?.Id;

var result = await conversationService.SendMessage(AgentId,
new RoleDialogModel("user", message),
var result = await conversationService.SendMessage(AgentId,
inputMsg,
replyMessage: null,
async msg =>
{
Expand Down