Skip to content

Commit 2c937ec

Browse files
[ML] Refactor OpenAI request managers (#124144) (#124240)
* Code compiling * Removing OpenAiAccount
1 parent 44f4692 commit 2c937ec

23 files changed

+349
-335
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,18 @@
1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1212
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
13-
import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager;
14-
import org.elasticsearch.xpack.inference.external.http.sender.OpenAiEmbeddingsRequestManager;
13+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
14+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
15+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
1516
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
17+
import org.elasticsearch.xpack.inference.external.http.sender.TruncatingRequestManager;
18+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
19+
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
20+
import org.elasticsearch.xpack.inference.external.openai.OpenAiResponseHandler;
21+
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiEmbeddingsRequest;
22+
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest;
23+
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
24+
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
1625
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1726
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
1827
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
@@ -27,6 +36,18 @@
2736
*/
2837
public class OpenAiActionCreator implements OpenAiActionVisitor {
2938
public static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions";
39+
public static final String USER_ROLE = "user";
40+
41+
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
42+
"openai completion",
43+
OpenAiChatCompletionResponseEntity::fromResponse
44+
);
45+
public static final ResponseHandler EMBEDDINGS_HANDLER = new OpenAiResponseHandler(
46+
"openai text embedding",
47+
OpenAiEmbeddingsResponseEntity::fromResponse,
48+
false
49+
);
50+
3051
private final Sender sender;
3152
private final ServiceComponents serviceComponents;
3253

@@ -38,20 +59,30 @@ public OpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
3859
@Override
3960
public ExecutableAction create(OpenAiEmbeddingsModel model, Map<String, Object> taskSettings) {
4061
var overriddenModel = OpenAiEmbeddingsModel.of(model, taskSettings);
41-
var requestCreator = OpenAiEmbeddingsRequestManager.of(
62+
var manager = new TruncatingRequestManager(
63+
serviceComponents.threadPool(),
4264
overriddenModel,
43-
serviceComponents.truncator(),
44-
serviceComponents.threadPool()
65+
EMBEDDINGS_HANDLER,
66+
(truncationResult) -> new OpenAiEmbeddingsRequest(serviceComponents.truncator(), truncationResult, overriddenModel),
67+
overriddenModel.getServiceSettings().maxInputTokens()
4568
);
69+
4670
var errorMessage = constructFailedToSendRequestMessage("OpenAI embeddings");
47-
return new SenderExecutableAction(sender, requestCreator, errorMessage);
71+
return new SenderExecutableAction(sender, manager, errorMessage);
4872
}
4973

5074
@Override
5175
public ExecutableAction create(OpenAiChatCompletionModel model, Map<String, Object> taskSettings) {
5276
var overriddenModel = OpenAiChatCompletionModel.of(model, taskSettings);
53-
var requestCreator = OpenAiCompletionRequestManager.of(overriddenModel, serviceComponents.threadPool());
77+
var manager = new GenericRequestManager<>(
78+
serviceComponents.threadPool(),
79+
overriddenModel,
80+
COMPLETION_HANDLER,
81+
(inputs) -> new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), overriddenModel),
82+
ChatCompletionInput.class
83+
);
84+
5485
var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
55-
return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, COMPLETION_ERROR_PREFIX);
86+
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
5687
}
5788
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManager.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.external.http.sender;
99

1010
import org.elasticsearch.threadpool.ThreadPool;
11+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1112
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1213

