Skip to content

Commit ceebb9a

Browse files
Add support for non-streaming completion task for HuggingFace
1 parent 404f640 commit ceebb9a

File tree

3 files changed

+56
-11
lines changed

3 files changed

+56
-11
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,17 +9,18 @@
99

1010
import org.elasticsearch.common.settings.SecureString;
1111
import org.elasticsearch.core.Nullable;
12-
import org.elasticsearch.inference.Model;
1312
import org.elasticsearch.inference.ModelConfigurations;
1413
import org.elasticsearch.inference.ModelSecrets;
1514
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
15+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1616
import org.elasticsearch.xpack.inference.services.ServiceUtils;
1717
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
1818
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
19+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1920

2021
import java.util.Objects;
2122

22-
public abstract class HuggingFaceModel extends Model {
23+
public abstract class HuggingFaceModel extends RateLimitGroupingModel {
2324
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
2425
private final SecureString apiKey;
2526

@@ -38,6 +39,16 @@ public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
3839
return rateLimitServiceSettings;
3940
}
4041

42+
/**
 * Returns a hash built from the URI of the rate limit service settings and this model's API key.
 * Two models with the same URI and API key therefore yield the same value.
 * NOTE(review): presumably consumed by the RateLimitGroupingModel machinery to bucket requests
 * that share one rate limit — confirm against the base class.
 *
 * @return the combined hash of the service URI and the API key
 */
@Override
43+
public int rateLimitGroupingHash() {
44+
return Objects.hash(rateLimitServiceSettings.uri(), apiKey);
45+
}
46+
47+
/**
 * Returns the rate limit settings by delegating to the model's rate limit service settings.
 *
 * @return the {@link RateLimitSettings} held by {@code rateLimitServiceSettings}
 */
@Override
48+
public RateLimitSettings rateLimitSettings() {
49+
return rateLimitServiceSettings.rateLimitSettings();
50+
}
51+
4152
/**
 * Returns the API key stored on this model.
 *
 * @return the API key as a {@link SecureString}
 */
public SecureString apiKey() {
4253
return apiKey;
4354
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,10 @@
2626
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
2727
import org.elasticsearch.rest.RestStatus;
2828
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
29+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
30+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
2931
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
32+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
3033
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3134
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
3235
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
@@ -36,6 +39,9 @@
3639
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
3740
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
3841
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
42+
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
43+
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
44+
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
3945
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
4046
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
4147

@@ -45,6 +51,7 @@
4551
import java.util.Map;
4652
import java.util.Set;
4753

54+
import static org.elasticsearch.core.Strings.format;
4855
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
4956
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
5057

@@ -55,13 +62,19 @@
5562
public class HuggingFaceService extends HuggingFaceBaseService {
5663
public static final String NAME = "hugging_face";
5764

65+
private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
66+
"Failed to send Hugging Face %s request from inference entity id [%s]";
5867
private static final String SERVICE_NAME = "Hugging Face";
5968
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
6069
TaskType.TEXT_EMBEDDING,
6170
TaskType.SPARSE_EMBEDDING,
6271
TaskType.COMPLETION,
6372
TaskType.CHAT_COMPLETION
6473
);
74+
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler(
75+
"hugging face chat completion",
76+
OpenAiChatCompletionResponseEntity::fromResponse
77+
);
6578

6679
public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
6780
super(factory, serviceComponents);
@@ -161,10 +174,18 @@ protected void doUnifiedCompletionInfer(
161174
listener.onFailure(createInvalidModelException(model));
162175
return;
163176
}
177+
164178
HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model;
165-
var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents());
166179
var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest());
167-
var action = overriddenModel.accept(actionCreator);
180+
var manager = new GenericRequestManager<>(
181+
getServiceComponents().threadPool(),
182+
overriddenModel,
183+
UNIFIED_CHAT_COMPLETION_HANDLER,
184+
unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel),
185+
UnifiedChatInput.class
186+
);
187+
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId());
188+
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
168189

169190
action.execute(inputs, timeout, listener);
170191
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,19 @@
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
12+
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
13+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
14+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
15+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
1216
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
17+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
1318
import org.elasticsearch.xpack.inference.services.ServiceComponents;
14-
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionRequestManager;
1519
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceEmbeddingsRequestManager;
1620
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler;
1721
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
1822
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
1923
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
24+
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
2025
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
2126
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
2227
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
@@ -31,8 +36,14 @@
3136
*/
3237
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
3338

39+
public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions";
40+
private static final String USER_ROLE = "user";
3441
private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
3542
"Failed to send Hugging Face %s request from inference entity id [%s]";
43+
static final ResponseHandler COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler(
44+
"hugging face completion",
45+
OpenAiChatCompletionResponseEntity::fromResponse
46+
);
3647
private final Sender sender;
3748
private final ServiceComponents serviceComponents;
3849

@@ -72,13 +83,15 @@ public ExecutableAction create(HuggingFaceElserModel model) {
7283

7384
@Override
7485
public ExecutableAction create(HuggingFaceChatCompletionModel model) {
75-
var responseHandler = new OpenAiUnifiedChatCompletionResponseHandler(
76-
"hugging face chat completion",
77-
OpenAiChatCompletionResponseEntity::fromResponse
86+
var manager = new GenericRequestManager<>(
87+
serviceComponents.threadPool(),
88+
model,
89+
COMPLETION_HANDLER,
90+
inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
91+
ChatCompletionInput.class
7892
);
7993

80-
var requestCreator = HuggingFaceChatCompletionRequestManager.of(model, responseHandler, serviceComponents.threadPool());
81-
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId());
82-
return new SenderExecutableAction(sender, requestCreator, errorMessage);
94+
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "COMPLETION", model.getInferenceEntityId());
95+
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
8396
}
8497
}

0 commit comments

Comments
 (0)