-
Notifications
You must be signed in to change notification settings - Fork 25.4k
Add Hugging Face Chat Completion support to Inference Plugin #127254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
63f21de
6b7dd2e
65e4060
404f640
ceebb9a
acaa35b
91fa92e
ff3ef50
965093b
6757b07
58ea9fd
df845eb
cc24e68
5bbe3b7
3684816
7670d2c
6630be7
1efb2ee
61537d0
64c0685
4688901
bfc8072
13ef13b
129caaf
214de5f
d3411d6
e170b96
473dee6
cb03100
c856853
bd2e601
aae528a
82f8049
b0679d5
2fa3dff
cdb3c1c
9370b57
9044bee
e72a312
e2cb334
a4b5d2c
c5988ed
1547559
71c6057
228fffa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,32 +26,55 @@ | |
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; | ||
import org.elasticsearch.rest.RestStatus; | ||
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; | ||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; | ||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; | ||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; | ||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; | ||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; | ||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | ||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; | ||
import org.elasticsearch.xpack.inference.services.ServiceComponents; | ||
import org.elasticsearch.xpack.inference.services.ServiceUtils; | ||
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator; | ||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; | ||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; | ||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; | ||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; | ||
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; | ||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; | ||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; | ||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; | ||
|
||
import java.util.EnumSet; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Set; | ||
|
||
import static org.elasticsearch.core.Strings.format; | ||
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; | ||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; | ||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; | ||
|
||
/** | ||
* This class is responsible for managing the Hugging Face inference service. | ||
* It handles the creation of models, chunked inference, and unified completion inference. | ||
*/ | ||
public class HuggingFaceService extends HuggingFaceBaseService { | ||
public static final String NAME = "hugging_face"; | ||
|
||
private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = | ||
"Failed to send Hugging Face %s request from inference entity id [%s]"; | ||
private static final String SERVICE_NAME = "Hugging Face"; | ||
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING); | ||
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of( | ||
TaskType.TEXT_EMBEDDING, | ||
TaskType.SPARSE_EMBEDDING, | ||
TaskType.COMPLETION, | ||
TaskType.CHAT_COMPLETION | ||
); | ||
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( | ||
"hugging face chat completion", | ||
OpenAiChatCompletionResponseEntity::fromResponse | ||
); | ||
|
||
public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { | ||
super(factory, serviceComponents); | ||
|
@@ -78,6 +101,14 @@ protected HuggingFaceModel createModel( | |
context | ||
); | ||
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); | ||
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel( | ||
inferenceEntityId, | ||
taskType, | ||
NAME, | ||
serviceSettings, | ||
secretSettings, | ||
context | ||
); | ||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); | ||
}; | ||
} | ||
|
@@ -139,7 +170,29 @@ protected void doUnifiedCompletionInfer( | |
TimeValue timeout, | ||
ActionListener<InferenceServiceResults> listener | ||
) { | ||
throwUnsupportedUnifiedCompletionOperation(NAME); | ||
if (model instanceof HuggingFaceChatCompletionModel == false) { | ||
listener.onFailure(createInvalidModelException(model)); | ||
return; | ||
} | ||
|
||
HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model; | ||
var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest()); | ||
var manager = new GenericRequestManager<>( | ||
getServiceComponents().threadPool(), | ||
overriddenModel, | ||
UNIFIED_CHAT_COMPLETION_HANDLER, | ||
unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel), | ||
UnifiedChatInput.class | ||
); | ||
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId()); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: How about we move this into a function, something like:
It might be a little easier to see how the string is being formatted if the raw string is included in the format call. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Suggestion: P.S. Also, having "ELSER" vs. "sparse embedding" used interchangeably might be worth unifying to keep the vocabulary more strict. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added the version described above. Please do tell if you'd like to stick with the version you proposed initially. |
||
var action = new SenderExecutableAction(getSender(), manager, errorMessage); | ||
|
||
action.execute(inputs, timeout, listener); | ||
} | ||
|
||
@Override | ||
public Set<TaskType> supportedStreamingTasks() { | ||
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); | ||
} | ||
|
||
@Override | ||
|
@@ -149,7 +202,7 @@ public InferenceServiceConfiguration getConfiguration() { | |
|
||
@Override | ||
public EnumSet<TaskType> supportedTaskTypes() { | ||
return supportedTaskTypes; | ||
return SUPPORTED_TASK_TYPES; | ||
} | ||
|
||
@Override | ||
|
@@ -167,13 +220,15 @@ public static InferenceServiceConfiguration get() { | |
return configuration.getOrCompute(); | ||
} | ||
|
||
private Configuration() {} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this line needed? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In short — to protect this class from being instantiated. Since there are only static members in this class, there is no reason for having an option of instantiating it. To protect this class from being instantiated, we can hide the default constructor that every Object has by declaring a private one. |
||
|
||
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>( | ||
() -> { | ||
var configurationMap = new HashMap<String, SettingsConfiguration>(); | ||
|
||
configurationMap.put( | ||
URL, | ||
new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings") | ||
new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDefaultValue("https://api.openai.com/v1/embeddings") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oops, looks like we have an existing bug here (unrelated to your changes). Can you remove the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I initially assumed it was there for some internal configuration and didn't want to introduce any risks by changing it. Removed. |
||
.setDescription("The URL endpoint to use for the requests.") | ||
.setLabel("URL") | ||
.setRequired(true) | ||
|
@@ -183,12 +238,12 @@ public static InferenceServiceConfiguration get() { | |
.build() | ||
); | ||
|
||
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); | ||
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); | ||
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); | ||
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); | ||
|
||
return new InferenceServiceConfiguration.Builder().setService(NAME) | ||
.setName(SERVICE_NAME) | ||
.setTaskTypes(supportedTaskTypes) | ||
.setTaskTypes(SUPPORTED_TASK_TYPES) | ||
.setConfigurations(configurationMap) | ||
.build(); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,14 +9,23 @@ | |
|
||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction; | ||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; | ||
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; | ||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; | ||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; | ||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; | ||
import org.elasticsearch.xpack.inference.external.http.sender.Sender; | ||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | ||
import org.elasticsearch.xpack.inference.services.ServiceComponents; | ||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager; | ||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler; | ||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; | ||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; | ||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; | ||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; | ||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity; | ||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; | ||
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; | ||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; | ||
|
||
import java.util.Objects; | ||
|
||
|
@@ -26,6 +35,15 @@ | |
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the hugging face model type. | ||
*/ | ||
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { | ||
|
||
public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions"; | ||
static final String USER_ROLE = "user"; | ||
private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = | ||
"Failed to send Hugging Face %s request from inference entity id [%s]"; | ||
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( | ||
"hugging face completion", | ||
OpenAiChatCompletionResponseEntity::fromResponse | ||
); | ||
private final Sender sender; | ||
private final ServiceComponents serviceComponents; | ||
|
||
|
@@ -46,11 +64,7 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) { | |
serviceComponents.truncator(), | ||
serviceComponents.threadPool() | ||
); | ||
var errorMessage = format( | ||
"Failed to send Hugging Face %s request from inference entity id [%s]", | ||
"text embeddings", | ||
model.getInferenceEntityId() | ||
); | ||
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "text embeddings", model.getInferenceEntityId()); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: Same comment as above, suggesting making this a function. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Did the change described in my comment above. |
||
return new SenderExecutableAction(sender, requestCreator, errorMessage); | ||
} | ||
|
||
|
@@ -63,11 +77,21 @@ public ExecutableAction create(HuggingFaceElserModel model) { | |
serviceComponents.truncator(), | ||
serviceComponents.threadPool() | ||
); | ||
var errorMessage = format( | ||
"Failed to send Hugging Face %s request from inference entity id [%s]", | ||
"ELSER", | ||
model.getInferenceEntityId() | ||
); | ||
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "ELSER", model.getInferenceEntityId()); | ||
return new SenderExecutableAction(sender, requestCreator, errorMessage); | ||
} | ||
|
||
@Override | ||
public ExecutableAction create(HuggingFaceChatCompletionModel model) { | ||
var manager = new GenericRequestManager<>( | ||
serviceComponents.threadPool(), | ||
model, | ||
COMPLETION_HANDLER, | ||
inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), | ||
ChatCompletionInput.class | ||
); | ||
|
||
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "COMPLETION", model.getInferenceEntityId()); | ||
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: The class also handles non-chunked inference, which should be included in the Javadoc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I rephrased it so it is more specific. Thanks.