Skip to content

Implemented ChatCompletion task for Google VertexAI with Gemini Models #128105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
00a6636
Implemented ChatCompletion task for Google VertexAI with Gemini Models
lhoet-google Apr 29, 2025
9be2a44
changelog
lhoet-google May 16, 2025
c2387e8
System Instruction bugfix
lhoet-google May 19, 2025
50770ea
Mapping role assistant -> model in vertex ai chat completion request …
lhoet-google May 19, 2025
42cbbe2
GoogleVertexAI chat completion using SSE events. Removed JsonArrayEve…
lhoet-google May 20, 2025
fe8e336
Removed buffer from GoogleVertexAiUnifiedStreamingProcessor
lhoet-google May 20, 2025
7c24f93
Casting inference inputs with `castoTo`
lhoet-google May 21, 2025
2140d05
Registered GoogleVertexAiChatCompletionServiceSettings in InferenceNa…
lhoet-google May 21, 2025
42dd376
Changed transport version to 8_19 for vertexai chatcompletion
lhoet-google May 21, 2025
0863316
Fix to transport version. Moved ML_INFERENCE_VERTEXAI_CHATCOMPLETION_…
lhoet-google May 21, 2025
f080e96
VertexAI Chat completion request entity jsonStringToMap using `ensure…
lhoet-google May 21, 2025
8f6648f
Fixed TransportVersions. Left vertexAi chat completion 8_19 and added…
lhoet-google May 22, 2025
848dc7a
Refactor switch statements by if-else for older java compatibility. I…
lhoet-google May 22, 2025
59862c6
Removed GoogleVertexAiChatCompletionResponseEntity and refactored cod…
lhoet-google May 22, 2025
93a7ca7
Removed redundant test `testUnifiedCompletionInfer_WithGoogleVertexAi…
lhoet-google May 22, 2025
7b99b1d
Returning whole body when fail to parse response from VertexAI
lhoet-google May 22, 2025
c05655f
Refactor use GenericRequestManager instead of GoogleVertexAiCompletio…
lhoet-google May 23, 2025
acc864f
Refactored to constructorArg for mandatory args in GoogleVertexAiUnif…
lhoet-google May 26, 2025
c371073
Changed transport version in GoogleVertexAiChatCompletionServiceSettings
lhoet-google May 26, 2025
efb90ba
Bugfix in tool calling with role tool
lhoet-google May 26, 2025
bb68715
Merge branch 'main' into google-vertexai-chatcompletion
lhoet-google May 26, 2025
1ead8c5
GoogleVertexAiModel added documentation info on rateLimitGroupingHash
leo-hoet May 27, 2025
ad9f0e1
Merge branch 'main' into google-vertexai-chatcompletion
leo-hoet May 27, 2025
f4057f3
Merge branch 'main' into google-vertexai-chatcompletion
jonathan-buttner May 28, 2025
38b9ca4
[CI] Auto commit changes from spotless
May 28, 2025
2e8dbee
Fix: using Locale.ROOT when calling toLowerCase
leo-hoet May 28, 2025
ddd19c5
Fix: Renamed test class to match convention & modified use of forbidd…
leo-hoet May 28, 2025
88a2780
Fix: Failing test in InferenceServicesIT
leo-hoet May 29, 2025
b841e4e
Merge branch 'main' into google-vertexai-chatcompletion
leo-hoet May 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/128105.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 128105
summary: "Adding Google VertexAI chat completion integration"
area: Inference
type: enhancement
issues: [ ]
3 changes: 3 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_38);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -266,6 +267,8 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_DRY_RUN = def(9_081_0_00);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION = def(9_082_0_00);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_083_0_00);

/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,22 @@ public void testGetServicesWithCompletionTaskType() throws IOException {

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(6));
assertThat(services.size(), equalTo(7));

var providers = providers(services);

assertThat(
providers,
containsInAnyOrder(
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker").toArray()
List.of(
"deepseek",
"elastic",
"openai",
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker",
"googlevertexai"
).toArray()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
Expand Down Expand Up @@ -453,6 +454,15 @@ private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry
GoogleVertexAiRerankTaskSettings::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
GoogleVertexAiChatCompletionServiceSettings.NAME,
GoogleVertexAiChatCompletionServiceSettings::new
)
);

}

private static void addInternalNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.net.URI;
import java.util.Map;
import java.util.Objects;

