Skip to content

Commit 37fbd6d

Browse files
authored
Merge pull request #554 from iceljc/features/add-image-edit
Features/add image edit
2 parents 5f459a0 + 4311eaf commit 37fbd6d

File tree

18 files changed

+457
-675
lines changed

18 files changed

+457
-675
lines changed

src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageVariation.cs renamed to src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageCompletion.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
namespace BotSharp.Abstraction.MLTasks;
44

5-
public interface IImageVariation
5+
public interface IImageCompletion
66
{
77
/// <summary>
88
/// The LLM provider like Microsoft Azure, OpenAI, ClaudAI
@@ -15,5 +15,7 @@ public interface IImageVariation
1515
/// <param name="model">deployment name</param>
1616
void SetModelName(string model);
1717

18+
Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogModel message);
19+
1820
Task<RoleDialogModel> GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName);
1921
}

src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageEdit.cs

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageGeneration.cs

Lines changed: 0 additions & 17 deletions
This file was deleted.

src/Infrastructure/BotSharp.Core/Files/Services/BotSharpFileService.Image.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ public partial class BotSharpFileService
66
{
77
public async Task<RoleDialogModel> GenerateImage(string? provider, string? model, string text)
88
{
9-
var completion = CompletionProvider.GetImageGeneration(_services, provider: provider ?? "openai", model: model ?? "dall-e-3");
9+
var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-3");
1010
var message = await completion.GetImageGeneration(new Agent()
1111
{
1212
Id = Guid.Empty.ToString(),
@@ -21,7 +21,7 @@ public async Task<RoleDialogModel> VarifyImage(string? provider, string? model,
2121
throw new ArgumentException($"Please fill in at least file url or file data!");
2222
}
2323

24-
var completion = CompletionProvider.GetImageVariation(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
24+
var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
2525
var bytes = await DownloadFile(file);
2626
using var stream = new MemoryStream();
2727
stream.Write(bytes, 0, bytes.Length);

src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs

Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace BotSharp.Core.Infrastructures;
55

66
public class CompletionProvider
77
{
8-
public static object? GetCompletion(IServiceProvider services,
8+
public static object GetCompletion(IServiceProvider services,
99
string? provider = null,
1010
string? model = null,
1111
AgentLlmConfig? agentConfig = null)
@@ -18,26 +18,20 @@ public class CompletionProvider
1818

1919
if (settings.Type == LlmModelType.Text)
2020
{
21-
return GetTextCompletion(services,
22-
provider: provider,
23-
model: model,
24-
agentConfig: agentConfig);
21+
return GetTextCompletion(services, provider: provider, model: model, agentConfig: agentConfig);
2522
}
2623
else if (settings.Type == LlmModelType.Embedding)
2724
{
28-
return GetTextEmbedding(services,
29-
provider: provider,
30-
model: model);
25+
return GetTextEmbedding(services, provider: provider, model: model);
3126
}
32-
else if (settings.Type == LlmModelType.Chat)
27+
else if (settings.Type == LlmModelType.Image)
3328
{
34-
return GetChatCompletion(services,
35-
provider: provider,
36-
model: model,
37-
agentConfig: agentConfig);
29+
return GetImageCompletion(services, provider: provider, model: model);
30+
}
31+
else
32+
{
33+
return GetChatCompletion(services, provider: provider, model: model, agentConfig: agentConfig);
3834
}
39-
40-
return null;
4135
}
4236

4337
public static IChatCompletion GetChatCompletion(IServiceProvider services,
@@ -82,36 +76,15 @@ public static ITextCompletion GetTextCompletion(IServiceProvider services,
8276
return completer;
8377
}
8478

85-
public static IImageGeneration GetImageGeneration(IServiceProvider services,
86-
string? provider = null,
87-
string? model = null,
88-
string? modelId = null,
89-
bool imageGenerate = false,
90-
AgentLlmConfig? agentConfig = null)
91-
{
92-
var completions = services.GetServices<IImageGeneration>();
93-
(provider, model) = GetProviderAndModel(services, provider: provider, model: model, modelId: modelId,
94-
imageGenerate: imageGenerate, agentConfig: agentConfig);
95-
96-
var completer = completions.FirstOrDefault(x => x.Provider == provider);
97-
if (completer == null)
98-
{
99-
var logger = services.GetRequiredService<ILogger<CompletionProvider>>();
100-
logger.LogError($"Can't resolve completion provider by {provider}");
101-
}
102-
103-
completer?.SetModelName(model);
104-
return completer;
105-
}
106-
107-
public static IImageVariation GetImageVariation(IServiceProvider services,
79+
public static IImageCompletion GetImageCompletion(IServiceProvider services,
10880
string? provider = null,
10981
string? model = null,
11082
string? modelId = null,
11183
bool imageGenerate = false)
11284
{
113-
var completions = services.GetServices<IImageVariation>();
114-
(provider, model) = GetProviderAndModel(services, provider: provider, model: model, modelId: modelId, imageGenerate: imageGenerate);
85+
var completions = services.GetServices<IImageCompletion>();
86+
(provider, model) = GetProviderAndModel(services, provider: provider,
87+
model: model, modelId: modelId, imageGenerate: imageGenerate);
11588

11689
var completer = completions.FirstOrDefault(x => x.Provider == provider);
11790
if (completer == null)

src/Infrastructure/BotSharp.OpenAPI/Controllers/InstructModeController.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ public async Task<ImageGenerationViewModel> ImageVariation([FromBody] IncomingMe
141141
try
142142
{
143143
var file = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
144+
if (file == null)
145+
{
146+
return new ImageGenerationViewModel { Message = "Error! Cannot find an image!" };
147+
}
144148
var message = await fileService.VarifyImage(input.Provider, input.Model, file);
145149
imageViewModel.Content = message.Content;
146150
imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();

src/Plugins/BotSharp.Plugin.AzureOpenAI/AzureOpenAiPlugin.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ public void RegisterDI(IServiceCollection services, IConfiguration config)
2929
services.AddScoped<ITextCompletion, TextCompletionProvider>();
3030
services.AddScoped<IChatCompletion, ChatCompletionProvider>();
3131
services.AddScoped<ITextEmbedding, TextEmbeddingProvider>();
32-
services.AddScoped<IImageGeneration, ImageGenerationProvider>();
33-
services.AddScoped<IImageVariation, ImageVariationProvider>();
32+
services.AddScoped<IImageCompletion, ImageCompletionProvider>();
3433
}
3534
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using OpenAI.Images;
2+
3+
namespace BotSharp.Plugin.AzureOpenAI.Providers.Image;
4+
5+
/// <summary>
/// Text-to-image half of the Azure OpenAI image completion provider.
/// </summary>
public partial class ImageCompletionProvider
{
    /// <summary>
    /// Generates one or more images from the prompt carried by <paramref name="message"/>.
    /// </summary>
    /// <param name="agent">Agent whose <c>Id</c> is stamped on the response as <c>CurrentAgentId</c>.</param>
    /// <param name="message">
    /// Incoming message; its <c>Payload</c> (or <c>Content</c> when the payload is null) is used as the prompt,
    /// and its <c>MessageId</c> is copied onto the response.
    /// </param>
    /// <returns>
    /// An assistant message whose <c>GeneratedImages</c> holds the produced images and whose content is the
    /// non-blank image descriptions (revised prompts) joined by CRLF.
    /// </returns>
    public async Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogModel message)
    {
        var client = ProviderHelper.GetClient(Provider, _model, _services);
        var (prompt, imageCount, options) = PrepareOptions(message);
        var imageClient = client.GetImageClient(_model);

        // Await the asynchronous API so the calling thread is not blocked for the
        // duration of the HTTP round trip (the method is already declared async).
        var response = await imageClient.GenerateImagesAsync(prompt, imageCount, options);

        var generatedImages = new List<ImageGeneration>();
        foreach (var value in response.Value)
        {
            if (value == null)
            {
                continue;
            }

            var generatedImage = new ImageGeneration { Description = value.RevisedPrompt ?? string.Empty };
            if (options.ResponseFormat == GeneratedImageFormat.Uri)
            {
                generatedImage.ImageUrl = value.ImageUri?.AbsoluteUri ?? string.Empty;
            }
            else if (options.ResponseFormat == GeneratedImageFormat.Bytes)
            {
                // Missing or empty image bytes are reported as an empty string.
                var bytes = value.ImageBytes?.ToArray();
                generatedImage.ImageData = bytes is { Length: > 0 }
                    ? Convert.ToBase64String(bytes)
                    : string.Empty;
            }

            generatedImages.Add(generatedImage);
        }

        var content = string.Join("\r\n", generatedImages
            .Where(x => !string.IsNullOrWhiteSpace(x.Description))
            .Select(x => x.Description));

        return new RoleDialogModel(AgentRole.Assistant, content)
        {
            CurrentAgentId = agent.Id,
            MessageId = message?.MessageId ?? string.Empty,
            GeneratedImages = generatedImages
        };
    }

    /// <summary>
    /// Builds the prompt, image count and generation options for a text-to-image request.
    /// </summary>
    /// <param name="message">Message supplying the prompt; a null message yields an empty prompt.</param>
    /// <returns>The prompt, the number of images to request, and the generation options.</returns>
    /// <remarks>
    /// Settings are read from the conversation state keys <c>image_size</c>, <c>image_quality</c>,
    /// <c>image_style</c>, <c>image_format</c> and <c>image_count</c> (the latter defaulting to "1").
    /// The <c>GetImage*</c> converters are declared in another part of this partial class.
    /// </remarks>
    private (string, int, ImageGenerationOptions) PrepareOptions(RoleDialogModel message)
    {
        var prompt = message?.Payload ?? message?.Content ?? string.Empty;

        var state = _services.GetRequiredService<IConversationStateService>();
        var options = new ImageGenerationOptions
        {
            Size = GetImageSize(state.GetState("image_size")),
            Quality = GetImageQuality(state.GetState("image_quality")),
            Style = GetImageStyle(state.GetState("image_style")),
            ResponseFormat = GetImageFormat(state.GetState("image_format"))
        };
        var count = GetImageCount(state.GetState("image_count", "1"));

        return (prompt, count, options);
    }
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
using OpenAI.Images;
2+
3+
namespace BotSharp.Plugin.AzureOpenAI.Providers.Image;
4+
5+
/// <summary>
/// Image-variation half of the Azure OpenAI image completion provider.
/// </summary>
public partial class ImageCompletionProvider
{
    /// <summary>
    /// Generates one or more variations of the supplied image.
    /// </summary>
    /// <param name="agent">Agent whose <c>Id</c> is stamped on the response as <c>CurrentAgentId</c>.</param>
    /// <param name="message">Incoming message; only its <c>MessageId</c> is copied onto the response.</param>
    /// <param name="image">Stream containing the source image; passed through to the image client.</param>
    /// <param name="imageFileName">File name reported to the image client for <paramref name="image"/>.</param>
    /// <returns>
    /// An assistant message whose <c>GeneratedImages</c> holds the produced variations and whose content is the
    /// non-blank image descriptions joined by CRLF.
    /// </returns>
    public async Task<RoleDialogModel> GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName)
    {
        var client = ProviderHelper.GetClient(Provider, _model, _services);
        var (imageCount, options) = PrepareOptions();
        var imageClient = client.GetImageClient(_model);

        // Await the asynchronous API so the calling thread is not blocked for the
        // duration of the upload and HTTP round trip (the method is already declared async).
        var response = await imageClient.GenerateImageVariationsAsync(image, imageFileName, imageCount, options);

        var generatedImages = new List<ImageGeneration>();
        foreach (var value in response.Value)
        {
            if (value == null)
            {
                continue;
            }

            var generatedImage = new ImageGeneration { Description = value.RevisedPrompt ?? string.Empty };
            if (options.ResponseFormat == GeneratedImageFormat.Uri)
            {
                generatedImage.ImageUrl = value.ImageUri?.AbsoluteUri ?? string.Empty;
            }
            else if (options.ResponseFormat == GeneratedImageFormat.Bytes)
            {
                // Missing or empty image bytes are reported as an empty string.
                var bytes = value.ImageBytes?.ToArray();
                generatedImage.ImageData = bytes is { Length: > 0 }
                    ? Convert.ToBase64String(bytes)
                    : string.Empty;
            }

            generatedImages.Add(generatedImage);
        }

        var content = string.Join("\r\n", generatedImages
            .Where(x => !string.IsNullOrWhiteSpace(x.Description))
            .Select(x => x.Description));

        return new RoleDialogModel(AgentRole.Assistant, content)
        {
            CurrentAgentId = agent.Id,
            MessageId = message?.MessageId ?? string.Empty,
            GeneratedImages = generatedImages
        };
    }

    /// <summary>
    /// Builds the image count and variation options for an image-variation request.
    /// </summary>
    /// <returns>The number of variations to request and the variation options.</returns>
    /// <remarks>
    /// Settings are read from the conversation state keys <c>image_size</c>, <c>image_format</c> and
    /// <c>image_count</c> (the latter defaulting to "1"). The <c>GetImage*</c> converters are declared
    /// in another part of this partial class.
    /// </remarks>
    private (int, ImageVariationOptions) PrepareOptions()
    {
        var state = _services.GetRequiredService<IConversationStateService>();
        var options = new ImageVariationOptions
        {
            Size = GetImageSize(state.GetState("image_size")),
            ResponseFormat = GetImageFormat(state.GetState("image_format"))
        };
        var count = GetImageCount(state.GetState("image_count", "1"));

        return (count, options);
    }
}

0 commit comments

Comments
 (0)