Skip to content

Commit 1d0ef81

Browse files
eiriktsarpalisCopilot
authored andcommitted
Fix schema generation for Nullable<T> function parameters. (#6596)
* Fix schema generation for Nullable<T> function parameters. * Update src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs Co-authored-by: Copilot <[email protected]> * Update src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs * Incorporate fix from dotnet/runtime#117493. * Update src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs * Extend fix to include AllowReadingFromString. --------- Co-authored-by: Copilot <[email protected]>
1 parent 736eda0 commit 1d0ef81

File tree

8 files changed

+184
-59
lines changed

8 files changed

+184
-59
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -289,24 +289,49 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js
289289
objSchema.InsertAtStart(TypePropertyName, "string");
290290
}
291291

292-
// Include the type keyword in nullable enum types
293-
if (Nullable.GetUnderlyingType(ctx.TypeInfo.Type)?.IsEnum is true && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
294-
{
295-
objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" });
296-
}
297-
298292
// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
299293
// schemas with "type": [...], and only understand "type" being a single value.
300294
// In certain configurations STJ represents .NET numeric types as ["string", "number"], which will then lead to an error.
301-
if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema, out string? numericType))
295+
if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema, out string? numericType, out bool isNullable))
302296
{
303297
// We don't want to emit any array for "type". In this case we know it contains "integer" or "number",
304298
// so reduce the type to that alone, assuming it's the most specific type.
305299
// This makes schemas for Int32 (etc) work with Ollama.
306300
JsonObject obj = ConvertSchemaToObject(ref schema);
307-
obj[TypePropertyName] = numericType;
301+
if (isNullable)
302+
{
303+
// If the type is nullable, we still need use a type array
304+
obj[TypePropertyName] = new JsonArray { (JsonNode)numericType, (JsonNode)"null" };
305+
}
306+
else
307+
{
308+
obj[TypePropertyName] = (JsonNode)numericType;
309+
}
310+
308311
_ = obj.Remove(PatternPropertyName);
309312
}
313+
314+
if (Nullable.GetUnderlyingType(ctx.TypeInfo.Type) is Type nullableElement)
315+
{
316+
// Account for bug https://github.com/dotnet/runtime/issues/117493
317+
// To be removed once System.Text.Json v10 becomes the lowest supported version.
318+
// null not inserted in the type keyword for root-level Nullable<T> types.
319+
if (objSchema.TryGetPropertyValue(TypePropertyName, out JsonNode? typeKeyWord) &&
320+
typeKeyWord?.GetValueKind() is JsonValueKind.String)
321+
{
322+
string typeValue = typeKeyWord.GetValue<string>()!;
323+
if (typeValue is not "null")
324+
{
325+
objSchema[TypePropertyName] = new JsonArray { (JsonNode)typeValue, (JsonNode)"null" };
326+
}
327+
}
328+
329+
// Include the type keyword in nullable enum types
330+
if (nullableElement.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
331+
{
332+
objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" });
333+
}
334+
}
310335
}
311336

312337
if (ctx.Path.IsEmpty && hasDefaultValue)
@@ -601,11 +626,12 @@ static JsonArray CreateJsonArray(object?[] values, JsonSerializerOptions seriali
601626
}
602627
}
603628

604-
private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema, [NotNullWhen(true)] out string? numericType)
629+
private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateContext ctx, JsonObject schema, [NotNullWhen(true)] out string? numericType, out bool isNullable)
605630
{
606631
numericType = null;
632+
isNullable = false;
607633

608-
if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray { Count: 2 } typeArray)
634+
if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray)
609635
{
610636
bool allowString = false;
611637

@@ -617,11 +643,23 @@ private static bool TypeIsIntegerWithStringNumberHandling(AIJsonSchemaCreateCont
617643
switch (type)
618644
{
619645
case "integer" or "number":
646+
if (numericType is not null)
647+
{
648+
// Conflicting numeric type
649+
return false;
650+
}
651+
620652
numericType = type;
621653
break;
622654
case "string":
623655
allowString = true;
624656
break;
657+
case "null":
658+
isNullable = true;
659+
break;
660+
default:
661+
// keyword is not valid in the context of numeric types.
662+
return false;
625663
}
626664
}
627665
}
@@ -665,7 +703,7 @@ private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)
665703