public abstract class GoogleVertexAiModel extends Model {
public abstract class GoogleVertexAiModel extends RateLimitGroupingModel {

private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;

Expand Down Expand Up @@ -58,4 +59,18 @@ public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
public URI uri() {
return uri;
}

@Override
public int rateLimitGroupingHash() {
// In VertexAI rate limiting is scoped to the project, region and model. URI already has this information so we are using that.
// API Key does not affect the quota
// https://ai.google.dev/gemini-api/docs/rate-limits
// https://cloud.google.com/vertex-ai/docs/quotas
return Objects.hash(uri);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify, it's not based on the service account key information too?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a link to the docs that indicates this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Will do. https://ai.google.dev/gemini-api/docs/rate-limits

Rate limits are applied per project, not per API key.

Also on the VertexAI quotas https://cloud.google.com/vertex-ai/docs/quotas#request_quotas

The following quotas apply to Vertex AI requests for a given project and supported region...

Some resources may not be affected by the region, but I choose to be conservative and go with a safe default

}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings().rateLimitSettings();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@

import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;

import java.util.function.Function;

import static org.elasticsearch.core.Strings.format;

public class GoogleVertexAiResponseHandler extends BaseResponseHandler {
Expand All @@ -24,6 +27,15 @@ public GoogleVertexAiResponseHandler(String requestType, ResponseParser parseFun
super(requestType, parseFunction, GoogleVertexAiErrorResponseEntity::fromResponse);
}

/**
 * Creates a response handler with an explicit error parser and streaming capability flag.
 * Unlike the two-argument constructor, this variant does not default to
 * {@code GoogleVertexAiErrorResponseEntity::fromResponse} for error parsing.
 *
 * @param requestType a description of the request type used in error messages
 * @param parseFunction parses a successful HTTP response into inference results
 * @param errorParseFunction parses an error HTTP response body into an {@link ErrorResponse}
 * @param canHandleStreamingResponses whether this handler supports streaming (e.g. SSE) responses
 */
public GoogleVertexAiResponseHandler(
    String requestType,
    ResponseParser parseFunction,
    Function<HttpResult, ErrorResponse> errorParseFunction,
    boolean canHandleStreamingResponses
) {
    super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses);
}

@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ public static Map<String, SettingsConfiguration> get() {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
SERVICE_ACCOUNT_JSON,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK)).setDescription(
"API Key for the provider you're connecting to."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION))
.setDescription("API Key for the provider you're connecting to.")
.setLabel("Credentials JSON")
.setRequired(true)
.setSensitive(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
Expand All @@ -38,34 +41,42 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID;
import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_ERROR_PREFIX;

public class GoogleVertexAiService extends SenderService {

public static final String NAME = "googlevertexai";

private static final String SERVICE_NAME = "Google Vertex AI";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK);
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.RERANK,
TaskType.CHAT_COMPLETION
);

public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
InputType.INGEST,
Expand All @@ -76,6 +87,15 @@ public class GoogleVertexAiService extends SenderService {
InputType.INTERNAL_SEARCH
);

private final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google VertexAI chat completion"
);

@Override
public Set<TaskType> supportedStreamingTasks() {
    // Only chat completion supports streaming for this service (delivered via SSE events).
    return EnumSet.of(TaskType.CHAT_COMPLETION);
}

/**
 * Creates the Google Vertex AI inference service.
 *
 * @param factory creates the HTTP sender used to dispatch requests
 * @param serviceComponents shared components (thread pool, settings, etc.) for the service
 */
public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
    super(factory, serviceComponents);
}
Expand Down Expand Up @@ -220,7 +240,24 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof GoogleVertexAiChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
var chatCompletionModel = (GoogleVertexAiChatCompletionModel) model;
var updatedChatCompletionModel = GoogleVertexAiChatCompletionModel.of(chatCompletionModel, inputs.getRequest());

var manager = new GenericRequestManager<>(
getServiceComponents().threadPool(),
updatedChatCompletionModel,
COMPLETION_HANDLER,
(unifiedChatInput) -> new GoogleVertexAiUnifiedChatCompletionRequest(unifiedChatInput, updatedChatCompletionModel),
UnifiedChatInput.class
);

var errorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
action.execute(inputs, timeout, listener);
}

@Override
Expand Down Expand Up @@ -320,6 +357,17 @@ private static GoogleVertexAiModel createModel(
secretSettings,
context
);

case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);

default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
Expand Down Expand Up @@ -348,7 +396,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
LOCATION,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription(
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
+ "For more information, refer to the {geminiVertexAIDocs}."
)
Expand Down
Loading
Loading