Skip to content

Fix schema generation for Nullable<T> function parameters. #6596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -289,24 +289,49 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js
objSchema.InsertAtStart(TypePropertyName, "string");
}

// Include the type keyword in nullable enum types
if (Nullable.GetUnderlyingType(ctx.TypeInfo.Type)?.IsEnum is true && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
{
objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" });
}

// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
// schemas with "type": [...], and only understand "type" being a single value.
// In certain configurations STJ represents .NET numeric types as ["string", "number"], which will then lead to an error.
if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema, out string? numericType))
if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema, out string? numericType, out bool isNullable))
{
// We don't want to emit any array for "type". In this case we know it contains "integer" or "number",
// so reduce the type to that alone, assuming it's the most specific type.
// This makes schemas for Int32 (etc) work with Ollama.
JsonObject obj = ConvertSchemaToObject(ref schema);
obj[TypePropertyName] = numericType;
if (isNullable)
{
// If the type is nullable, we still need use a type array
obj[TypePropertyName] = new JsonArray { (JsonNode)numericType, (JsonNode)"null" };
}
else
{
obj[TypePropertyName] = (JsonNode)numericType;
}

_ = obj.Remove(PatternPropertyName);
}

if (Nullable.GetUnderlyingType(ctx.TypeInfo.Type) is Type nullableElement)
{
// Account for bug https://github.com/dotnet/runtime/issues/117493
// To be removed once System.Text.Json v10 becomes the lowest supported version.
// null not inserted in the type keyword for root-level Nullable<T> types.
if (objSchema.TryGetPropertyValue(TypePropertyName, out JsonNode? typeKeyWord) &&
typeKeyWord?.GetValueKind() is JsonValueKind.String)
{
string typeValue = typeKeyWord.GetValue<string>()!;
if (typeValue is not "null")
{
objSchema[TypePropertyName] = new JsonArray { (JsonNode)typeValue, (JsonNode)"null" };
}
}

// Include the type keyword in nullable enum types
if (nullableElement.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
{
objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" });
}
}
}

if (ctx.Path.IsEmpty && hasDefaultValue)
Expand Down Expand Up @@ -601,11 +626,12 @@ static JsonArray CreateJsonArray(object?[] values, JsonSerializerOptions seriali
}
}

private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema, [NotNullWhen(true)] out string? numericType)
private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema, [NotNullWhen(true)] out string? numericType, out bool isNullable)
{
numericType = null;
isNullable = false;

if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray { Count: 2 } typeArray)
if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray)
{
bool allowString = false;

Expand All @@ -617,11 +643,23 @@ private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateCont
switch (type)
{
case "integer" or "number":
if (numericType is not null)
{
// Conflicting numeric type
return false;
}

numericType = type;
break;
case "string":
allowString = true;
break;
case "null":
isNullable = true;
break;
default:
// keyword is not valid in the context of numeric types.
return false;
}
}
}
Expand Down Expand Up @@ -665,7 +703,7 @@ private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)

