diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 9ba34ca3d05..dd0160b11a1 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -35,7 +35,6 @@ import com.google.cloud.vertexai.api.GenerationConfig; import com.google.cloud.vertexai.api.Part; import com.google.cloud.vertexai.api.SafetySetting; -import com.google.cloud.vertexai.api.Schema; import com.google.cloud.vertexai.api.Tool; import com.google.cloud.vertexai.api.Tool.GoogleSearch; import com.google.cloud.vertexai.generativeai.GenerativeModel; @@ -88,6 +87,7 @@ import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiConstants; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting; +import org.springframework.ai.vertexai.gemini.schema.VertexAiSchemaConverter; import org.springframework.ai.vertexai.gemini.schema.VertexToolCallingManager; import org.springframework.beans.factory.DisposableBean; import org.springframework.lang.NonNull; @@ -376,17 +376,6 @@ else if (rootNode.isArray()) { } } - private static Schema jsonToSchema(String json) { - try { - var schemaBuilder = Schema.newBuilder(); - JsonFormat.parser().ignoringUnknownFields().merge(json, schemaBuilder); - return schemaBuilder.build(); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - // https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini @Override public ChatResponse call(Prompt prompt) { @@ -697,7 +686,7 @@ GeminiRequest createGeminiRequest(Prompt prompt) { .map(toolDefinition -> FunctionDeclaration.newBuilder() .setName(toolDefinition.name()) .setDescription(toolDefinition.description()) - .setParameters(jsonToSchema(toolDefinition.inputSchema())) + .setParameters(VertexAiSchemaConverter.fromOpenApiSchema(toolDefinition.inputSchema())) .build()) .toList(); tools.add(Tool.newBuilder().addAllFunctionDeclarations(functionDeclarations).build()); @@ -759,6 +748,10 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) { if (options.getResponseMimeType() != null) { generationConfigBuilder.setResponseMimeType(options.getResponseMimeType()); } + if (options.getResponseSchema() != null) { + generationConfigBuilder + .setResponseSchema(VertexAiSchemaConverter.fromOpenApiSchema(options.getResponseSchema())); + } if (options.getFrequencyPenalty() != null) { generationConfigBuilder.setFrequencyPenalty(options.getFrequencyPenalty().floatValue()); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 69f32c8440c..c3534ad6d44 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -110,6 +110,11 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions { */ private @JsonProperty("responseMimeType") String responseMimeType; + /** + * Optional. OpenAPI response schema. + */ + private @JsonProperty("responseSchema") String responseSchema; + /** * Optional. Frequency penalties. */ @@ -170,8 +175,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setModel(fromOptions.getModel()); options.setToolCallbacks(fromOptions.getToolCallbacks()); options.setResponseMimeType(fromOptions.getResponseMimeType()); + options.setResponseSchema(fromOptions.getResponseSchema()); options.setToolNames(fromOptions.getToolNames()); - options.setResponseMimeType(fromOptions.getResponseMimeType()); options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); options.setSafetySettings(fromOptions.getSafetySettings()); options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()); @@ -265,6 +270,14 @@ public void setResponseMimeType(String mimeType) { this.responseMimeType = mimeType; } + public String getResponseSchema() { + return this.responseSchema; + } + + public void setResponseSchema(String responseSchema) { + this.responseSchema = responseSchema; + } + @Override public List getToolCallbacks() { return this.toolCallbacks; @@ -374,6 +387,7 @@ public boolean equals(Object o) { && Objects.equals(this.presencePenalty, that.presencePenalty) && Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model) && Objects.equals(this.responseMimeType, that.responseMimeType) + && Objects.equals(this.responseSchema, that.responseSchema) && Objects.equals(this.toolCallbacks, that.toolCallbacks) && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.safetySettings, that.safetySettings) @@ -386,8 +400,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType, - this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings, - this.internalToolExecutionEnabled, this.toolContext, this.logprobs, this.responseLogprobs); + this.responseSchema, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, + this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.logprobs, + this.responseLogprobs); } @Override @@ -396,10 +411,10 @@ public String toString() { + this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", frequencyPenalty=" + this.frequencyPenalty + ", presencePenalty=" + this.presencePenalty + ", candidateCount=" + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' - + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks - + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval - + ", safetySettings=" + this.safetySettings + ", logProbs=" + this.logprobs + ", responseLogprobs=" - + this.responseLogprobs + '}'; + + ", responseMimeType='" + this.responseMimeType + '\'' + ", responseSchema='" + this.responseSchema + + ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + ", logProbs=" + this.logprobs + + ", responseLogprobs=" + this.responseLogprobs + '}'; } @Override @@ -473,6 +488,11 @@ public Builder responseMimeType(String mimeType) { return this; } + public Builder responseSchema(String responseSchema) { + this.options.setResponseSchema(responseSchema); + return this; + } + public Builder toolCallbacks(List toolCallbacks) { this.options.toolCallbacks = toolCallbacks; return this; diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/VertexAiSchemaConverter.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/VertexAiSchemaConverter.java new file mode 100644 index 00000000000..40a35ffe841 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/VertexAiSchemaConverter.java @@ -0,0 +1,50 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.gemini.schema; + +import com.google.cloud.vertexai.api.Schema; +import com.google.protobuf.util.JsonFormat; + +/** + * Utility class for converting OpenAPI schemas to Vertex AI Schema objects. + * + * @since 1.1.0 + */ +public final class VertexAiSchemaConverter { + + private VertexAiSchemaConverter() { + // Prevent instantiation + } + + /** + * Converts an OpenAPI schema string to a Vertex AI Schema object. + * @param openApiSchema The OpenAPI schema in JSON format + * @return A Schema object representing the OpenAPI schema + * @throws RuntimeException if the schema cannot be parsed + */ + public static Schema fromOpenApiSchema(String openApiSchema) { + try { + var schemaBuilder = Schema.newBuilder(); + JsonFormat.parser().ignoringUnknownFields().merge(openApiSchema, schemaBuilder); + return schemaBuilder.build(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + +} diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java index a0ce5d23305..d45a3034595 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java @@ -23,6 +23,8 @@ import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.Content; import com.google.cloud.vertexai.api.Part; +import com.google.cloud.vertexai.api.Schema; +import com.google.cloud.vertexai.api.Type; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -264,6 +266,9 @@ public void createRequestWithGenerationConfigOptions() { .responseMimeType("application/json") .responseLogprobs(true) .logprobs(2) + .responseSchema(""" + {"type": "OBJECT"} + """) .build()) .build(); @@ -284,6 +289,8 @@ public void createRequestWithGenerationConfigOptions() { assertThat(request.model().getGenerationConfig().getResponseMimeType()).isEqualTo("application/json"); assertThat(request.model().getGenerationConfig().getLogprobs()).isEqualTo(2); assertThat(request.model().getGenerationConfig().getResponseLogprobs()).isEqualTo(true); + assertThat(request.model().getGenerationConfig().getResponseSchema()) + .isEqualTo(Schema.newBuilder().setType(Type.OBJECT).build()); } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/schema/VertexAiSchemaConverterTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/schema/VertexAiSchemaConverterTests.java new file mode 100644 index 00000000000..1fae7fab705 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/schema/VertexAiSchemaConverterTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.gemini.schema; + +import java.util.List; + +import com.google.cloud.vertexai.api.Schema; +import com.google.cloud.vertexai.api.Type; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class VertexAiSchemaConverterTests { + + @Test + public void fromOpenApiSchemaShouldConvertGenericFields() { + String openApiSchema = """ + { + "type": "OBJECT", + "format": "date-time", + "title": "Title", + "description": "Description", + "nullable": true, + "example": "Example", + "default": "0" + }"""; + + Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema); + + assertEquals(Type.OBJECT, schema.getType()); + assertEquals("date-time", schema.getFormat()); + assertEquals("Title", schema.getTitle()); + assertEquals("Description", schema.getDescription()); + assertTrue(schema.getNullable()); + assertEquals("Example", schema.getExample().getStringValue()); + assertEquals("0", schema.getDefault().getStringValue()); + } + + @Test + public void fromOpenApiSchemaShouldConvertStringFields() { + String openApiSchema = """ + { + "type": "STRING", + "enum": ["a", "b", "c"], + "minLength": 1, + "maxLength": 10, + "pattern": "[0-9.]+" + }"""; + + Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema); + + assertEquals(Type.STRING, schema.getType()); + assertEquals(List.of("a", "b", "c"), schema.getEnumList()); + assertEquals(1, schema.getMinLength()); + assertEquals(10, schema.getMaxLength()); + assertEquals("[0-9.]+", schema.getPattern()); + } + + @Test + public void fromOpenApiSchemaShouldConvertIntegerAndNumberFields() { + String openApiSchema = """ + { + "anyOf": [{"type": "INTEGER"}, {"type": "NUMBER"}], + "minimum": 0, + "maximum": 100 + }"""; + + Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema); + + assertEquals(Type.TYPE_UNSPECIFIED, schema.getType()); + assertEquals(Type.INTEGER, schema.getAnyOf(0).getType()); + assertEquals(Type.NUMBER, schema.getAnyOf(1).getType()); + assertEquals(0, schema.getMinimum()); + assertEquals(100, schema.getMaximum()); + } + + @Test + public void fromOpenApiSchemaShouldConvertArrayFields() { + String openApiSchema = """ + { + "type": "ARRAY", + "items": { + "type": "BOOLEAN" + }, + "minItems": 1, + "maxItems": 5 + }"""; + + Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema); + + assertEquals(Type.ARRAY, schema.getType()); + assertEquals(Type.BOOLEAN, schema.getItems().getType()); + assertEquals(1, schema.getMinItems()); + assertEquals(5, schema.getMaxItems()); + } + + @Test + public void fromOpenApiSchemaShouldConvertObjectFields() { + String openApiSchema = """ + { + "type": "OBJECT", + "properties": { + "property1": { + "type": "STRING" + }, + "property2": { + "type": "INTEGER" + } + }, + "minProperties": 1, + "maxProperties": 2, + "required": ["property1"], + "propertyOrdering": ["property1", "property2"] + }"""; + + Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema); + + assertEquals(Type.OBJECT, schema.getType()); + assertEquals(2, schema.getPropertiesMap().size()); + assertEquals(1, schema.getMinProperties()); + assertEquals(2, schema.getMaxProperties()); + assertEquals(List.of("property1"), schema.getRequiredList()); + assertEquals(List.of("property1", "property2"), schema.getPropertyOrderingList()); + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc index 32c41a7ac67..605f56f2c69 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc @@ -92,6 +92,7 @@ The prefix `spring.ai.vertex.ai.gemini.chat` is the property prefix that lets yo | spring.ai.vertex.ai.gemini.chat.options.model | Supported https://cloud.google.com/vertex-ai/generative-ai/docs/models#gemini-models[Vertex AI Gemini Chat model] to use include the `gemini-2.0-flash`, `gemini-2.0-flash-lite` and the new `gemini-2.5-pro-preview-03-25`, `gemini-2.5-flash-preview-04-17` models. | gemini-2.0-flash | spring.ai.vertex.ai.gemini.chat.options.response-mime-type | Output response mimetype of the generated candidate text. | `text/plain`: (default) Text output or `application/json`: JSON response. +| spring.ai.vertex.ai.gemini.chat.options.response-schema | String, containing the output response schema in OpenAPI format, as described in https://ai.google.dev/gemini-api/docs/structured-output#json-schemas. | - | spring.ai.vertex.ai.gemini.chat.options.google-search-retrieval | Use Google search Grounding feature | `true` or `false`, default `false`. | spring.ai.vertex.ai.gemini.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the generative. This value specifies default to be used by the backend while making the call to the generative. | 0.7 | spring.ai.vertex.ai.gemini.chat.options.top-k | The maximum number of tokens to consider when sampling. The generative uses combined Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. | -