Skip to content

Commit db00a61

Browse files
authored
Reuse generated code when possible (#1821)
* Initialize size improvements * Also optimize ccw * Fix build * Add comments and renaming
1 parent eb67781 commit db00a61

File tree

1 file changed

+130
-43
lines changed

1 file changed

+130
-43
lines changed

src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs

Lines changed: 130 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -773,21 +773,21 @@ private static void GenerateVtableAttributes(
773773
GenerateVtableAttributes(sourceProductionContext.AddSource, value.vtableAttributes, value.context.properties.isCsWinRTComponent, value.context.escapedAssemblyName);
774774
}
775775

776-
internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, string escapedAssemblyName)
776+
internal static string GenerateVtableEntry(VtableEntry vtableEntry, string escapedAssemblyName)
777777
{
778778
StringBuilder source = new();
779779

780-
foreach (var genericInterface in vtableAttribute.GenericInterfaces)
780+
foreach (var genericInterface in vtableEntry.GenericInterfaces)
781781
{
782782
source.AppendLine(GenericVtableInitializerStrings.GetInstantiationInitFunction(
783783
genericInterface.GenericDefinition,
784784
genericInterface.GenericParameters,
785785
escapedAssemblyName));
786786
}
787787

788-
if (vtableAttribute.IsDelegate)
788+
if (vtableEntry.IsDelegate)
789789
{
790-
var @interface = vtableAttribute.Interfaces.First();
790+
var @interface = vtableEntry.Interfaces.First();
791791
source.AppendLine();
792792
source.AppendLine($$"""
793793
var delegateInterface = new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry
@@ -799,15 +799,15 @@ internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, stri
799799
return global::WinRT.DelegateTypeDetails<{{@interface}}>.GetExposedInterfaces(delegateInterface);
800800
""");
801801
}
802-
else if (vtableAttribute.Interfaces.Any())
802+
else if (vtableEntry.Interfaces.Any())
803803
{
804804
source.AppendLine();
805805
source.AppendLine($$"""
806806
return new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry[]
807807
{
808808
""");
809809

810-
foreach (var @interface in vtableAttribute.Interfaces)
810+
foreach (var @interface in vtableEntry.Interfaces)
811811
{
812812
var genericStartIdx = @interface.IndexOf('<');
813813
var interfaceStaticsMethod = @interface[..(genericStartIdx == -1 ? @interface.Length : genericStartIdx)] + "Methods";
@@ -840,6 +840,10 @@ internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, stri
840840

841841
internal static void GenerateVtableAttributes(Action<string, string> addSource, ImmutableArray<VtableAttribute> vtableAttributes, bool isCsWinRTComponentFromAotOptimizer, string escapedAssemblyName)
842842
{
843+
var vtableEntryToVtableClassName = new Dictionary<VtableEntry, string>();
844+
StringBuilder vtableClassesSource = new();
845+
bool firstVtableClass = true;
846+
843847
// Using ToImmutableHashSet to avoid duplicate entries from the use of partial classes by the developer
844848
// to split out their implementation. When they do that, we will get multiple entries here for that
845849
// and try to generate the same attribute and file with the same data as we use the semantic model
@@ -850,11 +854,10 @@ internal static void GenerateVtableAttributes(Action<string, string> addSource,
850854
// from the AOT optimizer, then any public types are not handled
851855
// right now as they are handled by the WinRT component source generator
852856
// calling this.
853-
if (((isCsWinRTComponentFromAotOptimizer && !vtableAttribute.IsPublic) || !isCsWinRTComponentFromAotOptimizer) &&
857+
if (((isCsWinRTComponentFromAotOptimizer && !vtableAttribute.IsPublic) || !isCsWinRTComponentFromAotOptimizer) &&
854858
vtableAttribute.Interfaces.Any())
855859
{
856860
StringBuilder source = new();
857-
source.AppendLine("using static WinRT.TypeExtensions;\n");
858861
if (!vtableAttribute.IsGlobalNamespace)
859862
{
860863
source.AppendLine($$"""
@@ -863,6 +866,16 @@ namespace {{vtableAttribute.Namespace}}
863866
""");
864867
}
865868

869+
// Check if this class shares the same vtable as another class. If so, reuse the same generated class for it.
870+
VtableEntry entry = new(vtableAttribute.Interfaces, vtableAttribute.GenericInterfaces, vtableAttribute.IsDelegate);
871+
bool vtableEntryExists = vtableEntryToVtableClassName.TryGetValue(entry, out var ccwClassName);
872+
if (!vtableEntryExists)
873+
{
874+
var @namespace = vtableAttribute.IsGlobalNamespace ? "" : $"{vtableAttribute.Namespace}.";
875+
ccwClassName = GeneratorHelper.EscapeTypeNameForIdentifier(@namespace + vtableAttribute.ClassName);
876+
vtableEntryToVtableClassName.Add(entry, ccwClassName);
877+
}
878+
866879
var escapedClassName = GeneratorHelper.EscapeTypeNameForIdentifier(vtableAttribute.ClassName);
867880

868881
// Simple case when the type is not nested
@@ -874,7 +887,7 @@ namespace {{vtableAttribute.Namespace}}
874887
}
875888

876889
source.AppendLine($$"""
877-
[global::WinRT.WinRTExposedType(typeof({{escapedClassName}}WinRTTypeDetails))]
890+
[global::WinRT.WinRTExposedType(typeof(global::WinRT.{{escapedAssemblyName}}VtableClasses.{{ccwClassName}}WinRTTypeDetails))]
878891
partial class {{vtableAttribute.ClassName}}
879892
{
880893
}
@@ -900,7 +913,7 @@ partial class {{vtableAttribute.ClassName}}
900913
}
901914

902915
source.AppendLine($$"""
903-
[global::WinRT.WinRTExposedType(typeof({{escapedClassName}}WinRTTypeDetails))]
916+
[global::WinRT.WinRTExposedType(typeof(global::WinRT.{{escapedAssemblyName}}VtableClasses.{{ccwClassName}}WinRTTypeDetails))]
904917
partial {{classHierarchy[0].GetTypeKeyword()}} {{classHierarchy[0].QualifiedName}}
905918
{
906919
}
@@ -913,62 +926,78 @@ partial class {{vtableAttribute.ClassName}}
913926
}
914927
}
915928

916-
source.AppendLine();
917-
source.AppendLine($$"""
918-
internal sealed class {{escapedClassName}}WinRTTypeDetails : global::WinRT.IWinRTExposedTypeDetails
929+
// Only generate class, if this is the first time we run into this set of vtables.
930+
if (!vtableEntryExists)
931+
{
932+
if (firstVtableClass)
933+
{
934+
vtableClassesSource.AppendLine($$"""
935+
namespace WinRT.{{escapedAssemblyName}}VtableClasses
936+
{
937+
""");
938+
firstVtableClass = false;
939+
}
940+
else
941+
{
942+
vtableClassesSource.AppendLine();
943+
}
944+
945+
vtableClassesSource.AppendLine($$"""
946+
internal sealed class {{ccwClassName}}WinRTTypeDetails : global::WinRT.IWinRTExposedTypeDetails
919947
{
920948
public global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry[] GetExposedInterfaces()
921949
{
922950
""");
923951

924-
if (vtableAttribute.Interfaces.Any())
925-
{
926-
foreach (var genericInterface in vtableAttribute.GenericInterfaces)
952+
if (vtableAttribute.Interfaces.Any())
927953
{
928-
source.AppendLine(GenericVtableInitializerStrings.GetInstantiationInitFunction(
929-
genericInterface.GenericDefinition,
930-
genericInterface.GenericParameters,
931-
escapedAssemblyName));
932-
}
954+
foreach (var genericInterface in vtableAttribute.GenericInterfaces)
955+
{
956+
vtableClassesSource.AppendLine(GenericVtableInitializerStrings.GetInstantiationInitFunction(
957+
genericInterface.GenericDefinition,
958+
genericInterface.GenericParameters,
959+
escapedAssemblyName));
960+
}
933961

934-
source.AppendLine();
935-
source.AppendLine($$"""
962+
vtableClassesSource.AppendLine();
963+
vtableClassesSource.AppendLine($$"""
936964
return new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry[]
937965
{
938966
""");
939967

940-
foreach (var @interface in vtableAttribute.Interfaces)
941-
{
942-
var genericStartIdx = @interface.IndexOf('<');
943-
var interfaceStaticsMethod = @interface[..(genericStartIdx == -1 ? @interface.Length : genericStartIdx)] + "Methods";
944-
if (genericStartIdx != -1)
968+
foreach (var @interface in vtableAttribute.Interfaces)
945969
{
946-
interfaceStaticsMethod += @interface[genericStartIdx..@interface.Length];
947-
}
970+
var genericStartIdx = @interface.IndexOf('<');
971+
var interfaceStaticsMethod = @interface[..(genericStartIdx == -1 ? @interface.Length : genericStartIdx)] + "Methods";
972+
if (genericStartIdx != -1)
973+
{
974+
interfaceStaticsMethod += @interface[genericStartIdx..@interface.Length];
975+
}
948976

949-
source.AppendLine($$"""
977+
vtableClassesSource.AppendLine($$"""
950978
new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry
951979
{
952980
IID = global::ABI.{{interfaceStaticsMethod}}.IID,
953981
Vtable = global::ABI.{{interfaceStaticsMethod}}.AbiToProjectionVftablePtr
954982
},
955983
""");
956-
}
957-
source.AppendLine($$"""
984+
}
985+
vtableClassesSource.AppendLine($$"""
958986
};
959987
""");
960-
}
961-
else
962-
{
963-
source.AppendLine($$"""
988+
}
989+
else
990+
{
991+
vtableClassesSource.AppendLine($$"""
964992
return global::System.Array.Empty<global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry>();
965993
""");
966-
}
994+
}
967995

968-
source.AppendLine($$"""
996+
vtableClassesSource.AppendLine($$"""
969997
}
970998
}
971999
""");
1000+
}
9721001

9731002
if (!vtableAttribute.IsGlobalNamespace)
9741003
{
@@ -979,6 +1008,12 @@ internal sealed class {{escapedClassName}}WinRTTypeDetails : global::WinRT.IWinR
9791008
addSource($"{prefix}{escapedClassName}.WinRTVtable.g.cs", source.ToString());
9801009
}
9811010
}
1011+
1012+
if (vtableClassesSource.Length != 0)
1013+
{
1014+
vtableClassesSource.AppendLine("}");
1015+
addSource($"WinRTCCWVtable.g.cs", vtableClassesSource.ToString());
1016+
}
9821017
}
9831018

9841019
private static void GenerateCCWForGenericInstantiation(
@@ -1444,12 +1479,37 @@ private static ComWrappers.ComInterfaceEntry[] LookupVtableEntries(Type type)
14441479
""");
14451480
}
14461481

1482+
// We gather all the class names that have the same vtable and generate it
1483+
// as part of one if to reduce generated code.
1484+
var vtableEntryToClassNameList = new Dictionary<VtableEntry, List<string>>();
14471485
foreach (var vtableAttribute in value.vtableAttributes.ToImmutableHashSet())
14481486
{
1487+
VtableEntry entry = new(vtableAttribute.Interfaces, vtableAttribute.GenericInterfaces, vtableAttribute.IsDelegate);
1488+
if (!vtableEntryToClassNameList.TryGetValue(entry, out var classNameList))
1489+
{
1490+
classNameList = new List<string>();
1491+
vtableEntryToClassNameList.Add(entry, classNameList);
1492+
}
1493+
classNameList.Add(vtableAttribute.VtableLookupClassName);
1494+
}
1495+
1496+
foreach (var vtableEntry in vtableEntryToClassNameList)
1497+
{
1498+
source.AppendLine($$"""
1499+
if (typeName == "{{vtableEntry.Value[0]}}"
1500+
""");
1501+
1502+
for (var i = 1; i < vtableEntry.Value.Count; i++)
1503+
{
1504+
source.AppendLine($$"""
1505+
|| typeName == "{{vtableEntry.Value[i]}}"
1506+
""");
1507+
}
1508+
14491509
source.AppendLine($$"""
1450-
if (typeName == "{{vtableAttribute.VtableLookupClassName}}")
1510+
)
14511511
{
1452-
{{GenerateVtableEntry(vtableAttribute, value.context.escapedAssemblyName)}}
1512+
{{GenerateVtableEntry(vtableEntry.Key, value.context.escapedAssemblyName)}}
14531513
}
14541514
""");
14551515
}
@@ -1469,12 +1529,34 @@ private static string LookupRuntimeClassName(Type type)
14691529
string typeName = type.ToString();
14701530
""");
14711531

1532+
var runtimeClassNameToClassNameList = new Dictionary<string, List<string>>();
14721533
foreach (var vtableAttribute in value.vtableAttributes.ToImmutableHashSet().Where(static v => !string.IsNullOrEmpty(v.RuntimeClassName)))
1534+
{
1535+
if (!runtimeClassNameToClassNameList.TryGetValue(vtableAttribute.RuntimeClassName, out var classNameList))
1536+
{
1537+
classNameList = new List<string>();
1538+
runtimeClassNameToClassNameList.Add(vtableAttribute.RuntimeClassName, classNameList);
1539+
}
1540+
classNameList.Add(vtableAttribute.VtableLookupClassName);
1541+
}
1542+
1543+
foreach (var entry in runtimeClassNameToClassNameList)
14731544
{
14741545
source.AppendLine($$"""
1475-
if (typeName == "{{vtableAttribute.VtableLookupClassName}}")
1546+
if (typeName == "{{entry.Value[0]}}"
1547+
""");
1548+
1549+
for (var i = 1; i < entry.Value.Count; i++)
1550+
{
1551+
source.AppendLine($$"""
1552+
|| typeName == "{{entry.Value[i]}}"
1553+
""");
1554+
}
1555+
1556+
source.AppendLine($$"""
1557+
)
14761558
{
1477-
return "{{vtableAttribute.RuntimeClassName}}";
1559+
return "{{entry.Key}}";
14781560
}
14791561
""");
14801562
}
@@ -1630,6 +1712,11 @@ internal sealed record VtableAttribute(
16301712
bool IsPublic,
16311713
string RuntimeClassName = default);
16321714

1715+
sealed record VtableEntry(
1716+
EquatableArray<string> Interfaces,
1717+
EquatableArray<GenericInterface> GenericInterfaces,
1718+
bool IsDelegate);
1719+
16331720
internal readonly record struct BindableCustomProperty(
16341721
string Name,
16351722
string Type,

0 commit comments

Comments
 (0)