666704
if (defaultValue is null || (defaultValue == DBNull.Value && parameterType != typeof(DBNull)))
667705
{
668-
return parameterType.IsValueType
706+
return parameterType.IsValueType && Nullable.GetUnderlyingType(parameterType) is null
669707
#if NET
670708
? RuntimeHelpers.GetUninitializedObject(parameterType)
671709
#else

src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ private static class ReflectionHelpers
3131
public static bool IsBuiltInConverter(JsonConverter converter) =>
3232
converter.GetType().Assembly == typeof(JsonConverter).Assembly;
3333

34-
public static bool CanBeNull(Type type) => !type.IsValueType || Nullable.GetUnderlyingType(type) is not null;
35-
3634
public static Type GetElementType(JsonTypeInfo typeInfo)
3735
{
3836
Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Enumerable or JsonTypeInfoKind.Dictionary, "TypeInfo must be of collection type");

src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -452,20 +452,24 @@ JsonSchema CompleteSchema(ref GenerationState state, JsonSchema schema)
452452

453453
bool IsNullableSchema(ref GenerationState state)
454454
{
455-
// A schema is marked as nullable if either
455+
// A schema is marked as nullable if either:
456456
// 1. We have a schema for a property where either the getter or setter are marked as nullable.
457-
// 2. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable
457+
// 2. We have a schema for a Nullable<T> type.
458+
// 3. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable.
458459

459460
if (propertyInfo != null || parameterInfo != null)
460461
{
461462
return !isNonNullableType;
462463
}
463-
else
464+
465+
if (Nullable.GetUnderlyingType(typeInfo.Type) is not null)
464466
{
465-
return ReflectionHelpers.CanBeNull(typeInfo.Type) &&
466-
!parentPolymorphicTypeIsNonNullable &&
467-
!state.ExporterOptions.TreatNullObliviousAsNonNullable;
467+
return true;
468468
}
469+
470+
return !typeInfo.Type.IsValueType &&
471+
!parentPolymorphicTypeIsNonNullable &&
472+
!state.ExporterOptions.TreatNullObliviousAsNonNullable;
469473
}
470474
}
471475

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,21 +53,29 @@ public static void EqualFunctionCallParameters(
5353
public static void EqualFunctionCallResults(object? expected, object? actual, JsonSerializerOptions? options = null)
5454
=> AreJsonEquivalentValues(expected, actual, options);
5555

56-
private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null)
56+
/// <summary>
57+
/// Asserts that the two JSON values are equal.
58+
/// </summary>
59+
public static void EqualJsonValues(JsonElement expectedJson, JsonElement actualJson, string? propertyName = null)
5760
{
58-
options ??= AIJsonUtilities.DefaultOptions;
59-
JsonElement expectedElement = NormalizeToElement(expected, options);
60-
JsonElement actualElement = NormalizeToElement(actual, options);
6161
if (!JsonNode.DeepEquals(
62-
JsonSerializer.SerializeToNode(expectedElement, AIJsonUtilities.DefaultOptions),
63-
JsonSerializer.SerializeToNode(actualElement, AIJsonUtilities.DefaultOptions)))
62+
JsonSerializer.SerializeToNode(expectedJson, AIJsonUtilities.DefaultOptions),
63+
JsonSerializer.SerializeToNode(actualJson, AIJsonUtilities.DefaultOptions)))
6464
{
6565
string message = propertyName is null
66-
? $"Function result does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}"
67-
: $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}";
66+
? $"JSON result does not match expected JSON.\r\nExpected: {expectedJson.GetRawText()}\r\nActual: {actualJson.GetRawText()}"
67+
: $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedJson.GetRawText()}\r\nActual: {actualJson.GetRawText()}";
6868

