Skip to content

Commit 2d7c63b

Browse files
authored
Deduplicate base interfaces from another assembly in COM generator (#111931)
Fixes cross-assembly interface inheritance. Deduplicate external ComInterfaceInfos in the pipeline. Don't remove external interface methods in the generator pipeline - these are required to create the new VTable struct type introduced in #116289. Creates a new SourceAvailableIncrementalMethodStubGenerationContext for methods with source available and use IncrementalMethodStubGenerationContext as a base for externally defined methods. Moves the ComInterfaces from SharedTypes to the Common folder, and defines a limited set of base interfaces in SharedTypes for interfaces in ComInterfaceGenerator to derived from. Adds a number of tests for different levels of inheritance across assemblies.
1 parent 1d8a341 commit 2d7c63b

File tree

72 files changed

+776
-86
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+776
-86
lines changed

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,23 @@ internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interfa
2020
/// <summary>
2121
/// COM methods that require shadowing declarations on the derived interface.
2222
/// </summary>
23-
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface);
23+
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface && !m.IsExternallyDefined);
2424

2525
/// <summary>
2626
/// COM methods that are declared on an interface the interface inherits from.
2727
/// </summary>
2828
public IEnumerable<ComMethodContext> InheritedMethods => Methods.Where(m => m.IsInheritedMethod);
29+
30+
/// <summary>
31+
/// The size of the vtable for this interface, including the base interface methods and IUnknown methods.
32+
/// </summary>
33+
public int VTableSize => Methods.Length == 0
34+
? IUnknownConstants.VTableSize
35+
: 1 + Methods.Max(m => m.GenerationContext.VtableIndexData.Index);
36+
37+
/// <summary>
38+
/// The size of the vtable for the base interface, including it's base interface methods and IUnknown methods.
39+
/// </summary>
40+
public int BaseVTableSize => VTableSize - DeclaredMethods.Count();
2941
}
3042
}

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs

Lines changed: 71 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
5454
var externalInterfaceSymbols = attributedInterfaces.SelectMany(static (data, ct) =>
5555
{
5656
return ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(data.Symbol);
57-
});
57+
}).Collect().SelectMany(static (data, ct) => data.Distinct(ComInterfaceInfo.EqualityComparerForExternalIfaces.Instance));
5858

5959
var interfaceSymbolsWithoutDiagnostics = interfaceSymbolsToGenerateWithoutDiagnostics.Concat(externalInterfaceSymbols);
6060

@@ -84,11 +84,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
8484
.SelectMany(static (data, ct) =>
8585
{
8686
return ComMethodContext.CalculateAllMethods(data, ct);
87-
})
88-
// Now that we've determined method offsets, we can remove all externally defined methods.
89-
// We'll also filter out methods originally declared on externally defined base interfaces
90-
// as we may not be able to emit them into our assembly.
91-
.Where(context => !context.Method.OriginalDeclaringInterface.IsExternallyDefined);
87+
});
9288

9389
// Now that we've determined method offsets, we can remove all externally defined interfaces.
9490
var interfaceContextsToGenerate = interfaceContexts.Where(context => !context.IsExternallyDefined);
@@ -107,13 +103,20 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
107103
return new ComMethodContext(
108104
data.Method,
109105
data.OwningInterface,
110-
CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.OwningInterface.Info, ct));
106+
CalculateStubInformation(
107+
data.Method.MethodInfo.Syntax,
108+
symbolMap[data.Method.MethodInfo],
109+
data.Method.Index,
110+
env,
111+
data.OwningInterface.Info,
112+
ct));
111113
}).WithTrackingName(StepNames.CalculateStubInformation);
112114

113115
var interfaceAndMethodsContexts = comMethodContexts
114116
.Collect()
115117
.Combine(interfaceContextsToGenerate.Collect())
116-
.SelectMany((data, ct) => GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct));
118+
.SelectMany((data, ct) =>
119+
GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct));
117120

118121
// Generate the code for the managed-to-unmanaged stubs.
119122
var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts
@@ -256,12 +259,22 @@ private static bool IsHResultLikeType(ManagedTypeInfo type)
256259
|| typeName.Equals("hresult", StringComparison.OrdinalIgnoreCase);
257260
}
258261

259-
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterfaceInfo, CancellationToken ct)
262+
/// <summary>
263+
/// Calculates the shared information needed for both source-available and sourceless stub generation.
264+
/// </summary>
265+
private static IncrementalMethodStubGenerationContext CalculateSharedStubInformation(
266+
IMethodSymbol symbol,
267+
int index,
268+
StubEnvironment environment,
269+
ISignatureDiagnosticLocations diagnosticLocations,
270+
ComInterfaceInfo owningInterfaceInfo,
271+
CancellationToken ct)
260272
{
261273
ct.ThrowIfCancellationRequested();
262274
INamedTypeSymbol? lcidConversionAttrType = environment.LcidConversionAttrType;
263275
INamedTypeSymbol? suppressGCTransitionAttrType = environment.SuppressGCTransitionAttrType;
264276
INamedTypeSymbol? unmanagedCallConvAttrType = environment.UnmanagedCallConvAttrType;
277+
265278
// Get any attributes of interest on the method
266279
AttributeData? lcidConversionAttr = null;
267280
AttributeData? suppressGCTransitionAttribute = null;
@@ -282,8 +295,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
282295
}
283296
}
284297

