Implemented ChatCompletion task for Google VertexAI with Gemini Models #128105
Conversation
💚 CLA has been signed
Done, both developers have signed. Thanks!
Pinging @elastic/ml-core (Team:ML)
PR is looking good. Here is some initial feedback; I haven't gotten through the whole PR yet.
docs/changelog/128105.yaml (outdated)
@@ -0,0 +1,5 @@
pr: 128105
summary: "Google VertexAI integration now supports chat_completion task"
Suggested change:
- summary: "Google VertexAI integration now supports chat_completion task"
+ summary: "Adding Google VertexAI chat completion integration"
@@ -254,6 +254,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_078_0_00);
We'll want to backport this to 8.19. To do that, we need to reserve a transport version for 8.19, but in the main branch.
Let's add another transport version similar to what I did here: https://github.com/elastic/elasticsearch/pull/126805/files#diff-85e782e9e33a0f8ca8e99b41c17f9d04e3a7981d435abf44a3aa5d954a47cd8fR175
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_30);
Or whatever the latest version number is (it might be 30, or 31 etc).
import java.util.Objects;
import java.util.function.Supplier;

public class GoogleVertexAiCompletionRequestManager extends GoogleVertexAiRequestManager {
We're trying to transition away from the request manager pattern to avoid the extra class, since all the classes are pretty similar.
Here's an example of how we implemented it for voyageai: #124512
Here's how we do it for chat completions in openai: https://github.com/elastic/elasticsearch/blob/d2be03c946c94943dca8fe5da75a125fa70ddaa6/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreator.java
If we could switch to using a generic request manager, that'd be great.
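A rough sketch of what that wiring might look like, loosely modeled on the OpenAI action creator linked above; the constructor argument order and the request/handler names here are assumptions, not the exact API:

```java
// Hypothetical sketch only: a GenericRequestManager built directly in the
// action creator, replacing the dedicated GoogleVertexAiCompletionRequestManager.
var requestManager = new GenericRequestManager<>(
    serviceComponents.threadPool(),
    model,                                // the RateLimitGroupingModel
    UNIFIED_CHAT_COMPLETION_HANDLER,      // shared streaming response handler
    inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(inputs, model),
    UnifiedChatInput.class
);
return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
```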
I made the switch to use a generic request and it compiles and works fine. My only fear is that I had to change the base class of GoogleVertexAiModel from Model to RateLimitGroupingModel. Does that have any implication that I am not aware of?
Awesome! No that should be fine. Thanks for making that change.
) {
    return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
    logger.warn("Failed to parse Google Vertex AI error response body", e);
We'll likely refactor the error logic for all the services to return whatever is sent even if we can't parse it. If we fail to parse the error response, how about we return a new ErrorResponse
but just put the body as the message:
var resultAsString = new String(httpResult.body(), StandardCharsets.UTF_8);
return new ErrorResponse(Strings.format("Unable to parse the Google Vertex AI error, response body: [%s]", resultAsString));
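Put together, the fallback could look something like this; a minimal sketch assuming the existing fromResponse(HttpResult) shape around the snippet above:

```java
// Minimal sketch of the suggested fallback (method shape assumed from context).
static ErrorResponse fromResponse(HttpResult httpResult) {
    var resultAsString = new String(httpResult.body(), StandardCharsets.UTF_8);
    try (
        XContentParser parser = XContentFactory.xContent(XContentType.JSON)
            .createParser(XContentParserConfiguration.EMPTY, resultAsString)
    ) {
        return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
    } catch (Exception e) {
        logger.warn("Failed to parse Google Vertex AI error response body", e);
        // Fall back to surfacing the raw body so callers still see what the service returned.
        return new ErrorResponse(
            Strings.format("Unable to parse the Google Vertex AI error, response body: [%s]", resultAsString)
        );
    }
}
```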
Sounds good!
) {
    return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
    logger.warn("Failed to parse Google Vertex AI error string", e);
Same as comment above.
"Role [%s] not supported by Google VertexAI ChatCompletion. Supported roles: [%s, %s]", | ||
messageRole, | ||
USER_ROLE, | ||
MODEL_ROLE |
Just a reminder to switch this to `assistant`.
try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, jsonString)) {
    XContentParser.Token token = parser.nextToken();
    if (token != XContentParser.Token.START_OBJECT) {
If we omit this check, what is the error that is returned?
We also might be able to leverage the helper method:
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser);
The VertexAI request expects the arguments to be a map, but in the current spec the function arguments are a string, so I am using that method to convert between the two. Without the check, depending on the case, it can fail with org.elasticsearch.xcontent.XContentParseException: Unrecognized token or return an empty object. I put the check there so we are sure that the string being parsed is an object and not any other JSON token.
Gotcha, I believe ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser); checks the same thing, right? Or is the error message it produces not sufficient?
Will change this to use the ensureExpectedToken method.
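For reference, the helper-based version could look roughly like this; jsonStringToMap is the method name mentioned later in the commit log, the rest is an assumption:

```java
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

// Rough sketch: parse a JSON string of function arguments into a map,
// failing fast when the top-level token is not an object.
private static Map<String, Object> jsonStringToMap(String jsonString, XContentParserConfiguration parserConfig) throws IOException {
    try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, jsonString)) {
        XContentParser.Token token = parser.nextToken();
        ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser);
        return parser.map();
    }
}
```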
…medWriteablesProvider. Added InferenceSettingsTests
…ADDED to the right location
Left some more suggestions, thanks for implementing the initial changes!
@@ -254,6 +254,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_30);
Sorry, I meant we'll need two transport versions. Let's move ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 to be with the other 8_19-style versions. We'll need to create another one called ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED and have its version be 9_078_0_00, or whatever the latest is.
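The end state in TransportVersions.java would be something like the following two declarations; the IDs are whatever is current when this lands, so treat them as placeholders:

```java
// Reserved alongside the other 8_19-style versions (placement and IDs illustrative):
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_30);

// Added at the end of the 9.x list:
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_078_0_00);
```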
Got it! Fixed
Which one should GoogleVertexAiChatCompletionServiceSettings.getMinimalSupportedVersion return? Right now it's returning ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19.
For this PR we'll want it to point to the 9 version (not the 8_19 one). For the backport we'll switch it to be the 8.19 version.
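So on main the override would read roughly as follows; a sketch, assuming the usual getMinimalSupportedVersion override on the service settings class:

```java
// Sketch: main branch points at the 9.x constant; the backport swaps in the _8_19 one.
@Override
public TransportVersion getMinimalSupportedVersion() {
    return TransportVersions.ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED;
}
```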
builder.startArray(PARTS);
for (var systemMessage : systemMessages) {
    switch (systemMessage.content()) {
        case UnifiedCompletionRequest.ContentString contentString -> {
We can leave this, but just a heads up that when we go to backport the changes to the 8.x branch it's going to complain, because that branch isn't on the JDK version that supports this type of switch statement. It might be easier to change it here, even though the IDE will complain, to avoid having to adjust in the backport. Up to you.
Got it, will change this switch and the others that I made to if-else. I think that should work.
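As an illustration, the pattern-matching switch above could become a plain instanceof chain, which the older JDK on the 8.x branch accepts; the field constant and the else branch here are assumptions:

```java
// Illustrative if-else rewrite of the pattern-matching switch.
// TEXT_FIELD is a hypothetical constant for the "text" part field.
var content = systemMessage.content();
if (content instanceof UnifiedCompletionRequest.ContentString contentString) {
    builder.startObject();
    builder.field(TEXT_FIELD, contentString.content());
    builder.endObject();
} else {
    throw new IllegalStateException("Unexpected content type: " + content.getClass().getName());
}
```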
    return;
}

builder.startArray(TOOLS);
nit: adding some indentation with scoping via {} could help with understanding the nesting here; optional though.
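The nit in practice, with bare blocks mirroring the JSON nesting; a stylistic sketch, field name and variable are illustrative:

```java
// Scoped blocks make the builder calls line up with the JSON structure.
builder.startArray(TOOLS);
{
    builder.startObject();
    {
        builder.field("function_declarations", declarations); // hypothetical field/variable
    }
    builder.endObject();
}
builder.endArray();
```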
private static final String ERROR_STATUS_FIELD = "status";

public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
    super(requestType, parseFunction, GoogleVertexAiErrorResponse::fromResponse, true);
This is an area of the code that we need to refactor. The parseFunction is only used in non-streaming cases, so I think we can actually pass in an empty lambda-style function, maybe just one that's defined statically.
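Something along these lines; a sketch assuming ResponseParser is the usual two-argument functional interface in this package:

```java
// Hypothetical no-op parser: chat completion here is streaming-only, so the
// non-streaming parse function is never invoked.
private static final ResponseParser NOOP_RESPONSE_PARSER = (request, result) -> null;

public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) {
    super(requestType, NOOP_RESPONSE_PARSER, GoogleVertexAiErrorResponse::fromResponse, true);
}
```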
Great, so we can also remove GoogleVertexAiChatCompletionResponseEntity since we are only doing streaming responses and that class is not being used, right?
Never mind, saw your response in another comment. Will delete GoogleVertexAiChatCompletionResponseEntity and refactor the code as you suggested.
We can wait to delete it if you like, until we get a response to you about whether we want to include the completion task type. If we do that, we'll want to implement both streaming and non-streaming for completion.
StringBuilder fullText = new StringBuilder();

while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
I think we can use XContentParserUtils.parseList here instead.
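For the record, that suggestion would look roughly like this; parseList takes the parser plus a per-entry reader:

```java
// Sketch: let XContentParserUtils.parseList drive the array iteration.
List<Chunk> chunks = XContentParserUtils.parseList(parser, p -> Chunk.PARSER.apply(p, null));
```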
I removed this class from this PR. If we are doing completion in another PR I will add it there
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
    Chunk chunk = Chunk.PARSER.apply(parser, null);
    chunk.extractText().ifPresent(fullText::append);
This class would be used for non-streaming scenarios. If we get multiple entries in the array, could those be for separate input values in the originating request? Like:
{"input": ["text 1", "text 2"]}
Would we get 2 items in the array from the upstream server? If so, I don't think we want to combine the text, as we'd want to return a list of 2 items below.
I removed this class from this PR since the chat completion only supports streaming. If we are doing completion in another PR I will add it there
assertThat(
    httpRequest.getBody().toString(),
    equalTo(
        "{\"messages\":[{\"content\":\"Hello\",\"role\":\"user\"}],\"n\":1,\"stream\":true,\"stream_options\":{\"include_usage\":true},\"model\":\"gemini-2.0-flash-001\"}"
Let's use XContentHelper.stripWhitespace() for things like this. That way we can create a more readable multiline string in this file and strip the whitespace when we compare it for equality.
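Applied to the assertion above, that might read like this; a sketch using a Java text block (stripWhitespace throws IOException, so the test method needs to declare it):

```java
// Sketch: readable multiline JSON, normalized before comparison.
assertThat(httpRequest.getBody().toString(), equalTo(XContentHelper.stripWhitespace("""
    {
        "messages": [{ "content": "Hello", "role": "user" }],
        "n": 1,
        "stream": true,
        "stream_options": { "include_usage": true },
        "model": "gemini-2.0-flash-001"
    }
    """)));
```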
@@ -1299,6 +1299,64 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp
    }
}

public void testUnifiedCompletionInfer_WithGoogleVertexAiModel() throws IOException {
Let's move this to the google vertex service test file.
Yeah sorry, this test slipped in when we were testing some things. It's not necessary, so I will remove it.
    );
} finally {
    // Clean up the thread context
    threadPool.getThreadContext().stashContext();
Why do we need to stash the context here? Typically we terminate the thread pool after tests: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java#L119
Ah I see, that's in some of our other tests. I don't believe we need that. Let me know if the test starts failing after we remove it though.
… new one for ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDDED
…mproved indentation via `{}`
Thanks for the changes, left one more question
@Override
public int rateLimitGroupingHash() {
    // In VertexAI rate limiting is scoped to the project and the model. URI already has this information so we are using that
    return Objects.hash(uri);
Just to clarify, it's not based on the service account key information too?
Can you add a link to the docs that indicates this?
Great! Will do. From https://ai.google.dev/gemini-api/docs/rate-limits:
"Rate limits are applied per project, not per API key."
Also on the VertexAI quotas, https://cloud.google.com/vertex-ai/docs/quotas#request_quotas:
"The following quotas apply to Vertex AI requests for a given project and supported region..."
Some resources may not be affected by the region, but I chose to be conservative and go with a safe default.
# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
Thanks for the changes!
@elasticsearchmachine test this please
@elasticmachine test this please
}

private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) {
    var messageRoleLowered = messageRole.toLowerCase();
Looks like CI is complaining about using the default locale here:

> Task :x-pack:plugin:inference:forbiddenApisMain
Forbidden method invocation: java.lang.String#toLowerCase() [Uses default locale]
in org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequestEntity (GoogleVertexAiUnifiedChatCompletionRequestEntity.java:73)
Scanned 890 class file(s) for forbidden API invocations (in 0.78s), 1 error(s).

I think we can use Locale.ROOT instead.
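That is, the one-line fix; Locale.ROOT gives locale-independent lowercasing, which is what role names need:

```java
import java.util.Locale;

// Locale-independent lowercasing avoids the default-locale pitfall flagged by
// forbiddenApis (e.g. the Turkish dotless-i problem for strings like "USER").
var messageRoleLowered = messageRole.toLowerCase(Locale.ROOT);
```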
@elasticmachine test this please
@@ -0,0 +1,39 @@
/*
From CI:

Caused by: org.gradle.api.GradleException: Following test classes do not match naming convention to use suffix 'Tests':
org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAIChatCompletionServiceSettingsTest

The file name needs to be GoogleVertexAIChatCompletionServiceSettingsTests (trailing s).
@elasticmachine test this please
Looks like we have a failing test:

REPRODUCE WITH: ./gradlew ":x-pack:plugin:inference:qa:inference-service-tests:javaRestTest" --tests "org.elasticsearch.xpack.inference.InferenceGetServicesIT.testGetServicesWithChatCompletionTaskType" -Dtests.seed=54CCE93B3DEDB87E -Dtests.locale=su-Latn -Dtests.timezone=Asia/Aqtobe -Druntime.java=24

InferenceGetServicesIT > testGetServicesWithChatCompletionTaskType FAILED
    java.lang.AssertionError:
    Expected: <6>
         but: was <7>
        at __randomizedtesting.SeedInfo.seed([54CCE93B3DEDB87E:DC4FE40BF96FF6C7]:0)
        at org.hamcrest.MatcherAssert.assertThat(MatcherAssert.java:20)
        at org.hamcrest.MatcherAssert.assertThat(MatcherAssert.java:6)
        at org.elasticsearch.test.ESTestCase.assertThat(ESTestCase.java:2653)
        at org.elasticsearch.xpack.inference.InferenceGetServicesIT.testGetServicesWithChatCompletionTaskType(InferenceGetServicesIT.java:154)

I think we just need to bump the value.
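Presumably the assertion just needs the new service counted; a hypothetical one-liner, assuming the IT asserts the size of the returned services list:

```java
// Google VertexAI now also advertises chat_completion, so the expected count goes 6 -> 7.
assertThat(services.size(), equalTo(7));
```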
Working on it. If I do …
@elasticmachine test this please
This is the command: That'll run it the same way that CI did. Or if you want to run all the rest tests I think this would work:
@elasticmachine test this please
💔 Backport failed
You can use sqren/backport to manually backport by running |
Implemented ChatCompletion task for Google VertexAI with Gemini Models (elastic#128105)

* Implemented ChatCompletion task for Google VertexAI with Gemini Models
* changelog
* System Instruction bugfix
* Mapping role assistant -> model in vertex ai chat completion request for compatibility
* GoogleVertexAI chat completion using SSE events. Removed JsonArrayEventParser
* Removed buffer from GoogleVertexAiUnifiedStreamingProcessor
* Casting inference inputs with `castoTo`
* Registered GoogleVertexAiChatCompletionServiceSettings in InferenceNamedWriteablesProvider. Added InferenceSettingsTests
* Changed transport version to 8_19 for vertexai chatcompletion
* Fix to transport version. Moved ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED to the right location
* VertexAI Chat completion request entity jsonStringToMap using `ensureExpectedToken`
* Fixed TransportVersions. Left vertexAi chat completion 8_19 and added new one for ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDDED
* Refactor switch statements by if-else for older java compatibility. Improved indentation via `{}`
* Removed GoogleVertexAiChatCompletionResponseEntity and refactored code around it.
* Removed redundant test `testUnifiedCompletionInfer_WithGoogleVertexAiModel`
* Returning whole body when fail to parse response from VertexAI
* Refactor use GenericRequestManager instead of GoogleVertexAiCompletionRequestManager
* Refactored to constructorArg for mandatory args in GoogleVertexAiUnifiedStreamingProcessor
* Changed transport version in GoogleVertexAiChatCompletionServiceSettings
* Bugfix in tool calling with role tool
* GoogleVertexAiModel added documentation info on rateLimitGroupingHash
* [CI] Auto commit changes from spotless
* Fix: using Locale.ROOT when calling toLowerCase
* Fix: Renamed test class to match convention & modified use of forbidden api
* Fix: Failing test in InferenceServicesIT

---------

Co-authored-by: lhoet <[email protected]>
Co-authored-by: Jonathan Buttner <[email protected]>
Co-authored-by: elasticsearchmachine <[email protected]>
This PR implements the task type chat_completion for Google Vertex AI in the inference API. (@beltrangs)