6969
throw new XunitException(message);
7070
}
71+
}
72+
73+
private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null)
74+
{
75+
options ??= AIJsonUtilities.DefaultOptions;
76+
JsonElement expectedElement = NormalizeToElement(expected, options);
77+
JsonElement actualElement = NormalizeToElement(actual, options);
78+
EqualJsonValues(expectedElement, actualElement, propertyName);
7179

7280
static JsonElement NormalizeToElement(object? value, JsonSerializerOptions options)
7381
=> value is JsonElement e ? e : JsonSerializer.SerializeToElement(value, options);

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,13 +354,21 @@ public static void CreateFunctionJsonSchema_TreatsIntegralTypesAsInteger_EvenWit
354354
int i = 0;
355355
foreach (JsonProperty property in schemaParameters.EnumerateObject())
356356
{
357-
string numericType = Type.GetTypeCode(parameters[i].ParameterType) is TypeCode.Double or TypeCode.Single or TypeCode.Decimal
358-
? "number"
359-
: "integer";
357+
bool isNullable = false;
358+
Type type = parameters[i].ParameterType;
359+
if (Nullable.GetUnderlyingType(type) is { } elementType)
360+
{
361+
type = elementType;
362+
isNullable = true;
363+
}
364+
365+
string numericType = Type.GetTypeCode(type) is TypeCode.Double or TypeCode.Single or TypeCode.Decimal
366+
? "\"number\""
367+
: "\"integer\"";
360368

361369
JsonElement expected = JsonDocument.Parse($$"""
362370
{
363-
"type": "{{numericType}}"
371+
"type": {{(isNullable ? $"[{numericType}, \"null\"]" : numericType)}}
364372
}
365373
""").RootElement;
366374

test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Collections.Generic;
66
using System.ComponentModel;
7+
using System.Linq;
78
using System.Reflection;
89
using System.Text.Json;
910
using System.Text.Json.Nodes;
@@ -838,6 +839,71 @@ public async Task AIFunctionFactory_DefaultDefaultParameter()
838839
Assert.Contains("00000000-0000-0000-0000-000000000000,0", result?.ToString());
839840
}
840841

842+
[Fact]
843+
public async Task AIFunctionFactory_NullableParameters()
844+
{
845+
Assert.NotEqual(new StructWithDefaultCtor().Value, default(StructWithDefaultCtor).Value);
846+
847+
AIFunction f = AIFunctionFactory.Create(
848+
(int? limit = null, DateTime? from = null) => Enumerable.Repeat(from ?? default, limit ?? 4).Select(d => d.Year).ToArray(),
849+
serializerOptions: JsonContext.Default.Options);
850+
851+
JsonElement expectedSchema = JsonDocument.Parse("""
852+
{
853+
"type": "object",
854+
"properties": {
855+
"limit": {
856+
"type": ["integer", "null"],
857+
"default": null
858+
},
859+
"from": {
860+
"type": ["string", "null"],
861+
"format": "date-time",
862+
"default": null
863+
}
864+
}
865+
}
866+
""").RootElement;
867+
868+
AssertExtensions.EqualJsonValues(expectedSchema, f.JsonSchema);
869+
870+
object? result = await f.InvokeAsync();
871+
Assert.Contains("[1,1,1,1]", result?.ToString());
872+
}
873+
874+
[Fact]
875+
public async Task AIFunctionFactory_NullableParameters_AllowReadingFromString()
876+
{
877+
JsonSerializerOptions options = new(JsonContext.Default.Options) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
878+
Assert.NotEqual(new StructWithDefaultCtor().Value, default(StructWithDefaultCtor).Value);
879+
880+
AIFunction f = AIFunctionFactory.Create(
881+
(int? limit = null, DateTime? from = null) => Enumerable.Repeat(from ?? default, limit ?? 4).Select(d => d.Year).ToArray(),
882+
serializerOptions: options);
883+
884+
JsonElement expectedSchema = JsonDocument.Parse("""
885+
{
886+
"type": "object",
887+
"properties": {
888+
"limit": {
889+
"type": ["integer", "null"],
890+
"default": null
891+
},
892+
"from": {
893+
"type": ["string", "null"],
894+
"format": "date-time",
895+
"default": null
896+
}
897+
}
898+
}
899+
""").RootElement;
900+
901+
AssertExtensions.EqualJsonValues(expectedSchema, f.JsonSchema);
902+
903+
object? result = await f.InvokeAsync();
904+
Assert.Contains("[1,1,1,1]", result?.ToString());
905+
}
906+
841907
[Fact]
842908
public void AIFunctionFactory_ReturnTypeWithDescriptionAttribute()
843909
{
@@ -943,5 +1009,7 @@ private static AIFunctionFactoryOptions CreateKeyedServicesSupportOptions() =>
9431009
[JsonSerializable(typeof(Guid))]
9441010
[JsonSerializable(typeof(StructWithDefaultCtor))]
9451011
[JsonSerializable(typeof(B))]
1012+
[JsonSerializable(typeof(int?))]
1013+
[JsonSerializable(typeof(DateTime?))]
9461014
private partial class JsonContext : JsonSerializerContext;
9471015
}

test/Shared/JsonSchemaExporter/TestData.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ internal sealed record TestData<T>(
1313
T? Value,
1414
[StringSyntax(StringSyntaxAttribute.Json)] string ExpectedJsonSchema,
1515
IEnumerable<T?>? AdditionalValues = null,
16-
object? ExporterOptions = null,
16+
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
17+
System.Text.Json.Schema.JsonSchemaExporterOptions? ExporterOptions = null,
18+
#endif
1719
JsonSerializerOptions? Options = null,
1820
bool WritesNumbersAsStrings = false)
1921
: ITestData
@@ -22,7 +24,9 @@ internal sealed record TestData<T>(
2224

2325
public Type Type => typeof(T);
2426
object? ITestData.Value => Value;
27+
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
2528
object? ITestData.ExporterOptions => ExporterOptions;
29+
#endif
2630
JsonNode ITestData.ExpectedJsonSchema { get; } =
2731
JsonNode.Parse(ExpectedJsonSchema, documentOptions: _schemaParseOptions)
2832
?? throw new ArgumentNullException("schema must not be null");
@@ -32,7 +36,7 @@ IEnumerable<ITestData> ITestData.GetTestDataForAllValues()
3236
yield return this;
3337

3438
if (default(T) is null &&
35-
#if NET9_0_OR_GREATER
39+
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
3640
ExporterOptions is System.Text.Json.Schema.JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable: false } &&
3741
#endif
3842
Value is not null)
@@ -58,7 +62,9 @@ public interface ITestData
5862

5963
JsonNode ExpectedJsonSchema { get; }
6064

65+
#if TESTS_JSON_SCHEMA_EXPORTER_POLYFILL
6166
object? ExporterOptions { get; }
67+
#endif
6268

6369
JsonSerializerOptions? Options { get; }
6470

0 commit comments

Comments
 (0)