From 1d687a38a0de6ecd5bcbd75a5707a76b42a5c19f Mon Sep 17 00:00:00 2001 From: Jicheng Lu Date: Thu, 18 Jul 2024 22:28:24 -0500 Subject: [PATCH 1/3] refine image completion --- ...IImageVariation.cs => IImageCompletion.cs} | 4 +- .../MLTasks/IImageEdit.cs | 5 - .../MLTasks/IImageGeneration.cs | 17 -- .../Services/BotSharpFileService.Image.cs | 4 +- .../Infrastructures/CompletionProvider.cs | 53 ++--- .../AzureOpenAiPlugin.cs | 3 +- .../ImageCompletionProvider.Generation.cs | 71 ++++++ .../ImageCompletionProvider.Variation.cs | 65 ++++++ ...Provider.cs => ImageCompletionProvider.cs} | 100 ++------ .../Providers/Image/ImageVariationProvider.cs | 152 ------------- .../Functions/GenerateImageFn.cs | 3 +- .../BotSharp.Plugin.OpenAI/OpenAiPlugin.cs | 3 +- .../ImageCompletionProvider.Generation.cs | 73 ++++++ .../ImageCompletionProvider.Variation.cs | 65 ++++++ .../Image/ImageCompletionProvider.cs | 145 ++++++++++++ .../Image/ImageGenerationProvider.cs | 215 ------------------ .../Providers/Image/ImageVariationProvider.cs | 152 ------------- 17 files changed, 455 insertions(+), 675 deletions(-) rename src/Infrastructure/BotSharp.Abstraction/MLTasks/{IImageVariation.cs => IImageCompletion.cs} (81%) delete mode 100644 src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageEdit.cs delete mode 100644 src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageGeneration.cs create mode 100644 src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs create mode 100644 src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Variation.cs rename src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/{ImageGenerationProvider.cs => ImageCompletionProvider.cs} (50%) delete mode 100644 src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageVariationProvider.cs create mode 100644 src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Generation.cs create mode 100644 src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs create mode 100644 src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.cs delete mode 100644 src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageGenerationProvider.cs delete mode 100644 src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageVariationProvider.cs diff --git a/src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageVariation.cs b/src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageCompletion.cs similarity index 81% rename from src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageVariation.cs rename to src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageCompletion.cs index a60f43d2c..13132c633 100644 --- a/src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageVariation.cs +++ b/src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageCompletion.cs @@ -2,7 +2,7 @@ namespace BotSharp.Abstraction.MLTasks; -public interface IImageVariation +public interface IImageCompletion { /// /// The LLM provider like Microsoft Azure, OpenAI, ClaudAI @@ -15,5 +15,7 @@ public interface IImageVariation /// deployment name void SetModelName(string model); + Task GetImageGeneration(Agent agent, RoleDialogModel message); + Task GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName); } diff --git a/src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageEdit.cs b/src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageEdit.cs deleted file mode 100644 index 78f44d233..000000000 --- a/src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageEdit.cs +++ /dev/null @@ -1,5 +0,0 @@ -namespace BotSharp.Abstraction.MLTasks; - -public interface IImageEdit -{ -} diff --git a/src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageGeneration.cs b/src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageGeneration.cs deleted file mode 100644 index 2b3458628..000000000 --- a/src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageGeneration.cs +++ /dev/null @@ -1,17 +0,0 @@ -namespace BotSharp.Abstraction.MLTasks; - -public interface IImageGeneration -{ - /// - /// The LLM provider like Microsoft Azure, OpenAI, ClaudAI - /// - string Provider { get; } - - /// - /// Set model name, one provider can consume different model or version(s) - /// - /// deployment name - void SetModelName(string model); - - Task GetImageGeneration(Agent agent, RoleDialogModel message); -} diff --git a/src/Infrastructure/BotSharp.Core/Files/Services/BotSharpFileService.Image.cs b/src/Infrastructure/BotSharp.Core/Files/Services/BotSharpFileService.Image.cs index 40f990112..011362267 100644 --- a/src/Infrastructure/BotSharp.Core/Files/Services/BotSharpFileService.Image.cs +++ b/src/Infrastructure/BotSharp.Core/Files/Services/BotSharpFileService.Image.cs @@ -6,7 +6,7 @@ public partial class BotSharpFileService { public async Task GenerateImage(string? provider, string? model, string text) { - var completion = CompletionProvider.GetImageGeneration(_services, provider: provider ?? "openai", model: model ?? "dall-e-3"); + var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-3"); var message = await completion.GetImageGeneration(new Agent() { Id = Guid.Empty.ToString(), @@ -21,7 +21,7 @@ public async Task VarifyImage(string? provider, string? model, throw new ArgumentException($"Please fill in at least file url or file data!"); } - var completion = CompletionProvider.GetImageVariation(_services, provider: provider ?? "openai", model: model ?? "dall-e-2"); + var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-2"); var bytes = await DownloadFile(file); using var stream = new MemoryStream(); stream.Write(bytes, 0, bytes.Length); diff --git a/src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs b/src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs index 22675d3a2..dd6b99af8 100644 --- a/src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs +++ b/src/Infrastructure/BotSharp.Core/Infrastructures/CompletionProvider.cs @@ -5,7 +5,7 @@ namespace BotSharp.Core.Infrastructures; public class CompletionProvider { - public static object? GetCompletion(IServiceProvider services, + public static object GetCompletion(IServiceProvider services, string? provider = null, string? model = null, AgentLlmConfig? agentConfig = null) @@ -18,26 +18,20 @@ public class CompletionProvider if (settings.Type == LlmModelType.Text) { - return GetTextCompletion(services, - provider: provider, - model: model, - agentConfig: agentConfig); + return GetTextCompletion(services, provider: provider, model: model, agentConfig: agentConfig); } else if (settings.Type == LlmModelType.Embedding) { - return GetTextEmbedding(services, - provider: provider, - model: model); + return GetTextEmbedding(services, provider: provider, model: model); } - else if (settings.Type == LlmModelType.Chat) + else if (settings.Type == LlmModelType.Image) { - return GetChatCompletion(services, - provider: provider, - model: model, - agentConfig: agentConfig); + return GetImageCompletion(services, provider: provider, model: model); + } + else + { + return GetChatCompletion(services, provider: provider, model: model, agentConfig: agentConfig); } - - return null; } public static IChatCompletion GetChatCompletion(IServiceProvider services, @@ -82,36 +76,15 @@ public static ITextCompletion GetTextCompletion(IServiceProvider services, return completer; } - public static IImageGeneration GetImageGeneration(IServiceProvider services, - string? provider = null, - string? model = null, - string? modelId = null, - bool imageGenerate = false, - AgentLlmConfig? agentConfig = null) - { - var completions = services.GetServices(); - (provider, model) = GetProviderAndModel(services, provider: provider, model: model, modelId: modelId, - imageGenerate: imageGenerate, agentConfig: agentConfig); - - var completer = completions.FirstOrDefault(x => x.Provider == provider); - if (completer == null) - { - var logger = services.GetRequiredService>(); - logger.LogError($"Can't resolve completion provider by {provider}"); - } - - completer?.SetModelName(model); - return completer; - } - - public static IImageVariation GetImageVariation(IServiceProvider services, + public static IImageCompletion GetImageCompletion(IServiceProvider services, string? provider = null, string? model = null, string? modelId = null, bool imageGenerate = false) { - var completions = services.GetServices(); - (provider, model) = GetProviderAndModel(services, provider: provider, model: model, modelId: modelId, imageGenerate: imageGenerate); + var completions = services.GetServices(); + (provider, model) = GetProviderAndModel(services, provider: provider, + model: model, modelId: modelId, imageGenerate: imageGenerate); var completer = completions.FirstOrDefault(x => x.Provider == provider); if (completer == null) diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/AzureOpenAiPlugin.cs b/src/Plugins/BotSharp.Plugin.AzureOpenAI/AzureOpenAiPlugin.cs index e7f2130c4..dba74d274 100644 --- a/src/Plugins/BotSharp.Plugin.AzureOpenAI/AzureOpenAiPlugin.cs +++ b/src/Plugins/BotSharp.Plugin.AzureOpenAI/AzureOpenAiPlugin.cs @@ -29,7 +29,6 @@ public void RegisterDI(IServiceCollection services, IConfiguration config) services.AddScoped(); services.AddScoped(); services.AddScoped(); - services.AddScoped(); - services.AddScoped(); + services.AddScoped(); } } \ No newline at end of file diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs new file mode 100644 index 000000000..995054d63 --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs @@ -0,0 +1,71 @@ +using OpenAI.Images; + +namespace BotSharp.Plugin.AzureOpenAI.Providers.Image; + +public partial class ImageCompletionProvider +{ + public async Task GetImageGeneration(Agent agent, RoleDialogModel message) + { + var client = ProviderHelper.GetClient(Provider, _model, _services); + var (prompt, imageCount, options) = PrepareOptions(message); + var imageClient = client.GetImageClient(_model); + + var response = imageClient.GenerateImages(prompt, imageCount, options); + var values = response.Value; + + var generatedImages = new List(); + foreach (var value in values) + { + if (value == null) continue; + + var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty }; + if (options.ResponseFormat == GeneratedImageFormat.Uri) + { + generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty; + } + else if (options.ResponseFormat == GeneratedImageFormat.Bytes) + { + var base64Str = string.Empty; + var bytes = value?.ImageBytes?.ToArray(); + if (!bytes.IsNullOrEmpty()) + { + base64Str = Convert.ToBase64String(bytes); + } + generatedImage.ImageData = base64Str; + } + + generatedImages.Add(generatedImage); + } + + var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description)); + var responseMessage = new RoleDialogModel(AgentRole.Assistant, content) + { + CurrentAgentId = agent.Id, + MessageId = message?.MessageId ?? string.Empty, + GeneratedImages = generatedImages + }; + + return responseMessage; + } + + private (string, int, ImageGenerationOptions) PrepareOptions(RoleDialogModel message) + { + var prompt = message?.Payload ?? message?.Content ?? string.Empty; + + var state = _services.GetRequiredService(); + var size = state.GetState("image_size"); + var quality = state.GetState("image_quality"); + var style = state.GetState("image_style"); + var format = state.GetState("image_format"); + var count = GetImageCount(state.GetState("image_count", "1")); + + var options = new ImageGenerationOptions + { + Size = GetImageSize(size), + Quality = GetImageQuality(quality), + Style = GetImageStyle(style), + ResponseFormat = GetImageFormat(format) + }; + return (prompt, count, options); + } +} diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Variation.cs b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Variation.cs new file mode 100644 index 000000000..8d9b4f39b --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Variation.cs @@ -0,0 +1,65 @@ +using OpenAI.Images; + +namespace BotSharp.Plugin.AzureOpenAI.Providers.Image; + +public partial class ImageCompletionProvider +{ + public async Task GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName) + { + var client = ProviderHelper.GetClient(Provider, _model, _services); + var (imageCount, options) = PrepareOptions(); + var imageClient = client.GetImageClient(_model); + + var response = imageClient.GenerateImageVariations(image, imageFileName, imageCount, options); + var values = response.Value; + + var generatedImages = new List(); + foreach (var value in values) + { + if (value == null) continue; + + var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty }; + if (options.ResponseFormat == GeneratedImageFormat.Uri) + { + generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty; + } + else if (options.ResponseFormat == GeneratedImageFormat.Bytes) + { + var base64Str = string.Empty; + var bytes = value?.ImageBytes?.ToArray(); + if (!bytes.IsNullOrEmpty()) + { + base64Str = Convert.ToBase64String(bytes); + } + generatedImage.ImageData = base64Str; + } + + generatedImages.Add(generatedImage); + } + + var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description)); + var responseMessage = new RoleDialogModel(AgentRole.Assistant, content) + { + CurrentAgentId = agent.Id, + MessageId = message?.MessageId ?? string.Empty, + GeneratedImages = generatedImages + }; + + return await Task.FromResult(responseMessage); + } + + private (int, ImageVariationOptions) PrepareOptions() + { + var state = _services.GetRequiredService(); + var size = state.GetState("image_size"); + var format = state.GetState("image_format"); + var count = GetImageCount(state.GetState("image_count", "1")); + + var options = new ImageVariationOptions + { + Size = GetImageSize(size), + ResponseFormat = GetImageFormat(format) + }; + return (count, options); + } +} diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageGenerationProvider.cs b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.cs similarity index 50% rename from src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageGenerationProvider.cs rename to src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.cs index cf995987e..6e19380a7 100644 --- a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageGenerationProvider.cs +++ b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.cs @@ -2,11 +2,11 @@ namespace BotSharp.Plugin.AzureOpenAI.Providers.Image; -public class ImageGenerationProvider : IImageGeneration +public partial class ImageCompletionProvider : IImageCompletion { protected readonly AzureOpenAiSettings _settings; protected readonly IServiceProvider _services; - protected readonly ILogger _logger; + protected readonly ILogger _logger; private const int DEFAULT_IMAGE_COUNT = 1; private const int IMAGE_COUNT_LIMIT = 5; @@ -15,9 +15,9 @@ public class ImageGenerationProvider : IImageGeneration public virtual string Provider => "azure-openai"; - public ImageGenerationProvider( + public ImageCompletionProvider( AzureOpenAiSettings settings, - ILogger logger, + ILogger logger, IServiceProvider services) { _settings = settings; @@ -25,91 +25,12 @@ public ImageGenerationProvider( _logger = logger; } - - public async Task GetImageGeneration(Agent agent, RoleDialogModel message) - { - var client = ProviderHelper.GetClient(Provider, _model, _services); - var (prompt, imageCount, options) = PrepareOptions(message); - var imageClient = client.GetImageClient(_model); - - var response = imageClient.GenerateImages(prompt, imageCount, options); - var values = response.Value; - - var generatedImages = new List(); - foreach (var value in values) - { - if (value == null) continue; - - var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty }; - if (options.ResponseFormat == GeneratedImageFormat.Uri) - { - generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty; - } - else if (options.ResponseFormat == GeneratedImageFormat.Bytes) - { - var base64Str = string.Empty; - var bytes = value?.ImageBytes?.ToArray(); - if (!bytes.IsNullOrEmpty()) - { - base64Str = Convert.ToBase64String(bytes); - } - generatedImage.ImageData = base64Str; - } - - generatedImages.Add(generatedImage); - } - - var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description)); - var responseMessage = new RoleDialogModel(AgentRole.Assistant, content) - { - CurrentAgentId = agent.Id, - MessageId = message?.MessageId ?? string.Empty, - GeneratedImages = generatedImages - }; - - // After - var contentHooks = _services.GetServices().ToList(); - foreach (var hook in contentHooks) - { - await hook.AfterGenerated(responseMessage, new TokenStatsModel - { - Prompt = prompt, - Provider = Provider, - Model = _model, - PromptCount = prompt.Split(' ', StringSplitOptions.RemoveEmptyEntries).Count(), - CompletionCount = content.Split(' ', StringSplitOptions.RemoveEmptyEntries).Count() - }); - } - - return responseMessage; - } - public void SetModelName(string model) { _model = model; } - private (string, int, ImageGenerationOptions) PrepareOptions(RoleDialogModel message) - { - var prompt = message?.Payload ?? message?.Content ?? string.Empty; - - var state = _services.GetRequiredService(); - var size = state.GetState("image_size"); - var quality = state.GetState("image_quality"); - var style = state.GetState("image_style"); - var format = state.GetState("image_format"); - var count = GetImageCount(state.GetState("image_count", "1")); - - var options = new ImageGenerationOptions - { - Size = GetImageSize(size), - Quality = GetImageQuality(quality), - Style = GetImageStyle(style), - ResponseFormat = GetImageFormat(format) - }; - return (prompt, count, options); - } - + #region Private methods private GeneratedImageSize GetImageSize(string size) { var value = !string.IsNullOrEmpty(size) ? size : "1024x1024"; @@ -210,6 +131,15 @@ private int GetImageCount(string count) return DEFAULT_IMAGE_COUNT; } - return retCount > 0 && retCount <= IMAGE_COUNT_LIMIT ? retCount : DEFAULT_IMAGE_COUNT; + if (retCount <= 0) + { + retCount = DEFAULT_IMAGE_COUNT; + } + else if (retCount > IMAGE_COUNT_LIMIT) + { + retCount = IMAGE_COUNT_LIMIT; + } + return retCount; } + #endregion } diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageVariationProvider.cs b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageVariationProvider.cs deleted file mode 100644 index ab6598644..000000000 --- a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageVariationProvider.cs +++ /dev/null @@ -1,152 +0,0 @@ -using OpenAI.Images; - -namespace BotSharp.Plugin.AzureOpenAI.Providers.Image; - -public class ImageVariationProvider : IImageVariation -{ - protected readonly AzureOpenAiSettings _settings; - protected readonly IServiceProvider _services; - protected readonly ILogger _logger; - - private const int DEFAULT_IMAGE_COUNT = 1; - private const int IMAGE_COUNT_LIMIT = 5; - - protected string _model; - - public virtual string Provider => "azure-openai"; - - public ImageVariationProvider( - AzureOpenAiSettings settings, - ILogger logger, - IServiceProvider services) - { - _settings = settings; - _services = services; - _logger = logger; - } - - public async Task GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName) - { - var client = ProviderHelper.GetClient(Provider, _model, _services); - var (imageCount, options) = PrepareOptions(); - var imageClient = client.GetImageClient(_model); - - var response = imageClient.GenerateImageVariations(image, imageFileName, imageCount, options); - var values = response.Value; - - var generatedImages = new List(); - foreach (var value in values) - { - if (value == null) continue; - - var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty }; - if (options.ResponseFormat == GeneratedImageFormat.Uri) - { - generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty; - } - else if (options.ResponseFormat == GeneratedImageFormat.Bytes) - { - var base64Str = string.Empty; - var bytes = value?.ImageBytes?.ToArray(); - if (!bytes.IsNullOrEmpty()) - { - base64Str = Convert.ToBase64String(bytes); - } - generatedImage.ImageData = base64Str; - } - - generatedImages.Add(generatedImage); - } - - var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description)); - var responseMessage = new RoleDialogModel(AgentRole.Assistant, content) - { - CurrentAgentId = agent.Id, - MessageId = message?.MessageId ?? string.Empty, - GeneratedImages = generatedImages - }; - - return await Task.FromResult(responseMessage); - } - - public void SetModelName(string model) - { - _model = model; - } - - private (int, ImageVariationOptions) PrepareOptions() - { - var state = _services.GetRequiredService(); - var size = state.GetState("image_size"); - var format = state.GetState("image_format"); - var count = GetImageCount(state.GetState("image_count", "1")); - - var options = new ImageVariationOptions - { - Size = GetImageSize(size), - ResponseFormat = GetImageFormat(format) - }; - return (count, options); - } - - private GeneratedImageSize GetImageSize(string size) - { - var value = !string.IsNullOrEmpty(size) ? size : "1024x1024"; - - GeneratedImageSize retSize; - switch (value) - { - case "256x256": - retSize = GeneratedImageSize.W256xH256; - break; - case "512x512": - retSize = GeneratedImageSize.W512xH512; - break; - case "1024x1024": - retSize = GeneratedImageSize.W1024xH1024; - break; - case "1024x1792": - retSize = GeneratedImageSize.W1024xH1792; - break; - case "1792x1024": - retSize = GeneratedImageSize.W1792xH1024; - break; - default: - retSize = GeneratedImageSize.W1024xH1024; - break; - } - - return retSize; - } - - private GeneratedImageFormat GetImageFormat(string format) - { - var value = !string.IsNullOrEmpty(format) ? format : "uri"; - - GeneratedImageFormat retFormat; - switch (value) - { - case "uri": - retFormat = GeneratedImageFormat.Uri; - break; - case "bytes": - retFormat = GeneratedImageFormat.Bytes; - break; - default: - retFormat = GeneratedImageFormat.Uri; - break; - } - - return retFormat; - } - - private int GetImageCount(string count) - { - if (!int.TryParse(count, out var retCount)) - { - return DEFAULT_IMAGE_COUNT; - } - - return retCount > 0 && retCount <= IMAGE_COUNT_LIMIT ? retCount : DEFAULT_IMAGE_COUNT; - } -} diff --git a/src/Plugins/BotSharp.Plugin.FileHandler/Functions/GenerateImageFn.cs b/src/Plugins/BotSharp.Plugin.FileHandler/Functions/GenerateImageFn.cs index e787da844..095585dcf 100644 --- a/src/Plugins/BotSharp.Plugin.FileHandler/Functions/GenerateImageFn.cs +++ b/src/Plugins/BotSharp.Plugin.FileHandler/Functions/GenerateImageFn.cs @@ -18,7 +18,6 @@ public GenerateImageFn( _logger = logger; } - public async Task Execute(RoleDialogModel message) { var args = JsonSerializer.Deserialize(message.FunctionArgs); @@ -59,7 +58,7 @@ private async Task GetImageGeneration(Agent agent, RoleDialogModel messa { try { - var completion = CompletionProvider.GetImageGeneration(_services, provider: "openai", model: "dall-e-3"); + var completion = CompletionProvider.GetImageCompletion(_services, provider: "openai", model: "dall-e-3"); var text = !string.IsNullOrWhiteSpace(description) ? description : message.Content; var dialog = RoleDialogModel.From(message, AgentRole.User, text); var result = await completion.GetImageGeneration(agent, dialog); diff --git a/src/Plugins/BotSharp.Plugin.OpenAI/OpenAiPlugin.cs b/src/Plugins/BotSharp.Plugin.OpenAI/OpenAiPlugin.cs index b2f663873..1bc69aaf4 100644 --- a/src/Plugins/BotSharp.Plugin.OpenAI/OpenAiPlugin.cs +++ b/src/Plugins/BotSharp.Plugin.OpenAI/OpenAiPlugin.cs @@ -29,7 +29,6 @@ public void RegisterDI(IServiceCollection services, IConfiguration config) services.AddScoped(); services.AddScoped(); services.AddScoped(); - services.AddScoped(); - services.AddScoped(); + services.AddScoped(); } } \ No newline at end of file diff --git a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Generation.cs b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Generation.cs new file mode 100644 index 000000000..c03deb6ff --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Generation.cs @@ -0,0 +1,73 @@ +using OpenAI.Images; + +namespace BotSharp.Plugin.OpenAI.Providers.Image; + +public partial class ImageCompletionProvider +{ + public async Task GetImageGeneration(Agent agent, RoleDialogModel message) + { + var client = ProviderHelper.GetClient(Provider, _model, _services); + var (prompt, imageCount, options) = PrepareOptions(message); + var imageClient = client.GetImageClient(_model); + + var response = imageClient.GenerateImages(prompt, imageCount, options); + var values = response.Value; + + var generatedImages = new List(); + foreach (var value in values) + { + if (value == null) continue; + + var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty }; + if (options.ResponseFormat == GeneratedImageFormat.Uri) + { + generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty; + } + else if (options.ResponseFormat == GeneratedImageFormat.Bytes) + { + var base64Str = string.Empty; + var bytes = value?.ImageBytes?.ToArray(); + if (!bytes.IsNullOrEmpty()) + { + base64Str = Convert.ToBase64String(bytes); + } + generatedImage.ImageData = base64Str; + } + + generatedImages.Add(generatedImage); + } + + var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description)); + var responseMessage = new RoleDialogModel(AgentRole.Assistant, content) + { + CurrentAgentId = agent.Id, + MessageId = message?.MessageId ?? string.Empty, + GeneratedImages = generatedImages + }; + + return await Task.FromResult(responseMessage); + } + + private (string, int, ImageGenerationOptions) PrepareOptions(RoleDialogModel message) + { + var prompt = message?.Payload ?? message?.Content ?? string.Empty; + + var state = _services.GetRequiredService(); + var size = state.GetState("image_size"); + var quality = state.GetState("image_quality"); + var style = state.GetState("image_style"); + var format = state.GetState("image_format"); + var count = GetImageCount(state.GetState("image_count", "1")); + + var options = new ImageGenerationOptions + { + Size = GetImageSize(size), + Quality = GetImageQuality(quality), + Style = GetImageStyle(style), + ResponseFormat = GetImageFormat(format) + }; + return (prompt, count, options); + } + + +} \ No newline at end of file diff --git a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs new file mode 100644 index 000000000..13142bc06 --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs @@ -0,0 +1,65 @@ +using OpenAI.Images; + +namespace BotSharp.Plugin.OpenAI.Providers.Image; + +public partial class ImageCompletionProvider +{ + public async Task GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName) + { + var client = ProviderHelper.GetClient(Provider, _model, _services); + var (imageCount, options) = PrepareOptions(); + var imageClient = client.GetImageClient(_model); + + var response = imageClient.GenerateImageVariations(image, imageFileName, imageCount, options); + var values = response.Value; + + var generatedImages = new List(); + foreach (var value in values) + { + if (value == null) continue; + + var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty }; + if (options.ResponseFormat == GeneratedImageFormat.Uri) + { + generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty; + } + else if (options.ResponseFormat == GeneratedImageFormat.Bytes) + { + var base64Str = string.Empty; + var bytes = value?.ImageBytes?.ToArray(); + if (!bytes.IsNullOrEmpty()) + { + base64Str = Convert.ToBase64String(bytes); + } + generatedImage.ImageData = base64Str; + } + + generatedImages.Add(generatedImage); + } + + var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description)); + var responseMessage = new RoleDialogModel(AgentRole.Assistant, content) + { + CurrentAgentId = agent.Id, + MessageId = message?.MessageId ?? string.Empty, + GeneratedImages = generatedImages + }; + + return await Task.FromResult(responseMessage); + } + + private (int, ImageVariationOptions) PrepareOptions() + { + var state = _services.GetRequiredService(); + var size = state.GetState("image_size"); + var format = state.GetState("image_format"); + var count = GetImageCount(state.GetState("image_count", "1")); + + var options = new ImageVariationOptions + { + Size = GetImageSize(size), + ResponseFormat = GetImageFormat(format) + }; + return (count, options); + } +} diff --git a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.cs new file mode 100644 index 000000000..13fb197c1 --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.cs @@ -0,0 +1,145 @@ +using OpenAI.Images; + +namespace BotSharp.Plugin.OpenAI.Providers.Image; + +public partial class ImageCompletionProvider : IImageCompletion +{ + protected readonly OpenAiSettings _settings; + protected readonly IServiceProvider _services; + protected readonly ILogger _logger; + + private const int DEFAULT_IMAGE_COUNT = 1; + private const int IMAGE_COUNT_LIMIT = 5; + + protected string _model; + + public virtual string Provider => "openai"; + + public ImageCompletionProvider( + OpenAiSettings settings, + ILogger logger, + IServiceProvider services) + { + _settings = settings; + _services = services; + _logger = logger; + } + + public void SetModelName(string model) + { + _model = model; + } + + #region Private methods + private GeneratedImageSize GetImageSize(string size) + { + var value = !string.IsNullOrEmpty(size) ? size : "1024x1024"; + + GeneratedImageSize retSize; + switch (value) + { + case "256x256": + retSize = GeneratedImageSize.W256xH256; + break; + case "512x512": + retSize = GeneratedImageSize.W512xH512; + break; + case "1024x1024": + retSize = GeneratedImageSize.W1024xH1024; + break; + case "1024x1792": + retSize = GeneratedImageSize.W1024xH1792; + break; + case "1792x1024": + retSize = GeneratedImageSize.W1792xH1024; + break; + default: + retSize = GeneratedImageSize.W1024xH1024; + break; + } + + return retSize; + } + + private GeneratedImageQuality GetImageQuality(string quality) + { + var value = !string.IsNullOrEmpty(quality) ? quality : "standard"; + + GeneratedImageQuality retQuality; + switch (value) + { + case "standard": + retQuality = GeneratedImageQuality.Standard; + break; + case "hd": + retQuality = GeneratedImageQuality.High; + break; + default: + retQuality = GeneratedImageQuality.Standard; + break; + } + + return retQuality; + } + + private GeneratedImageStyle GetImageStyle(string style) + { + var value = !string.IsNullOrEmpty(style) ? style : "natural"; + + GeneratedImageStyle retStyle; + switch (value) + { + case "natural": + retStyle = GeneratedImageStyle.Natural; + break; + case "vivid": + retStyle = GeneratedImageStyle.Vivid; + break; + default: + retStyle = GeneratedImageStyle.Natural; + break; + } + + return retStyle; + } + + private GeneratedImageFormat GetImageFormat(string format) + { + var value = !string.IsNullOrEmpty(format) ? format : "uri"; + + GeneratedImageFormat retFormat; + switch (value) + { + case "uri": + retFormat = GeneratedImageFormat.Uri; + break; + case "bytes": + retFormat = GeneratedImageFormat.Bytes; + break; + default: + retFormat = GeneratedImageFormat.Uri; + break; + } + + return retFormat; + } + + private int GetImageCount(string count) + { + if (!int.TryParse(count, out var retCount)) + { + return DEFAULT_IMAGE_COUNT; + } + + if (retCount <= 0) + { + retCount = DEFAULT_IMAGE_COUNT; + } + else if (retCount > IMAGE_COUNT_LIMIT) + { + retCount = IMAGE_COUNT_LIMIT; + } + return retCount; + } + #endregion +} diff --git a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageGenerationProvider.cs b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageGenerationProvider.cs deleted file mode 100644 index ddbabe859..000000000 --- a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageGenerationProvider.cs +++ /dev/null @@ -1,215 +0,0 @@ -using OpenAI.Images; - -namespace BotSharp.Plugin.OpenAI.Providers.Image; - -public class ImageGenerationProvider : IImageGeneration -{ - protected readonly OpenAiSettings _settings; - protected readonly IServiceProvider _services; - protected readonly ILogger _logger; - - private const int DEFAULT_IMAGE_COUNT = 1; - private const int IMAGE_COUNT_LIMIT = 5; - - protected string _model; - - public virtual string Provider => "openai"; - - public ImageGenerationProvider( - OpenAiSettings settings, - ILogger logger, - IServiceProvider services) - { - _settings = settings; - _services = services; - _logger = logger; - } - - - public async Task GetImageGeneration(Agent agent, RoleDialogModel message) - { - var client = ProviderHelper.GetClient(Provider, _model, _services); - var (prompt, imageCount, options) = PrepareOptions(message); - var imageClient = client.GetImageClient(_model); - - var response = imageClient.GenerateImages(prompt, imageCount, options); - var values = response.Value; - - var generatedImages = new List(); - foreach (var value in values) - { - if (value == null) continue; - - var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty }; - if (options.ResponseFormat == GeneratedImageFormat.Uri) - { - generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty; - } - else if (options.ResponseFormat == GeneratedImageFormat.Bytes) - { - var base64Str = string.Empty; - var bytes = value?.ImageBytes?.ToArray(); - if (!bytes.IsNullOrEmpty()) - { - base64Str = Convert.ToBase64String(bytes); - } - generatedImage.ImageData = base64Str; - } - - generatedImages.Add(generatedImage); - } - - var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description)); - var responseMessage = new RoleDialogModel(AgentRole.Assistant, content) - { - CurrentAgentId = agent.Id, - MessageId = message?.MessageId ?? string.Empty, - GeneratedImages = generatedImages - }; - - // After - var contentHooks = _services.GetServices().ToList(); - foreach (var hook in contentHooks) - { - await hook.AfterGenerated(responseMessage, new TokenStatsModel - { - Prompt = prompt, - Provider = Provider, - Model = _model, - PromptCount = prompt.Split(' ', StringSplitOptions.RemoveEmptyEntries).Count(), - CompletionCount = content.Split(' ', StringSplitOptions.RemoveEmptyEntries).Count() - }); - } - - return responseMessage; - } - - public void SetModelName(string model) - { - _model = model; - } - - private (string, int, ImageGenerationOptions) PrepareOptions(RoleDialogModel message) - { - var prompt = message?.Payload ?? message?.Content ?? string.Empty; - - var state = _services.GetRequiredService(); - var size = state.GetState("image_size"); - var quality = state.GetState("image_quality"); - var style = state.GetState("image_style"); - var format = state.GetState("image_format"); - var count = GetImageCount(state.GetState("image_count", "1")); - - var options = new ImageGenerationOptions - { - Size = GetImageSize(size), - Quality = GetImageQuality(quality), - Style = GetImageStyle(style), - ResponseFormat = GetImageFormat(format) - }; - return (prompt, count, options); - } - - private GeneratedImageSize GetImageSize(string size) - { - var value = !string.IsNullOrEmpty(size) ? size : "1024x1024"; - - GeneratedImageSize retSize; - switch (value) - { - case "256x256": - retSize = GeneratedImageSize.W256xH256; - break; - case "512x512": - retSize = GeneratedImageSize.W512xH512; - break; - case "1024x1024": - retSize = GeneratedImageSize.W1024xH1024; - break; - case "1024x1792": - retSize = GeneratedImageSize.W1024xH1792; - break; - case "1792x1024": - retSize = GeneratedImageSize.W1792xH1024; - break; - default: - retSize = GeneratedImageSize.W1024xH1024; - break; - } - - return retSize; - } - - private GeneratedImageQuality GetImageQuality(string quality) - { - var value = !string.IsNullOrEmpty(quality) ? quality : "standard"; - - GeneratedImageQuality retQuality; - switch (value) - { - case "standard": - retQuality = GeneratedImageQuality.Standard; - break; - case "hd": - retQuality = GeneratedImageQuality.High; - break; - default: - retQuality = GeneratedImageQuality.Standard; - break; - } - - return retQuality; - } - - private GeneratedImageStyle GetImageStyle(string style) - { - var value = !string.IsNullOrEmpty(style) ? style : "natural"; - - GeneratedImageStyle retStyle; - switch (value) - { - case "natural": - retStyle = GeneratedImageStyle.Natural; - break; - case "vivid": - retStyle = GeneratedImageStyle.Vivid; - break; - default: - retStyle = GeneratedImageStyle.Natural; - break; - } - - return retStyle; - } - - private GeneratedImageFormat GetImageFormat(string format) - { - var value = !string.IsNullOrEmpty(format) ? format : "uri"; - - GeneratedImageFormat retFormat; - switch (value) - { - case "uri": - retFormat = GeneratedImageFormat.Uri; - break; - case "bytes": - retFormat = GeneratedImageFormat.Bytes; - break; - default: - retFormat = GeneratedImageFormat.Uri; - break; - } - - return retFormat; - } - - private int GetImageCount(string count) - { - if (!int.TryParse(count, out var retCount)) - { - return DEFAULT_IMAGE_COUNT; - } - - return retCount > 0 && retCount <= IMAGE_COUNT_LIMIT ? retCount : DEFAULT_IMAGE_COUNT; - } -} \ No newline at end of file diff --git a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageVariationProvider.cs b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageVariationProvider.cs deleted file mode 100644 index 619c665fe..000000000 --- a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageVariationProvider.cs +++ /dev/null @@ -1,152 +0,0 @@ -using OpenAI.Images; - -namespace BotSharp.Plugin.OpenAI.Providers.Image; - -public class ImageVariationProvider : IImageVariation -{ - protected readonly OpenAiSettings _settings; - protected readonly IServiceProvider _services; - protected readonly ILogger _logger; - - private const int DEFAULT_IMAGE_COUNT = 1; - private const int IMAGE_COUNT_LIMIT = 5; - - protected string _model; - - public virtual string Provider => "openai"; - - public ImageVariationProvider( - OpenAiSettings settings, - ILogger logger, - IServiceProvider services) - { - _settings = settings; - _services = services; - _logger = logger; - } - - public async Task GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName) - { - var client = ProviderHelper.GetClient(Provider, _model, _services); - var (imageCount, options) = PrepareOptions(); - var imageClient = client.GetImageClient(_model); - - var response = imageClient.GenerateImageVariations(image, imageFileName, imageCount, options); - var values = response.Value; - - var generatedImages = new List(); - foreach (var value in values) - { - if (value == null) continue; - - var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty }; - if (options.ResponseFormat == GeneratedImageFormat.Uri) - { - generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty; - } - else if (options.ResponseFormat == GeneratedImageFormat.Bytes) - { - var base64Str = string.Empty; - var bytes = value?.ImageBytes?.ToArray(); - if (!bytes.IsNullOrEmpty()) - { - base64Str = Convert.ToBase64String(bytes); - } - generatedImage.ImageData = base64Str; - } - - generatedImages.Add(generatedImage); - } - - var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description)); - var responseMessage = new RoleDialogModel(AgentRole.Assistant, content) - { - CurrentAgentId = agent.Id, - MessageId = message?.MessageId ?? string.Empty, - GeneratedImages = generatedImages - }; - - return await Task.FromResult(responseMessage); - } - - public void SetModelName(string model) - { - _model = model; - } - - private (int, ImageVariationOptions) PrepareOptions() - { - var state = _services.GetRequiredService(); - var size = state.GetState("image_size"); - var format = state.GetState("image_format"); - var count = GetImageCount(state.GetState("image_count", "1")); - - var options = new ImageVariationOptions - { - Size = GetImageSize(size), - ResponseFormat = GetImageFormat(format) - }; - return (count, options); - } - - private GeneratedImageSize GetImageSize(string size) - { - var value = !string.IsNullOrEmpty(size) ? size : "1024x1024"; - - GeneratedImageSize retSize; - switch (value) - { - case "256x256": - retSize = GeneratedImageSize.W256xH256; - break; - case "512x512": - retSize = GeneratedImageSize.W512xH512; - break; - case "1024x1024": - retSize = GeneratedImageSize.W1024xH1024; - break; - case "1024x1792": - retSize = GeneratedImageSize.W1024xH1792; - break; - case "1792x1024": - retSize = GeneratedImageSize.W1792xH1024; - break; - default: - retSize = GeneratedImageSize.W1024xH1024; - break; - } - - return retSize; - } - - private GeneratedImageFormat GetImageFormat(string format) - { - var value = !string.IsNullOrEmpty(format) ? format : "uri"; - - GeneratedImageFormat retFormat; - switch (value) - { - case "uri": - retFormat = GeneratedImageFormat.Uri; - break; - case "bytes": - retFormat = GeneratedImageFormat.Bytes; - break; - default: - retFormat = GeneratedImageFormat.Uri; - break; - } - - return retFormat; - } - - private int GetImageCount(string count) - { - if (!int.TryParse(count, out var retCount)) - { - return DEFAULT_IMAGE_COUNT; - } - - return retCount > 0 && retCount <= IMAGE_COUNT_LIMIT ? retCount : DEFAULT_IMAGE_COUNT; - } -} From bb9feb2697abdd34b5545bb5508ba77be6f98743 Mon Sep 17 00:00:00 2001 From: Jicheng Lu Date: Thu, 18 Jul 2024 22:49:24 -0500 Subject: [PATCH 2/3] minor change --- .../Controllers/InstructModeController.cs | 4 ++++ .../ImageCompletionProvider.Generation.cs | 16 ++++++++-------- .../Image/ImageCompletionProvider.Variation.cs | 8 ++++---- .../ImageCompletionProvider.Generation.cs | 18 ++++++++---------- .../Image/ImageCompletionProvider.Variation.cs | 10 +++++----- 5 files changed, 29 insertions(+), 27 deletions(-) diff --git a/src/Infrastructure/BotSharp.OpenAPI/Controllers/InstructModeController.cs b/src/Infrastructure/BotSharp.OpenAPI/Controllers/InstructModeController.cs index dbfdeea0d..8d4331237 100644 --- a/src/Infrastructure/BotSharp.OpenAPI/Controllers/InstructModeController.cs +++ b/src/Infrastructure/BotSharp.OpenAPI/Controllers/InstructModeController.cs @@ -141,6 +141,10 @@ public async Task ImageVariation([FromBody] IncomingMe try { var file = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData)); + if (file == null) + { + return new ImageGenerationViewModel { Message = "Error! Cannot find an image!" }; + } var message = await fileService.VarifyImage(input.Provider, input.Model, file); imageViewModel.Content = message.Content; imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList(); diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs index 995054d63..eeda728ba 100644 --- a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs +++ b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs @@ -53,18 +53,18 @@ public async Task GetImageGeneration(Agent agent, RoleDialogMod var prompt = message?.Payload ?? message?.Content ?? string.Empty; var state = _services.GetRequiredService(); - var size = state.GetState("image_size"); - var quality = state.GetState("image_quality"); - var style = state.GetState("image_style"); - var format = state.GetState("image_format"); + var size = GetImageSize(state.GetState("image_size")); + var quality = GetImageQuality(state.GetState("image_quality")); + var style = GetImageStyle(state.GetState("image_style")); + var format = GetImageFormat(state.GetState("image_format")); var count = GetImageCount(state.GetState("image_count", "1")); var options = new ImageGenerationOptions { - Size = GetImageSize(size), - Quality = GetImageQuality(quality), - Style = GetImageStyle(style), - ResponseFormat = GetImageFormat(format) + Size = size, + Quality = quality, + Style = style, + ResponseFormat = format }; return (prompt, count, options); } diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Variation.cs b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Variation.cs index 8d9b4f39b..e7543e8a0 100644 --- a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Variation.cs +++ b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Variation.cs @@ -51,14 +51,14 @@ public async Task GetImageVariation(Agent agent, RoleDialogMode private (int, ImageVariationOptions) PrepareOptions() { var state = _services.GetRequiredService(); - var size = state.GetState("image_size"); - var format = state.GetState("image_format"); + var size = GetImageSize(state.GetState("image_size")); + var format = GetImageFormat(state.GetState("image_format")); var count = GetImageCount(state.GetState("image_count", "1")); var options = new ImageVariationOptions { - Size = GetImageSize(size), - ResponseFormat = GetImageFormat(format) + Size = size, + ResponseFormat = format }; return (count, options); } diff --git a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Generation.cs b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Generation.cs index c03deb6ff..18bf72283 100644 --- a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Generation.cs +++ b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Generation.cs @@ -53,21 +53,19 @@ public async Task GetImageGeneration(Agent agent, RoleDialogMod var prompt = message?.Payload ?? message?.Content ?? string.Empty; var state = _services.GetRequiredService(); - var size = state.GetState("image_size"); - var quality = state.GetState("image_quality"); - var style = state.GetState("image_style"); - var format = state.GetState("image_format"); + var size = GetImageSize(state.GetState("image_size")); + var quality = GetImageQuality(state.GetState("image_quality")); + var style = GetImageStyle(state.GetState("image_style")); + var format = GetImageFormat(state.GetState("image_format")); var count = GetImageCount(state.GetState("image_count", "1")); var options = new ImageGenerationOptions { - Size = GetImageSize(size), - Quality = GetImageQuality(quality), - Style = GetImageStyle(style), - ResponseFormat = GetImageFormat(format) + Size = size, + Quality = quality, + Style = style, + ResponseFormat = format }; return (prompt, count, options); } - - } \ No newline at end of file diff --git a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs index 13142bc06..a95097bf8 100644 --- a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs +++ b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs @@ -10,7 +10,7 @@ public async Task GetImageVariation(Agent agent, RoleDialogMode var (imageCount, options) = PrepareOptions(); var imageClient = client.GetImageClient(_model); - var response = imageClient.GenerateImageVariations(image, imageFileName, imageCount, options); + var response = await imageClient.GenerateImageVariationsAsync(image, imageFileName, imageCount, options); var values = response.Value; var generatedImages = new List(); @@ -51,14 +51,14 @@ public async Task GetImageVariation(Agent agent, RoleDialogMode private (int, ImageVariationOptions) PrepareOptions() { var state = _services.GetRequiredService(); - var size = state.GetState("image_size"); - var format = state.GetState("image_format"); + var size = GetImageSize(state.GetState("image_size")); + var format = GetImageFormat(state.GetState("image_format")); var count = GetImageCount(state.GetState("image_count", "1")); var options = new ImageVariationOptions { - Size = GetImageSize(size), - ResponseFormat = GetImageFormat(format) + Size = size, + ResponseFormat = format }; return (count, options); } From 4311eaf87e9dfb249a24dbf9eba06f177fa386b9 Mon Sep 17 00:00:00 2001 From: Jicheng Lu Date: Thu, 18 Jul 2024 22:51:52 -0500 Subject: [PATCH 3/3] minor change --- .../Providers/Image/ImageCompletionProvider.Generation.cs | 2 +- .../Providers/Image/ImageCompletionProvider.Variation.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs index eeda728ba..dc2de8074 100644 --- a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs +++ b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs @@ -45,7 +45,7 @@ public async Task GetImageGeneration(Agent agent, RoleDialogMod GeneratedImages = generatedImages }; - return responseMessage; + return await Task.FromResult(responseMessage); } private (string, int, ImageGenerationOptions) PrepareOptions(RoleDialogModel message) diff --git a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs index a95097bf8..0233bad4e 100644 --- a/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs +++ b/src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageCompletionProvider.Variation.cs @@ -10,7 +10,7 @@ public async Task GetImageVariation(Agent agent, RoleDialogMode var (imageCount, options) = PrepareOptions(); var imageClient = client.GetImageClient(_model); - var response = await imageClient.GenerateImageVariationsAsync(image, imageFileName, imageCount, options); + var response = imageClient.GenerateImageVariations(image, imageFileName, imageCount, options); var values = response.Value; var generatedImages = new List();