Skip to content

Commit bdd7b6a

Browse files
authored
[ML] Integrate with DeepSeek API (#122218) (#124796)
Integrating for Chat Completion and Completion task types, both calling the chat completion API for DeepSeek.
1 parent c31a432 commit bdd7b6a

File tree

13 files changed

+1090
-11
lines changed

13 files changed

+1090
-11
lines changed

docs/changelog/122218.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 122218
2+
summary: Integrate with `DeepSeek` API
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ static TransportVersion def(int id) {
193193
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
194194
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
195195
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);
196+
public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_09);
196197

197198
/*
198199
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2626
@SuppressWarnings("unchecked")
2727
public void testGetServicesWithoutTaskType() throws IOException {
2828
List<Object> services = getAllServices();
29-
assertThat(services.size(), equalTo(20));
29+
assertThat(services.size(), equalTo(21));
3030

3131
String[] providers = new String[services.size()];
3232
for (int i = 0; i < services.size(); i++) {
@@ -42,6 +42,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
4242
"azureaistudio",
4343
"azureopenai",
4444
"cohere",
45+
"deepseek",
4546
"elastic",
4647
"elasticsearch",
4748
"googleaistudio",
@@ -115,7 +116,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
115116
@SuppressWarnings("unchecked")
116117
public void testGetServicesWithCompletionTaskType() throws IOException {
117118
List<Object> services = getServices(TaskType.COMPLETION);
118-
assertThat(services.size(), equalTo(9));
119+
assertThat(services.size(), equalTo(10));
119120

120121
String[] providers = new String[services.size()];
121122
for (int i = 0; i < services.size(); i++) {
@@ -131,6 +132,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
131132
"azureaistudio",
132133
"azureopenai",
133134
"cohere",
135+
"deepseek",
134136
"googleaistudio",
135137
"openai",
136138
"streaming_completion_test_service"
@@ -143,15 +145,15 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
143145
@SuppressWarnings("unchecked")
144146
public void testGetServicesWithChatCompletionTaskType() throws IOException {
145147
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
146-
assertThat(services.size(), equalTo(3));
148+
assertThat(services.size(), equalTo(4));
147149

148150
String[] providers = new String[services.size()];
149151
for (int i = 0; i < services.size(); i++) {
150152
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
151153
providers[i] = (String) serviceConfig.get("service");
152154
}
153155

154-
assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers);
156+
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
155157
}
156158

157159
@SuppressWarnings("unchecked")

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
5959
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
6060
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
61+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
6162
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6263
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
6364
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
@@ -153,6 +154,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
153154
addUnifiedNamedWriteables(namedWriteables);
154155

155156
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
157+
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
156158

157159
return namedWriteables;
158160
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
117117
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
118118
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
119+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
119120
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
120121
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
121122
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
@@ -362,6 +363,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
362363
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
363364
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
364365
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
366+
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
365367
ElasticsearchInternalService::new
366368
);
367369
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.deepseek;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.elasticsearch.ElasticsearchException;
14+
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.xcontent.ToXContent;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xcontent.json.JsonXContent;
18+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
19+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
20+
import org.elasticsearch.xpack.inference.external.request.Request;
21+
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
22+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
23+
24+
import java.io.IOException;
25+
import java.net.URI;
26+
import java.nio.charset.StandardCharsets;
27+
import java.util.Objects;
28+
29+
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
30+
31+
public class DeepSeekChatCompletionRequest implements Request {
32+
private static final String MODEL_FIELD = "model";
33+
private static final String MAX_TOKENS = "max_tokens";
34+
35+
private final DeepSeekChatCompletionModel model;
36+
private final UnifiedChatInput unifiedChatInput;
37+
38+
public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
39+
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
40+
this.model = Objects.requireNonNull(model);
41+
}
42+
43+
@Override
44+
public HttpRequest createHttpRequest() {
45+
HttpPost httpPost = new HttpPost(model.uri());
46+
47+
httpPost.setEntity(createEntity());
48+
49+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
50+
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
51+
52+
return new HttpRequest(httpPost, getInferenceEntityId());
53+
}
54+
55+
private ByteArrayEntity createEntity() {
56+
var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model);
57+
try (var builder = JsonXContent.contentBuilder()) {
58+
builder.startObject();
59+
new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS);
60+
builder.field(MODEL_FIELD, modelId);
61+
62+
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
63+
builder.field(MAX_TOKENS, unifiedChatInput.getRequest().maxCompletionTokens());
64+
}
65+
66+
builder.endObject();
67+
return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8));
68+
} catch (IOException e) {
69+
throw new ElasticsearchException("Failed to serialize request payload.", e);
70+
}
71+
}
72+
73+
@Override
74+
public URI getURI() {
75+
return model.uri();
76+
}
77+
78+
@Override
79+
public Request truncate() {
80+
return this;
81+
}
82+
83+
@Override
84+
public boolean[] getTruncationInfo() {
85+
return null;
86+
}
87+
88+
@Override
89+
public String getInferenceEntityId() {
90+
return model.getInferenceEntityId();
91+
}
92+
93+
@Override
94+
public boolean isStreaming() {
95+
return unifiedChatInput.stream();
96+
}
97+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.deepseek.DeepSeekChatCompletionRequest;
16+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
17+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
19+
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
20+
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler;
21+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
22+
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException;
27+
28+
public class DeepSeekRequestManager extends BaseRequestManager {
29+
30+
private static final Logger logger = LogManager.getLogger(DeepSeekRequestManager.class);
31+
32+
private static final ResponseHandler CHAT_COMPLETION = createChatCompletionHandler();
33+
private static final ResponseHandler COMPLETION = createCompletionHandler();
34+
35+
private final DeepSeekChatCompletionModel model;
36+
37+
public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) {
38+
super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings());
39+
this.model = Objects.requireNonNull(model);
40+
}
41+
42+
@Override
43+
public void execute(
44+
InferenceInputs inferenceInputs,
45+
RequestSender requestSender,
46+
Supplier<Boolean> hasRequestCompletedFunction,
47+
ActionListener<InferenceServiceResults> listener
48+
) {
49+
if (inferenceInputs instanceof UnifiedChatInput uci) {
50+
execute(uci, requestSender, hasRequestCompletedFunction, listener);
51+
} else if (inferenceInputs instanceof ChatCompletionInput cci) {
52+
execute(cci, requestSender, hasRequestCompletedFunction, listener);
53+
} else {
54+
throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
55+
}
56+
}
57+
58+
private void execute(
59+
UnifiedChatInput inferenceInputs,
60+
RequestSender requestSender,
61+
Supplier<Boolean> hasRequestCompletedFunction,
62+
ActionListener<InferenceServiceResults> listener
63+
) {
64+
var request = new DeepSeekChatCompletionRequest(inferenceInputs, model);
65+
execute(new ExecutableInferenceRequest(requestSender, logger, request, CHAT_COMPLETION, hasRequestCompletedFunction, listener));
66+
}
67+
68+
private void execute(
69+
ChatCompletionInput inferenceInputs,
70+
RequestSender requestSender,
71+
Supplier<Boolean> hasRequestCompletedFunction,
72+
ActionListener<InferenceServiceResults> listener
73+
) {
74+
var unifiedInputs = new UnifiedChatInput(inferenceInputs.getInputs(), "user", inferenceInputs.stream());
75+
var request = new DeepSeekChatCompletionRequest(unifiedInputs, model);
76+
execute(new ExecutableInferenceRequest(requestSender, logger, request, COMPLETION, hasRequestCompletedFunction, listener));
77+
}
78+
79+
private static ResponseHandler createChatCompletionHandler() {
80+
return new OpenAiUnifiedChatCompletionResponseHandler("deepseek chat completion", OpenAiChatCompletionResponseEntity::fromResponse);
81+
}
82+
83+
private static ResponseHandler createCompletionHandler() {
84+
return new OpenAiChatCompletionResponseHandler("deepseek completion", OpenAiChatCompletionResponseEntity::fromResponse);
85+
}
86+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionRequestEntity.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec
2121

2222
public static final String USER_FIELD = "user";
2323
private static final String MODEL_FIELD = "model";
24+
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
2425

26+
private final UnifiedChatInput unifiedChatInput;
2527
private final OpenAiChatCompletionModel model;
2628
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
2729

2830
public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) {
29-
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
31+
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
32+
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
3033
this.model = Objects.requireNonNull(model);
3134
}
3235

@@ -41,6 +44,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4144
builder.field(USER_FIELD, model.getTaskSettings().user());
4245
}
4346

47+
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
48+
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
49+
}
50+
4451
builder.endObject();
4552

4653
return builder;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717

1818
public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject {
1919
private static final String MODEL_FIELD = "model";
20+
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
2021

22+
private final UnifiedChatInput unifiedChatInput;
2123
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
2224
private final String modelId;
2325

2426
public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
25-
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
27+
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
28+
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
2629
this.modelId = Objects.requireNonNull(modelId);
2730
}
2831

@@ -31,6 +34,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
3134
builder.startObject();
3235
unifiedRequestEntity.toXContent(builder, params);
3336
builder.field(MODEL_FIELD, modelId);
37+
38+
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
39+
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
40+
}
41+
3442
builder.endObject();
3543

3644
return builder;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
3232
public static final String MESSAGES_FIELD = "messages";
3333
private static final String ROLE_FIELD = "role";
3434
private static final String CONTENT_FIELD = "content";
35-
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
3635
private static final String STOP_FIELD = "stop";
3736
private static final String TEMPERATURE_FIELD = "temperature";
3837
private static final String TOOL_CHOICE_FIELD = "tool_choice";
@@ -102,10 +101,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
102101
}
103102
builder.endArray();
104103

105-
if (unifiedRequest.maxCompletionTokens() != null) {
106-
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
107-
}
108-
109104
// Underlying providers expect OpenAI to only return 1 possible choice.
110105
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
111106

0 commit comments

Comments
 (0)