|
26 | 26 | import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
27 | 27 | import org.elasticsearch.rest.RestStatus;
|
28 | 28 | import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
| 29 | +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; |
| 30 | +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; |
29 | 31 | import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
| 32 | +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; |
30 | 33 | import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
31 | 34 | import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
32 | 35 | import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
|
36 | 39 | import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
37 | 40 | import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
38 | 41 | import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
| 42 | +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; |
| 43 | +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; |
| 44 | +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; |
39 | 45 | import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
40 | 46 | import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
41 | 47 |
|
|
45 | 51 | import java.util.Map;
|
46 | 52 | import java.util.Set;
|
47 | 53 |
|
| 54 | +import static org.elasticsearch.core.Strings.format; |
48 | 55 | import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
|
49 | 56 | import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
|
50 | 57 |
|
|
55 | 62 | public class HuggingFaceService extends HuggingFaceBaseService {
|
56 | 63 | public static final String NAME = "hugging_face";
|
57 | 64 |
|
| 65 | + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = |
| 66 | + "Failed to send Hugging Face %s request from inference entity id [%s]"; |
58 | 67 | private static final String SERVICE_NAME = "Hugging Face";
|
59 | 68 | private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
|
60 | 69 | TaskType.TEXT_EMBEDDING,
|
61 | 70 | TaskType.SPARSE_EMBEDDING,
|
62 | 71 | TaskType.COMPLETION,
|
63 | 72 | TaskType.CHAT_COMPLETION
|
64 | 73 | );
|
| 74 | + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( |
| 75 | + "hugging face chat completion", |
| 76 | + OpenAiChatCompletionResponseEntity::fromResponse |
| 77 | + ); |
65 | 78 |
|
66 | 79 | public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
|
67 | 80 | super(factory, serviceComponents);
|
@@ -161,10 +174,18 @@ protected void doUnifiedCompletionInfer(
|
161 | 174 | listener.onFailure(createInvalidModelException(model));
|
162 | 175 | return;
|
163 | 176 | }
|
| 177 | + |
164 | 178 | HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model;
|
165 |
| - var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); |
166 | 179 | var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest());
|
167 |
| - var action = overriddenModel.accept(actionCreator); |
| 180 | + var manager = new GenericRequestManager<>( |
| 181 | + getServiceComponents().threadPool(), |
| 182 | + overriddenModel, |
| 183 | + UNIFIED_CHAT_COMPLETION_HANDLER, |
| 184 | + unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel), |
| 185 | + UnifiedChatInput.class |
| 186 | + ); |
| 187 | + var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId()); |
| 188 | + var action = new SenderExecutableAction(getSender(), manager, errorMessage); |
168 | 189 |
|
169 | 190 | action.execute(inputs, timeout, listener);
|
170 | 191 | }
|
|
0 commit comments