285-
var locations = new MethodSignatureDiagnosticLocations(syntax);
286-
var generatorDiagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), locations, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
298+
var generatorDiagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), diagnosticLocations, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
287299

288300
if (lcidConversionAttr is not null)
289301
{
@@ -293,8 +305,8 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
293305

294306
GeneratedComInterfaceCompilationData.TryGetGeneratedComInterfaceAttributeFromInterface(symbol.ContainingType, out var generatedComAttribute);
295307
var generatedComInterfaceAttributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComAttribute);
296-
// Create the stub.
297308

309+
// Create the stub.
298310
var signatureContext = SignatureContext.Create(
299311
symbol,
300312
DefaultMarshallingInfoParser.Create(
@@ -387,21 +399,14 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
387399
GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallReturnValue);
388400
}
389401

390-
var containingSyntaxContext = new ContainingSyntaxContext(syntax);
391-
392-
var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);
393-
394402
ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes(
395403
suppressGCTransitionAttribute,
396404
unmanagedCallConvAttribute,
397405
ImmutableArray.Create(FunctionPointerUnmanagedCallingConvention(Identifier("MemberFunction"))));
398406

399407
var declaringType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);
400408

401-
var virtualMethodIndexData = new VirtualMethodIndexData(index, ImplicitThisParameter: true, direction, true, ExceptionMarshalling.Com);
402-
403409
MarshallingInfo exceptionMarshallingInfo;
404-
405410
if (generatedComInterfaceAttributeData.ExceptionToUnmanagedMarshaller is null)
406411
{
407412
exceptionMarshallingInfo = new ComExceptionMarshalling();
@@ -418,11 +423,9 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
418423

419424
return new IncrementalMethodStubGenerationContext(
420425
signatureContext,
421-
containingSyntaxContext,
422-
methodSyntaxTemplate,
423-
locations,
426+
diagnosticLocations,
424427
callConv.ToSequenceEqualImmutableArray(SyntaxEquivalentComparer.Instance),
425-
virtualMethodIndexData,
428+
new VirtualMethodIndexData(index, ImplicitThisParameter: true, direction, true, ExceptionMarshalling.Com),
426429
exceptionMarshallingInfo,
427430
environment.EnvironmentFlags,
428431
owningInterfaceInfo.Type,
@@ -431,6 +434,45 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
431434
ComInterfaceDispatchMarshallingInfo.Instance);
432435
}
433436

