Skip to content

Commit 60bb770

Browse files
[ML] Remove Voyageai request manager classes (#124512) (#124795)
* Removing voyage request managers * Fixing tests (cherry picked from commit 1bee2cc)
1 parent 226d5c5 commit 60bb770

File tree

15 files changed

+183
-434
lines changed

15 files changed

+183
-434
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,16 @@
1010
import org.elasticsearch.inference.InputType;
1111
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1212
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
13+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
14+
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
15+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
16+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
1317
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
14-
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
15-
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager;
18+
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
19+
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest;
20+
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity;
21+
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity;
22+
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
1623
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1724
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
1825
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
@@ -26,6 +33,15 @@
2633
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type.
2734
*/
2835
public class VoyageAIActionCreator implements VoyageAIActionVisitor {
36+
public static final ResponseHandler EMBEDDINGS_HANDLER = new VoyageAIResponseHandler(
37+
"voyageai text embedding",
38+
VoyageAIEmbeddingsResponseEntity::fromResponse
39+
);
40+
static final ResponseHandler RERANK_HANDLER = new VoyageAIResponseHandler(
41+
"voyageai rerank",
42+
(request, response) -> VoyageAIRerankResponseEntity.fromResponse(response)
43+
);
44+
2945
private final Sender sender;
3046
private final ServiceComponents serviceComponents;
3147

@@ -37,16 +53,30 @@ public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents)
3753
@Override
3854
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
3955
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
56+
var manager = new GenericRequestManager<>(
57+
serviceComponents.threadPool(),
58+
overriddenModel,
59+
EMBEDDINGS_HANDLER,
60+
(documentsOnlyInput) -> new VoyageAIEmbeddingsRequest(documentsOnlyInput.getInputs(), overriddenModel),
61+
DocumentsOnlyInput.class
62+
);
63+
4064
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("VoyageAI embeddings");
41-
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
42-
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
65+
return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage);
4366
}
4467

4568
@Override
4669
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
4770
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
71+
var manager = new GenericRequestManager<>(
72+
serviceComponents.threadPool(),
73+
overriddenModel,
74+
RERANK_HANDLER,
75+
(rerankInput) -> new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), overriddenModel),
76+
QueryAndDocsInputs.class
77+
);
78+
4879
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("VoyageAI rerank");
49-
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
50-
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
80+
return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage);
5181
}
5282
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIEmbeddingsRequestManager.java

Lines changed: 0 additions & 57 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java

Lines changed: 0 additions & 54 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRerankRequestManager.java

Lines changed: 0 additions & 56 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/RequestUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,9 @@ public static URI buildUri(URI accountUri, String service, CheckedSupplier<URI,
3434
}
3535
}
3636

37+
public static URI buildUri(String service, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
38+
return buildUri(null, service, uriBuilder);
39+
}
40+
3741
private RequestUtils() {}
3842
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
import org.elasticsearch.common.Strings;
1313
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1414
import org.elasticsearch.xpack.inference.external.request.Request;
15-
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
1615
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
1716
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
18-
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
1917

2018
import java.net.URI;
2119
import java.nio.charset.StandardCharsets;
@@ -24,47 +22,43 @@
2422

2523
public class VoyageAIEmbeddingsRequest extends VoyageAIRequest {
2624

27-
private final VoyageAIAccount account;
2825
private final List<String> input;
29-
private final VoyageAIEmbeddingsServiceSettings serviceSettings;
30-
private final VoyageAIEmbeddingsTaskSettings taskSettings;
31-
private final String model;
32-
private final String inferenceEntityId;
26+
private final VoyageAIEmbeddingsModel embeddingsModel;
3327

3428
public VoyageAIEmbeddingsRequest(List<String> input, VoyageAIEmbeddingsModel embeddingsModel) {
35-
Objects.requireNonNull(embeddingsModel);
36-
37-
account = VoyageAIAccount.of(embeddingsModel);
29+
this.embeddingsModel = Objects.requireNonNull(embeddingsModel);
3830
this.input = Objects.requireNonNull(input);
39-
serviceSettings = embeddingsModel.getServiceSettings();
40-
taskSettings = embeddingsModel.getTaskSettings();
41-
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
42-
inferenceEntityId = embeddingsModel.getInferenceEntityId();
4331
}
4432

4533
@Override
4634
public HttpRequest createHttpRequest() {
47-
HttpPost httpPost = new HttpPost(account.uri());
35+
HttpPost httpPost = new HttpPost(embeddingsModel.uri());
4836

4937
ByteArrayEntity byteEntity = new ByteArrayEntity(
50-
Strings.toString(new VoyageAIEmbeddingsRequestEntity(input, serviceSettings, taskSettings, model))
51-
.getBytes(StandardCharsets.UTF_8)
38+
Strings.toString(
39+
new VoyageAIEmbeddingsRequestEntity(
40+
input,
41+
embeddingsModel.getServiceSettings(),
42+
embeddingsModel.getTaskSettings(),
43+
embeddingsModel.getServiceSettings().modelId()
44+
)
45+
).getBytes(StandardCharsets.UTF_8)
5246
);
5347
httpPost.setEntity(byteEntity);
5448

55-
decorateWithHeaders(httpPost, account);
49+
decorateWithHeaders(httpPost, embeddingsModel);
5650

5751
return new HttpRequest(httpPost, getInferenceEntityId());
5852
}
5953

6054
@Override
6155
public String getInferenceEntityId() {
62-
return inferenceEntityId;
56+
return embeddingsModel.getInferenceEntityId();
6357
}
6458

6559
@Override
6660
public URI getURI() {
67-
return account.uri();
61+
return embeddingsModel.uri();
6862
}
6963

7064
@Override
@@ -77,11 +71,7 @@ public boolean[] getTruncationInfo() {
7771
return null;
7872
}
7973

80-
public VoyageAIEmbeddingsTaskSettings getTaskSettings() {
81-
return taskSettings;
82-
}
83-
8474
public VoyageAIEmbeddingsServiceSettings getServiceSettings() {
85-
return serviceSettings;
75+
return embeddingsModel.getServiceSettings();
8676
}
8777
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
import org.apache.http.client.methods.HttpPost;
1212
import org.elasticsearch.xcontent.XContentType;
1313
import org.elasticsearch.xpack.inference.external.request.Request;
14-
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
14+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
1515

1616
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
1717

1818
public abstract class VoyageAIRequest implements Request {
1919

20-
public static void decorateWithHeaders(HttpPost request, VoyageAIAccount account) {
20+
public static void decorateWithHeaders(HttpPost request, VoyageAIModel model) {
2121
request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
22-
request.setHeader(createAuthBearerHeader(account.apiKey()));
22+
request.setHeader(createAuthBearerHeader(model.apiKey()));
2323
request.setHeader(VoyageAIUtils.createRequestSourceHeader());
2424
}
2525

0 commit comments

Comments
 (0)