@@ -54,7 +54,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
54
54
var externalInterfaceSymbols = attributedInterfaces . SelectMany ( static ( data , ct ) =>
55
55
{
56
56
return ComInterfaceInfo . CreateInterfaceInfoForBaseInterfacesInOtherCompilations ( data . Symbol ) ;
57
- } ) ;
57
+ } ) . Collect ( ) . SelectMany ( static ( data , ct ) => data . Distinct ( ComInterfaceInfo . EqualityComparerForExternalIfaces . Instance ) ) ;
58
58
59
59
var interfaceSymbolsWithoutDiagnostics = interfaceSymbolsToGenerateWithoutDiagnostics . Concat ( externalInterfaceSymbols ) ;
60
60
@@ -84,11 +84,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
84
84
. SelectMany ( static ( data , ct ) =>
85
85
{
86
86
return ComMethodContext . CalculateAllMethods ( data , ct ) ;
87
- } )
88
- // Now that we've determined method offsets, we can remove all externally defined methods.
89
- // We'll also filter out methods originally declared on externally defined base interfaces
90
- // as we may not be able to emit them into our assembly.
91
- . Where ( context => ! context . Method . OriginalDeclaringInterface . IsExternallyDefined ) ;
87
+ } ) ;
92
88
93
89
// Now that we've determined method offsets, we can remove all externally defined interfaces.
94
90
var interfaceContextsToGenerate = interfaceContexts . Where ( context => ! context . IsExternallyDefined ) ;
@@ -107,13 +103,20 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
107
103
return new ComMethodContext (
108
104
data . Method ,
109
105
data . OwningInterface ,
110
- CalculateStubInformation ( data . Method . MethodInfo . Syntax , symbolMap [ data . Method . MethodInfo ] , data . Method . Index , env , data . OwningInterface . Info , ct ) ) ;
106
+ CalculateStubInformation (
107
+ data . Method . MethodInfo . Syntax ,
108
+ symbolMap [ data . Method . MethodInfo ] ,
109
+ data . Method . Index ,
110
+ env ,
111
+ data . OwningInterface . Info ,
112
+ ct ) ) ;
111
113
} ) . WithTrackingName ( StepNames . CalculateStubInformation ) ;
112
114
113
115
var interfaceAndMethodsContexts = comMethodContexts
114
116
. Collect ( )
115
117
. Combine ( interfaceContextsToGenerate . Collect ( ) )
116
- . SelectMany ( ( data , ct ) => GroupComContextsForInterfaceGeneration ( data . Left , data . Right , ct ) ) ;
118
+ . SelectMany ( ( data , ct ) =>
119
+ GroupComContextsForInterfaceGeneration ( data . Left , data . Right , ct ) ) ;
117
120
118
121
// Generate the code for the managed-to-unmanaged stubs.
119
122
var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts
@@ -256,12 +259,22 @@ private static bool IsHResultLikeType(ManagedTypeInfo type)
256
259
|| typeName . Equals ( "hresult" , StringComparison . OrdinalIgnoreCase ) ;
257
260
}
258
261
259
- private static IncrementalMethodStubGenerationContext CalculateStubInformation ( MethodDeclarationSyntax syntax , IMethodSymbol symbol , int index , StubEnvironment environment , ComInterfaceInfo owningInterfaceInfo , CancellationToken ct )
262
+ /// <summary>
263
+ /// Calculates the shared information needed for both source-available and sourceless stub generation.
264
+ /// </summary>
265
+ private static IncrementalMethodStubGenerationContext CalculateSharedStubInformation (
266
+ IMethodSymbol symbol ,
267
+ int index ,
268
+ StubEnvironment environment ,
269
+ ISignatureDiagnosticLocations diagnosticLocations ,
270
+ ComInterfaceInfo owningInterfaceInfo ,
271
+ CancellationToken ct )
260
272
{
261
273
ct . ThrowIfCancellationRequested ( ) ;
262
274
INamedTypeSymbol ? lcidConversionAttrType = environment . LcidConversionAttrType ;
263
275
INamedTypeSymbol ? suppressGCTransitionAttrType = environment . SuppressGCTransitionAttrType ;
264
276
INamedTypeSymbol ? unmanagedCallConvAttrType = environment . UnmanagedCallConvAttrType ;
277
+
265
278
// Get any attributes of interest on the method
266
279
AttributeData ? lcidConversionAttr = null ;
267
280
AttributeData ? suppressGCTransitionAttribute = null ;
@@ -282,8 +295,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
282
295
}
283
296
}
284
297
285
- var locations = new MethodSignatureDiagnosticLocations ( syntax ) ;
286
- var generatorDiagnostics = new GeneratorDiagnosticsBag ( new DiagnosticDescriptorProvider ( ) , locations , SR . ResourceManager , typeof ( FxResources . Microsoft . Interop . ComInterfaceGenerator . SR ) ) ;
298
+ var generatorDiagnostics = new GeneratorDiagnosticsBag ( new DiagnosticDescriptorProvider ( ) , diagnosticLocations , SR . ResourceManager , typeof ( FxResources . Microsoft . Interop . ComInterfaceGenerator . SR ) ) ;
287
299
288
300
if ( lcidConversionAttr is not null )
289
301
{
@@ -293,8 +305,8 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
293
305
294
306
GeneratedComInterfaceCompilationData . TryGetGeneratedComInterfaceAttributeFromInterface ( symbol . ContainingType , out var generatedComAttribute ) ;
295
307
var generatedComInterfaceAttributeData = GeneratedComInterfaceCompilationData . GetDataFromAttribute ( generatedComAttribute ) ;
296
- // Create the stub.
297
308
309
+ // Create the stub.
298
310
var signatureContext = SignatureContext . Create (
299
311
symbol ,
300
312
DefaultMarshallingInfoParser . Create (
@@ -387,21 +399,14 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
387
399
GeneratorDiagnostics . SizeOfInCollectionMustBeDefinedAtCallReturnValue ) ;
388
400
}
389
401
390
- var containingSyntaxContext = new ContainingSyntaxContext ( syntax ) ;
391
-
392
- var methodSyntaxTemplate = new ContainingSyntax ( new SyntaxTokenList ( syntax . Modifiers . Where ( static m => ! m . IsKind ( SyntaxKind . NewKeyword ) ) ) . StripAccessibilityModifiers ( ) , SyntaxKind . MethodDeclaration , syntax . Identifier , syntax . TypeParameterList ) ;
393
-
394
402
ImmutableArray < FunctionPointerUnmanagedCallingConventionSyntax > callConv = VirtualMethodPointerStubGenerator . GenerateCallConvSyntaxFromAttributes (
395
403
suppressGCTransitionAttribute ,
396
404
unmanagedCallConvAttribute ,
397
405
ImmutableArray . Create ( FunctionPointerUnmanagedCallingConvention ( Identifier ( "MemberFunction" ) ) ) ) ;
398
406
399
407
var declaringType = ManagedTypeInfo . CreateTypeInfoForTypeSymbol ( symbol . ContainingType ) ;
400
408
401
- var virtualMethodIndexData = new VirtualMethodIndexData ( index , ImplicitThisParameter : true , direction , true , ExceptionMarshalling . Com ) ;
402
-
403
409
MarshallingInfo exceptionMarshallingInfo ;
404
-
405
410
if ( generatedComInterfaceAttributeData . ExceptionToUnmanagedMarshaller is null )
406
411
{
407
412
exceptionMarshallingInfo = new ComExceptionMarshalling ( ) ;
@@ -418,11 +423,9 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
418
423
419
424
return new IncrementalMethodStubGenerationContext (
420
425
signatureContext ,
421
- containingSyntaxContext ,
422
- methodSyntaxTemplate ,
423
- locations ,
426
+ diagnosticLocations ,
424
427
callConv . ToSequenceEqualImmutableArray ( SyntaxEquivalentComparer . Instance ) ,
425
- virtualMethodIndexData ,
428
+ new VirtualMethodIndexData ( index , ImplicitThisParameter : true , direction , true , ExceptionMarshalling . Com ) ,
426
429
exceptionMarshallingInfo ,
427
430
environment . EnvironmentFlags ,
428
431
owningInterfaceInfo . Type ,
@@ -431,6 +434,45 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
431
434
ComInterfaceDispatchMarshallingInfo . Instance ) ;
432
435
}
433
436
437
+ private static IncrementalMethodStubGenerationContext CalculateStubInformation ( MethodDeclarationSyntax ? syntax , IMethodSymbol symbol , int index , StubEnvironment environment , ComInterfaceInfo owningInterface , CancellationToken ct )
438
+ {
439
+ ISignatureDiagnosticLocations locations = syntax is null
440
+ ? NoneSignatureDiagnosticLocations . Instance
441
+ : new MethodSignatureDiagnosticLocations ( syntax ) ;
442
+
443
+ var sourcelessStubInformation = CalculateSharedStubInformation (
444
+ symbol ,
445
+ index ,
446
+ environment ,
447
+ locations ,
448
+ owningInterface ,
449
+ ct ) ;
450
+
451
+ if ( syntax is null )
452
+ return sourcelessStubInformation ;
453
+
454
+ var containingSyntaxContext = new ContainingSyntaxContext ( syntax ) ;
455
+ var methodSyntaxTemplate = new ContainingSyntax (
456
+ new SyntaxTokenList ( syntax . Modifiers . Where ( static m => ! m . IsKind ( SyntaxKind . NewKeyword ) ) ) . StripAccessibilityModifiers ( ) ,
457
+ SyntaxKind . MethodDeclaration ,
458
+ syntax . Identifier ,
459
+ syntax . TypeParameterList ) ;
460
+
461
+ return new SourceAvailableIncrementalMethodStubGenerationContext (
462
+ sourcelessStubInformation . SignatureContext ,
463
+ containingSyntaxContext ,
464
+ methodSyntaxTemplate ,
465
+ locations ,
466
+ sourcelessStubInformation . CallingConvention ,
467
+ sourcelessStubInformation . VtableIndexData ,
468
+ sourcelessStubInformation . ExceptionMarshallingInfo ,
469
+ sourcelessStubInformation . EnvironmentFlags ,
470
+ sourcelessStubInformation . TypeKeyOwner ,
471
+ sourcelessStubInformation . DeclaringType ,
472
+ sourcelessStubInformation . Diagnostics ,
473
+ ComInterfaceDispatchMarshallingInfo . Instance ) ;
474
+ }
475
+
434
476
private static MarshalDirection GetDirectionFromOptions ( ComInterfaceOptions options )
435
477
{
436
478
if ( options . HasFlag ( ComInterfaceOptions . ManagedObjectWrapper | ComInterfaceOptions . ComObjectWrapper ) )
@@ -520,12 +562,12 @@ static bool MethodEquals(ComMethodContext a, ComMethodContext b)
520
562
private static InterfaceDeclarationSyntax GenerateImplementationInterface ( ComInterfaceAndMethodsContext interfaceGroup , CancellationToken _ )
521
563
{
522
564
var definingType = interfaceGroup . Interface . Info . Type ;
523
- var shadowImplementations = interfaceGroup . InheritedMethods . Select ( m => ( Method : m , ManagedToUnmanagedStub : m . ManagedToUnmanagedStub ) )
565
+ var shadowImplementations = interfaceGroup . InheritedMethods . Where ( m => ! m . IsExternallyDefined ) . Select ( m => ( Method : m , ManagedToUnmanagedStub : m . ManagedToUnmanagedStub ) )
524
566
. Where ( p => p . ManagedToUnmanagedStub is GeneratedStubCodeContext )
525
567
. Select ( ctx => ( ( GeneratedStubCodeContext ) ctx . ManagedToUnmanagedStub ) . Stub . Node
526
568
. WithExplicitInterfaceSpecifier (
527
569
ExplicitInterfaceSpecifier ( ParseName ( definingType . FullTypeName ) ) ) ) ;
528
- var inheritedStubs = interfaceGroup . InheritedMethods . Select ( m => m . UnreachableExceptionStub ) ;
570
+ var inheritedStubs = interfaceGroup . InheritedMethods . Where ( m => ! m . IsExternallyDefined ) . Select ( m => m . UnreachableExceptionStub ) ;
529
571
return ImplementationInterfaceTemplate
530
572
. AddBaseListTypes ( SimpleBaseType ( definingType . Syntax ) )
531
573
. WithMembers (
@@ -661,7 +703,6 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
661
703
662
704
BlockSyntax fillBaseInterfaceSlots ;
663
705
664
-
665
706
if ( interfaceMethods . Interface . Base is null )
666
707
{
667
708
// If we don't have a base interface, we need to manually fill in the base iUnknown slots.
@@ -740,7 +781,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
740
781
}
741
782
else
742
783
{
743
- // NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInteraceDetailsStrategy .GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <startingOffset >));
784
+ // NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy .GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <baseVTableSize >));
744
785
fillBaseInterfaceSlots = Block (
745
786
MethodInvocationStatement (
746
787
TypeSyntaxes . System_Runtime_InteropServices_NativeMemory ,
@@ -750,7 +791,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
750
791
TypeSyntaxes . StrategyBasedComWrappers
751
792
. Dot ( IdentifierName ( "DefaultIUnknownInterfaceDetailsStrategy" ) ) ,
752
793
IdentifierName ( "GetIUnknownDerivedDetails" ) ,
753
- Argument ( //baseInterfaceTypeInfo.BaseInterface.FullTypeName)),
794
+ Argument (
754
795
TypeOfExpression ( ParseTypeName ( interfaceMethods . Interface . Base . Info . Type . FullTypeName ) )
755
796
. Dot ( IdentifierName ( "TypeHandle" ) ) ) )
756
797
. Dot ( IdentifierName ( "ManagedVirtualMethodTable" ) ) ) ,
@@ -767,7 +808,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
767
808
ParenthesizedExpression (
768
809
BinaryExpression ( SyntaxKind . MultiplyExpression ,
769
810
SizeOfExpression ( PointerType ( PredefinedType ( Token ( SyntaxKind . VoidKeyword ) ) ) ) ,
770
- LiteralExpression ( SyntaxKind . NumericLiteralExpression , Literal ( interfaceMethods . InheritedMethods . Count ( ) + 3 ) ) ) ) ) ) ) ) ;
811
+ LiteralExpression ( SyntaxKind . NumericLiteralExpression , Literal ( interfaceMethods . BaseVTableSize ) ) ) ) ) ) ) ) ;
771
812
}
772
813
773
814
var validDeclaredMethods = interfaceMethods . DeclaredMethods
@@ -787,7 +828,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
787
828
IdentifierName ( $ "{ declaredMethodContext . MethodInfo . MethodName } _{ declaredMethodContext . GenerationContext . VtableIndexData . Index } ") ) ,
788
829
PrefixUnaryExpression (
789
830
SyntaxKind . AddressOfExpression ,
790
- IdentifierName ( $ "ABI_{ declaredMethodContext . GenerationContext . StubMethodSyntaxTemplate . Identifier } ") ) ) ) ) ;
831
+ IdentifierName ( $ "ABI_{ ( ( SourceAvailableIncrementalMethodStubGenerationContext ) declaredMethodContext . GenerationContext ) . StubMethodSyntaxTemplate . Identifier } ") ) ) ) ) ;
791
832
}
792
833
793
834
return ImplementationInterfaceTemplate
0 commit comments