if (defaultValue is null || (defaultValue == DBNull.Value && parameterType != typeof(DBNull)))
{
return parameterType.IsValueType
return parameterType.IsValueType && Nullable.GetUnderlyingType(parameterType) is null
#if NET
? RuntimeHelpers.GetUninitializedObject(parameterType)
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ private static class ReflectionHelpers
public static bool IsBuiltInConverter(JsonConverter converter) =>
converter.GetType().Assembly == typeof(JsonConverter).Assembly;

public static bool CanBeNull(Type type) => !type.IsValueType || Nullable.GetUnderlyingType(type) is not null;

public static Type GetElementType(JsonTypeInfo typeInfo)
{
Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Enumerable or JsonTypeInfoKind.Dictionary, "TypeInfo must be of collection type");
Expand Down
16 changes: 10 additions & 6 deletions src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -452,20 +452,24 @@ JsonSchema CompleteSchema(ref GenerationState state, JsonSchema schema)

bool IsNullableSchema(ref GenerationState state)
{
// A schema is marked as nullable if either
// A schema is marked as nullable if either:
// 1. We have a schema for a property where either the getter or setter are marked as nullable.
// 2. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable
// 2. We have a schema for a Nullable<T> type.
// 3. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable.

if (propertyInfo != null || parameterInfo != null)
{
return !isNonNullableType;
}
else

if (Nullable.GetUnderlyingType(typeInfo.Type) is not null)
{
return ReflectionHelpers.CanBeNull(typeInfo.Type) &&
!parentPolymorphicTypeIsNonNullable &&
!state.ExporterOptions.TreatNullObliviousAsNonNullable;
return true;
}

return !typeInfo.Type.IsValueType &&
!parentPolymorphicTypeIsNonNullable &&
!state.ExporterOptions.TreatNullObliviousAsNonNullable;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,29 @@ public static void EqualFunctionCallParameters(
public static void EqualFunctionCallResults(object? expected, object? actual, JsonSerializerOptions? options = null)
=> AreJsonEquivalentValues(expected, actual, options);

private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null)
/// <summary>
/// Asserts that the two JSON values are equal.
/// </summary>
public static void EqualJsonValues(JsonElement expectedJson, JsonElement actualJson, string? propertyName = null)
{
options ??= AIJsonUtilities.DefaultOptions;
JsonElement expectedElement = NormalizeToElement(expected, options);
JsonElement actualElement = NormalizeToElement(actual, options);
if (!JsonNode.DeepEquals(
JsonSerializer.SerializeToNode(expectedElement, AIJsonUtilities.DefaultOptions),
JsonSerializer.SerializeToNode(actualElement, AIJsonUtilities.DefaultOptions)))
JsonSerializer.SerializeToNode(expectedJson, AIJsonUtilities.DefaultOptions),
JsonSerializer.SerializeToNode(actualJson, AIJsonUtilities.DefaultOptions)))
{
string message = propertyName is null
? $"Function result does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}"
: $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}";
? $"JSON result does not match expected JSON.\r\nExpected: {expectedJson.GetRawText()}\r\nActual: {actualJson.GetRawText()}"
: $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedJson.GetRawText()}\r\nActual: {actualJson.GetRawText()}";

throw new XunitException(message);
}
}

