Skip to content

Commit 42e331b

Browse files
committed
support responseSchema in VertexAiGeminiChatOptions
Closes #2087 Signed-off-by: Andrei Sumin <[email protected]>
1 parent dca1f98 commit 42e331b

File tree

6 files changed

+230
-19
lines changed

6 files changed

+230
-19
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import com.google.cloud.vertexai.api.GenerationConfig;
3636
import com.google.cloud.vertexai.api.Part;
3737
import com.google.cloud.vertexai.api.SafetySetting;
38-
import com.google.cloud.vertexai.api.Schema;
3938
import com.google.cloud.vertexai.api.Tool;
4039
import com.google.cloud.vertexai.api.Tool.GoogleSearch;
4140
import com.google.cloud.vertexai.generativeai.GenerativeModel;
@@ -87,6 +86,7 @@
8786
import org.springframework.ai.tool.definition.ToolDefinition;
8887
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiConstants;
8988
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
89+
import org.springframework.ai.vertexai.gemini.schema.VertexAiSchemaConverter;
9090
import org.springframework.ai.vertexai.gemini.schema.VertexToolCallingManager;
9191
import org.springframework.beans.factory.DisposableBean;
9292
import org.springframework.lang.NonNull;
@@ -375,17 +375,6 @@ else if (rootNode.isArray()) {
375375
}
376376
}
377377

378-
private static Schema jsonToSchema(String json) {
379-
try {
380-
var schemaBuilder = Schema.newBuilder();
381-
JsonFormat.parser().ignoringUnknownFields().merge(json, schemaBuilder);
382-
return schemaBuilder.build();
383-
}
384-
catch (Exception e) {
385-
throw new RuntimeException(e);
386-
}
387-
}
388-
389378
// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
390379
@Override
391380
public ChatResponse call(Prompt prompt) {
@@ -676,7 +665,7 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
676665
.map(toolDefinition -> FunctionDeclaration.newBuilder()
677666
.setName(toolDefinition.name())
678667
.setDescription(toolDefinition.description())
679-
.setParameters(jsonToSchema(toolDefinition.inputSchema()))
668+
.setParameters(VertexAiSchemaConverter.fromOpenApiSchema(toolDefinition.inputSchema()))
680669
.build())
681670
.toList();
682671
tools.add(Tool.newBuilder().addAllFunctionDeclarations(functionDeclarations).build());
@@ -738,6 +727,10 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) {
738727
if (options.getResponseMimeType() != null) {
739728
generationConfigBuilder.setResponseMimeType(options.getResponseMimeType());
740729
}
730+
if (options.getResponseSchema() != null) {
731+
generationConfigBuilder
732+
.setResponseSchema(VertexAiSchemaConverter.fromOpenApiSchema(options.getResponseSchema()));
733+
}
741734
if (options.getFrequencyPenalty() != null) {
742735
generationConfigBuilder.setFrequencyPenalty(options.getFrequencyPenalty().floatValue());
743736
}

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions {
9696
*/
9797
private @JsonProperty("responseMimeType") String responseMimeType;
9898

99+
/**
100+
* Optional. OpenAPI response schema.
101+
*/
102+
private @JsonProperty("responseSchema") String responseSchema;
103+
99104
/**
100105
* Optional. Frequency penalties.
101106
*/
@@ -156,8 +161,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
156161
options.setModel(fromOptions.getModel());
157162
options.setToolCallbacks(fromOptions.getToolCallbacks());
158163
options.setResponseMimeType(fromOptions.getResponseMimeType());
164+
options.setResponseSchema(fromOptions.getResponseSchema());
159165
options.setToolNames(fromOptions.getToolNames());
160-
options.setResponseMimeType(fromOptions.getResponseMimeType());
161166
options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval());
162167
options.setSafetySettings(fromOptions.getSafetySettings());
163168
options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled());
@@ -245,6 +250,14 @@ public void setResponseMimeType(String mimeType) {
245250
this.responseMimeType = mimeType;
246251
}
247252

253+
public String getResponseSchema() {
254+
return this.responseSchema;
255+
}
256+
257+
public void setResponseSchema(String responseSchema) {
258+
this.responseSchema = responseSchema;
259+
}
260+
248261
@Override
249262
public List<ToolCallback> getToolCallbacks() {
250263
return this.toolCallbacks;
@@ -342,6 +355,7 @@ public boolean equals(Object o) {
342355
&& Objects.equals(this.presencePenalty, that.presencePenalty)
343356
&& Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model)
344357
&& Objects.equals(this.responseMimeType, that.responseMimeType)
358+
&& Objects.equals(this.responseSchema, that.responseSchema)
345359
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
346360
&& Objects.equals(this.toolNames, that.toolNames)
347361
&& Objects.equals(this.safetySettings, that.safetySettings)
@@ -353,8 +367,8 @@ public boolean equals(Object o) {
353367
public int hashCode() {
354368
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
355369
this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType,
356-
this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings,
357-
this.internalToolExecutionEnabled, this.toolContext);
370+
this.responseSchema, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval,
371+
this.safetySettings, this.internalToolExecutionEnabled, this.toolContext);
358372
}
359373