1314
import java.util.Objects;
@@ -17,14 +18,31 @@
1718
abstract class BaseRequestManager implements RequestManager {
1819
private final ThreadPool threadPool;
1920
private final String inferenceEntityId;
20-
private final Object rateLimitGroup;
21+
// It's possible that two inference endpoints have the same information defining the group but have different
22+
// rate limits then they should be in different groups otherwise whoever initially created the group will set
23+
// the rate and the other inference endpoint's rate will be ignored
24+
private final EndpointGrouping endpointGrouping;
2125
private final RateLimitSettings rateLimitSettings;
2226

2327
BaseRequestManager(ThreadPool threadPool, String inferenceEntityId, Object rateLimitGroup, RateLimitSettings rateLimitSettings) {
2428
this.threadPool = Objects.requireNonNull(threadPool);
2529
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
26-
this.rateLimitGroup = Objects.requireNonNull(rateLimitGroup);
27-
this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings);
30+
31+
Objects.requireNonNull(rateLimitSettings);
32+
this.endpointGrouping = new EndpointGrouping(Objects.requireNonNull(rateLimitGroup).hashCode(), rateLimitSettings);
33+
this.rateLimitSettings = rateLimitSettings;
34+
}
35+
36+
BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel) {
37+
this.threadPool = Objects.requireNonNull(threadPool);
38+
Objects.requireNonNull(rateLimitGroupingModel);
39+
40+
this.inferenceEntityId = rateLimitGroupingModel.inferenceEntityId();
41+
this.endpointGrouping = new EndpointGrouping(
42+
rateLimitGroupingModel.rateLimitGroupingHash(),
43+
rateLimitGroupingModel.rateLimitSettings()
44+
);
45+
this.rateLimitSettings = rateLimitGroupingModel.rateLimitSettings();
2846
}
2947

3048
protected void execute(Runnable runnable) {
@@ -38,16 +56,13 @@ public String inferenceEntityId() {
3856

3957
@Override
4058
public Object rateLimitGrouping() {
41-
// It's possible that two inference endpoints have the same information defining the group but have different
42-
// rate limits then they should be in different groups otherwise whoever initially created the group will set
43-
// the rate and the other inference endpoint's rate will be ignored
44-
return new EndpointGrouping(rateLimitGroup, rateLimitSettings);
59+
return endpointGrouping;
4560
}
4661

4762
@Override
4863
public RateLimitSettings rateLimitSettings() {
4964
return rateLimitSettings;
5065
}
5166

52-
private record EndpointGrouping(Object group, RateLimitSettings settings) {}
67+
private record EndpointGrouping(int group, RateLimitSettings settings) {}
5368
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.request.Request;
18+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
19+
20+
import java.util.Objects;
21+
import java.util.function.Function;
22+
import java.util.function.Supplier;
23+
24+
/**
25+
* This is a temporary class to use while we refactor all the request managers. After all the request managers extend
26+
* this class we'll move this functionality directly into the {@link BaseRequestManager}.
27+
*/
28+
public class GenericRequestManager<T extends InferenceInputs> extends BaseRequestManager {
29+
private static final Logger logger = LogManager.getLogger(GenericRequestManager.class);
30+
31+
protected final ResponseHandler responseHandler;
32+
protected final Function<T, Request> requestCreator;
33+
protected final Class<T> inputType;
34+
35+
public GenericRequestManager(
36+
ThreadPool threadPool,
37+
RateLimitGroupingModel rateLimitGroupingModel,
38+
ResponseHandler responseHandler,
39+
Function<T, Request> requestCreator,
40+
Class<T> inputType
41+
) {
42+
super(threadPool, rateLimitGroupingModel);
43+
this.responseHandler = Objects.requireNonNull(responseHandler);
44+
this.requestCreator = Objects.requireNonNull(requestCreator);
45+
this.inputType = Objects.requireNonNull(inputType);
46+
}
47+
48+
@Override
49+
public void execute(
50+
InferenceInputs inferenceInputs,
51+
RequestSender requestSender,
52+
Supplier<Boolean> hasRequestCompletedFunction,
53+
ActionListener<InferenceServiceResults> listener
54+
) {
55+
var request = requestCreator.apply(inferenceInputs.castTo(inputType));
56+
57+
execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
58+
}
59+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public static IllegalArgumentException createUnsupportedTypeException(InferenceI
2222
);
2323
}
2424

25-
public <T> T castTo(Class<T> clazz) {
25+
public <T extends InferenceInputs> T castTo(Class<T> clazz) {
2626
if (clazz.isInstance(this) == false) {
2727
throw createUnsupportedTypeException(this, clazz);
2828
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java

Lines changed: 0 additions & 58 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java

Lines changed: 0 additions & 69 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiRequestManager.java

Lines changed: 0 additions & 40 deletions
This file was deleted.

0 commit comments

Comments
 (0)