private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null)
{
options ??= AIJsonUtilities.DefaultOptions;
JsonElement expectedElement = NormalizeToElement(expected, options);
JsonElement actualElement = NormalizeToElement(actual, options);
EqualJsonValues(expectedElement, actualElement, propertyName);

static JsonElement NormalizeToElement(object? value, JsonSerializerOptions options)
=> value is JsonElement e ? e : JsonSerializer.SerializeToElement(value, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,21 @@ public static void CreateFunctionJsonSchema_TreatsIntegralTypesAsInteger_EvenWit
int i = 0;
foreach (JsonProperty property in schemaParameters.EnumerateObject())
{
string numericType = Type.GetTypeCode(parameters[i].ParameterType) is TypeCode.Double or TypeCode.Single or TypeCode.Decimal
? "number"
: "integer";
bool isNullable = false;
Type type = parameters[i].ParameterType;
if (Nullable.GetUnderlyingType(type) is { } elementType)
{
type = elementType;
isNullable = true;
}

string numericType = Type.GetTypeCode(type) is TypeCode.Double or TypeCode.Single or TypeCode.Decimal
? "\"number\""
: "\"integer\"";

JsonElement expected = JsonDocument.Parse($$"""
{
"type": "{{numericType}}"
"type": {{(isNullable ? $"[{numericType}, \"null\"]" : numericType)}}
}
""").RootElement;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
Expand Down Expand Up @@ -854,6 +855,71 @@ public async Task AIFunctionFactory_DefaultDefaultParameter()
Assert.Contains("00000000-0000-0000-0000-000000000000,0", result?.ToString());
}

[Fact]
public async Task AIFunctionFactory_NullableParameters()
{
Assert.NotEqual(new StructWithDefaultCtor().Value, default(StructWithDefaultCtor).Value);

AIFunction f = AIFunctionFactory.Create(
(int? limit = null, DateTime? from = null) => Enumerable.Repeat(from ?? default, limit ?? 4).Select(d => d.Year).ToArray(),
serializerOptions: JsonContext.Default.Options);

JsonElement expectedSchema = JsonDocument.Parse("""
{
"type": "object",
"properties": {
"limit": {
"type": ["integer", "null"],
"default": null
},
"from": {
"type": ["string", "null"],
"format": "date-time",
"default": null
}
}
}
""").RootElement;

AssertExtensions.EqualJsonValues(expectedSchema, f.JsonSchema);

object? result = await f.InvokeAsync();
Assert.Contains("[1,1,1,1]", result?.ToString());
}

[Fact]
public async Task AIFunctionFactory_NullableParameters_AllowReadingFromString()
{
JsonSerializerOptions options = new(JsonContext.Default.Options) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
Assert.NotEqual(new StructWithDefaultCtor().Value, default(StructWithDefaultCtor).Value);

AIFunction f = AIFunctionFactory.Create(
(int? limit = null, DateTime? from = null) => Enumerable.Repeat(from ?? default, limit ?? 4).Select(d => d.Year).ToArray(),
serializerOptions: options);

JsonElement expectedSchema = JsonDocument.Parse("""
{
"type": "object",
"properties": {
"limit": {
"type": ["integer", "null"],
"default": null
},
"from": {
"type": ["string", "null"],
"format": "date-time",
"default": null
}
}
}
""").RootElement;

AssertExtensions.EqualJsonValues(expectedSchema, f.JsonSchema);

object? result = await f.InvokeAsync();
Assert.Contains("[1,1,1,1]", result?.ToString());
}

[Fact]
public void AIFunctionFactory_ReturnTypeWithDescriptionAttribute()
{
Expand Down Expand Up @@ -959,5 +1025,7 @@ private static AIFunctionFactoryOptions CreateKeyedServicesSupportOptions() =>
[JsonSerializable(typeof(Guid))]
[JsonSerializable(typeof(StructWithDefaultCtor))]
[JsonSerializable(typeof(B))]
[JsonSerializable(typeof(int?))]
[JsonSerializable(typeof(DateTime?))]
private partial class JsonContext : JsonSerializerContext;
}
10 changes: 8 additions & 2 deletions test/Shared/JsonSchemaExporter/TestData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ internal sealed record TestData<T>(
T? Value,
[StringSyntax(StringSyntaxAttribute.Json)] string ExpectedJsonSchema,
IEnumerable<T?>? AdditionalValues = null,
object? ExporterOptions = null,
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
System.Text.Json.Schema.JsonSchemaExporterOptions? ExporterOptions = null,
#endif
JsonSerializerOptions? Options = null,
bool WritesNumbersAsStrings = false)
: ITestData
Expand All @@ -22,7 +24,9 @@ internal sealed record TestData<T>(

public Type Type => typeof(T);
object? ITestData.Value => Value;
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
object? ITestData.ExporterOptions => ExporterOptions;
#endif
JsonNode ITestData.ExpectedJsonSchema { get; } =
JsonNode.Parse(ExpectedJsonSchema, documentOptions: _schemaParseOptions)
?? throw new ArgumentNullException("schema must not be null");
Expand All @@ -32,7 +36,7 @@ IEnumerable<ITestData> ITestData.GetTestDataForAllValues()
yield return this;

if (default(T) is null &&
#if NET9_0_OR_GREATER
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
ExporterOptions is System.Text.Json.Schema.JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable: false } &&
#endif
Value is not null)
Expand All @@ -58,7 +62,9 @@ public interface ITestData

JsonNode ExpectedJsonSchema { get; }

#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
object? ExporterOptions { get; }
#endif

JsonSerializerOptions? Options { get; }

Expand Down
Loading
Loading