360374
@Override
@@ -363,9 +377,9 @@ public String toString() {
363377
+ this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", frequencyPenalty="
364378
+ this.frequencyPenalty + ", presencePenalty=" + this.presencePenalty + ", candidateCount="
365379
+ this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\''
366-
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks
367-
+ ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval
368-
+ ", safetySettings=" + this.safetySettings + '}';
380+
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", responseSchema='" + this.responseSchema
381+
+ ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval="
382+
+ this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + '}';
369383
}
370384

371385
@Override
@@ -439,6 +453,11 @@ public Builder responseMimeType(String mimeType) {
439453
return this;
440454
}
441455

456+
public Builder responseSchema(String responseSchema) {
457+
this.options.setResponseSchema(responseSchema);
458+
return this;
459+
}
460+
442461
public Builder toolCallbacks(List<ToolCallback> toolCallbacks) {
443462
this.options.toolCallbacks = toolCallbacks;
444463
return this;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vertexai.gemini.schema;
18+
19+
import com.google.cloud.vertexai.api.Schema;
20+
import com.google.protobuf.util.JsonFormat;
21+
22+
/**
23+
* Utility class for converting OpenAPI schemas to Vertex AI Schema objects.
24+
*
25+
* @since 1.1.0
26+
*/
27+
public final class VertexAiSchemaConverter {
28+
29+
private VertexAiSchemaConverter() {
30+
// Prevent instantiation
31+
}
32+
33+
/**
34+
* Converts an OpenAPI schema string to a Vertex AI Schema object.
35+
* @param openApiSchema The OpenAPI schema in JSON format
36+
* @return A Schema object representing the OpenAPI schema
37+
* @throws RuntimeException if the schema cannot be parsed
38+
*/
39+
public static Schema fromOpenApiSchema(String openApiSchema) {
40+
try {
41+
var schemaBuilder = Schema.newBuilder();
42+
JsonFormat.parser().ignoringUnknownFields().merge(openApiSchema, schemaBuilder);
43+
return schemaBuilder.build();
44+
}
45+
catch (Exception e) {
46+
throw new RuntimeException(e);
47+
}
48+
}
49+
50+
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import com.google.cloud.vertexai.VertexAI;
2424
import com.google.cloud.vertexai.api.Content;
2525
import com.google.cloud.vertexai.api.Part;
26+
import com.google.cloud.vertexai.api.Schema;
27+
import com.google.cloud.vertexai.api.Type;
2628
import org.junit.jupiter.api.Test;
2729
import org.junit.jupiter.api.extension.ExtendWith;
2830
import org.mockito.Mock;
@@ -262,6 +264,9 @@ public void createRequestWithGenerationConfigOptions() {
262264
.stopSequences(List.of("stop1", "stop2"))
263265
.candidateCount(1)
264266
.responseMimeType("application/json")
267+
.responseSchema("""
268+
{"type": "OBJECT"}
269+
""")
265270
.build())
266271
.build();
267272

@@ -280,6 +285,8 @@ public void createRequestWithGenerationConfigOptions() {
280285
assertThat(request.model().getGenerationConfig().getStopSequences(0)).isEqualTo("stop1");
281286
assertThat(request.model().getGenerationConfig().getStopSequences(1)).isEqualTo("stop2");
282287
assertThat(request.model().getGenerationConfig().getResponseMimeType()).isEqualTo("application/json");
288+
assertThat(request.model().getGenerationConfig().getResponseSchema())
289+
.isEqualTo(Schema.newBuilder().setType(Type.OBJECT).build());
283290
}
284291

285292
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vertexai.gemini.schema;
18+
19+
import java.util.List;
20+
21+
import com.google.cloud.vertexai.api.Schema;
22+
import com.google.cloud.vertexai.api.Type;
23+
import org.junit.jupiter.api.Test;
24+
25+
import static org.junit.jupiter.api.Assertions.assertEquals;
26+
import static org.junit.jupiter.api.Assertions.assertTrue;
27+
28+
public class VertexAiSchemaConverterTests {
29+
30+
@Test
31+
public void fromOpenApiSchemaShouldConvertGenericFields() {
32+
String openApiSchema = """
33+
{
34+
"type": "OBJECT",
35+
"format": "date-time",
36+
"title": "Title",
37+
"description": "Description",
38+
"nullable": true,
39+
"example": "Example",
40+
"default": "0"
41+
}""";
42+
43+
Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);
44+
45+
assertEquals(Type.OBJECT, schema.getType());
46+
assertEquals("date-time", schema.getFormat());
47+
assertEquals("Title", schema.getTitle());
48+
assertEquals("Description", schema.getDescription());
49+
assertTrue(schema.getNullable());
50+
assertEquals("Example", schema.getExample().getStringValue());
51+
assertEquals("0", schema.getDefault().getStringValue());
52+
}
53+
54+
@Test
55+
public void fromOpenApiSchemaShouldConvertStringFields() {
56+
String openApiSchema = """
57+
{
58+
"type": "STRING",
59+
"enum": ["a", "b", "c"],
60+
"minLength": 1,
61+
"maxLength": 10,
62+
"pattern": "[0-9.]+"
63+
}""";
64+
65+
Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);
66+
67+
assertEquals(Type.STRING, schema.getType());
68+
assertEquals(List.of("a", "b", "c"), schema.getEnumList());
69+
assertEquals(1, schema.getMinLength());
70+
assertEquals(10, schema.getMaxLength());
71+
assertEquals("[0-9.]+", schema.getPattern());
72+
}
73+
74+
@Test
75+
public void fromOpenApiSchemaShouldConvertIntegerAndNumberFields() {
76+
String openApiSchema = """
77+
{
78+
"anyOf": [{"type": "INTEGER"}, {"type": "NUMBER"}],
79+
"minimum": 0,
80+
"maximum": 100
81+
}""";
82+
83+
Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);
84+
85+
assertEquals(Type.TYPE_UNSPECIFIED, schema.getType());
86+
assertEquals(Type.INTEGER, schema.getAnyOf(0).getType());
87+
assertEquals(Type.NUMBER, schema.getAnyOf(1).getType());
88+
assertEquals(0, schema.getMinimum());
89+
assertEquals(100, schema.getMaximum());
90+
}
91+
92+
@Test
93+
public void fromOpenApiSchemaShouldConvertArrayFields() {
94+
String openApiSchema = """
95+
{
96+
"type": "ARRAY",
97+
"items": {
98+
"type": "BOOLEAN"
99+
},
100+
"minItems": 1,
101+
"maxItems": 5
102+
}""";
103+
104+
Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);
105+
106+
assertEquals(Type.ARRAY, schema.getType());
107+
assertEquals(Type.BOOLEAN, schema.getItems().getType());
108+
assertEquals(1, schema.getMinItems());
109+
assertEquals(5, schema.getMaxItems());
110+
}
111+
112+
@Test
113+
public void fromOpenApiSchemaShouldConvertObjectFields() {
114+
String openApiSchema = """
115+
{
116+
"type": "OBJECT",
117+
"properties": {
118+
"property1": {
119+
"type": "STRING"
120+
},
121+
"property2": {
122+
"type": "INTEGER"
123+
}
124+
},
125+
"minProperties": 1,
126+
"maxProperties": 2,
127+
"required": ["property1"],
128+
"propertyOrdering": ["property1", "property2"]
129+
}""";
130+
131+
Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);
132+
133+
assertEquals(Type.OBJECT, schema.getType());
134+
assertEquals(2, schema.getPropertiesMap().size());
135+
assertEquals(1, schema.getMinProperties());
136+
assertEquals(2, schema.getMaxProperties());
137+
assertEquals(List.of("property1"), schema.getRequiredList());
138+
assertEquals(List.of("property1", "property2"), schema.getPropertyOrderingList());
139+
}
140+
141+
}

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ The prefix `spring.ai.vertex.ai.gemini.chat` is the property prefix that lets yo
9292

