Skip to content

Commit f4b19e9

Browse files
prwhelanalbertzaharovits
authored andcommitted
[ML] Integrate with DeepSeek API (elastic#122218)
Integrating for Chat Completion and Completion task types, both calling the chat completion API for DeepSeek.
1 parent d716f4a commit f4b19e9

File tree

13 files changed

+1089
-11
lines changed

13 files changed

+1089
-11
lines changed

docs/changelog/122218.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 122218
2+
summary: Integrate with `DeepSeek` API
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ static TransportVersion def(int id) {
147147
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
148148
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
149149
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);
150+
public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_09);
150151
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
151152
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
152153
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -183,6 +184,7 @@ static TransportVersion def(int id) {
183184
public static final TransportVersion ESQL_SERIALIZE_BLOCK_TYPE_CODE = def(9_026_0_00);
184185
public static final TransportVersion ESQL_THREAD_NAME_IN_DRIVER_PROFILE = def(9_027_0_00);
185186
public static final TransportVersion INFERENCE_CONTEXT = def(9_028_0_00);
187+
public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_029_00_0);
186188

187189
/*
188190
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2525
@SuppressWarnings("unchecked")
2626
public void testGetServicesWithoutTaskType() throws IOException {
2727
List<Object> services = getAllServices();
28-
assertThat(services.size(), equalTo(20));
28+
assertThat(services.size(), equalTo(21));
2929

3030
String[] providers = new String[services.size()];
3131
for (int i = 0; i < services.size(); i++) {
@@ -41,6 +41,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
4141
"azureaistudio",
4242
"azureopenai",
4343
"cohere",
44+
"deepseek",
4445
"elastic",
4546
"elasticsearch",
4647
"googleaistudio",
@@ -114,7 +115,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
114115
@SuppressWarnings("unchecked")
115116
public void testGetServicesWithCompletionTaskType() throws IOException {
116117
List<Object> services = getServices(TaskType.COMPLETION);
117-
assertThat(services.size(), equalTo(9));
118+
assertThat(services.size(), equalTo(10));
118119

119120
String[] providers = new String[services.size()];
120121
for (int i = 0; i < services.size(); i++) {
@@ -130,6 +131,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
130131
"azureaistudio",
131132
"azureopenai",
132133
"cohere",
134+
"deepseek",
133135
"googleaistudio",
134136
"openai",
135137
"streaming_completion_test_service"
@@ -141,15 +143,15 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
141143
@SuppressWarnings("unchecked")
142144
public void testGetServicesWithChatCompletionTaskType() throws IOException {
143145
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
144-
assertThat(services.size(), equalTo(3));
146+
assertThat(services.size(), equalTo(4));
145147

146148
String[] providers = new String[services.size()];
147149
for (int i = 0; i < services.size(); i++) {
148150
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
149151
providers[i] = (String) serviceConfig.get("service");
150152
}
151153

152-
assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers);
154+
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
153155
}
154156

155157
@SuppressWarnings("unchecked")

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
5959
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
6060
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
61+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
6162
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6263
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
6364
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
@@ -153,6 +154,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
153154
addUnifiedNamedWriteables(namedWriteables);
154155

155156
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
157+
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
156158

157159
return namedWriteables;
158160
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
117117
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
118118
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
119+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
119120
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
120121
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
121122
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
@@ -362,6 +363,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
362363
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
363364
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
364365
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
366+
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
365367
ElasticsearchInternalService::new
366368
);
367369
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.deepseek;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.elasticsearch.ElasticsearchException;
14+
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.xcontent.ToXContent;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xcontent.json.JsonXContent;
18+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
19+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
20+
import org.elasticsearch.xpack.inference.external.request.Request;
21+
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
22+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
23+
24+
import java.io.IOException;
25+
import java.net.URI;
26+
import java.nio.charset.StandardCharsets;
27+
import java.util.Objects;
28+
29+
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
30+
31+
public class DeepSeekChatCompletionRequest implements Request {
32+
private static final String MODEL_FIELD = "model";
33+
private static final String MAX_TOKENS = "max_tokens";
34+
35+
private final DeepSeekChatCompletionModel model;
36+
private final UnifiedChatInput unifiedChatInput;
37+
38+
public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
39+
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
40+
this.model = Objects.requireNonNull(model);
41+
}
42+
43+
@Override
44+
public HttpRequest createHttpRequest() {
45+
HttpPost httpPost = new HttpPost(model.uri());
46+
47+
httpPost.setEntity(createEntity());
48+
49+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
50+
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
51+
52+
return new HttpRequest(httpPost, getInferenceEntityId());
53+
}
54+
55+
private ByteArrayEntity createEntity() {
56+
var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model);
57+
try (var builder = JsonXContent.contentBuilder()) {
58+
builder.startObject();
59+
new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS);
60+
builder.field(MODEL_FIELD, modelId);
61+
62+
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
63+
builder.field(MAX_TOKENS, unifiedChatInput.getRequest().maxCompletionTokens());
64+
}
65+
66+
builder.endObject();
67+
return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8));
68+
} catch (IOException e) {
69+
throw new ElasticsearchException("Failed to serialize request payload.", e);
70+
}
71+
}
72+
73+
@Override
74+
public URI getURI() {
75+
return model.uri();
76+
}
77+
78+
@Override
79+
public Request truncate() {
80+
return this;
81+
}
82+
83+
@Override
84+
public boolean[] getTruncationInfo() {
85+
return null;
86+
}
87+
88+
@Override
89+
public String getInferenceEntityId() {
90+
return model.getInferenceEntityId();
91+
}
92+
93+
@Override
94+
public boolean isStreaming() {
95+
return unifiedChatInput.stream();
96+
}
97+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.deepseek.DeepSeekChatCompletionRequest;
16+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
17+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
19+
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
20+
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler;
21+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
22+
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException;
27+
28+
public class DeepSeekRequestManager extends BaseRequestManager {
29+
30+
private static final Logger logger = LogManager.getLogger(DeepSeekRequestManager.class);
31+
32+
private static final ResponseHandler CHAT_COMPLETION = createChatCompletionHandler();
33+
private static final ResponseHandler COMPLETION = createCompletionHandler();
34+
35+
private final DeepSeekChatCompletionModel model;
36+
37+
public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) {
38+
super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings());
39+
this.model = Objects.requireNonNull(model);
40+
}
41+
42+
@Override
43+
public void execute(
44+
InferenceInputs inferenceInputs,
45+
RequestSender requestSender,
46+
Supplier<Boolean> hasRequestCompletedFunction,
47+
ActionListener<InferenceServiceResults> listener
48+
) {
49+
switch (inferenceInputs) {
50+
case UnifiedChatInput uci -> execute(uci, requestSender, hasRequestCompletedFunction, listener);
51+
case ChatCompletionInput cci -> execute(cci, requestSender, hasRequestCompletedFunction, listener);
52+
default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
53+
}
54+
}
55+
56+
private void execute(
57+
UnifiedChatInput inferenceInputs,
58+
RequestSender requestSender,
59+
Supplier<Boolean> hasRequestCompletedFunction,
60+
ActionListener<InferenceServiceResults> listener
61+
) {
62+
var request = new DeepSeekChatCompletionRequest(inferenceInputs, model);
63+
execute(new ExecutableInferenceRequest(requestSender, logger, request, CHAT_COMPLETION, hasRequestCompletedFunction, listener));
64+
}
65+
66+
private void execute(
67+
ChatCompletionInput inferenceInputs,
68+
RequestSender requestSender,
69+
Supplier<Boolean> hasRequestCompletedFunction,
70+
ActionListener<InferenceServiceResults> listener
71+
) {
72+
var unifiedInputs = new UnifiedChatInput(inferenceInputs.getInputs(), "user", inferenceInputs.stream());
73+
var request = new DeepSeekChatCompletionRequest(unifiedInputs, model);
74+
execute(new ExecutableInferenceRequest(requestSender, logger, request, COMPLETION, hasRequestCompletedFunction, listener));
75+
}
76+
77+
private static ResponseHandler createChatCompletionHandler() {
78+
return new OpenAiUnifiedChatCompletionResponseHandler("deepseek chat completion", OpenAiChatCompletionResponseEntity::fromResponse);
79+
}
80+
81+
private static ResponseHandler createCompletionHandler() {
82+
return new OpenAiChatCompletionResponseHandler("deepseek completion", OpenAiChatCompletionResponseEntity::fromResponse);
83+
}
84+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionRequestEntity.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec
2121

2222
public static final String USER_FIELD = "user";
2323
private static final String MODEL_FIELD = "model";
24+
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
2425

26+
private final UnifiedChatInput unifiedChatInput;
2527
private final OpenAiChatCompletionModel model;
2628
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
2729

2830
public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) {
29-
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
31+
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
32+
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
3033
this.model = Objects.requireNonNull(model);
3134
}
3235

@@ -41,6 +44,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4144
builder.field(USER_FIELD, model.getTaskSettings().user());
4245
}
4346

47+
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
48+
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
49+
}
50+
4451
builder.endObject();
4552

4653
return builder;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717

1818
public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject {
1919
private static final String MODEL_FIELD = "model";
20+
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
2021

22+
private final UnifiedChatInput unifiedChatInput;
2123
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
2224
private final String modelId;
2325

2426
public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
25-
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
27+
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
28+
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
2629
this.modelId = Objects.requireNonNull(modelId);
2730
}
2831

@@ -31,6 +34,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
3134
builder.startObject();
3235
unifiedRequestEntity.toXContent(builder, params);
3336
builder.field(MODEL_FIELD, modelId);
37+
38+
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
39+
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
40+
}
41+
3442
builder.endObject();
3543

3644
return builder;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
3232
public static final String MESSAGES_FIELD = "messages";
3333
private static final String ROLE_FIELD = "role";
3434
private static final String CONTENT_FIELD = "content";
35-
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
3635
private static final String STOP_FIELD = "stop";
3736
private static final String TEMPERATURE_FIELD = "temperature";
3837
private static final String TOOL_CHOICE_FIELD = "tool_choice";
@@ -104,10 +103,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
104103
}
105104
builder.endArray();
106105

107-
if (unifiedRequest.maxCompletionTokens() != null) {
108-
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
109-
}
110-
111106
// Underlying providers expect OpenAI to only return 1 possible choice.
112107
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
113108

0 commit comments

Comments
 (0)