437+
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax? syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterface, CancellationToken ct)
438+
{
439+
ISignatureDiagnosticLocations locations = syntax is null
440+
? NoneSignatureDiagnosticLocations.Instance
441+
: new MethodSignatureDiagnosticLocations(syntax);
442+
443+
var sourcelessStubInformation = CalculateSharedStubInformation(
444+
symbol,
445+
index,
446+
environment,
447+
locations,
448+
owningInterface,
449+
ct);
450+
451+
if (syntax is null)
452+
return sourcelessStubInformation;
453+
454+
var containingSyntaxContext = new ContainingSyntaxContext(syntax);
455+
var methodSyntaxTemplate = new ContainingSyntax(
456+
new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(),
457+
SyntaxKind.MethodDeclaration,
458+
syntax.Identifier,
459+
syntax.TypeParameterList);
460+
461+
return new SourceAvailableIncrementalMethodStubGenerationContext(
462+
sourcelessStubInformation.SignatureContext,
463+
containingSyntaxContext,
464+
methodSyntaxTemplate,
465+
locations,
466+
sourcelessStubInformation.CallingConvention,
467+
sourcelessStubInformation.VtableIndexData,
468+
sourcelessStubInformation.ExceptionMarshallingInfo,
469+
sourcelessStubInformation.EnvironmentFlags,
470+
sourcelessStubInformation.TypeKeyOwner,
471+
sourcelessStubInformation.DeclaringType,
472+
sourcelessStubInformation.Diagnostics,
473+
ComInterfaceDispatchMarshallingInfo.Instance);
474+
}
475+
434476
private static MarshalDirection GetDirectionFromOptions(ComInterfaceOptions options)
435477
{
436478
if (options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper | ComInterfaceOptions.ComObjectWrapper))
@@ -520,12 +562,12 @@ static bool MethodEquals(ComMethodContext a, ComMethodContext b)
520562
private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
521563
{
522564
var definingType = interfaceGroup.Interface.Info.Type;
523-
var shadowImplementations = interfaceGroup.InheritedMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
565+
var shadowImplementations = interfaceGroup.InheritedMethods.Where(m => !m.IsExternallyDefined).Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
524566
.Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
525567
.Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
526568
.WithExplicitInterfaceSpecifier(
527569
ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
528-
var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.UnreachableExceptionStub);
570+
var inheritedStubs = interfaceGroup.InheritedMethods.Where(m => !m.IsExternallyDefined).Select(m => m.UnreachableExceptionStub);
529571
return ImplementationInterfaceTemplate
530572
.AddBaseListTypes(SimpleBaseType(definingType.Syntax))
531573
.WithMembers(
@@ -661,7 +703,6 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
661703

662704
BlockSyntax fillBaseInterfaceSlots;
663705

664-
665706
if (interfaceMethods.Interface.Base is null)
666707
{
667708
// If we don't have a base interface, we need to manually fill in the base iUnknown slots.
@@ -740,7 +781,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
740781
}
741782
else
742783
{
743-
// NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInteraceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <startingOffset>));
784+
// NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <baseVTableSize>));
744785
fillBaseInterfaceSlots = Block(
745786
MethodInvocationStatement(
746787
TypeSyntaxes.System_Runtime_InteropServices_NativeMemory,
@@ -750,7 +791,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
750791
TypeSyntaxes.StrategyBasedComWrappers
751792
.Dot(IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
752793
IdentifierName("GetIUnknownDerivedDetails"),
753-
Argument( //baseInterfaceTypeInfo.BaseInterface.FullTypeName)),
794+
Argument(
754795
TypeOfExpression(ParseTypeName(interfaceMethods.Interface.Base.Info.Type.FullTypeName))
755796
.Dot(IdentifierName("TypeHandle"))))
756797
.Dot(IdentifierName("ManagedVirtualMethodTable"))),
@@ -767,7 +808,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
767808
ParenthesizedExpression(
768809
BinaryExpression(SyntaxKind.MultiplyExpression,
769810
SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))),
770-
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3))))))));
811+
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.BaseVTableSize))))))));
771812
}
772813

773814
var validDeclaredMethods = interfaceMethods.DeclaredMethods
@@ -787,7 +828,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
787828
IdentifierName($"{declaredMethodContext.MethodInfo.MethodName}_{declaredMethodContext.GenerationContext.VtableIndexData.Index}")),
788829
PrefixUnaryExpression(
789830
SyntaxKind.AddressOfExpression,
790-
IdentifierName($"ABI_{declaredMethodContext.GenerationContext.StubMethodSyntaxTemplate.Identifier}")))));
831+
IdentifierName($"ABI_{((SourceAvailableIncrementalMethodStubGenerationContext)declaredMethodContext.GenerationContext).StubMethodSyntaxTemplate.Identifier}")))));
791832
}
792833

793834
return ImplementationInterfaceTemplate

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Collections.Generic;
56
using System.Collections.Immutable;
7+
using System.Diagnostics;
68
using System.Diagnostics.CodeAnalysis;
79
using System.Threading;
810
using Microsoft.CodeAnalysis;
911
using Microsoft.CodeAnalysis.CSharp;
1012
using Microsoft.CodeAnalysis.CSharp.Syntax;
1113
using InterfaceInfo = (Microsoft.Interop.ComInterfaceInfo InterfaceInfo, Microsoft.CodeAnalysis.INamedTypeSymbol Symbol);
1214
using DiagnosticOrInterfaceInfo = Microsoft.Interop.DiagnosticOr<(Microsoft.Interop.ComInterfaceInfo InterfaceInfo, Microsoft.CodeAnalysis.INamedTypeSymbol Symbol)>;
13-
using System.Diagnostics;
1415

1516
namespace Microsoft.Interop
1617
{
@@ -176,6 +177,13 @@ public static ImmutableArray<InterfaceInfo> CreateInterfaceInfoForBaseInterfaces
176177
return builder.ToImmutable();
177178
}
178179

180+
internal sealed class EqualityComparerForExternalIfaces : IEqualityComparer<(ComInterfaceInfo InterfaceInfo, INamedTypeSymbol Symbol)>
181+
{
182+
public bool Equals((ComInterfaceInfo, INamedTypeSymbol) x, (ComInterfaceInfo, INamedTypeSymbol) y) => SymbolEqualityComparer.Default.Equals(x.Item2, y.Item2);
183+
public int GetHashCode((ComInterfaceInfo, INamedTypeSymbol) obj) => SymbolEqualityComparer.Default.GetHashCode(obj.Item2);
184+
public static readonly EqualityComparerForExternalIfaces Instance = new();
185+
}
186+
179187
private static bool IsInPartialContext(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax, [NotNullWhen(false)] out DiagnosticInfo? diagnostic)
180188
{
181189
// Verify that the types the interface is declared in are marked partial.

0 commit comments

Comments
 (0)