9393
| spring.ai.vertex.ai.gemini.chat.options.model | Supported https://cloud.google.com/vertex-ai/generative-ai/docs/models#gemini-models[Vertex AI Gemini Chat model] to use include the `gemini-2.0-flash`, `gemini-2.0-flash-lite` and the new `gemini-2.5-pro-preview-03-25`, `gemini-2.5-flash-preview-04-17` models. | gemini-2.0-flash
9494
| spring.ai.vertex.ai.gemini.chat.options.response-mime-type | Output response mimetype of the generated candidate text. | `text/plain`: (default) Text output or `application/json`: JSON response.
95+
| spring.ai.vertex.ai.gemini.chat.options.response-schema | String, containing the output response schema in OpenAPI format, as described in https://ai.google.dev/gemini-api/docs/structured-output#json-schemas. | -
9596
| spring.ai.vertex.ai.gemini.chat.options.google-search-retrieval | Use Google search Grounding feature | `true` or `false`, default `false`.
9697
| spring.ai.vertex.ai.gemini.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the generative. This value specifies default to be used by the backend while making the call to the generative. | 0.7
9798
| spring.ai.vertex.ai.gemini.chat.options.top-k | The maximum number of tokens to consider when sampling. The generative uses combined Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. | -

0 commit comments

Comments
 (0)