diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java index a68b0afe1d40f..a7767ab4e764b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java @@ -10,9 +10,18 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; -import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.OpenAiEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.TruncatingRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiResponseHandler; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; @@ -27,6 +36,18 @@ */ public class OpenAiActionCreator implements OpenAiActionVisitor { public static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions"; + public static final String USER_ROLE = "user"; + + static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( + "openai completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + public static final ResponseHandler EMBEDDINGS_HANDLER = new OpenAiResponseHandler( + "openai text embedding", + OpenAiEmbeddingsResponseEntity::fromResponse, + false + ); + private final Sender sender; private final ServiceComponents serviceComponents; @@ -38,20 +59,30 @@ public OpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) { @Override public ExecutableAction create(OpenAiEmbeddingsModel model, Map taskSettings) { var overriddenModel = OpenAiEmbeddingsModel.of(model, taskSettings); - var requestCreator = OpenAiEmbeddingsRequestManager.of( + var manager = new TruncatingRequestManager( + serviceComponents.threadPool(), overriddenModel, - serviceComponents.truncator(), - serviceComponents.threadPool() + EMBEDDINGS_HANDLER, + (truncationResult) -> new OpenAiEmbeddingsRequest(serviceComponents.truncator(), truncationResult, overriddenModel), + overriddenModel.getServiceSettings().maxInputTokens() ); + var errorMessage = constructFailedToSendRequestMessage("OpenAI embeddings"); - return new SenderExecutableAction(sender, requestCreator, errorMessage); + return new SenderExecutableAction(sender, manager, errorMessage); } @Override public ExecutableAction create(OpenAiChatCompletionModel model, Map taskSettings) { var overriddenModel = OpenAiChatCompletionModel.of(model, taskSettings); - var requestCreator = OpenAiCompletionRequestManager.of(overriddenModel, serviceComponents.threadPool()); + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + overriddenModel, + COMPLETION_HANDLER, + (inputs) -> new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), overriddenModel), + ChatCompletionInput.class + ); + var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); - return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, COMPLETION_ERROR_PREFIX); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java index a015716b81032..b4242ff524dcf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.Objects; @@ -17,14 +18,31 @@ abstract class BaseRequestManager implements RequestManager { private final ThreadPool threadPool; private final String inferenceEntityId; - private final Object rateLimitGroup; + // It's possible that two inference endpoints have the same information defining the group but have different + // rate limits then they should be in different groups otherwise whoever initially created the group will set + // the rate and the other inference endpoint's rate will be ignored + private final EndpointGrouping endpointGrouping; private final RateLimitSettings rateLimitSettings; BaseRequestManager(ThreadPool threadPool, String inferenceEntityId, Object rateLimitGroup, RateLimitSettings rateLimitSettings) { this.threadPool = Objects.requireNonNull(threadPool); this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId); - this.rateLimitGroup = Objects.requireNonNull(rateLimitGroup); - this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings); + + Objects.requireNonNull(rateLimitSettings); + this.endpointGrouping = new EndpointGrouping(Objects.requireNonNull(rateLimitGroup).hashCode(), rateLimitSettings); + this.rateLimitSettings = rateLimitSettings; + } + + BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel) { + this.threadPool = Objects.requireNonNull(threadPool); + Objects.requireNonNull(rateLimitGroupingModel); + + this.inferenceEntityId = rateLimitGroupingModel.inferenceEntityId(); + this.endpointGrouping = new EndpointGrouping( + rateLimitGroupingModel.rateLimitGroupingHash(), + rateLimitGroupingModel.rateLimitSettings() + ); + this.rateLimitSettings = rateLimitGroupingModel.rateLimitSettings(); } protected void execute(Runnable runnable) { @@ -38,10 +56,7 @@ public String inferenceEntityId() { @Override public Object rateLimitGrouping() { - // It's possible that two inference endpoints have the same information defining the group but have different - // rate limits then they should be in different groups otherwise whoever initially created the group will set - // the rate and the other inference endpoint's rate will be ignored - return new EndpointGrouping(rateLimitGroup, rateLimitSettings); + return endpointGrouping; } @Override @@ -49,5 +64,5 @@ public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } - private record EndpointGrouping(Object group, RateLimitSettings settings) {} + private record EndpointGrouping(int group, RateLimitSettings settings) {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GenericRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GenericRequestManager.java new file mode 100644 index 0000000000000..77f51e169ab17 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GenericRequestManager.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; + +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * This is a temporary class to use while we refactor all the request managers. After all the request managers extend + * this class we'll move this functionality directly into the {@link BaseRequestManager}. + */ +public class GenericRequestManager extends BaseRequestManager { + private static final Logger logger = LogManager.getLogger(GenericRequestManager.class); + + protected final ResponseHandler responseHandler; + protected final Function requestCreator; + protected final Class inputType; + + public GenericRequestManager( + ThreadPool threadPool, + RateLimitGroupingModel rateLimitGroupingModel, + ResponseHandler responseHandler, + Function requestCreator, + Class inputType + ) { + super(threadPool, rateLimitGroupingModel); + this.responseHandler = Objects.requireNonNull(responseHandler); + this.requestCreator = Objects.requireNonNull(requestCreator); + this.inputType = Objects.requireNonNull(inputType); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var request = requestCreator.apply(inferenceInputs.castTo(inputType)); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index e85ea6f1d9b35..816d6550f9b04 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -22,7 +22,7 @@ public static IllegalArgumentException createUnsupportedTypeException(InferenceI ); } - public T castTo(Class clazz) { + public T castTo(Class clazz) { if (clazz.isInstance(this) == false) { throw createUnsupportedTypeException(this, clazz); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java deleted file mode 100644 index ca25b56953251..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; -import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; -import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; - -import java.util.Objects; -import java.util.function.Supplier; - -public class OpenAiCompletionRequestManager extends OpenAiRequestManager { - - private static final Logger logger = LogManager.getLogger(OpenAiCompletionRequestManager.class); - private static final ResponseHandler HANDLER = createCompletionHandler(); - public static final String USER_ROLE = "user"; - - public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { - return new OpenAiCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); - } - - private final OpenAiChatCompletionModel model; - - private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { - super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); - this.model = Objects.requireNonNull(model); - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - var chatCompletionInputs = inferenceInputs.castTo(ChatCompletionInput.class); - var request = new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(chatCompletionInputs, USER_ROLE), model); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } - - private static ResponseHandler createCompletionHandler() { - return new OpenAiChatCompletionResponseHandler("openai completion", OpenAiChatCompletionResponseEntity::fromResponse); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java deleted file mode 100644 index 49fa15e5bc843..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.common.Truncator; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.openai.OpenAiResponseHandler; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequest; -import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity; -import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; - -import java.util.List; -import java.util.Objects; -import java.util.function.Supplier; - -import static org.elasticsearch.xpack.inference.common.Truncator.truncate; - -public class OpenAiEmbeddingsRequestManager extends OpenAiRequestManager { - - private static final Logger logger = LogManager.getLogger(OpenAiEmbeddingsRequestManager.class); - - private static final ResponseHandler HANDLER = createEmbeddingsHandler(); - - private static ResponseHandler createEmbeddingsHandler() { - return new OpenAiResponseHandler("openai text embedding", OpenAiEmbeddingsResponseEntity::fromResponse, false); - } - - public static OpenAiEmbeddingsRequestManager of(OpenAiEmbeddingsModel model, Truncator truncator, ThreadPool threadPool) { - return new OpenAiEmbeddingsRequestManager( - Objects.requireNonNull(model), - Objects.requireNonNull(truncator), - Objects.requireNonNull(threadPool) - ); - } - - private final Truncator truncator; - private final OpenAiEmbeddingsModel model; - - private OpenAiEmbeddingsRequestManager(OpenAiEmbeddingsModel model, Truncator truncator, ThreadPool threadPool) { - super(threadPool, model, OpenAiEmbeddingsRequest::buildDefaultUri); - this.model = Objects.requireNonNull(model); - this.truncator = Objects.requireNonNull(truncator); - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); - var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); - OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiRequestManager.java deleted file mode 100644 index a97e912141631..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiRequestManager.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.elasticsearch.common.CheckedSupplier; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount; -import org.elasticsearch.xpack.inference.services.openai.OpenAiModel; - -import java.net.URI; -import java.net.URISyntaxException; -import java.util.Objects; - -abstract class OpenAiRequestManager extends BaseRequestManager { - - protected OpenAiRequestManager(ThreadPool threadPool, OpenAiModel model, CheckedSupplier uriBuilder) { - super( - threadPool, - model.getInferenceEntityId(), - RateLimitGrouping.of(model, uriBuilder), - model.rateLimitServiceSettings().rateLimitSettings() - ); - } - - record RateLimitGrouping(int accountHash, int modelIdHash) { - public static RateLimitGrouping of(OpenAiModel model, CheckedSupplier uriBuilder) { - Objects.requireNonNull(model); - - return new RateLimitGrouping( - OpenAiAccount.of(model, uriBuilder).hashCode(), - model.rateLimitServiceSettings().modelId().hashCode() - ); - } - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java deleted file mode 100644 index 3b0f770e3e061..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; -import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; -import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; - -import java.util.Objects; -import java.util.function.Supplier; - -public class OpenAiUnifiedCompletionRequestManager extends OpenAiRequestManager { - - private static final Logger logger = LogManager.getLogger(OpenAiUnifiedCompletionRequestManager.class); - - private static final ResponseHandler HANDLER = createCompletionHandler(); - - public static OpenAiUnifiedCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { - return new OpenAiUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); - } - - private final OpenAiChatCompletionModel model; - - private OpenAiUnifiedCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { - super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); - this.model = Objects.requireNonNull(model); - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - - OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest( - inferenceInputs.castTo(UnifiedChatInput.class), - model - ); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } - - private static ResponseHandler createCompletionHandler() { - return new OpenAiUnifiedChatCompletionResponseHandler("openai completion", OpenAiChatCompletionResponseEntity::fromResponse); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java new file mode 100644 index 0000000000000..a292b8f7acc10 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; + +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class TruncatingRequestManager extends BaseRequestManager { + private static final Logger logger = LogManager.getLogger(TruncatingRequestManager.class); + + private final ResponseHandler responseHandler; + private final Function requestCreator; + private final Integer maxInputTokens; + + public TruncatingRequestManager( + ThreadPool threadPool, + RateLimitGroupingModel rateLimitGroupingModel, + ResponseHandler responseHandler, + Function requestCreator, + @Nullable Integer maxInputTokens + ) { + super(threadPool, rateLimitGroupingModel); + this.responseHandler = Objects.requireNonNull(responseHandler); + this.requestCreator = Objects.requireNonNull(requestCreator); + this.maxInputTokens = maxInputTokens; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var docsInput = inferenceInputs.castTo(DocumentsOnlyInput.class).getInputs(); + var truncatedInput = truncate(docsInput, maxInputTokens); + var request = requestCreator.apply(truncatedInput); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiAccount.java deleted file mode 100644 index 07ccf298a0bd3..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiAccount.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.openai; - -import org.elasticsearch.common.CheckedSupplier; -import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.xpack.inference.services.openai.OpenAiModel; - -import java.net.URI; -import java.net.URISyntaxException; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; - -public record OpenAiAccount(URI uri, @Nullable String organizationId, SecureString apiKey) { - - public static OpenAiAccount of(OpenAiModel model, CheckedSupplier uriBuilder) { - var uri = buildUri(model.rateLimitServiceSettings().uri(), "OpenAI", uriBuilder); - - return new OpenAiAccount(uri, model.rateLimitServiceSettings().organizationId(), model.apiKey()); - } - - public OpenAiAccount { - Objects.requireNonNull(uri); - Objects.requireNonNull(apiKey); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java index 7f8626dacc684..dfd4330713d61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java @@ -9,18 +9,15 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.Truncator; -import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; import java.net.URI; -import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.Objects; @@ -30,19 +27,17 @@ public class OpenAiEmbeddingsRequest implements Request { private final Truncator truncator; - private final OpenAiAccount account; private final Truncator.TruncationResult truncationResult; private final OpenAiEmbeddingsModel model; public OpenAiEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, OpenAiEmbeddingsModel model) { this.truncator = Objects.requireNonNull(truncator); - this.account = OpenAiAccount.of(model, OpenAiEmbeddingsRequest::buildDefaultUri); this.truncationResult = Objects.requireNonNull(input); this.model = Objects.requireNonNull(model); } public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); + HttpPost httpPost = new HttpPost(model.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( Strings.toString( @@ -58,9 +53,9 @@ public HttpRequest createHttpRequest() { httpPost.setEntity(byteEntity); httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); - httpPost.setHeader(createAuthBearerHeader(account.apiKey())); + httpPost.setHeader(createAuthBearerHeader(model.apiKey())); - var org = account.organizationId(); + var org = model.rateLimitServiceSettings().organizationId(); if (org != null) { httpPost.setHeader(createOrgHeader(org)); } @@ -75,7 +70,7 @@ public String getInferenceEntityId() { @Override public URI getURI() { - return account.uri(); + return model.uri(); } @Override @@ -89,11 +84,4 @@ public Request truncate() { public boolean[] getTruncationInfo() { return truncationResult.truncated().clone(); } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(OpenAiUtils.HOST) - .setPathSegments(OpenAiUtils.VERSION_1, OpenAiUtils.EMBEDDINGS_PATH) - .build(); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java index e5b85633a499b..84f90e9027323 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java @@ -9,18 +9,15 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import java.net.URI; -import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.Objects; @@ -29,19 +26,17 @@ public class OpenAiUnifiedChatCompletionRequest implements Request { - private final OpenAiAccount account; private final OpenAiChatCompletionModel model; private final UnifiedChatInput unifiedChatInput; public OpenAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { - this.account = OpenAiAccount.of(model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); this.model = Objects.requireNonNull(model); } @Override public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); + HttpPost httpPost = new HttpPost(model.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8) @@ -49,9 +44,9 @@ public HttpRequest createHttpRequest() { httpPost.setEntity(byteEntity); httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); - httpPost.setHeader(createAuthBearerHeader(account.apiKey())); + httpPost.setHeader(createAuthBearerHeader(model.apiKey())); - var org = account.organizationId(); + var org = model.rateLimitServiceSettings().organizationId(); if (org != null) { httpPost.setHeader(createOrgHeader(org)); } @@ -61,7 +56,7 @@ public HttpRequest createHttpRequest() { @Override public URI getURI() { - return account.uri(); + return model.uri(); } @Override @@ -85,11 +80,4 @@ public String getInferenceEntityId() { public boolean isStreaming() { return unifiedChatInput.stream(); } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(OpenAiUtils.HOST) - .setPathSegments(OpenAiUtils.VERSION_1, OpenAiUtils.CHAT_PATH, OpenAiUtils.COMPLETIONS_PATH) - .build(); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/RateLimitGroupingModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/RateLimitGroupingModel.java new file mode 100644 index 0000000000000..87a14f238d2ad --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/RateLimitGroupingModel.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public abstract class RateLimitGroupingModel extends Model { + protected RateLimitGroupingModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + protected RateLimitGroupingModel(RateLimitGroupingModel model, TaskSettings taskSettings) { + super(model, taskSettings); + } + + protected RateLimitGroupingModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + public String inferenceEntityId() { + return getInferenceEntityId(); + } + + public abstract int rateLimitGroupingHash(); + + public abstract RateLimitSettings rateLimitSettings(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java index caf09de31794e..dd0afaa1c6977 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiModel.java @@ -9,34 +9,39 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.net.URI; import java.util.Map; import java.util.Objects; -public abstract class OpenAiModel extends Model { +public abstract class OpenAiModel extends RateLimitGroupingModel { private final OpenAiRateLimitServiceSettings rateLimitServiceSettings; private final SecureString apiKey; + private final URI uri; public OpenAiModel( ModelConfigurations configurations, ModelSecrets secrets, OpenAiRateLimitServiceSettings rateLimitServiceSettings, - @Nullable ApiKeySecrets apiKeySecrets + @Nullable ApiKeySecrets apiKeySecrets, + URI uri ) { super(configurations, secrets); this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); apiKey = ServiceUtils.apiKey(apiKeySecrets); + this.uri = Objects.requireNonNull(uri); } protected OpenAiModel(OpenAiModel model, TaskSettings taskSettings) { @@ -44,6 +49,7 @@ protected OpenAiModel(OpenAiModel model, TaskSettings taskSettings) { rateLimitServiceSettings = model.rateLimitServiceSettings(); apiKey = model.apiKey(); + uri = model.uri; } protected OpenAiModel(OpenAiModel model, ServiceSettings serviceSettings) { @@ -51,6 +57,7 @@ protected OpenAiModel(OpenAiModel model, ServiceSettings serviceSettings) { rateLimitServiceSettings = model.rateLimitServiceSettings(); apiKey = model.apiKey(); + uri = model.uri; } public SecureString apiKey() { @@ -62,4 +69,16 @@ public OpenAiRateLimitServiceSettings rateLimitServiceSettings() { } public abstract ExecutableAction accept(OpenAiActionVisitor creator, Map taskSettings); + + public int rateLimitGroupingHash() { + return Objects.hash(rateLimitServiceSettings.modelId(), apiKey, uri); + } + + public RateLimitSettings rateLimitSettings() { + return rateLimitServiceSettings.rateLimitSettings(); + } + + public URI uri() { + return uri; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 30973bea16ec5..8df3dc01b4fbe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -32,11 +32,15 @@ import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.external.http.sender.OpenAiUnifiedCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -82,6 +86,10 @@ public class OpenAiService extends SenderService { * The task types that the {@link InferenceAction.Request} can accept. */ private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION); + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + "openai completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); @@ -295,9 +303,17 @@ public void doUnifiedCompletionInfer( OpenAiChatCompletionModel openAiModel = (OpenAiChatCompletionModel) model; var overriddenModel = OpenAiChatCompletionModel.of(openAiModel, inputs.getRequest()); - var requestCreator = OpenAiUnifiedCompletionRequestManager.of(overriddenModel, getServiceComponents().threadPool()); + + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + UNIFIED_CHAT_COMPLETION_HANDLER, + (unifiedChatInput) -> new OpenAiUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); - var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); action.execute(inputs, timeout, listener); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index dea703b9ce243..9b1d27d569cc8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.openai.completion; +import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -14,13 +15,19 @@ import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.openai.OpenAiModel; +import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import java.net.URI; +import java.net.URISyntaxException; import java.util.Map; import java.util.Objects; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; + public class OpenAiChatCompletionModel extends OpenAiModel { public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map taskSettings) { @@ -83,10 +90,18 @@ public OpenAiChatCompletionModel( new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets), serviceSettings, - secrets + secrets, + buildUri(serviceSettings.uri(), OpenAiService.NAME, OpenAiChatCompletionModel::buildDefaultUri) ); } + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(OpenAiUtils.HOST) + .setPathSegments(OpenAiUtils.VERSION_1, OpenAiUtils.CHAT_PATH, OpenAiUtils.COMPLETIONS_PATH) + .build(); + } + private OpenAiChatCompletionModel(OpenAiChatCompletionModel originalModel, OpenAiChatCompletionTaskSettings taskSettings) { super(originalModel, taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java index 5659c46050ad8..e1265fe12b1fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.openai.embeddings; +import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ModelConfigurations; @@ -14,12 +15,18 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionVisitor; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.openai.OpenAiModel; +import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import java.net.URI; +import java.net.URISyntaxException; import java.util.Map; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; + public class OpenAiEmbeddingsModel extends OpenAiModel { public static OpenAiEmbeddingsModel of(OpenAiEmbeddingsModel model, Map taskSettings) { @@ -66,10 +73,18 @@ public OpenAiEmbeddingsModel( new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secrets), serviceSettings, - secrets + secrets, + buildUri(serviceSettings.uri(), OpenAiService.NAME, OpenAiEmbeddingsModel::buildDefaultUri) ); } + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(OpenAiUtils.HOST) + .setPathSegments(OpenAiUtils.VERSION_1, OpenAiUtils.EMBEDDINGS_PATH) + .build(); + } + private OpenAiEmbeddingsModel(OpenAiEmbeddingsModel originalModel, OpenAiEmbeddingsTaskSettings taskSettings) { super(originalModel, taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java index 1092d84a6ef6b..733a6c1a917ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java @@ -19,7 +19,7 @@ import java.util.List; -import static org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager.USER_ROLE; +import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator.USER_ROLE; /** * This class uses the unified chat completion method to perform validation. diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index c96372eadfbc2..d29e19a7902ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -28,10 +28,12 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.junit.After; import org.junit.Before; @@ -44,6 +46,8 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator.COMPLETION_HANDLER; +import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator.USER_ROLE; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; @@ -282,8 +286,14 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc private ExecutableAction createAction(String url, String org, String apiKey, String modelName, @Nullable String user, Sender sender) { var model = createCompletionModel(url, org, apiKey, modelName, user); - var requestCreator = OpenAiCompletionRequestManager.of(model, threadPool); + var manager = new GenericRequestManager<>( + threadPool, + model, + COMPLETION_HANDLER, + (inputs) -> new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); var errorMessage = constructFailedToSendRequestMessage("OpenAI chat completions"); - return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, "OpenAI chat completions"); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "OpenAI chat completions"); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java index c8a0e1c398d4b..799804482e2f6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java @@ -27,8 +27,9 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; -import org.elasticsearch.xpack.inference.external.http.sender.OpenAiEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.TruncatingRequestManager; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequest; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; import org.junit.After; @@ -41,6 +42,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator.EMBEDDINGS_HANDLER; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; @@ -223,9 +225,16 @@ public void testExecute_ThrowsExceptionWithNullUrl() { private ExecutableAction createAction(String url, String org, String apiKey, String modelName, @Nullable String user, Sender sender) { var model = createModel(url, org, apiKey, modelName, user); - var requestCreator = OpenAiEmbeddingsRequestManager.of(model, TruncatorTests.createTruncator(), threadPool); + var manager = new TruncatingRequestManager( + threadPool, + model, + EMBEDDINGS_HANDLER, + (truncationResult) -> new OpenAiEmbeddingsRequest(TruncatorTests.createTruncator(), truncationResult, model), + null + ); + var errorMessage = constructFailedToSendRequestMessage("OpenAI embeddings"); - return new SenderExecutableAction(sender, requestCreator, errorMessage); + return new SenderExecutableAction(sender, manager, errorMessage); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManagerTests.java index eb7f7c4a0035d..b4dd269d872ea 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManagerTests.java @@ -10,11 +10,13 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequest; +import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator.EMBEDDINGS_HANDLER; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; public class OpenAiEmbeddingsRequestManagerTests { - public static OpenAiEmbeddingsRequestManager makeCreator( + public static RequestManager makeCreator( String url, @Nullable String org, String apiKey, @@ -23,11 +25,16 @@ public static OpenAiEmbeddingsRequestManager makeCreator( ThreadPool threadPool ) { var model = createModel(url, org, apiKey, modelName, user); - - return OpenAiEmbeddingsRequestManager.of(model, TruncatorTests.createTruncator(), threadPool); + return new TruncatingRequestManager( + threadPool, + model, + EMBEDDINGS_HANDLER, + (truncationResult) -> new OpenAiEmbeddingsRequest(TruncatorTests.createTruncator(), truncationResult, model), + null + ); } - public static OpenAiEmbeddingsRequestManager makeCreator( + public static RequestManager makeCreator( String url, @Nullable String org, String apiKey, @@ -37,7 +44,12 @@ public static OpenAiEmbeddingsRequestManager makeCreator( ThreadPool threadPool ) { var model = createModel(url, org, apiKey, modelName, user, inferenceEntityId); - - return OpenAiEmbeddingsRequestManager.of(model, TruncatorTests.createTruncator(), threadPool); + return new TruncatingRequestManager( + threadPool, + model, + EMBEDDINGS_HANDLER, + (truncationResult) -> new OpenAiEmbeddingsRequest(TruncatorTests.createTruncator(), truncationResult, model), + null + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequestTests.java index 935b27cfb688a..e80cdc848a1e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequestTests.java @@ -21,8 +21,8 @@ import java.util.List; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequest.buildDefaultUri; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; +import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel.buildDefaultUri; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java index ec4231bd73154..7b15ab5abed88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java @@ -21,8 +21,8 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest.buildDefaultUri; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel.buildDefaultUri; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -85,7 +85,7 @@ public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); + assertThat(httpPost.getURI().toString(), is(buildDefaultUri().toString())); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER)); @@ -108,7 +108,7 @@ public void testCreateRequest_WithStreaming() throws IOException { public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException { var request = createRequest(null, null, "secret", "abcd", "model", null, true); var truncatedRequest = request.truncate(); - assertThat(request.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); + assertThat(request.getURI().toString(), is(buildDefaultUri().toString())); var httpRequest = truncatedRequest.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));