Add Hugging Face Chat Completion support to Inference Plugin #127254


Merged
Changes from 12 commits

Commits (45 total)
63f21de
Add Hugging Face Chat Completion support to Inference Plugin
Jan-Kazlouski-elastic Apr 23, 2025
6b7dd2e
Merge remote-tracking branch 'refs/remotes/origin/main' into feature/…
Jan-Kazlouski-elastic Apr 25, 2025
65e4060
Add support for streaming chat completion task for HuggingFace
Jan-Kazlouski-elastic Apr 25, 2025
404f640
[CI] Auto commit changes from spotless
Apr 25, 2025
ceebb9a
Add support for non-streaming completion task for HuggingFace
Jan-Kazlouski-elastic Apr 25, 2025
acaa35b
Remove RequestManager for HF Chat Completion Task
Jan-Kazlouski-elastic Apr 25, 2025
91fa92e
Merge remote-tracking branch 'refs/remotes/origin/main' into feature/…
Jan-Kazlouski-elastic Apr 28, 2025
ff3ef50
Refactored Hugging Face Completion Service Settings, removed Request …
Jan-Kazlouski-elastic Apr 28, 2025
965093b
Refactored Hugging Face Action Creator, added Unit Tests
Jan-Kazlouski-elastic Apr 29, 2025
6757b07
Add Hugging Face Server Test
Jan-Kazlouski-elastic Apr 29, 2025
58ea9fd
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic Apr 29, 2025
df845eb
[CI] Auto commit changes from spotless
Apr 29, 2025
cc24e68
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 2, 2025
5bbe3b7
Removed parameters from media type for Chat Completion Request and un…
Jan-Kazlouski-elastic May 2, 2025
3684816
Removed OpenAI default URL in HuggingFaceService's configuration, fix…
Jan-Kazlouski-elastic May 2, 2025
7670d2c
Refactor error message handling in HuggingFaceActionCreator and Huggi…
Jan-Kazlouski-elastic May 2, 2025
6630be7
Update minimal supported version and add Hugging Face transport versi…
Jan-Kazlouski-elastic May 2, 2025
1efb2ee
Made modelId field optional in HuggingFaceChatCompletionModel, update…
Jan-Kazlouski-elastic May 2, 2025
61537d0
Removed max input tokens field from HuggingFaceChatCompletionServiceS…
Jan-Kazlouski-elastic May 2, 2025
64c0685
Removed if statement checking TransportVersion for HuggingFaceChatCom…
Jan-Kazlouski-elastic May 2, 2025
4688901
Removed getFirst() method calls for backport compatibility
Jan-Kazlouski-elastic May 2, 2025
bfc8072
Made HuggingFaceChatCompletionServiceSettingsTests extend AbstractBWC…
Jan-Kazlouski-elastic May 2, 2025
13ef13b
Refactored tests to use stripWhitespace method for readability
Jan-Kazlouski-elastic May 2, 2025
129caaf
Refactored javadoc for HuggingFaceService
Jan-Kazlouski-elastic May 2, 2025
214de5f
Renamed HF chat completion TransportVersion constant names
Jan-Kazlouski-elastic May 2, 2025
d3411d6
Added random string generation in unit test
Jan-Kazlouski-elastic May 2, 2025
e170b96
Refactored javadocs for HuggingFace requests
Jan-Kazlouski-elastic May 2, 2025
473dee6
Refactored tests to reduce duplication
Jan-Kazlouski-elastic May 2, 2025
cb03100
Added changelog file
Jan-Kazlouski-elastic May 2, 2025
c856853
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 5, 2025
bd2e601
Merge remote-tracking branch 'refs/remotes/origin/main' into feature/…
Jan-Kazlouski-elastic May 5, 2025
aae528a
Add HuggingFaceChatCompletionResponseHandler and associated tests
Jan-Kazlouski-elastic May 5, 2025
82f8049
Refactor error handling in HuggingFaceServiceTests to standardize err…
Jan-Kazlouski-elastic May 5, 2025
b0679d5
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 6, 2025
2fa3dff
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 7, 2025
cdb3c1c
Refactor HuggingFace error handling to improve response structure and…
Jan-Kazlouski-elastic May 7, 2025
9370b57
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 11, 2025
9044bee
Allowing null function name for hugging face models
jonathan-buttner May 9, 2025
e72a312
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 12, 2025
e2cb334
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 13, 2025
a4b5d2c
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 13, 2025
c5988ed
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 14, 2025
1547559
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 19, 2025
71c6057
Merge branch 'main' into feature/hugging-face-chat-completion-integra…
Jan-Kazlouski-elastic May 19, 2025
228fffa
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 19, 2025
----------------------------------------
@@ -115,7 +115,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(10));
assertThat(services.size(), equalTo(11));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
@@ -133,6 +133,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"cohere",
"deepseek",
"googleaistudio",
"hugging_face",
"openai",
"streaming_completion_test_service"
).toArray(),
@@ -143,15 +144,18 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(4));
assertThat(services.size(), equalTo(5));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}

assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
assertArrayEquals(
List.of("deepseek", "elastic", "hugging_face", "openai", "streaming_completion_test_service").toArray(),
providers
);
}

@SuppressWarnings("unchecked")
----------------------------------------
@@ -78,6 +78,7 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
@@ -353,6 +354,13 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceChatCompletionServiceSettings.NAME,
HuggingFaceChatCompletionServiceSettings::new
)
);
}

private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
----------------------------------------
@@ -9,17 +9,18 @@

import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.Objects;

public abstract class HuggingFaceModel extends Model {
public abstract class HuggingFaceModel extends RateLimitGroupingModel {
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
private final SecureString apiKey;

@@ -38,6 +39,16 @@ public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

@Override
public int rateLimitGroupingHash() {
return Objects.hash(rateLimitServiceSettings.uri(), apiKey);
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings.rateLimitSettings();
}

public SecureString apiKey() {
return apiKey;
}
----------------------------------------
@@ -19,7 +19,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRequest;
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;

import java.util.List;
import java.util.Objects;
@@ -64,7 +64,7 @@ public void execute(
) {
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
var truncatedInput = truncate(docsInput, model.getTokenLimit());
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
}
----------------------------------------
@@ -26,32 +26,55 @@
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;

/**
* This class is responsible for managing the Hugging Face inference service.
* It handles the creation of models, chunked inference, and unified completion inference.
Member: nit: The class also handles non-chunked inference which should be included in the javadoc.

Contributor Author: I rephrased it so it is more specific. Thanks.

*/
public class HuggingFaceService extends HuggingFaceBaseService {
public static final String NAME = "hugging_face";

private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
"Failed to send Hugging Face %s request from inference entity id [%s]";
private static final String SERVICE_NAME = "Hugging Face";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING);
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.SPARSE_EMBEDDING,
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION
);
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler(
"hugging face chat completion",
OpenAiChatCompletionResponseEntity::fromResponse
);

public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
@@ -78,6 +101,14 @@ protected HuggingFaceModel createModel(
context
);
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
secretSettings,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
@@ -139,7 +170,29 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof HuggingFaceChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}

HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model;
var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest());
var manager = new GenericRequestManager<>(
getServiceComponents().threadPool(),
overriddenModel,
UNIFIED_CHAT_COMPLETION_HANDLER,
unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel),
UnifiedChatInput.class
);
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId());
Contributor: nit: How about we move this into a function, something like:

private static String errorMessage(String requestDescription, String inferenceId) {
    return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestDescription, inferenceId);
}

It might be a little easier to see how the string is being formatted if the raw string is included in the format call.

Contributor Author (Jan-Kazlouski-elastic, May 2, 2025): Suggestion: maybe we should use a TaskType taskType parameter instead of String requestDescription? That would restrict the values to the clearly defined task types and remove the possibility of inconsistent formatting; in the current implementation the descriptions are "text embeddings" and "ELSER", which is a bit messy. Such an approach would change "ELSER" to sparse_embedding and make the other values lowercase as well.

P.S. Having "elser" and "sparse embedding" used interchangeably might also be worth unifying to keep the vocabulary more strict.

Contributor Author: Added the version described above. Please do tell if you'd like to stick with the version you proposed initially.
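
A minimal sketch of the TaskType-based helper described above (not necessarily the exact code that lands in a later commit), assuming the statically imported org.elasticsearch.core.Strings.format already used in this file:

// Hedged sketch only; the merged helper may differ in name, lowercasing, and formatting details.
private static String errorMessage(TaskType taskType, String inferenceId) {
    // Formatting from the TaskType enum keeps the request description restricted to the known task types.
    return format("Failed to send Hugging Face %s request from inference entity id [%s]", taskType, inferenceId);
}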

var action = new SenderExecutableAction(getSender(), manager, errorMessage);

action.execute(inputs, timeout, listener);
}

@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
}

@Override
Expand All @@ -149,7 +202,7 @@ public InferenceServiceConfiguration getConfiguration() {

@Override
public EnumSet<TaskType> supportedTaskTypes() {
return supportedTaskTypes;
return SUPPORTED_TASK_TYPES;
}

@Override
@@ -167,13 +220,15 @@ public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
}

private Configuration() {}
Member: Why is this line needed?

Contributor Author: In short, to protect this class from being instantiated. Since there are only static members in this class, there is no reason to allow instantiating it; declaring a private constructor hides the default constructor that every object otherwise gets. It is optional, so if you want, I can remove it.
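
As a standalone illustration of that pattern (a hypothetical class, not code from this PR):

// Static-only holder: the private constructor prevents callers from ever instantiating it.
final class ExampleConstants {
    static final String SERVICE_NAME = "hugging_face";

    private ExampleConstants() {
        // intentionally empty; every member is static
    }
}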


private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
() -> {
var configurationMap = new HashMap<String, SettingsConfiguration>();

configurationMap.put(
URL,
new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings")
new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDefaultValue("https://api.openai.com/v1/embeddings")
Contributor: Oops, looks like we have an existing bug here (unrelated to your changes). Can you remove the setDefaultValue that shouldn't be pointing to OpenAI? 😅

Contributor Author: I initially assumed it was there for some internal configuration and didn't want to introduce any risks by changing it. Removed.

.setDescription("The URL endpoint to use for the requests.")
.setLabel("URL")
.setRequired(true)
@@ -183,12 +238,12 @@ public static InferenceServiceConfiguration get() {
.build()
);

configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes));
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes));
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(SERVICE_NAME)
.setTaskTypes(supportedTaskTypes)
.setTaskTypes(SUPPORTED_TASK_TYPES)
.setConfigurations(configurationMap)
.build();
}
----------------------------------------
@@ -9,14 +9,23 @@

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;

import java.util.Objects;

@@ -26,6 +35,15 @@
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the hugging face model type.
*/
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {

public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions";
static final String USER_ROLE = "user";
private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
"Failed to send Hugging Face %s request from inference entity id [%s]";
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"hugging face completion",
OpenAiChatCompletionResponseEntity::fromResponse
);
private final Sender sender;
private final ServiceComponents serviceComponents;

@@ -46,11 +64,7 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) {
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = format(
"Failed to send Hugging Face %s request from inference entity id [%s]",
"text embeddings",
model.getInferenceEntityId()
);
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "text embeddings", model.getInferenceEntityId());
Contributor: nit: Same comment as above suggesting making this a function.

Contributor Author: Did the change described in my comment above.

return new SenderExecutableAction(sender, requestCreator, errorMessage);
}

@@ -63,11 +77,21 @@ public ExecutableAction create(HuggingFaceElserModel model) {
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = format(
"Failed to send Hugging Face %s request from inference entity id [%s]",
"ELSER",
model.getInferenceEntityId()
);
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "ELSER", model.getInferenceEntityId());
return new SenderExecutableAction(sender, requestCreator, errorMessage);
}

@Override
public ExecutableAction create(HuggingFaceChatCompletionModel model) {
var manager = new GenericRequestManager<>(
serviceComponents.threadPool(),
model,
COMPLETION_HANDLER,
inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
ChatCompletionInput.class
);

var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "COMPLETION", model.getInferenceEntityId());
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
}
}
----------------------------------------
@@ -8,11 +8,14 @@
package org.elasticsearch.xpack.inference.services.huggingface.action;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;

public interface HuggingFaceActionVisitor {
ExecutableAction create(HuggingFaceEmbeddingsModel model);

ExecutableAction create(HuggingFaceElserModel model);

ExecutableAction create(HuggingFaceChatCompletionModel model);
}