From ead9bc3ce3c1c3db2c447243398395e31959e8d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20Laferrie=CC=80re?= Date: Thu, 4 Jun 2020 18:10:23 -0700 Subject: [PATCH 01/36] [ModuleInterface] Don't print SPI attributes on unsupported decls When emitting the private swiftinterface, the compiler prints the attribute explicitly even when it is deduced from the context. This can lead to unparsable private swiftinterface files. As a narrow fix, check if the decl type is supported before printing the attribute. rdar://64039069 --- lib/AST/ASTPrinter.cpp | 4 +++- test/SPI/private_swiftinterface.swift | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index b3815525627ee..79b0aa87d06ad 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -991,7 +991,9 @@ void PrintAST::printAttributes(const Decl *D) { } // SPI groups - if (Options.PrintSPIs) { + if (Options.PrintSPIs && + DeclAttribute::canAttributeAppearOnDeclKind( + DAK_SPIAccessControl, D->getKind())) { interleave(D->getSPIGroups(), [&](Identifier spiName) { Printer.printAttrName("_spi", true); diff --git a/test/SPI/private_swiftinterface.swift b/test/SPI/private_swiftinterface.swift index 8be3d34925410..3530365c4cac7 100644 --- a/test/SPI/private_swiftinterface.swift +++ b/test/SPI/private_swiftinterface.swift @@ -88,6 +88,20 @@ private class PrivateClassLocal {} // CHECK-PUBLIC-NOT: extensionSPIMethod } +@_spi(LocalSPI) public protocol SPIProto3 { +// CHECK-PRIVATE: @_spi(LocalSPI) public protocol SPIProto3 +// CHECK-PUBLIC-NOT: SPIProto3 + + associatedtype AssociatedType + // CHECK-PRIVATE: associatedtype AssociatedType + // CHECK-PRIVATE-NOT: @_spi(LocalSPI) associatedtype AssociatedType + // CHECK-PUBLIC-NOT: AssociatedType + + func implicitSPIMethod() + // CHECK-PRIVATE: @_spi(LocalSPI) func implicitSPIMethod() + // CHECK-PUBLIC-NOT: implicitSPIMethod +} + // Test the dummy conformance printed to replace private types used in // conditional conformances. rdar://problem/63352700 From f44cbe4697c69d76c2e7ecd2f28e3565827a0505 Mon Sep 17 00:00:00 2001 From: Karoy Lorentey Date: Thu, 2 Jul 2020 01:08:05 -0700 Subject: [PATCH 02/36] [Foundation] Update & simplify class name stability check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move the ObjC class name stability check logic to the Swift runtime, exposing it as a new SPI called _swift_isObjCTypeNameSerializable. Update the reporting logic. The ObjC names of generic classes are considered stable now, but private classes and classes defined in function bodies or other anonymous contexts are unstable by design. On the overlay side, rewrite the check’s implementation in Swift and considerably simplify it. rdar://57809977 --- include/swift/Runtime/FoundationSupport.h | 40 +++ .../public/Darwin/Foundation/CMakeLists.txt | 2 +- stdlib/public/Darwin/Foundation/CheckClass.mm | 291 ------------------ .../public/Darwin/Foundation/CheckClass.swift | 64 ++++ .../SwiftShims/FoundationOverlayShims.h | 5 + stdlib/public/runtime/CMakeLists.txt | 1 + stdlib/public/runtime/FoundationSupport.cpp | 61 ++++ .../SDK/check_class_for_archiving.swift | 80 +++-- .../SDK/check_class_for_archiving_log.swift | 232 ++++++-------- 9 files changed, 324 insertions(+), 452 deletions(-) create mode 100644 include/swift/Runtime/FoundationSupport.h delete mode 100644 stdlib/public/Darwin/Foundation/CheckClass.mm create mode 100644 stdlib/public/Darwin/Foundation/CheckClass.swift create mode 100644 stdlib/public/runtime/FoundationSupport.cpp diff --git a/include/swift/Runtime/FoundationSupport.h b/include/swift/Runtime/FoundationSupport.h new file mode 100644 index 0000000000000..cbda6635c7a67 --- /dev/null +++ b/include/swift/Runtime/FoundationSupport.h @@ -0,0 +1,40 @@ +//===--- FoundationSupport.cpp - Support functions for Foundation ---------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// Helper functions for the Foundation framework. +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_RUNTIME_FOUNDATION_SUPPORT_H +#define SWIFT_RUNTIME_FOUNDATION_SUPPORT_H + +#include "swift/Runtime/Config.h" + +#if SWIFT_OBJC_INTEROP +#include + +#ifdef __cplusplus +namespace swift { extern "C" { +#endif + +/// Returns a boolean indicating whether the Objective-C name of a class type is +/// stable across executions, i.e., if the class name is safe to serialize. (The +/// names of private and local types are unstable.) +SWIFT_RUNTIME_STDLIB_SPI +bool _swift_isObjCTypeNameSerializable(Class theClass); + +#ifdef __cplusplus +}} // extern "C", namespace swift +#endif + +#endif // SWIFT_OBJC_INTEROP +#endif // SWIFT_RUNTIME_FOUNDATION_SUPPORT_H diff --git a/stdlib/public/Darwin/Foundation/CMakeLists.txt b/stdlib/public/Darwin/Foundation/CMakeLists.txt index 4b489d06a297e..da9a8428a25db 100644 --- a/stdlib/public/Darwin/Foundation/CMakeLists.txt +++ b/stdlib/public/Darwin/Foundation/CMakeLists.txt @@ -7,7 +7,7 @@ add_swift_target_library(swiftFoundation ${SWIFT_SDK_OVERLAY_LIBRARY_BUILD_TYPES BundleLookup.mm Calendar.swift CharacterSet.swift - CheckClass.mm + CheckClass.swift Codable.swift Collections+DataProtocol.swift CombineTypealiases.swift diff --git a/stdlib/public/Darwin/Foundation/CheckClass.mm b/stdlib/public/Darwin/Foundation/CheckClass.mm deleted file mode 100644 index 8a5621f1e3e41..0000000000000 --- a/stdlib/public/Darwin/Foundation/CheckClass.mm +++ /dev/null @@ -1,291 +0,0 @@ -#import - -#include - -#include "swift/Runtime/HeapObject.h" -#include "swift/Runtime/Metadata.h" - -@interface NSKeyedUnarchiver (SwiftAdditions) -+ (int)_swift_checkClassAndWarnForKeyedArchiving:(Class)cls - operation:(int)operation - NS_SWIFT_NAME(_swift_checkClassAndWarnForKeyedArchiving(_:operation:)); -@end - -static bool isASCIIIdentifierChar(char c) { - if (c >= 'a' && c <= 'z') return true; - if (c >= 'A' && c <= 'Z') return true; - if (c >= '0' && c <= '9') return true; - if (c == '_') return true; - if (c == '$') return true; - return false; -} - -template -static constexpr size_t arrayLength(T (&)[N]) { return N; } - -static void logIfFirstOccurrence(Class objcClass, void (^log)(void)) { - static auto queue = dispatch_queue_create( - "SwiftFoundation._checkClassAndWarnForKeyedArchivingQueue", - DISPATCH_QUEUE_SERIAL); - static NSHashTable *seenClasses = nil; - - dispatch_sync(queue, ^{ - // Will be NO when seenClasses is still nil. - if ([seenClasses containsObject:objcClass]) - return; - - if (!seenClasses) { - NSPointerFunctionsOptions options = 0; - options |= NSPointerFunctionsOpaqueMemory; - options |= NSPointerFunctionsObjectPointerPersonality; - seenClasses = [[NSHashTable alloc] initWithOptions:options capacity:16]; - } - [seenClasses addObject:objcClass]; - - // Synchronize logging so that multiple lines aren't interleaved. - log(); - }); -} - -namespace { - class StringRefLite { - StringRefLite(const char *data, size_t len) : data(data), length(len) {} - public: - const char *data; - size_t length; - - StringRefLite() : data(nullptr), length(0) {} - - template - StringRefLite(const char (&staticStr)[N]) : data(staticStr), length(N) {} - - StringRefLite(swift::TypeNamePair rawValue) - : data(rawValue.data), - length(rawValue.length){} - - NS_RETURNS_RETAINED - NSString *newNSStringNoCopy() const { - return [[NSString alloc] initWithBytesNoCopy:const_cast(data) - length:length - encoding:NSUTF8StringEncoding - freeWhenDone:NO]; - } - - const char &operator[](size_t offset) const { - assert(offset < length); - return data[offset]; - } - - StringRefLite slice(size_t from, size_t to) const { - assert(from <= to); - assert(to <= length); - return {data + from, to - from}; - } - - const char *begin() const { - return data; - } - const char *end() const { - return data + length; - } - }; -} - -/// Assume that a non-generic demangled class name always ends in ".MyClass" -/// or ".(MyClass plus extra info)". -static StringRefLite findBaseName(StringRefLite demangledName) { - size_t end = demangledName.length; - size_t parenCount = 0; - for (size_t i = end; i != 0; --i) { - switch (demangledName[i - 1]) { - case '.': - if (parenCount == 0) { - if (i != end && demangledName[i] == '(') - ++i; - return demangledName.slice(i, end); - } - break; - case ')': - parenCount += 1; - break; - case '(': - if (parenCount > 0) - parenCount -= 1; - break; - case ' ': - end = i - 1; - break; - default: - break; - } - } - return {}; -} - -@implementation NSKeyedUnarchiver (SwiftAdditions) - -/// Checks if class \p objcClass is good for archiving. -/// -/// If not, a runtime warning is printed. -/// -/// \param operation Specifies the archiving operation. Valid operations are: -/// 0: archiving -/// 1: unarchiving -/// \return Returns the status -/// 0: not a problem class (either non-Swift or has an explicit name) -/// 1: a Swift generic class -/// 2: a Swift non-generic class where adding @objc is valid -/// Future versions of this API will return nonzero values for additional cases -/// that mean the class shouldn't be archived. -+ (int)_swift_checkClassAndWarnForKeyedArchiving:(Class)objcClass - operation:(int)operation { - using namespace swift; - const ClassMetadata *theClass = (ClassMetadata *)objcClass; - - // Is it a (real) swift class? - if (!theClass->isTypeMetadata() || theClass->isArtificialSubclass()) - return 0; - - // Does the class already have a custom name? - if (theClass->getFlags() & ClassFlags::HasCustomObjCName) - return 0; - - // Is it a mangled name? - const char *className = class_getName(objcClass); - if (!(className[0] == '_' && className[1] == 'T')) - return 0; - // Is it a name in the form .? Note: the module name could - // start with "_T". - if (strchr(className, '.')) - return 0; - - // Is it a generic class? - if (theClass->getDescription()->isGeneric()) { - logIfFirstOccurrence(objcClass, ^{ - // Use actual NSStrings to force UTF-8. - StringRefLite demangledName = swift_getTypeName(theClass, - /*qualified*/true); - NSString *demangledString = demangledName.newNSStringNoCopy(); - NSString *mangledString = NSStringFromClass(objcClass); - - NSString *primaryMessage; - switch (operation) { - case 1: - primaryMessage = [[NSString alloc] initWithFormat: - @"Attempting to unarchive generic Swift class '%@' with mangled " - "runtime name '%@'. Runtime names for generic classes are " - "unstable and may change in the future, leading to " - "non-decodable data.", demangledString, mangledString]; - break; - default: - primaryMessage = [[NSString alloc] initWithFormat: - @"Attempting to archive generic Swift class '%@' with mangled " - "runtime name '%@'. Runtime names for generic classes are " - "unstable and may change in the future, leading to " - "non-decodable data.", demangledString, mangledString]; - break; - } - NSString *generatedNote = [[NSString alloc] initWithFormat: - @"To avoid this failure, create a concrete subclass and register " - "it with NSKeyedUnarchiver.setClass(_:forClassName:) instead, " - "using the name \"%@\".", mangledString]; - const char *staticNote = - "If you need to produce archives compatible with older versions " - "of your program, use NSKeyedArchiver.setClassName(_:for:) as well."; - - NSLog(@"%@", primaryMessage); - NSLog(@"%@", generatedNote); - NSLog(@"%s", staticNote); - - RuntimeErrorDetails::Note notes[] = { - { [generatedNote UTF8String], /*numFixIts*/0, /*fixIts*/nullptr }, - { staticNote, /*numFixIts*/0, /*fixIts*/nullptr }, - }; - - RuntimeErrorDetails errorInfo = {}; - errorInfo.version = RuntimeErrorDetails::currentVersion; - errorInfo.errorType = "nskeyedarchiver-incompatible-class"; - errorInfo.notes = notes; - errorInfo.numNotes = arrayLength(notes); - - _swift_reportToDebugger(RuntimeErrorFlagNone, [primaryMessage UTF8String], - &errorInfo); - - [primaryMessage release]; - [generatedNote release]; - [demangledString release]; - }); - return 1; - } - - // It's a swift class with a (compiler generated) mangled name, which should - // be written into an NSArchive. - logIfFirstOccurrence(objcClass, ^{ - // Use actual NSStrings to force UTF-8. - StringRefLite demangledName = swift_getTypeName(theClass,/*qualified*/true); - NSString *demangledString = demangledName.newNSStringNoCopy(); - NSString *mangledString = NSStringFromClass(objcClass); - - NSString *primaryMessage; - switch (operation) { - case 1: - primaryMessage = [[NSString alloc] initWithFormat: - @"Attempting to unarchive Swift class '%@' with mangled runtime " - "name '%@'. The runtime name for this class is unstable and may " - "change in the future, leading to non-decodable data.", - demangledString, mangledString]; - break; - default: - primaryMessage = [[NSString alloc] initWithFormat: - @"Attempting to archive Swift class '%@' with mangled runtime " - "name '%@'. The runtime name for this class is unstable and may " - "change in the future, leading to non-decodable data.", - demangledString, mangledString]; - break; - } - - NSString *firstNote = [[NSString alloc] initWithFormat: - @"You can use the 'objc' attribute to ensure that the name will not " - "change: \"@objc(%@)\"", mangledString]; - - StringRefLite baseName = findBaseName(demangledName); - // Offer a more generic message if the base name we found doesn't look like - // an ASCII identifier. This avoids printing names like "ABCモデル". - if (baseName.length == 0 || - !std::all_of(baseName.begin(), baseName.end(), isASCIIIdentifierChar)) { - baseName = "MyModel"; - } - - NSString *secondNote = [[NSString alloc] initWithFormat: - @"If there are no existing archives containing this class, you should " - "choose a unique, prefixed name instead: \"@objc(ABC%1$.*2$s)\"", - baseName.data, (int)baseName.length]; - - NSLog(@"%@", primaryMessage); - NSLog(@"%@", firstNote); - NSLog(@"%@", secondNote); - - // FIXME: We could suggest these as fix-its if we had source locations for - // the class. - RuntimeErrorDetails::Note notes[] = { - { [firstNote UTF8String], /*numFixIts*/0, /*fixIts*/nullptr }, - { [secondNote UTF8String], /*numFixIts*/0, /*fixIts*/nullptr }, - }; - - RuntimeErrorDetails errorInfo = {}; - errorInfo.version = RuntimeErrorDetails::currentVersion; - errorInfo.errorType = "nskeyedarchiver-incompatible-class"; - errorInfo.notes = notes; - errorInfo.numNotes = arrayLength(notes); - - _swift_reportToDebugger(RuntimeErrorFlagNone, [primaryMessage UTF8String], - &errorInfo); - - [primaryMessage release]; - [firstNote release]; - [secondNote release]; - [demangledString release]; - }); - return 2; -} -@end diff --git a/stdlib/public/Darwin/Foundation/CheckClass.swift b/stdlib/public/Darwin/Foundation/CheckClass.swift new file mode 100644 index 0000000000000..b7ac9da9cc8c2 --- /dev/null +++ b/stdlib/public/Darwin/Foundation/CheckClass.swift @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +@_exported import Foundation // Clang module +@_implementationOnly import _SwiftFoundationOverlayShims +import Dispatch + +private let _queue = DispatchQueue(label: "com.apple.SwiftFoundation._checkClassAndWarnForKeyedArchivingQueue") +private var _seenClasses: Set = [] +private func _isClassFirstSeen(_ theClass: AnyClass) -> Bool { + _queue.sync { + let id = ObjectIdentifier(theClass) + return _seenClasses.insert(id).inserted + } +} + +extension NSKeyedUnarchiver { + /// Checks if class `theClass` is good for archiving. + /// + /// If not, a runtime warning is printed. + /// + /// - Parameter operation: Specifies the archiving operation. Supported values + /// are 0 for archiving, and 1 for unarchiving. + /// - Returns: 0 if the given class is safe to archive, and non-zero if it + /// isn't. + @objc(_swift_checkClassAndWarnForKeyedArchiving:operation:) + @usableFromInline + internal class func _swift_checkClassAndWarnForKeyedArchiving( + _ theClass: AnyClass, + operation: CInt + ) -> CInt { + if _swift_isObjCTypeNameSerializable(theClass) { return 0 } + + if _isClassFirstSeen(theClass) { + let demangledName = String(reflecting: theClass) + let mangledName = NSStringFromClass(theClass) + + let op = (operation == 1 ? "unarchive" : "archive") + + let message = """ + Attempting to \(op) Swift class '\(demangledName)' with unstable runtime name '\(mangledName)'. + The runtime name for this class may change in the future, leading to non-decodable data. + + You can use the 'objc' attribute to ensure that the name will not change: + "@objc(\(mangledName))" + + If there are no existing archives containing this class, you should choose a unique, prefixed name instead: + "@objc(ABCMyModel)" + """ + NSLog("%@", message) + _swift_reportToDebugger(0, message, nil) + } + return 1 + } +} diff --git a/stdlib/public/SwiftShims/FoundationOverlayShims.h b/stdlib/public/SwiftShims/FoundationOverlayShims.h index 1c175352cae86..8c1692c8ffb8b 100644 --- a/stdlib/public/SwiftShims/FoundationOverlayShims.h +++ b/stdlib/public/SwiftShims/FoundationOverlayShims.h @@ -75,3 +75,8 @@ static inline _Bool _withStackOrHeapBuffer(size_t amount, void (__attribute__((n @protocol _NSKVOCompatibilityShim + (void)_noteProcessHasUsedKVOSwiftOverlay; @end + + +// Exported by libswiftCore: +extern bool _swift_isObjCTypeNameSerializable(Class theClass); +extern void _swift_reportToDebugger(uintptr_t flags, const char *message, void *details); diff --git a/stdlib/public/runtime/CMakeLists.txt b/stdlib/public/runtime/CMakeLists.txt index a43be7ce52b1d..ed72fba6bf02d 100644 --- a/stdlib/public/runtime/CMakeLists.txt +++ b/stdlib/public/runtime/CMakeLists.txt @@ -43,6 +43,7 @@ set(swift_runtime_sources Exclusivity.cpp ExistentialContainer.cpp Float16Support.cpp + FoundationSupport.cpp Heap.cpp HeapObject.cpp ImageInspectionCommon.cpp diff --git a/stdlib/public/runtime/FoundationSupport.cpp b/stdlib/public/runtime/FoundationSupport.cpp new file mode 100644 index 0000000000000..bab3f01e75ba7 --- /dev/null +++ b/stdlib/public/runtime/FoundationSupport.cpp @@ -0,0 +1,61 @@ +//===--- FoundationSupport.cpp - Support functions for Foundation ---------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// Helper functions for the Foundation framework. +// +//===----------------------------------------------------------------------===// + +#include "swift/Runtime/FoundationSupport.h" + +#if SWIFT_OBJC_INTEROP +#include "swift/Runtime/Metadata.h" +#include "swift/Runtime/HeapObject.h" + +using namespace swift; + +/// Returns a boolean indicating whether the Objective-C name of a class type is +/// stable across executions, i.e., if the class name is safe to serialize. (The +/// names of private and local types are unstable.) +bool +swift::_swift_isObjCTypeNameSerializable(Class theClass) { + auto type = (AnyClassMetadata *)theClass; + switch (type->getKind()) { + case MetadataKind::ObjCClassWrapper: + case MetadataKind::ForeignClass: + return true; + case MetadataKind::Class: { + // Pure ObjC classes always have stable names. + if (type->isPureObjC()) + return true; + auto cls = static_cast(type); + // Peek through artificial subclasses. + if (cls->isArtificialSubclass()) { + cls = cls->Superclass; + } + // A custom ObjC name is always considered stable. + if (cls->getFlags() & ClassFlags::HasCustomObjCName) + return true; + // Otherwise the name is stable if the class has no anonymous ancestor context. + auto desc = static_cast(cls->getDescription()); + while (desc) { + if (desc->getKind() == ContextDescriptorKind::Anonymous) { + return false; + } + desc = desc->Parent.get(); + } + return true; + } + default: + return false; + } +} +#endif // SWIFT_OBJC_INTEROP diff --git a/test/Interpreter/SDK/check_class_for_archiving.swift b/test/Interpreter/SDK/check_class_for_archiving.swift index faca73a9d1bd2..cbfd9bc21572b 100644 --- a/test/Interpreter/SDK/check_class_for_archiving.swift +++ b/test/Interpreter/SDK/check_class_for_archiving.swift @@ -1,12 +1,13 @@ // RUN: %empty-directory(%t) // RUN: %target-build-swift %s -module-name=_Test -import-objc-header %S/Inputs/check_class_for_archiving.h -o %t/a.out // RUN: %target-codesign %t/a.out -// RUN: %target-run %t/a.out | %FileCheck %s +// RUN: %target-run %t/a.out // REQUIRES: executable_test // REQUIRES: objc_interop import Foundation +import StdlibUnittest class SwiftClass {} @@ -35,36 +36,57 @@ struct DEF { class InnerClass : NSObject {} } +let suite = TestSuite("check_class_for_archiving") +defer { runAllTests() } + let op: Int32 = 0 // archiving -// CHECK: SwiftClass: 0 -print("SwiftClass: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(SwiftClass.self, operation: op))") -// CHECK: ObjcClass: 0 -print("ObjcClass: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(ObjcClass.self, operation: op))") -// CHECK: NamedClass1: 0 -print("NamedClass1: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(NamedClass1.self, operation: op))") -// CHECK: NamedClass2: 0 -print("NamedClass2: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(NamedClass2.self, operation: op))") -// CHECK: DerivedClass: 0 -print("DerivedClass: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(DerivedClass.self, operation: op))") -// CHECK: DerivedClassWithName: 0 -print("DerivedClassWithName: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(DerivedClass.self, operation: op))") -// CHECK: NSKeyedUnarchiver: 0 -print("NSKeyedUnarchiver: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(NSKeyedUnarchiver.self, operation: op))") +suite.test("SwiftClass") { + expectEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(SwiftClass.self, operation: op)) +} +suite.test("ObjcClass") { + expectEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(ObjcClass.self, operation: op)) +} +suite.test("NamedClass1") { + expectEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(NamedClass1.self, operation: op)) +} +suite.test("NamedClass2") { + expectEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(NamedClass2.self, operation: op)) +} +suite.test("DerivedClass") { + expectEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(DerivedClass.self, operation: op)) +} +suite.test("DerivedClassWithName") { + expectEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(DerivedClassWithName.self, operation: op)) +} +suite.test("NSKeyedUnarchiver") { + expectEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(NSKeyedUnarchiver.self, operation: op)) +} +// Disable negative tests on older OSes because of rdar://problem/50504765 if #available(iOS 13, macOS 10.15, tvOS 13, watchOS 6, *) { - // CHECK: PrivateClass: 2 - print("PrivateClass: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(PrivateClass.self, operation: op))") - // CHECK: GenericClass: 1 - print("GenericClass: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(GenericClass.self, operation: op))") - // CHECK: InnerClass: 2 - print("InnerClass: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(ABC.InnerClass.self, operation: op))") - // CHECK: InnerClass2: 1 - print("InnerClass2: \(NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(DEF.InnerClass.self, operation: op))") -} else { - // Disable the checks for older OSes because of rdar://problem/50504765 - print("PrivateClass: 2") - print("GenericClass: 1") - print("InnerClass: 2") - print("InnerClass2: 1") + suite.test("PrivateClass") { + expectNotEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(PrivateClass.self, operation: op)) + } +} + +if #available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *) { + // Generic classes and nested classes were considered to have unstable names + // in earlier releases. + suite.test("GenericClass") { + expectEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(GenericClass.self, operation: op)) + } + suite.test("InnerClass") { + print(NSStringFromClass(ABC.InnerClass.self)) + expectEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(ABC.InnerClass.self, operation: op)) + } + suite.test("InnerClass2") { + print(NSStringFromClass(DEF.InnerClass.self)) + expectEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(DEF.InnerClass.self, operation: op)) + } + + suite.test("LocalClass") { + class LocalClass: NSObject {} + expectNotEqual(0, NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(LocalClass.self, operation: op)) + } } diff --git a/test/Interpreter/SDK/check_class_for_archiving_log.swift b/test/Interpreter/SDK/check_class_for_archiving_log.swift index bce3b64c99c38..4e5810959485c 100644 --- a/test/Interpreter/SDK/check_class_for_archiving_log.swift +++ b/test/Interpreter/SDK/check_class_for_archiving_log.swift @@ -24,162 +24,132 @@ if #available(iOS 13, macOS 10.15, tvOS 13, watchOS 6, *) { class SwiftClass {} -func checkArchiving(_ cls: AnyObject.Type) { - NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(cls, operation: 0) +func _check(_ label: String, _ cls: AnyObject.Type, _ op: CInt) { + NSLog("--%@ start", label) + NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(cls, operation: op) + NSLog("--%@ end", label) } -func checkUnarchiving(_ cls: AnyObject.Type) { - NSKeyedUnarchiver._swift_checkClassAndWarnForKeyedArchiving(cls, operation: 1) +func checkArchiving(_ label: String, _ cls: AnyObject.Type) { + _check(label, cls, 0) } - -func mark(line: Int32 = #line) { - NSLog("--%d--", line) +func checkUnarchiving(_ label: String, _ cls: AnyObject.Type) { + _check(label, cls, 1) } -mark() // CHECK: --[[@LINE]]-- -checkArchiving(SwiftClass.self) -mark() // CHECK-NEXT: --[[@LINE]]-- + +// CHECK-LABEL: --SwiftClass start +checkArchiving("SwiftClass", SwiftClass.self) +// CHECK-NEXT: --SwiftClass end private class ArchivedTwice {} -checkArchiving(ArchivedTwice.self) -// CHECK-NEXT: Attempting to archive Swift class '_Test.({{.+}}).ArchivedTwice' with mangled runtime name '_TtC{{.+[0-9]+}}ArchivedTwice' -// CHECK-NEXT: @objc(_TtC{{.+[0-9]+}}ArchivedTwice) -// CHECK-NEXT: @objc(ABCArchivedTwice) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkArchiving(ArchivedTwice.self) -mark() // CHECK-NEXT: --[[@LINE]]-- +// CHECK-LABEL: --ArchivedTwice1 start +checkArchiving("ArchivedTwice1", ArchivedTwice.self) +// CHECK: Attempting to archive Swift class '_Test.({{.+}}).ArchivedTwice' with {{.+}} runtime name '_TtC{{.+[0-9]+}}ArchivedTwice' +// CHECK: @objc(_TtC{{.+[0-9]+}}ArchivedTwice) + +// CHECK-LABEL: --ArchivedTwice2 start +checkArchiving("ArchivedTwice2", ArchivedTwice.self) +// CHECK-NEXT: --ArchivedTwice2 end private class UnarchivedTwice {} -checkUnarchiving(UnarchivedTwice.self) -// CHECK-NEXT: Attempting to unarchive Swift class '_Test.({{.+}}).UnarchivedTwice' with mangled runtime name '_TtC{{.+[0-9]+}}UnarchivedTwice' -// CHECK-NEXT: @objc(_TtC{{.+[0-9]+}}UnarchivedTwice) -// CHECK-NEXT: @objc(ABCUnarchivedTwice) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkUnarchiving(UnarchivedTwice.self) -mark() // CHECK-NEXT: --[[@LINE]]-- +// CHECK-LABEL: --UnarchivedTwice1 start +checkUnarchiving("UnarchivedTwice1", UnarchivedTwice.self) +// CHECK: Attempting to unarchive Swift class '_Test.({{.+}}).UnarchivedTwice' with {{.+}} runtime name '_TtC{{.+[0-9]+}}UnarchivedTwice' +// CHECK: @objc(_TtC{{.+[0-9]+}}UnarchivedTwice) + +// CHECK-LABEL: --UnarchivedTwice2 start +checkUnarchiving("UnarchivedTwice2", UnarchivedTwice.self) +// CHECK-NEXT: --UnarchivedTwice2 end private class ArchivedThenUnarchived {} -checkArchiving(ArchivedThenUnarchived.self) -// CHECK-NEXT: Attempting to archive Swift class '_Test.({{.+}}).ArchivedThenUnarchived' with mangled runtime name '_TtC{{.+[0-9]+}}ArchivedThenUnarchived' -// CHECK-NEXT: @objc(_TtC{{.+[0-9]+}}ArchivedThenUnarchived) -// CHECK-NEXT: @objc(ABCArchivedThenUnarchived) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkUnarchiving(ArchivedThenUnarchived.self) -mark() // CHECK-NEXT: --[[@LINE]]-- +// CHECK-LABEL: --ArchivedThenUnarchived1 start +checkArchiving("ArchivedThenUnarchived1", ArchivedThenUnarchived.self) +// CHECK: Attempting to archive Swift class '_Test.({{.+}}).ArchivedThenUnarchived' with {{.+}} runtime name '_TtC{{.+[0-9]+}}ArchivedThenUnarchived' +// CHECK: @objc(_TtC{{.+[0-9]+}}ArchivedThenUnarchived) + +// CHECK-LABEL: --ArchivedThenUnarchived2 start +checkUnarchiving("ArchivedThenUnarchived2", ArchivedThenUnarchived.self) +// CHECK-NEXT: --ArchivedThenUnarchived2 end private class UnarchivedThenArchived {} -checkUnarchiving(UnarchivedThenArchived.self) -// CHECK-NEXT: Attempting to unarchive Swift class '_Test.({{.+}}).UnarchivedThenArchived' with mangled runtime name '_TtC{{.+[0-9]+}}UnarchivedThenArchived' -// CHECK-NEXT: @objc(_TtC{{.+[0-9]+}}UnarchivedThenArchived) -// CHECK-NEXT: @objc(ABCUnarchivedThenArchived) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkArchiving(UnarchivedThenArchived.self) -mark() // CHECK-NEXT: --[[@LINE]]-- +// CHECK-LABEL: --UnarchivedThenArchived1 start +checkUnarchiving("UnarchivedThenArchived1", UnarchivedThenArchived.self) +// CHECK: Attempting to unarchive Swift class '_Test.({{.+}}).UnarchivedThenArchived' with {{.+}} runtime name '_TtC{{.+[0-9]+}}UnarchivedThenArchived' +// CHECK: @objc(_TtC{{.+[0-9]+}}UnarchivedThenArchived) -class Outer { +// CHECK-LABEL: --UnarchivedThenArchived2 start +checkArchiving("UnarchivedThenArchived2", UnarchivedThenArchived.self) +// CHECK-NEXT: --UnarchivedThenArchived2 end + +private class Outer { class ArchivedTwice {} class UnarchivedTwice {} class ArchivedThenUnarchived {} class UnarchivedThenArchived {} } -checkArchiving(Outer.ArchivedTwice.self) -// CHECK-NEXT: Attempting to archive Swift class '_Test.Outer.ArchivedTwice' -// CHECK-NEXT: @objc(_TtC{{.+[0-9]+}}ArchivedTwice) -// CHECK-NEXT: @objc(ABCArchivedTwice) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkArchiving(Outer.ArchivedTwice.self) -mark() // CHECK-NEXT: --[[@LINE]]-- - -checkUnarchiving(Outer.UnarchivedTwice.self) -// CHECK-NEXT: Attempting to unarchive Swift class '_Test.Outer.UnarchivedTwice' -// CHECK-NEXT: @objc(_TtC{{.+[0-9]+}}UnarchivedTwice) -// CHECK-NEXT: @objc(ABCUnarchivedTwice) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkUnarchiving(Outer.UnarchivedTwice.self) -mark() // CHECK-NEXT: --[[@LINE]]-- - -checkArchiving(Outer.ArchivedThenUnarchived.self) -// CHECK-NEXT: Attempting to archive Swift class '_Test.Outer.ArchivedThenUnarchived' -// CHECK-NEXT: @objc(_TtC{{.+[0-9]+}}ArchivedThenUnarchived) -// CHECK-NEXT: @objc(ABCArchivedThenUnarchived) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkUnarchiving(Outer.ArchivedThenUnarchived.self) -mark() // CHECK-NEXT: --[[@LINE]]-- - -checkUnarchiving(Outer.UnarchivedThenArchived.self) -// CHECK-NEXT: Attempting to unarchive Swift class '_Test.Outer.UnarchivedThenArchived' -// CHECK-NEXT: @objc(_TtC{{.+[0-9]+}}UnarchivedThenArchived) -// CHECK-NEXT: @objc(ABCUnarchivedThenArchived) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkArchiving(Outer.UnarchivedThenArchived.self) -mark() // CHECK-NEXT: --[[@LINE]]-- +// CHECK-LABEL: --Outer.ArchivedTwice1 start +checkArchiving("Outer.ArchivedTwice1", Outer.ArchivedTwice.self) +// CHECK: Attempting to archive Swift class '_Test.({{.+}}).Outer.ArchivedTwice' +// CHECK: @objc(_TtC{{.+[0-9]+}}ArchivedTwice) + +// CHECK-LABEL: --Outer.ArchivedTwice2 start +checkArchiving("Outer.ArchivedTwice2", Outer.ArchivedTwice.self) +// CHECK-NEXT: --Outer.ArchivedTwice2 end + +// CHECK-LABEL: --Outer.UnarchivedTwice1 start +checkUnarchiving("Outer.UnarchivedTwice1", Outer.UnarchivedTwice.self) +// CHECK: Attempting to unarchive Swift class '_Test.({{.+}}).Outer.UnarchivedTwice' +// CHECK: @objc(_TtC{{.+[0-9]+}}UnarchivedTwice) + +// CHECK-LABEL: --Outer.UnarchivedTwice2 start +checkUnarchiving("Outer.UnarchivedTwice2", Outer.UnarchivedTwice.self) +// CHECK-NEXT: --Outer.UnarchivedTwice2 end + +// CHECK-LABEL: --Outer.ArchivedThenUnarchived1 start +checkArchiving("Outer.ArchivedThenUnarchived1", Outer.ArchivedThenUnarchived.self) +// CHECK: Attempting to archive Swift class '_Test.({{.+}}).Outer.ArchivedThenUnarchived' +// CHECK: @objc(_TtC{{.+[0-9]+}}ArchivedThenUnarchived) + +// CHECK-LABEL: --Outer.ArchivedThenUnarchived2 start +checkUnarchiving("Outer.ArchivedThenUnarchived2", Outer.ArchivedThenUnarchived.self) +// CHECK-NEXT: --Outer.ArchivedThenUnarchived2 end + +// CHECK-LABEL: --Outer.UnarchivedThenArchived1 start +checkUnarchiving("Outer.UnarchivedThenArchived1", Outer.UnarchivedThenArchived.self) +// CHECK: Attempting to unarchive Swift class '_Test.({{.+}}).Outer.UnarchivedThenArchived' +// CHECK: @objc(_TtC{{.+[0-9]+}}UnarchivedThenArchived) + +// CHECK-LABEL: --Outer.UnarchivedThenArchived2 start +checkArchiving("Outer.UnarchivedThenArchived2", Outer.UnarchivedThenArchived.self) +// CHECK-NEXT: --Outer.UnarchivedThenArchived2 end private class 日本語 {} -checkArchiving(日本語.self) -// CHECK-NEXT: Attempting to archive Swift class '_Test.({{.*}}).日本語' -// CHECK-NEXT: @objc(_TtC{{.+[0-9]+}}9日本語) -// CHECK-NEXT: @objc(ABCMyModel) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkArchiving(日本語.self) -mark() // CHECK-NEXT: --[[@LINE]]-- - - -class ArchivedTwiceGeneric {} - -checkArchiving(ArchivedTwiceGeneric.self) -// CHECK-NEXT: Attempting to archive generic Swift class '_Test.ArchivedTwiceGeneric' with mangled runtime name '_TtGC5_Test20ArchivedTwiceGenericSi_' -// CHECK-NEXT: NSKeyedUnarchiver.setClass(_:forClassName:) -// CHECK-SAME: _TtGC5_Test20ArchivedTwiceGenericSi_ -// CHECK-NEXT: NSKeyedArchiver.setClassName(_:for:) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkArchiving(ArchivedTwiceGeneric.self) -mark() // CHECK-NEXT: --[[@LINE]]-- - -checkArchiving(ArchivedTwiceGeneric.self) -// CHECK-NEXT: Attempting to archive generic Swift class '_Test.ArchivedTwiceGeneric<__C.NSObject>' with mangled runtime name '_TtGC5_Test20ArchivedTwiceGenericCSo8NSObject_' -// CHECK-NEXT: NSKeyedUnarchiver.setClass(_:forClassName:) -// CHECK-SAME: _TtGC5_Test20ArchivedTwiceGenericCSo8NSObject_ -// CHECK-NEXT: NSKeyedArchiver.setClassName(_:for:) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkArchiving(ArchivedTwiceGeneric.self) -mark() // CHECK-NEXT: --[[@LINE]]-- - -class UnarchivedTwiceGeneric {} - -checkUnarchiving(UnarchivedTwiceGeneric.self) -// CHECK-NEXT: Attempting to unarchive generic Swift class '_Test.UnarchivedTwiceGeneric' with mangled runtime name '_TtGC5_Test22UnarchivedTwiceGenericSi_' -// CHECK-NEXT: NSKeyedUnarchiver.setClass(_:forClassName:) -// CHECK-SAME: _TtGC5_Test22UnarchivedTwiceGenericSi_ -// CHECK-NEXT: NSKeyedArchiver.setClassName(_:for:) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkUnarchiving(UnarchivedTwiceGeneric.self) -mark() // CHECK-NEXT: --[[@LINE]]-- - -class ArchivedThenUnarchivedGeneric {} - -checkArchiving(ArchivedThenUnarchivedGeneric.self) -// CHECK-NEXT: Attempting to archive generic Swift class '_Test.ArchivedThenUnarchivedGeneric' with mangled runtime name '_TtGC5_Test29ArchivedThenUnarchivedGenericSi_' -// CHECK-NEXT: NSKeyedUnarchiver.setClass(_:forClassName:) -// CHECK-SAME: _TtGC5_Test29ArchivedThenUnarchivedGenericSi_ -// CHECK-NEXT: NSKeyedArchiver.setClassName(_:for:) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkUnarchiving(ArchivedThenUnarchivedGeneric.self) -mark() // CHECK-NEXT: --[[@LINE]]-- - -class UnarchivedThenArchivedGeneric {} - -checkUnarchiving(UnarchivedThenArchivedGeneric.self) -// CHECK-NEXT: Attempting to unarchive generic Swift class '_Test.UnarchivedThenArchivedGeneric' with mangled runtime name '_TtGC5_Test29UnarchivedThenArchivedGenericSi_' -// CHECK-NEXT: NSKeyedUnarchiver.setClass(_:forClassName:) -// CHECK-SAME: _TtGC5_Test29UnarchivedThenArchivedGenericSi_ -// CHECK-NEXT: NSKeyedArchiver.setClassName(_:for:) -mark() // CHECK-NEXT: --[[@LINE]]-- -checkArchiving(UnarchivedThenArchivedGeneric.self) -mark() // CHECK-NEXT: --[[@LINE]]-- +// CHECK-LABEL: --Japanese1 start +checkArchiving("Japanese1", 日本語.self) +// CHECK: Attempting to archive Swift class '_Test.({{.*}}).日本語' + +// CHECK-LABEL: --Japanese2 start +checkArchiving("Japanese2", 日本語.self) +// CHECK-NEXT: --Japanese2 end + +func someFunction() { + class LocalArchived: NSObject {} + class LocalUnarchived: NSObject {} + + // CHECK-LABEL: --LocalArchived start + checkArchiving("LocalArchived", LocalArchived.self) + // CHECK: Attempting to archive Swift class '_Test.({{.+}}).LocalArchived' + + // CHECK-LABEL: --LocalUnarchived start + checkUnarchiving("LocalUnarchived", LocalUnarchived.self) + // CHECK: Attempting to unarchive Swift class '_Test.({{.+}}).LocalUnarchived' +} +someFunction() From f361b250de94acea47abba43eac038c167635cf0 Mon Sep 17 00:00:00 2001 From: Artem Chikin Date: Tue, 7 Jul 2020 17:11:55 -0700 Subject: [PATCH 03/36] [Explicit Module Builds] Add canImport functionality to the ExplicitSwiftModuleLoader It needs to check against the provided ExplicitModuleMap instead of looking into search paths. --- include/swift/Frontend/ModuleInterfaceLoader.h | 2 ++ lib/Frontend/ModuleInterfaceLoader.cpp | 11 +++++++++++ test/ScanDependencies/can_import_with_map.swift | 17 +++++++++++++++++ 3 files changed, 30 insertions(+) create mode 100644 test/ScanDependencies/can_import_with_map.swift diff --git a/include/swift/Frontend/ModuleInterfaceLoader.h b/include/swift/Frontend/ModuleInterfaceLoader.h index 626febd3da678..2f68343f38b3e 100644 --- a/include/swift/Frontend/ModuleInterfaceLoader.h +++ b/include/swift/Frontend/ModuleInterfaceLoader.h @@ -141,6 +141,8 @@ class ExplicitSwiftModuleLoader: public SerializedModuleLoaderBase { std::unique_ptr *ModuleDocBuffer, std::unique_ptr *ModuleSourceInfoBuffer) override; + bool canImportModule(Located mID) override; + bool isCached(StringRef DepPath) override { return false; }; struct Implementation; diff --git a/lib/Frontend/ModuleInterfaceLoader.cpp b/lib/Frontend/ModuleInterfaceLoader.cpp index 3f277e08fa868..9f64c7f5ea01a 100644 --- a/lib/Frontend/ModuleInterfaceLoader.cpp +++ b/lib/Frontend/ModuleInterfaceLoader.cpp @@ -1514,6 +1514,17 @@ std::error_code ExplicitSwiftModuleLoader::findModuleFilesInDirectory( return std::error_code(); } +bool ExplicitSwiftModuleLoader::canImportModule( + Located mID) { + StringRef moduleName = mID.Item.str(); + auto it = Impl.ExplicitModuleMap.find(moduleName); + // If no provided explicit module matches the name, then it cannot be imported. + if (it == Impl.ExplicitModuleMap.end()) { + return false; + } + return true; +} + void ExplicitSwiftModuleLoader::collectVisibleTopLevelModuleNames( SmallVectorImpl &names) const { for (auto &entry: Impl.ExplicitModuleMap) { diff --git a/test/ScanDependencies/can_import_with_map.swift b/test/ScanDependencies/can_import_with_map.swift new file mode 100644 index 0000000000000..a65d5bd8e128e --- /dev/null +++ b/test/ScanDependencies/can_import_with_map.swift @@ -0,0 +1,17 @@ +// RUN: %empty-directory(%t) +// RUN: mkdir -p %t/clang-module-cache +// RUN: mkdir -p %t/inputs +// RUN: echo "public func foo() {}" >> %t/foo.swift +// RUN: %target-swift-frontend -emit-module -emit-module-path %t/inputs/Foo.swiftmodule -emit-module-doc-path %t/inputs/Foo.swiftdoc -emit-module-source-info -emit-module-source-info-path %t/inputs/Foo.swiftsourceinfo -module-cache-path %t.module-cache %t/foo.swift -module-name Foo + +// RUN: echo "[{" > %/t/inputs/map.json +// RUN: echo "\"moduleName\": \"Foo\"," >> %/t/inputs/map.json +// RUN: echo "\"modulePath\": \"%/t/inputs/Foo.swiftmodule\"," >> %/t/inputs/map.json +// RUN: echo "\"docPath\": \"%/t/inputs/Foo.swiftdoc\"," >> %/t/inputs/map.json +// RUN: echo "\"sourceInfoPath\": \"%/t/inputs/Foo.swiftsourceinfo\"" >> %/t/inputs/map.json +// RUN: echo "}]" >> %/t/inputs/map.json + +// RUN: %target-swift-frontend -typecheck %s -explicit-swift-module-map-file %t/inputs/map.json -disable-implicit-swift-modules +#if canImport(Foo) +import Foo +#endif From b614a4d33596cb5b63d5c564cca5e8b50e2c7a8a Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 8 Jul 2020 15:50:32 -0400 Subject: [PATCH 04/36] Sema: Don't look through nested typealiases when checking for unsupported member reference It appears that a long time ago, we didn't enforce that a member reference to a typealias nested inside a generic type would supply the generic arguments at all. To avoid breaking source compatibility, we carved out some exceptions. Tighten up the exception to prohibit the case where a typealias references another typealias, to fix a crash. While this worked in 5.1, it would crash in 5.2 and 5.3, and at this point it's more trouble than it is worth to make it work again, because of subtle representational issues. So let's just ban it. Fixes . --- lib/Sema/TypeCheckNameLookup.cpp | 3 +-- test/decl/typealias/dependent_types.swift | 6 +++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/Sema/TypeCheckNameLookup.cpp b/lib/Sema/TypeCheckNameLookup.cpp index a5b1bd0bc9b3e..6068b1a2c527f 100644 --- a/lib/Sema/TypeCheckNameLookup.cpp +++ b/lib/Sema/TypeCheckNameLookup.cpp @@ -353,8 +353,7 @@ bool TypeChecker::isUnsupportedMemberTypeAccess(Type type, TypeDecl *typeDecl) { // underlying type is not dependent. if (auto *aliasDecl = dyn_cast(typeDecl)) { if (!aliasDecl->isGeneric() && - aliasDecl->getUnderlyingType()->getCanonicalType() - ->hasTypeParameter()) { + aliasDecl->getUnderlyingType()->hasTypeParameter()) { return true; } } diff --git a/test/decl/typealias/dependent_types.swift b/test/decl/typealias/dependent_types.swift index c854388e0f669..4334f3d71734e 100644 --- a/test/decl/typealias/dependent_types.swift +++ b/test/decl/typealias/dependent_types.swift @@ -23,11 +23,12 @@ struct X1 : P1 { } } -struct GenericStruct { // expected-note 2{{generic type 'GenericStruct' declared here}} +struct GenericStruct { // expected-note 3{{generic type 'GenericStruct' declared here}} typealias Alias = T typealias MetaAlias = T.Type typealias Concrete = Int + typealias ReferencesConcrete = Concrete func methodOne() -> Alias.Type {} func methodTwo() -> MetaAlias {} @@ -59,6 +60,9 @@ let _: GenericStruct.MetaAlias = metaFoo() // we are OK. let _: GenericStruct.Concrete = foo() +let _: GenericStruct.ReferencesConcrete = foo() +// expected-error@-1 {{reference to generic type 'GenericStruct' requires arguments in <...>}} + class SuperG { typealias Composed = (T, U) } From 9bcb54910e88ab9017d391362d860fe4fc3a49cf Mon Sep 17 00:00:00 2001 From: Nathan Hawes Date: Mon, 29 Jun 2020 14:57:21 -0700 Subject: [PATCH 05/36] [AST] Prefer the 'macOS' spelling over 'OSX' when printing the platform kind. This affects module interfaces, interface generation in sourcekitd, and diagnostics. Also fixes a fixit that was assuming the 'OSX' spelling when computing the source range to replace. Resolves rdar://problem/64667960 --- lib/AST/PlatformKind.cpp | 9 + lib/IDE/CodeCompletion.cpp | 10 +- lib/Parse/ParseStmt.cpp | 5 +- test/ClangImporter/objc_factory_method.swift | 2 +- ...it-interface-macos-canonical-version.swift | 4 +- ...to-print-availability.h.module.printed.txt | 16 +- ...header-to-print-availability.h.printed.txt | 16 +- test/IDE/print_ast_tc_decls.swift | 6 +- ...t_tc_decls_macosx_canonical_versions.swift | 2 +- .../print_swift_module_with_available.swift | 8 +- ...print_synthesized_extensions_nomerge.swift | 2 +- test/ModuleInterface/conformances.swift | 8 +- .../originally-defined-attr.swift | 4 +- test/Parse/availability_query.swift | 2 +- test/Parse/original_defined_in_attr.swift | 2 +- test/Sema/availability_versions.swift | 78 +++---- .../availability_versions_playgrounds.swift | 4 +- .../doc_swift_module.swift.response | 4 +- .../gen_clang_module.swift.response | 198 +++++++++--------- test/attr/Inputs/OldAndNew.swift | 4 +- test/attr/Inputs/PackageDescription.swift | 2 +- 21 files changed, 195 insertions(+), 191 deletions(-) diff --git a/lib/AST/PlatformKind.cpp b/lib/AST/PlatformKind.cpp index 2d283c951b654..370a9f83b821b 100644 --- a/lib/AST/PlatformKind.cpp +++ b/lib/AST/PlatformKind.cpp @@ -24,6 +24,15 @@ using namespace swift; StringRef swift::platformString(PlatformKind platform) { + // FIXME: Update PlatformKinds.def to use the macOS spelling by default. + switch (platform) { + case PlatformKind::OSX: + return "macOS"; + case PlatformKind::OSXApplicationExtension: + return "macOSApplicationExtension"; + default: break; + } + switch (platform) { case PlatformKind::none: return "*"; diff --git a/lib/IDE/CodeCompletion.cpp b/lib/IDE/CodeCompletion.cpp index aff9699acf673..db073e6604365 100644 --- a/lib/IDE/CodeCompletion.cpp +++ b/lib/IDE/CodeCompletion.cpp @@ -4429,14 +4429,10 @@ class CompletionLookup final : public swift::VisibleDeclConsumer { if (ParamIndex == 0) { addDeclAttrParamKeyword("*", "Platform", false); - // For code completion, suggest 'macOS' instead of 'OSX'. + // For code completion, suggest 'macOS' instead of 'OSX'. #define AVAILABILITY_PLATFORM(X, PrettyName) \ - if (StringRef(#X) == "OSX") \ - addDeclAttrParamKeyword("macOS", "Platform", false); \ - else if (StringRef(#X) == "OSXApplicationExtension") \ - addDeclAttrParamKeyword("macOSApplicationExtension", "Platform", false); \ - else \ - addDeclAttrParamKeyword(#X, "Platform", false); + addDeclAttrParamKeyword(swift::platformString(PlatformKind::X), \ + "Platform", false); #include "swift/AST/PlatformKinds.def" } else { diff --git a/lib/Parse/ParseStmt.cpp b/lib/Parse/ParseStmt.cpp index c0cfc7da3de43..893a61f48b548 100644 --- a/lib/Parse/ParseStmt.cpp +++ b/lib/Parse/ParseStmt.cpp @@ -1350,10 +1350,9 @@ Parser::parseAvailabilitySpecList(SmallVectorImpl &Specs) { auto *PlatformSpec = cast(Previous); - auto PlatformName = platformString(PlatformSpec->getPlatform()); auto PlatformNameEndLoc = - PlatformSpec->getPlatformLoc().getAdvancedLoc( - PlatformName.size()); + Lexer::getLocForEndOfToken(SourceManager, + PlatformSpec->getPlatformLoc()); diagnose(PlatformSpec->getPlatformLoc(), diag::avail_query_meant_introduced) diff --git a/test/ClangImporter/objc_factory_method.swift b/test/ClangImporter/objc_factory_method.swift index 4fcfb4d9e7d38..e71873fefab1a 100644 --- a/test/ClangImporter/objc_factory_method.swift +++ b/test/ClangImporter/objc_factory_method.swift @@ -45,7 +45,7 @@ func testFactoryWithLaterIntroducedInit() { // expected-note @-1 {{add 'if #available' version check}} _ = NSHavingConvenienceFactoryAndLaterDesignatedInit(flam:5) // expected-error {{'init(flam:)' is only available in macOS 10.52 or newer}} - // expected-note @-1 {{add 'if #available' version check}} {{3-63=if #available(OSX 10.52, *) {\n _ = NSHavingConvenienceFactoryAndLaterDesignatedInit(flam:5)\n \} else {\n // Fallback on earlier versions\n \}}} + // expected-note @-1 {{add 'if #available' version check}} {{3-63=if #available(macOS 10.52, *) {\n _ = NSHavingConvenienceFactoryAndLaterDesignatedInit(flam:5)\n \} else {\n // Fallback on earlier versions\n \}}} // Don't prefer more available factory initializer over less diff --git a/test/Frontend/emit-interface-macos-canonical-version.swift b/test/Frontend/emit-interface-macos-canonical-version.swift index e21c9348887f2..75aadc9ebbdb2 100644 --- a/test/Frontend/emit-interface-macos-canonical-version.swift +++ b/test/Frontend/emit-interface-macos-canonical-version.swift @@ -6,13 +6,13 @@ @available(macOS 10.16, *) public func introduced10_16() { } -// CHECK: @available(OSX 11.0, *) +// CHECK: @available(macOS 11.0, *) // CHECK-NEXT: public func introduced10_16() @available(OSX 11.0, *) public func introduced11_0() { } -// CHECK-NEXT: @available(OSX 11.0, *) +// CHECK-NEXT: @available(macOS 11.0, *) // CHECK-NEXT: public func introduced11_0() diff --git a/test/IDE/Inputs/print_clang_header/header-to-print-availability.h.module.printed.txt b/test/IDE/Inputs/print_clang_header/header-to-print-availability.h.module.printed.txt index 327c11d0e28e1..4b533febd394c 100644 --- a/test/IDE/Inputs/print_clang_header/header-to-print-availability.h.module.printed.txt +++ b/test/IDE/Inputs/print_clang_header/header-to-print-availability.h.module.printed.txt @@ -1,18 +1,18 @@ class MaybeAvailable { - @available(OSX 10.1, *) + @available(macOS 10.1, *) class func method1() - @available(OSX 10.1, *) + @available(macOS 10.1, *) func method1() - @available(OSX 10.1, *) + @available(macOS 10.1, *) class func method2() - @available(OSX 10.1, *) + @available(macOS 10.1, *) func method2() - @available(OSX, deprecated: 10.10) + @available(macOS, deprecated: 10.10) class func method3() - @available(OSX, deprecated: 10.10) + @available(macOS, deprecated: 10.10) func method3() - @available(OSX, introduced: 10.1, deprecated: 10.10, obsoleted: 10.11) + @available(macOS, introduced: 10.1, deprecated: 10.10, obsoleted: 10.11) class func method4() - @available(OSX, introduced: 10.1, deprecated: 10.10, obsoleted: 10.11) + @available(macOS, introduced: 10.1, deprecated: 10.10, obsoleted: 10.11) func method4() } diff --git a/test/IDE/Inputs/print_clang_header/header-to-print-availability.h.printed.txt b/test/IDE/Inputs/print_clang_header/header-to-print-availability.h.printed.txt index 327c11d0e28e1..4b533febd394c 100644 --- a/test/IDE/Inputs/print_clang_header/header-to-print-availability.h.printed.txt +++ b/test/IDE/Inputs/print_clang_header/header-to-print-availability.h.printed.txt @@ -1,18 +1,18 @@ class MaybeAvailable { - @available(OSX 10.1, *) + @available(macOS 10.1, *) class func method1() - @available(OSX 10.1, *) + @available(macOS 10.1, *) func method1() - @available(OSX 10.1, *) + @available(macOS 10.1, *) class func method2() - @available(OSX 10.1, *) + @available(macOS 10.1, *) func method2() - @available(OSX, deprecated: 10.10) + @available(macOS, deprecated: 10.10) class func method3() - @available(OSX, deprecated: 10.10) + @available(macOS, deprecated: 10.10) func method3() - @available(OSX, introduced: 10.1, deprecated: 10.10, obsoleted: 10.11) + @available(macOS, introduced: 10.1, deprecated: 10.10, obsoleted: 10.11) class func method4() - @available(OSX, introduced: 10.1, deprecated: 10.10, obsoleted: 10.11) + @available(macOS, introduced: 10.1, deprecated: 10.10, obsoleted: 10.11) func method4() } diff --git a/test/IDE/print_ast_tc_decls.swift b/test/IDE/print_ast_tc_decls.swift index 59bc8d69f22cc..d6d4150c41e7f 100644 --- a/test/IDE/print_ast_tc_decls.swift +++ b/test/IDE/print_ast_tc_decls.swift @@ -516,19 +516,19 @@ class d0170_TestAvailability { @available(OSX, unavailable) func f3() {} // PASS_COMMON-NEXT: {{^}} @available(iOS, unavailable){{$}} -// PASS_COMMON-NEXT: {{^}} @available(OSX, unavailable){{$}} +// PASS_COMMON-NEXT: {{^}} @available(macOS, unavailable){{$}} // PASS_COMMON-NEXT: {{^}} func f3(){{$}} @available(iOS 8.0, OSX 10.10, *) func f4() {} -// PASS_COMMON-NEXT: {{^}} @available(iOS 8.0, OSX 10.10, *){{$}} +// PASS_COMMON-NEXT: {{^}} @available(iOS 8.0, macOS 10.10, *){{$}} // PASS_COMMON-NEXT: {{^}} func f4(){{$}} // Convert long-form @available() to short form when possible. @available(iOS, introduced: 8.0) @available(OSX, introduced: 10.10) func f5() {} -// PASS_COMMON-NEXT: {{^}} @available(iOS 8.0, OSX 10.10, *){{$}} +// PASS_COMMON-NEXT: {{^}} @available(iOS 8.0, macOS 10.10, *){{$}} // PASS_COMMON-NEXT: {{^}} func f5(){{$}} } diff --git a/test/IDE/print_ast_tc_decls_macosx_canonical_versions.swift b/test/IDE/print_ast_tc_decls_macosx_canonical_versions.swift index 86a6565b740af..f51f8f2e173ca 100644 --- a/test/IDE/print_ast_tc_decls_macosx_canonical_versions.swift +++ b/test/IDE/print_ast_tc_decls_macosx_canonical_versions.swift @@ -12,5 +12,5 @@ @available(iOS 10.16, OSX 10.16, *) func introduced10_16() {} -// PASS_COMMON: {{^}}@available(iOS 10.16, OSX 11.0, *){{$}} +// PASS_COMMON: {{^}}@available(iOS 10.16, macOS 11.0, *){{$}} // PASS_COMMON-NEXT: {{^}}func introduced10_16(){{$}} diff --git a/test/IDE/print_swift_module_with_available.swift b/test/IDE/print_swift_module_with_available.swift index 930a3e28d3dc7..3345e9f1f7155 100644 --- a/test/IDE/print_swift_module_with_available.swift +++ b/test/IDE/print_swift_module_with_available.swift @@ -5,17 +5,17 @@ // REQUIRES: OS=macosx -@available(OSX 10.11, iOS 8.0, *) +@available(macOS 10.11, iOS 8.0, *) public class C1 { } -@available(OSX 10.12, *) +@available(macOS 10.12, *) public extension C1 { func ext_foo() {} } -// CHECK1: @available(OSX 10.11, iOS 8.0, *) +// CHECK1: @available(macOS 10.11, iOS 8.0, *) // CHECK1-NEXT: public class C1 { -// CHECK1: @available(OSX 10.12, *) +// CHECK1: @available(macOS 10.12, *) // CHECK1-NEXT: extension C1 { diff --git a/test/IDE/print_synthesized_extensions_nomerge.swift b/test/IDE/print_synthesized_extensions_nomerge.swift index aff0ee8c9ca49..d0a4101571aa2 100644 --- a/test/IDE/print_synthesized_extensions_nomerge.swift +++ b/test/IDE/print_synthesized_extensions_nomerge.swift @@ -15,7 +15,7 @@ public extension S1 { func bar() {} } -// CHECK1: @available(OSX 10.15, *) +// CHECK1: @available(macOS 10.15, *) // CHECK1: extension S1 { // CHECK1: @available(iOS 13, *) // CHECK1: extension S1 { diff --git a/test/ModuleInterface/conformances.swift b/test/ModuleInterface/conformances.swift index cc18f3ec33b2c..2557c79b9211b 100644 --- a/test/ModuleInterface/conformances.swift +++ b/test/ModuleInterface/conformances.swift @@ -196,20 +196,20 @@ extension Bool: ExtraHashable {} public struct CoolTVType: PrivateSubProto {} // CHECK: public struct CoolTVType { // CHECK-END: @available(iOS, unavailable) -// CHECK-END-NEXT: @available(OSX, unavailable) +// CHECK-END-NEXT: @available(macOS, unavailable) // CHECK-END-NEXT: extension conformances.CoolTVType : conformances.PublicBaseProto {} @available(macOS 10.99, *) public struct VeryNewMacType: PrivateSubProto {} // CHECK: public struct VeryNewMacType { -// CHECK-END: @available(OSX 10.99, *) +// CHECK-END: @available(macOS 10.99, *) // CHECK-END-NEXT: extension conformances.VeryNewMacType : conformances.PublicBaseProto {} public struct VeryNewMacProto {} @available(macOS 10.98, *) extension VeryNewMacProto: PrivateSubProto {} // CHECK: public struct VeryNewMacProto { -// CHECK-END: @available(OSX 10.98, *) +// CHECK-END: @available(macOS 10.98, *) // CHECK-END-NEXT: extension conformances.VeryNewMacProto : conformances.PublicBaseProto {} public struct PrivateProtoConformer {} @@ -237,7 +237,7 @@ public struct NestedAvailabilityOuter { } // CHECK-END: @available(swift 4.2.123) -// CHECK-END-NEXT: @available(OSX 10.97, iOS 23, *) +// CHECK-END-NEXT: @available(macOS 10.97, iOS 23, *) // CHECK-END-NEXT: @available(tvOS, unavailable) // CHECK-END-NEXT: extension conformances.NestedAvailabilityOuter.Inner : conformances.PublicBaseProto {} diff --git a/test/ModuleInterface/originally-defined-attr.swift b/test/ModuleInterface/originally-defined-attr.swift index 0940414f57e32..63949a9170e85 100644 --- a/test/ModuleInterface/originally-defined-attr.swift +++ b/test/ModuleInterface/originally-defined-attr.swift @@ -8,13 +8,13 @@ // RUN: %target-swift-ide-test -print-module -module-to-print Foo -I %t -source-filename %s > %t/printed-module.txt // RUN: %FileCheck %s < %t/printed-module.txt -// CHECK: @_originallyDefinedIn(module: "another", OSX 13.13) +// CHECK: @_originallyDefinedIn(module: "another", macOS 13.13) @available(OSX 10.8, *) @_originallyDefinedIn(module: "another", OSX 13.13) public protocol SimpleProto { } // CHECK: @_originallyDefinedIn(module: "original", tvOS 1.0) -// CHECK: @_originallyDefinedIn(module: "another_original", OSX 2.0) +// CHECK: @_originallyDefinedIn(module: "another_original", macOS 2.0) // CHECK: @_originallyDefinedIn(module: "another_original", iOS 3.0) // CHECK: @_originallyDefinedIn(module: "another_original", watchOS 4.0) @available(tvOS 0.7, OSX 1.1, iOS 2.1, watchOS 3.2, *) diff --git a/test/Parse/availability_query.swift b/test/Parse/availability_query.swift index eea8399c4ffaa..a6f91497402dc 100644 --- a/test/Parse/availability_query.swift +++ b/test/Parse/availability_query.swift @@ -46,7 +46,7 @@ if #available(iDishwasherOS 10.51) { // expected-warning {{unrecognized platform if #available(iDishwasherOS 10.51, *) { // expected-warning {{unrecognized platform name 'iDishwasherOS'}} } -if #available(OSX 10.51, OSX 10.52, *) { // expected-error {{version for 'OSX' already specified}} +if #available(OSX 10.51, OSX 10.52, *) { // expected-error {{version for 'macOS' already specified}} } if #available(OSX 10.52) { } // expected-error {{must handle potential future platforms with '*'}} {{24-24=, *}} diff --git a/test/Parse/original_defined_in_attr.swift b/test/Parse/original_defined_in_attr.swift index 56604c02ce9d9..253171d7d9a0c 100644 --- a/test/Parse/original_defined_in_attr.swift +++ b/test/Parse/original_defined_in_attr.swift @@ -22,7 +22,7 @@ class ToplevelClass3 {} @available(OSX 13.10, *) @_originallyDefinedIn(module: "foo", * 13.13) // expected-warning {{* as platform name has no effect}} expected-error {{expected at least one platform version in @_originallyDefinedIn}} @_originallyDefinedIn(module: "foo", OSX 13.13, iOS 7.0) -@_originallyDefinedIn(module: "foo", OSX 13.14, * 7.0) // expected-warning {{* as platform name has no effect}} expected-error {{duplicate version number for platform OSX}} +@_originallyDefinedIn(module: "foo", OSX 13.14, * 7.0) // expected-warning {{* as platform name has no effect}} expected-error {{duplicate version number for platform macOS}} class ToplevelClass4 { @_originallyDefinedIn(module: "foo", OSX 13.13) // expected-error {{'@_originallyDefinedIn' attribute cannot be applied to this declaration}} subscript(index: Int) -> Int { diff --git a/test/Sema/availability_versions.swift b/test/Sema/availability_versions.swift index e3adcb811d4cc..ad4159215becf 100644 --- a/test/Sema/availability_versions.swift +++ b/test/Sema/availability_versions.swift @@ -1017,7 +1017,7 @@ func functionWithDefaultAvailabilityAndUselessCheck(_ p: Bool) { if #available(OSX 10.51, *) { // expected-note {{enclosing scope here}} let _ = globalFuncAvailableOn10_51() - if #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'OSX'; enclosing scope ensures guard will always be true}} + if #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'macOS'; enclosing scope ensures guard will always be true}} let _ = globalFuncAvailableOn10_51() } } @@ -1025,7 +1025,7 @@ func functionWithDefaultAvailabilityAndUselessCheck(_ p: Bool) { if #available(OSX 10.9, *) { // expected-note {{enclosing scope here}} } else { // Make sure we generate a warning about an unnecessary check even if the else branch of if is dead. - if #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'OSX'; enclosing scope ensures guard will always be true}} + if #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'macOS'; enclosing scope ensures guard will always be true}} } } @@ -1033,7 +1033,7 @@ func functionWithDefaultAvailabilityAndUselessCheck(_ p: Bool) { if p { guard #available(OSX 10.9, *) else { // expected-note {{enclosing scope here}} // Make sure we generate a warning about an unnecessary check even if the else branch of guard is dead. - if #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'OSX'; enclosing scope ensures guard will always be true}} + if #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'macOS'; enclosing scope ensures guard will always be true}} } } } @@ -1079,11 +1079,11 @@ func functionWithUnavailableInDeadBranch() { @available(OSX, introduced: 10.51) func functionWithSpecifiedAvailabilityAndUselessCheck() { // expected-note 2{{enclosing scope here}} - if #available(OSX 10.9, *) { // expected-warning {{unnecessary check for 'OSX'; enclosing scope ensures guard will always be true}} + if #available(OSX 10.9, *) { // expected-warning {{unnecessary check for 'macOS'; enclosing scope ensures guard will always be true}} let _ = globalFuncAvailableOn10_9() } - if #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'OSX'; enclosing scope ensures guard will always be true}} + if #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'macOS'; enclosing scope ensures guard will always be true}} let _ = globalFuncAvailableOn10_51() } } @@ -1128,7 +1128,7 @@ if let _ = injectToOptional(globalFuncAvailableOn10_51()), #available(OSX 10.51, } if let _ = injectToOptional(5), #available(OSX 10.51, *), // expected-note {{enclosing scope here}} - let _ = injectToOptional(6), #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'OSX'; enclosing scope ensures guard will always be true}} + let _ = injectToOptional(6), #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'macOS'; enclosing scope ensures guard will always be true}} } @@ -1148,7 +1148,7 @@ func useGuardAvailable() { let _ = globalFuncAvailableOn10_52() // expected-error {{'globalFuncAvailableOn10_52()' is only available in macOS 10.52 or newer}} // expected-note@-1 {{add 'if #available' version check}} - if #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'OSX'; enclosing scope ensures guard will always be true}} + if #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'macOS'; enclosing scope ensures guard will always be true}} } if globalFuncAvailableOn10_51() > 0 { @@ -1196,7 +1196,7 @@ while #available(OSX 10.51, *), // expected-note {{enclosing scope here}} let _ = globalFuncAvailableOn10_52(); } - while #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'OSX'; enclosing scope ensures guard will always be true}} + while #available(OSX 10.51, *) { // expected-warning {{unnecessary check for 'macOS'; enclosing scope ensures guard will always be true}} } } @@ -1208,34 +1208,34 @@ while #available(OSX 10.51, *), // expected-note {{enclosing scope here}} functionAvailableOn10_51() // expected-error@-1 {{'functionAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{1-27=if #available(OSX 10.51, *) {\n functionAvailableOn10_51()\n} else {\n // Fallback on earlier versions\n}}} + // expected-note@-2 {{add 'if #available' version check}} {{1-27=if #available(macOS 10.51, *) {\n functionAvailableOn10_51()\n} else {\n // Fallback on earlier versions\n}}} let declForFixitAtTopLevel: ClassAvailableOn10_51? = nil // expected-error@-1 {{'ClassAvailableOn10_51' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{1-57=if #available(OSX 10.51, *) {\n let declForFixitAtTopLevel: ClassAvailableOn10_51? = nil\n} else {\n // Fallback on earlier versions\n}}} + // expected-note@-2 {{add 'if #available' version check}} {{1-57=if #available(macOS 10.51, *) {\n let declForFixitAtTopLevel: ClassAvailableOn10_51? = nil\n} else {\n // Fallback on earlier versions\n}}} func fixitForReferenceInGlobalFunction() { - // expected-note@-1 {{add @available attribute to enclosing global function}} {{1-1=@available(OSX 10.51, *)\n}} + // expected-note@-1 {{add @available attribute to enclosing global function}} {{1-1=@available(macOS 10.51, *)\n}} functionAvailableOn10_51() // expected-error@-1 {{'functionAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{3-29=if #available(OSX 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{3-29=if #available(macOS 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} } public func fixitForReferenceInGlobalFunctionWithDeclModifier() { - // expected-note@-1 {{add @available attribute to enclosing global function}} {{1-1=@available(OSX 10.51, *)\n}} + // expected-note@-1 {{add @available attribute to enclosing global function}} {{1-1=@available(macOS 10.51, *)\n}} functionAvailableOn10_51() // expected-error@-1 {{'functionAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{3-29=if #available(OSX 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{3-29=if #available(macOS 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} } func fixitForReferenceInGlobalFunctionWithAttribute() -> Never { - // expected-note@-1 {{add @available attribute to enclosing global function}} {{1-1=@available(OSX 10.51, *)\n}} + // expected-note@-1 {{add @available attribute to enclosing global function}} {{1-1=@available(macOS 10.51, *)\n}} _ = 0 // Avoid treating the call to functionAvailableOn10_51 as an implicit return functionAvailableOn10_51() // expected-error@-1 {{'functionAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{3-29=if #available(OSX 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{3-29=if #available(macOS 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} } @@ -1243,40 +1243,40 @@ func takesAutoclosure(_ c : @autoclosure () -> ()) { } class ClassForFixit { - // expected-note@-1 12{{add @available attribute to enclosing class}} {{1-1=@available(OSX 10.51, *)\n}} + // expected-note@-1 12{{add @available attribute to enclosing class}} {{1-1=@available(macOS 10.51, *)\n}} func fixitForReferenceInMethod() { - // expected-note@-1 {{add @available attribute to enclosing instance method}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-1 {{add @available attribute to enclosing instance method}} {{3-3=@available(macOS 10.51, *)\n }} functionAvailableOn10_51() // expected-error@-1 {{'functionAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{5-31=if #available(OSX 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{5-31=if #available(macOS 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} } func fixitForReferenceNestedInMethod() { - // expected-note@-1 3{{add @available attribute to enclosing instance method}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-1 3{{add @available attribute to enclosing instance method}} {{3-3=@available(macOS 10.51, *)\n }} func inner() { functionAvailableOn10_51() // expected-error@-1 {{'functionAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{7-33=if #available(OSX 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{7-33=if #available(macOS 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} } let _: () -> () = { () in functionAvailableOn10_51() // expected-error@-1 {{'functionAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{7-33=if #available(OSX 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{7-33=if #available(macOS 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} } takesAutoclosure(functionAvailableOn10_51()) // expected-error@-1 {{'functionAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{5-49=if #available(OSX 10.51, *) {\n takesAutoclosure(functionAvailableOn10_51())\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{5-49=if #available(macOS 10.51, *) {\n takesAutoclosure(functionAvailableOn10_51())\n } else {\n // Fallback on earlier versions\n }}} } var fixitForReferenceInPropertyAccessor: Int { - // expected-note@-1 {{add @available attribute to enclosing property}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-1 {{add @available attribute to enclosing property}} {{3-3=@available(macOS 10.51, *)\n }} get { functionAvailableOn10_51() // expected-error@-1 {{'functionAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{7-33=if #available(OSX 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{7-33=if #available(macOS 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} return 5 } @@ -1287,53 +1287,53 @@ class ClassForFixit { lazy var fixitForReferenceInLazyPropertyType: ClassAvailableOn10_51? = nil // expected-error@-1 {{'ClassAvailableOn10_51' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add @available attribute to enclosing property}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-2 {{add @available attribute to enclosing property}} {{3-3=@available(macOS 10.51, *)\n }} private lazy var fixitForReferenceInPrivateLazyPropertyType: ClassAvailableOn10_51? = nil // expected-error@-1 {{'ClassAvailableOn10_51' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add @available attribute to enclosing property}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-2 {{add @available attribute to enclosing property}} {{3-3=@available(macOS 10.51, *)\n }} lazy private var fixitForReferenceInLazyPrivatePropertyType: ClassAvailableOn10_51? = nil // expected-error@-1 {{'ClassAvailableOn10_51' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add @available attribute to enclosing property}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-2 {{add @available attribute to enclosing property}} {{3-3=@available(macOS 10.51, *)\n }} static var fixitForReferenceInStaticPropertyType: ClassAvailableOn10_51? = nil // expected-error@-1 {{'ClassAvailableOn10_51' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add @available attribute to enclosing static property}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-2 {{add @available attribute to enclosing static property}} {{3-3=@available(macOS 10.51, *)\n }} var fixitForReferenceInPropertyTypeMultiple: ClassAvailableOn10_51? = nil, other: Int = 7 // expected-error@-1 {{'ClassAvailableOn10_51' is only available in macOS 10.51 or newer}} func fixitForRefInGuardOfIf() { - // expected-note@-1 {{add @available attribute to enclosing instance method}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-1 {{add @available attribute to enclosing instance method}} {{3-3=@available(macOS 10.51, *)\n }} if (globalFuncAvailableOn10_51() > 1066) { let _ = 5 let _ = 6 } // expected-error@-4 {{'globalFuncAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-5 {{add 'if #available' version check}} {{5-6=if #available(OSX 10.51, *) {\n if (globalFuncAvailableOn10_51() > 1066) {\n let _ = 5\n let _ = 6\n }\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-5 {{add 'if #available' version check}} {{5-6=if #available(macOS 10.51, *) {\n if (globalFuncAvailableOn10_51() > 1066) {\n let _ = 5\n let _ = 6\n }\n } else {\n // Fallback on earlier versions\n }}} } } extension ClassToExtend { // expected-note@-1 {{add @available attribute to enclosing extension}} func fixitForReferenceInExtensionMethod() { - // expected-note@-1 {{add @available attribute to enclosing instance method}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-1 {{add @available attribute to enclosing instance method}} {{3-3=@available(macOS 10.51, *)\n }} functionAvailableOn10_51() // expected-error@-1 {{'functionAvailableOn10_51()' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{5-31=if #available(OSX 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{5-31=if #available(macOS 10.51, *) {\n functionAvailableOn10_51()\n } else {\n // Fallback on earlier versions\n }}} } } enum EnumForFixit { - // expected-note@-1 2{{add @available attribute to enclosing enum}} {{1-1=@available(OSX 10.51, *)\n}} + // expected-note@-1 2{{add @available attribute to enclosing enum}} {{1-1=@available(macOS 10.51, *)\n}} case CaseWithUnavailablePayload(p: ClassAvailableOn10_51) // expected-error@-1 {{'ClassAvailableOn10_51' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add @available attribute to enclosing case}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-2 {{add @available attribute to enclosing case}} {{3-3=@available(macOS 10.51, *)\n }} case CaseWithUnavailablePayload2(p: ClassAvailableOn10_51), WithoutPayload // expected-error@-1 {{'ClassAvailableOn10_51' is only available in macOS 10.51 or newer}} - // expected-note@-2 {{add @available attribute to enclosing case}} {{3-3=@available(OSX 10.51, *)\n }} + // expected-note@-2 {{add @available attribute to enclosing case}} {{3-3=@available(macOS 10.51, *)\n }} } @@ -1348,18 +1348,18 @@ class X { } func testForFixitWithNestedMemberRefExpr() { - // expected-note@-1 2{{add @available attribute to enclosing global function}} {{1-1=@available(OSX 10.52, *)\n}} + // expected-note@-1 2{{add @available attribute to enclosing global function}} {{1-1=@available(macOS 10.52, *)\n}} let x = X() x.y.z = globalFuncAvailableOn10_52() // expected-error@-1 {{'globalFuncAvailableOn10_52()' is only available in macOS 10.52 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{3-39=if #available(OSX 10.52, *) {\n x.y.z = globalFuncAvailableOn10_52()\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{3-39=if #available(macOS 10.52, *) {\n x.y.z = globalFuncAvailableOn10_52()\n } else {\n // Fallback on earlier versions\n }}} // Access via dynamic member reference let anyX: AnyObject = x anyX.y?.z = globalFuncAvailableOn10_52() // expected-error@-1 {{'globalFuncAvailableOn10_52()' is only available in macOS 10.52 or newer}} - // expected-note@-2 {{add 'if #available' version check}} {{3-43=if #available(OSX 10.52, *) {\n anyX.y?.z = globalFuncAvailableOn10_52()\n } else {\n // Fallback on earlier versions\n }}} + // expected-note@-2 {{add 'if #available' version check}} {{3-43=if #available(macOS 10.52, *) {\n anyX.y?.z = globalFuncAvailableOn10_52()\n } else {\n // Fallback on earlier versions\n }}} } diff --git a/test/Sema/availability_versions_playgrounds.swift b/test/Sema/availability_versions_playgrounds.swift index 33c08e68a5c46..2c644db52b61a 100644 --- a/test/Sema/availability_versions_playgrounds.swift +++ b/test/Sema/availability_versions_playgrounds.swift @@ -26,7 +26,7 @@ func someFunction() { if #available(OSX 10.50, *) { // expected-note {{enclosing scope here}} // Still warn if the check is useless because an enclosing #available rules // it out. - if #available(OSX 10.50, *) { // expected-warning {{unnecessary check for 'OSX'; enclosing scope ensures guard will always be true}} + if #available(OSX 10.50, *) { // expected-warning {{unnecessary check for 'macOS'; enclosing scope ensures guard will always be true}} } } } @@ -35,7 +35,7 @@ func someFunction() { func availableOn10_50() { // expected-note {{enclosing scope here}} // Still warn if the check is useless because an enclosing @available rules // it out. - if #available(OSX 10.50, *) { // expected-warning {{unnecessary check for 'OSX'; enclosing scope ensures guard will always be true}} + if #available(OSX 10.50, *) { // expected-warning {{unnecessary check for 'macOS'; enclosing scope ensures guard will always be true}} } } diff --git a/test/SourceKit/DocSupport/doc_swift_module.swift.response b/test/SourceKit/DocSupport/doc_swift_module.swift.response index c35a607481291..ed1f6b1897d97 100644 --- a/test/SourceKit/DocSupport/doc_swift_module.swift.response +++ b/test/SourceKit/DocSupport/doc_swift_module.swift.response @@ -2186,7 +2186,7 @@ func shouldPrintAnyAsKeyword(x x: Any) }, { key.kind: source.lang.swift.decl.extension.class, - key.doc.full_as_xml: "@available(OSX 10.12, iOS 10.0, watchOS 3.0, tvOS 10.0, *)\nextension C1some comments", + key.doc.full_as_xml: "@available(macOS 10.12, iOS 10.0, watchOS 3.0, tvOS 10.0, *)\nextension C1some comments", key.offset: 473, key.length: 37, key.fully_annotated_decl: "extension C1", @@ -2437,7 +2437,7 @@ func shouldPrintAnyAsKeyword(x x: Any) }, { key.kind: source.lang.swift.decl.extension.class, - key.doc.full_as_xml: "@available(OSX 10.12, iOS 10.0, watchOS 3.0, tvOS 10.0, *)\nextension C2some comments", + key.doc.full_as_xml: "@available(macOS 10.12, iOS 10.0, watchOS 3.0, tvOS 10.0, *)\nextension C2some comments", key.offset: 982, key.length: 37, key.fully_annotated_decl: "extension C2", diff --git a/test/SourceKit/InterfaceGen/gen_clang_module.swift.response b/test/SourceKit/InterfaceGen/gen_clang_module.swift.response index 82e393d599c2b..0927f5c120840 100644 --- a/test/SourceKit/InterfaceGen/gen_clang_module.swift.response +++ b/test/SourceKit/InterfaceGen/gen_clang_module.swift.response @@ -328,11 +328,11 @@ open class FooUnavailableMembers : FooClassBase { open func deprecated() - @available(OSX 10.1, *) + @available(macOS 10.1, *) open func availabilityIntroduced() - @available(OSX, introduced: 10.1, message: "x") + @available(macOS, introduced: 10.1, message: "x") open func availabilityIntroducedMsg() } @@ -3362,241 +3362,241 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { { key.kind: source.lang.swift.syntaxtype.keyword, key.offset: 6485, - key.length: 3 + key.length: 5 }, { key.kind: source.lang.swift.syntaxtype.number, - key.offset: 6489, + key.offset: 6491, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 6502, + key.offset: 6504, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 6507, + key.offset: 6509, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6512, + key.offset: 6514, key.length: 22 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 6547, + key.offset: 6549, key.length: 10 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6558, - key.length: 3 + key.offset: 6560, + key.length: 5 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6563, + key.offset: 6567, key.length: 10 }, { key.kind: source.lang.swift.syntaxtype.number, - key.offset: 6575, + key.offset: 6579, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6581, + key.offset: 6585, key.length: 7 }, { key.kind: source.lang.swift.syntaxtype.string, - key.offset: 6590, + key.offset: 6594, key.length: 3 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 6599, + key.offset: 6603, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 6604, + key.offset: 6608, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6609, + key.offset: 6613, key.length: 25 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 6640, + key.offset: 6644, key.length: 6 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 6647, + key.offset: 6651, key.length: 5 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6653, + key.offset: 6657, key.length: 9 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 6668, + key.offset: 6672, key.length: 10 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6682, + key.offset: 6686, key.length: 10 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6694, + key.offset: 6698, key.length: 7 }, { key.kind: source.lang.swift.syntaxtype.string, - key.offset: 6703, + key.offset: 6707, key.length: 27 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 6732, + key.offset: 6736, key.length: 6 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 6739, + key.offset: 6743, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6744, + key.offset: 6748, key.length: 21 }, { key.kind: source.lang.swift.syntaxtype.typeidentifier, - key.offset: 6768, + key.offset: 6772, key.length: 3 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 6784, + key.offset: 6788, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6789, + key.offset: 6793, key.length: 13 }, { key.kind: source.lang.swift.syntaxtype.number, - key.offset: 6805, + key.offset: 6809, key.length: 1 }, { key.kind: source.lang.swift.syntaxtype.comment, - key.offset: 6807, + key.offset: 6811, key.length: 54 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 6866, + key.offset: 6870, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6871, + key.offset: 6875, key.length: 10 }, { key.kind: source.lang.swift.syntaxtype.number, - key.offset: 6884, + key.offset: 6888, key.length: 1 }, { key.kind: source.lang.swift.syntaxtype.comment, - key.offset: 6886, + key.offset: 6890, key.length: 51 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 6939, + key.offset: 6943, key.length: 6 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 6946, + key.offset: 6950, key.length: 5 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6952, + key.offset: 6956, key.length: 19 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 6979, + key.offset: 6983, key.length: 6 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 6986, + key.offset: 6990, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 6991, + key.offset: 6995, key.length: 1 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 6998, + key.offset: 7002, key.length: 6 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 7005, + key.offset: 7009, key.length: 5 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 7011, + key.offset: 7015, key.length: 22 }, { key.kind: source.lang.swift.syntaxtype.typeidentifier, - key.offset: 7036, + key.offset: 7040, key.length: 3 }, { key.kind: source.lang.swift.syntaxtype.typeidentifier, - key.offset: 7040, + key.offset: 7044, key.length: 19 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 7067, + key.offset: 7071, key.length: 8 }, { key.kind: source.lang.swift.syntaxtype.attribute.builtin, - key.offset: 7076, + key.offset: 7080, key.length: 6 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 7083, + key.offset: 7087, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 7088, + key.offset: 7092, key.length: 1 } ] @@ -4220,18 +4220,18 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { }, { key.kind: source.lang.swift.ref.struct, - key.offset: 6768, + key.offset: 6772, key.length: 3, key.is_system: 1 }, { key.kind: source.lang.swift.ref.module, - key.offset: 7036, + key.offset: 7040, key.length: 3 }, { key.kind: source.lang.swift.ref.class, - key.offset: 7040, + key.offset: 7044, key.length: 19 } ] @@ -6859,11 +6859,11 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.accessibility: source.lang.swift.accessibility.open, key.name: "FooUnavailableMembers", key.offset: 6298, - key.length: 340, + key.length: 344, key.nameoffset: 6304, key.namelength: 21, key.bodyoffset: 6342, - key.bodylength: 295, + key.bodylength: 299, key.inheritedtypes: [ { key.name: "FooClassBase" @@ -6941,19 +6941,19 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.kind: source.lang.swift.decl.function.method.instance, key.accessibility: source.lang.swift.accessibility.open, key.name: "availabilityIntroduced()", - key.offset: 6507, + key.offset: 6509, key.length: 29, - key.nameoffset: 6512, + key.nameoffset: 6514, key.namelength: 24, key.attributes: [ { - key.offset: 6502, + key.offset: 6504, key.length: 4, key.attribute: source.decl.attribute.open }, { key.offset: 6474, - key.length: 23, + key.length: 25, key.attribute: source.decl.attribute.available } ] @@ -6962,19 +6962,19 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.kind: source.lang.swift.decl.function.method.instance, key.accessibility: source.lang.swift.accessibility.open, key.name: "availabilityIntroducedMsg()", - key.offset: 6604, + key.offset: 6608, key.length: 32, - key.nameoffset: 6609, + key.nameoffset: 6613, key.namelength: 27, key.attributes: [ { - key.offset: 6599, + key.offset: 6603, key.length: 4, key.attribute: source.decl.attribute.open }, { - key.offset: 6547, - key.length: 47, + key.offset: 6549, + key.length: 49, key.attribute: source.decl.attribute.available } ] @@ -6985,15 +6985,15 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.kind: source.lang.swift.decl.class, key.accessibility: source.lang.swift.accessibility.public, key.name: "FooCFType", - key.offset: 6647, + key.offset: 6651, key.length: 19, - key.nameoffset: 6653, + key.nameoffset: 6657, key.namelength: 9, - key.bodyoffset: 6664, + key.bodyoffset: 6668, key.bodylength: 1, key.attributes: [ { - key.offset: 6640, + key.offset: 6644, key.length: 6, key.attribute: source.decl.attribute.public } @@ -7003,11 +7003,11 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.kind: source.lang.swift.decl.enum, key.accessibility: source.lang.swift.accessibility.public, key.name: "ABAuthorizationStatus", - key.offset: 6739, + key.offset: 6743, key.length: 199, - key.nameoffset: 6744, + key.nameoffset: 6748, key.namelength: 21, - key.bodyoffset: 6773, + key.bodyoffset: 6777, key.bodylength: 164, key.inheritedtypes: [ { @@ -7016,12 +7016,12 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { ], key.attributes: [ { - key.offset: 6732, + key.offset: 6736, key.length: 6, key.attribute: source.decl.attribute.public }, { - key.offset: 6668, + key.offset: 6672, key.length: 63, key.attribute: source.decl.attribute.available } @@ -7029,14 +7029,14 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.elements: [ { key.kind: source.lang.swift.structure.elem.typeref, - key.offset: 6768, + key.offset: 6772, key.length: 3 } ], key.substructure: [ { key.kind: source.lang.swift.decl.enumcase, - key.offset: 6784, + key.offset: 6788, key.length: 22, key.nameoffset: 0, key.namelength: 0, @@ -7045,14 +7045,14 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.kind: source.lang.swift.decl.enumelement, key.accessibility: source.lang.swift.accessibility.public, key.name: "notDetermined", - key.offset: 6789, + key.offset: 6793, key.length: 17, - key.nameoffset: 6789, + key.nameoffset: 6793, key.namelength: 13, key.elements: [ { key.kind: source.lang.swift.structure.elem.init_expr, - key.offset: 6805, + key.offset: 6809, key.length: 1 } ] @@ -7061,7 +7061,7 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { }, { key.kind: source.lang.swift.decl.enumcase, - key.offset: 6866, + key.offset: 6870, key.length: 19, key.nameoffset: 0, key.namelength: 0, @@ -7070,14 +7070,14 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.kind: source.lang.swift.decl.enumelement, key.accessibility: source.lang.swift.accessibility.public, key.name: "restricted", - key.offset: 6871, + key.offset: 6875, key.length: 14, - key.nameoffset: 6871, + key.nameoffset: 6875, key.namelength: 10, key.elements: [ { key.kind: source.lang.swift.structure.elem.init_expr, - key.offset: 6884, + key.offset: 6888, key.length: 1 } ] @@ -7090,15 +7090,15 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.kind: source.lang.swift.decl.class, key.accessibility: source.lang.swift.accessibility.public, key.name: "FooOverlayClassBase", - key.offset: 6946, + key.offset: 6950, key.length: 50, - key.nameoffset: 6952, + key.nameoffset: 6956, key.namelength: 19, - key.bodyoffset: 6973, + key.bodyoffset: 6977, key.bodylength: 22, key.attributes: [ { - key.offset: 6939, + key.offset: 6943, key.length: 6, key.attribute: source.decl.attribute.public } @@ -7108,13 +7108,13 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.kind: source.lang.swift.decl.function.method.instance, key.accessibility: source.lang.swift.accessibility.public, key.name: "f()", - key.offset: 6986, + key.offset: 6990, key.length: 8, - key.nameoffset: 6991, + key.nameoffset: 6995, key.namelength: 3, key.attributes: [ { - key.offset: 6979, + key.offset: 6983, key.length: 6, key.attribute: source.decl.attribute.public } @@ -7126,11 +7126,11 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.kind: source.lang.swift.decl.class, key.accessibility: source.lang.swift.accessibility.public, key.name: "FooOverlayClassDerived", - key.offset: 7005, + key.offset: 7009, key.length: 88, - key.nameoffset: 7011, + key.nameoffset: 7015, key.namelength: 22, - key.bodyoffset: 7061, + key.bodyoffset: 7065, key.bodylength: 31, key.inheritedtypes: [ { @@ -7139,7 +7139,7 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { ], key.attributes: [ { - key.offset: 6998, + key.offset: 7002, key.length: 6, key.attribute: source.decl.attribute.public } @@ -7147,7 +7147,7 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.elements: [ { key.kind: source.lang.swift.structure.elem.typeref, - key.offset: 7036, + key.offset: 7040, key.length: 23 } ], @@ -7156,18 +7156,18 @@ public class FooOverlayClassDerived : Foo.FooOverlayClassBase { key.kind: source.lang.swift.decl.function.method.instance, key.accessibility: source.lang.swift.accessibility.public, key.name: "f()", - key.offset: 7083, + key.offset: 7087, key.length: 8, - key.nameoffset: 7088, + key.nameoffset: 7092, key.namelength: 3, key.attributes: [ { - key.offset: 7076, + key.offset: 7080, key.length: 6, key.attribute: source.decl.attribute.public }, { - key.offset: 7067, + key.offset: 7071, key.length: 8, key.attribute: source.decl.attribute.override } diff --git a/test/attr/Inputs/OldAndNew.swift b/test/attr/Inputs/OldAndNew.swift index bf573f65b95af..0332aaacbc457 100644 --- a/test/attr/Inputs/OldAndNew.swift +++ b/test/attr/Inputs/OldAndNew.swift @@ -6,7 +6,7 @@ public func fiveOnly() -> Int { } // CHECK: @available(swift 5.0) -// CHECK-NEXT: @available(OSX 10.1, *) +// CHECK-NEXT: @available(macOS 10.1, *) // CHECK-NEXT: func fiveOnlyWithMac() -> Int @available(swift, introduced: 5.0) @available(macOS, introduced: 10.1) @@ -15,7 +15,7 @@ public func fiveOnlyWithMac() -> Int { } // CHECK: @available(swift 5.0) -// CHECK-NEXT: @available(OSX 10.1, *) +// CHECK-NEXT: @available(macOS 10.1, *) // CHECK-NEXT: func fiveOnlyWithMac2() -> Int @available(macOS, introduced: 10.1) @available(swift, introduced: 5.0) diff --git a/test/attr/Inputs/PackageDescription.swift b/test/attr/Inputs/PackageDescription.swift index 5a370ab0457f2..59e3221edd8ce 100644 --- a/test/attr/Inputs/PackageDescription.swift +++ b/test/attr/Inputs/PackageDescription.swift @@ -6,7 +6,7 @@ public enum SwiftVersion { case v4 // CHECK: @available(_PackageDescription 5.0) - // CHECK-NEXT: @available(OSX 10.1, *) + // CHECK-NEXT: @available(macOS 10.1, *) // CHECK-NEXT: v5 @available(_PackageDescription, introduced: 5.0) @available(macOS, introduced: 10.1) From 1ac9acb5fe60648480014c79073171234d48f922 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 8 Jul 2020 18:36:29 -0400 Subject: [PATCH 06/36] Sema: We don't have to explicitly set the type of the main function --- lib/Sema/TypeCheckAttr.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 147affbcbc40e..bab9ff834aee5 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -1871,10 +1871,6 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { bool mainFunctionThrows = mainFunction->hasThrows(); - auto voidToVoidFunctionType = - FunctionType::get({}, context.TheEmptyTupleType, - FunctionType::ExtInfo().withThrows(mainFunctionThrows)); - auto nominalToVoidToVoidFunctionType = FunctionType::get({AnyFunctionType::Param(nominal->getInterfaceType())}, voidToVoidFunctionType); auto *func = FuncDecl::create( context, /*StaticLoc*/ SourceLoc(), StaticSpellingKind::KeywordStatic, /*FuncLoc*/ SourceLoc(), @@ -1930,7 +1926,6 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { auto *body = BraceStmt::create(context, SourceLoc(), stmts, SourceLoc(), /*Implicit*/true); func->setBodyParsed(body); - func->setInterfaceType(nominalToVoidToVoidFunctionType); iterableDeclContext->addMember(func); From 56a0b82a014770136080606665007c308ca81acc Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 8 Jul 2020 18:50:54 -0400 Subject: [PATCH 07/36] Sema: Lazily synthesize body of main function --- lib/Sema/TypeCheckAttr.cpp | 117 +++++++++++++++++++++---------------- 1 file changed, 67 insertions(+), 50 deletions(-) diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index bab9ff834aee5..a97ec5ca215a8 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -1782,6 +1782,67 @@ void AttributeChecker::visitUIApplicationMainAttr(UIApplicationMainAttr *attr) { C.getIdentifier("UIApplicationMain")); } +namespace { +struct MainTypeAttrParams { + FuncDecl *mainFunction; + MainTypeAttr *attr; +}; + +} +static std::pair +synthesizeMainBody(AbstractFunctionDecl *fn, void *arg) { + ASTContext &context = fn->getASTContext(); + MainTypeAttrParams *params = (MainTypeAttrParams *) arg; + + FuncDecl *mainFunction = params->mainFunction; + auto location = params->attr->getLocation(); + NominalTypeDecl *nominal = fn->getDeclContext()->getSelfNominalTypeDecl(); + + auto *typeExpr = TypeExpr::createImplicit(nominal->getDeclaredType(), context); + + SubstitutionMap substitutionMap; + if (auto *environment = mainFunction->getGenericEnvironment()) { + substitutionMap = SubstitutionMap::get( + environment->getGenericSignature(), + [&](SubstitutableType *type) { return nominal->getDeclaredType(); }, + LookUpConformanceInModule(nominal->getModuleContext())); + } else { + substitutionMap = SubstitutionMap(); + } + + auto funcDeclRef = ConcreteDeclRef(mainFunction, substitutionMap); + + auto *memberRefExpr = new (context) MemberRefExpr( + typeExpr, SourceLoc(), funcDeclRef, DeclNameLoc(location), + /*Implicit*/ true); + memberRefExpr->setImplicit(true); + + auto *callExpr = CallExpr::createImplicit(context, memberRefExpr, {}, {}); + callExpr->setImplicit(true); + callExpr->setThrows(mainFunction->hasThrows()); + callExpr->setType(context.TheEmptyTupleType); + + Expr *returnedExpr; + + if (mainFunction->hasThrows()) { + auto *tryExpr = new (context) TryExpr( + SourceLoc(), callExpr, context.TheEmptyTupleType, /*implicit=*/true); + returnedExpr = tryExpr; + } else { + returnedExpr = callExpr; + } + + auto *returnStmt = + new (context) ReturnStmt(SourceLoc(), callExpr, /*Implicit=*/true); + + SmallVector stmts; + stmts.push_back(returnStmt); + auto *body = BraceStmt::create(context, SourceLoc(), stmts, + SourceLoc(), /*Implicit*/true); + + return std::make_pair(body, /*typechecked=*/false); +} + void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { auto *extension = dyn_cast(D); @@ -1802,11 +1863,8 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { braces = nominal->getBraces(); } - if (!nominal) { - assert(false && "Should have already recognized that the MainType decl " + assert(nominal && "Should have already recognized that the MainType decl " "isn't applicable to decls other than NominalTypeDecls"); - return; - } assert(iterableDeclContext); assert(declContext); @@ -1833,7 +1891,6 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { // mainType.main() from the entry point, and that would require fully // type-checking the call to mainType.main(). auto &context = D->getASTContext(); - auto location = attr->getLocation(); auto resolution = resolveValueMember( *declContext, nominal->getInterfaceType(), context.Id_main); @@ -1869,14 +1926,12 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { mainFunction = viableCandidates[0]; } - bool mainFunctionThrows = mainFunction->hasThrows(); - auto *func = FuncDecl::create( context, /*StaticLoc*/ SourceLoc(), StaticSpellingKind::KeywordStatic, /*FuncLoc*/ SourceLoc(), DeclName(context, DeclBaseName(context.Id_MainEntryPoint), ParameterList::createEmpty(context)), - /*NameLoc*/ SourceLoc(), /*Throws=*/mainFunctionThrows, + /*NameLoc*/ SourceLoc(), /*Throws=*/mainFunction->hasThrows(), /*ThrowsLoc=*/SourceLoc(), /*GenericParams=*/nullptr, ParameterList::createEmpty(context), /*FnRetType=*/TypeLoc::withoutLoc(TupleType::getEmpty(context)), @@ -1884,48 +1939,10 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { func->setImplicit(true); func->setSynthesized(true); - auto *typeExpr = TypeExpr::createImplicit(nominal->getDeclaredType(), context); - - SubstitutionMap substitutionMap; - if (auto *environment = mainFunction->getGenericEnvironment()) { - substitutionMap = SubstitutionMap::get( - environment->getGenericSignature(), - [&](SubstitutableType *type) { return nominal->getDeclaredType(); }, - LookUpConformanceInModule(nominal->getModuleContext())); - } else { - substitutionMap = SubstitutionMap(); - } - - auto funcDeclRef = ConcreteDeclRef(mainFunction, substitutionMap); - - auto *memberRefExpr = new (context) MemberRefExpr( - typeExpr, SourceLoc(), funcDeclRef, DeclNameLoc(location), - /*Implicit*/ true); - memberRefExpr->setImplicit(true); - - auto *callExpr = CallExpr::createImplicit(context, memberRefExpr, {}, {}); - callExpr->setImplicit(true); - callExpr->setThrows(mainFunctionThrows); - callExpr->setType(context.TheEmptyTupleType); - - Expr *returnedExpr; - - if (mainFunctionThrows) { - auto *tryExpr = new (context) TryExpr( - SourceLoc(), callExpr, context.TheEmptyTupleType, /*implicit=*/true); - returnedExpr = tryExpr; - } else { - returnedExpr = callExpr; - } - - auto *returnStmt = - new (context) ReturnStmt(SourceLoc(), callExpr, /*Implicit=*/true); - - SmallVector stmts; - stmts.push_back(returnStmt); - auto *body = BraceStmt::create(context, SourceLoc(), stmts, - SourceLoc(), /*Implicit*/true); - func->setBodyParsed(body); + auto *params = context.Allocate(); + params->mainFunction = mainFunction; + params->attr = attr; + func->setBodySynthesizer(synthesizeMainBody, params); iterableDeclContext->addMember(func); From 50632d294ccb302f3dc95eb8660e6a4fed359a35 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 8 Jul 2020 19:12:15 -0400 Subject: [PATCH 08/36] Sema: Add a request to synthesize the main function --- include/swift/AST/ASTTypeIDZone.def | 1 + include/swift/AST/ASTTypeIDs.h | 1 + include/swift/AST/TypeCheckRequests.h | 17 +++++++++ include/swift/AST/TypeCheckerTypeIDZone.def | 2 + lib/Sema/TypeCheckAttr.cpp | 41 +++++++++++++++------ 5 files changed, 51 insertions(+), 11 deletions(-) diff --git a/include/swift/AST/ASTTypeIDZone.def b/include/swift/AST/ASTTypeIDZone.def index daf638f797480..bf6aeaf1e65dc 100644 --- a/include/swift/AST/ASTTypeIDZone.def +++ b/include/swift/AST/ASTTypeIDZone.def @@ -40,6 +40,7 @@ SWIFT_TYPEID_NAMED(ConstructorDecl *, ConstructorDecl) SWIFT_TYPEID_NAMED(CustomAttr *, CustomAttr) SWIFT_TYPEID_NAMED(Decl *, Decl) SWIFT_TYPEID_NAMED(EnumDecl *, EnumDecl) +SWIFT_TYPEID_NAMED(FuncDecl *, FuncDecl) SWIFT_TYPEID_NAMED(GenericParamList *, GenericParamList) SWIFT_TYPEID_NAMED(GenericTypeParamType *, GenericTypeParamType) SWIFT_TYPEID_NAMED(InfixOperatorDecl *, InfixOperatorDecl) diff --git a/include/swift/AST/ASTTypeIDs.h b/include/swift/AST/ASTTypeIDs.h index b81fdc686b9ac..973f7dc2c9836 100644 --- a/include/swift/AST/ASTTypeIDs.h +++ b/include/swift/AST/ASTTypeIDs.h @@ -30,6 +30,7 @@ class ConstructorDecl; class CustomAttr; class Decl; class EnumDecl; +class FuncDecl; enum class FunctionBuilderBodyPreCheck : uint8_t; class GenericParamList; class GenericSignature; diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index b2866e493daf7..540b999668150 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -2564,6 +2564,23 @@ class CustomAttrTypeRequest void cacheResult(Type value) const; }; +class SynthesizeMainFunctionRequest + : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + FuncDecl *evaluate(Evaluator &evaluator, Decl *) const; + +public: + bool isCached() const { return true; } +}; + // Allow AnyValue to compare two Type values, even though Type doesn't // support ==. template<> diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index f9cfdb060ecf9..f940dc811a05a 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -274,3 +274,5 @@ SWIFT_REQUEST(TypeChecker, LookupAllConformancesInContextRequest, Uncached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, SimpleDidSetRequest, bool(AccessorDecl *), Cached, NoLocationInfo) +SWIFT_REQUEST(TypeChecker, SynthesizeMainFunctionRequest, + FuncDecl *(Decl *), Cached, NoLocationInfo) diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index a97ec5ca215a8..40a63188d2836 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -1843,7 +1843,15 @@ synthesizeMainBody(AbstractFunctionDecl *fn, void *arg) { return std::make_pair(body, /*typechecked=*/false); } -void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { +FuncDecl * +SynthesizeMainFunctionRequest::evaluate(Evaluator &evaluator, + Decl *D) const { + auto &context = D->getASTContext(); + + MainTypeAttr *attr = D->getAttrs().getAttribute(); + if (attr == nullptr) + return nullptr; + auto *extension = dyn_cast(D); IterableDeclContext *iterableDeclContext; @@ -1870,10 +1878,10 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { // The type cannot be generic. if (nominal->isGenericContext()) { - diagnose(attr->getLocation(), - diag::attr_generic_ApplicationMain_not_supported, 2); + context.Diags.diagnose(attr->getLocation(), + diag::attr_generic_ApplicationMain_not_supported, 2); attr->setInvalid(); - return; + return nullptr; } SourceFile *file = cast(declContext->getModuleScopeContext()); @@ -1890,7 +1898,6 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { // usual type-checking. The alternative would be to directly call // mainType.main() from the entry point, and that would require fully // type-checking the call to mainType.main(). - auto &context = D->getASTContext(); auto resolution = resolveValueMember( *declContext, nominal->getInterfaceType(), context.Id_main); @@ -1918,10 +1925,11 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { } if (viableCandidates.size() != 1) { - diagnose(attr->getLocation(), diag::attr_MainType_without_main, - nominal->getBaseName()); + context.Diags.diagnose(attr->getLocation(), + diag::attr_MainType_without_main, + nominal->getBaseName()); attr->setInvalid(); - return; + return nullptr; } mainFunction = viableCandidates[0]; } @@ -1968,12 +1976,23 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { // Of course, this function's body does not type-check. file->DelayedFunctions.push_back(func); + return func; +} + +void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) { + auto &context = D->getASTContext(); + + SourceFile *file = D->getDeclContext()->getParentSourceFile(); + assert(file); + + auto *func = evaluateOrDefault(context.evaluator, + SynthesizeMainFunctionRequest{D}, + nullptr); + // Register the func as the main decl in the module. If there are multiples // they will be diagnosed. - if (file->registerMainDecl(func, attr->getLocation())) { + if (file->registerMainDecl(func, attr->getLocation())) attr->setInvalid(); - return; - } } /// Determine whether the given context is an extension to an Objective-C class From 244dc4a7685c41774ac90c821a4f1dbc57f9b469 Mon Sep 17 00:00:00 2001 From: Nathan Hawes Date: Wed, 8 Jul 2020 15:47:06 -0700 Subject: [PATCH 09/36] [AST] Rename PlatformKind::OSX to PlatformKind::macOS Because the names are coming from a .def file used for printing too, this simplifies the printing logic as well. --- include/swift/AST/PlatformKinds.def | 4 +-- lib/AST/Availability.cpp | 2 +- lib/AST/PlatformKind.cpp | 27 +++++++------------ lib/ClangImporter/ClangImporter.cpp | 12 ++++----- lib/ClangImporter/ImportDecl.cpp | 4 +-- lib/Driver/DarwinToolChains.cpp | 2 +- lib/IDE/CodeCompletion.cpp | 1 - lib/PrintAsObjC/DeclAndTypePrinter.cpp | 4 +-- lib/Sema/TypeCheckAvailability.cpp | 4 +-- lib/Serialization/SerializedModuleLoader.cpp | 2 +- lib/SymbolGraphGen/AvailabilityMixin.cpp | 4 +-- lib/TBDGen/TBDGen.cpp | 4 +-- .../lib/SwiftLang/SwiftDocSupport.cpp | 4 +-- .../ModuleAnalyzerNodes.cpp | 2 +- 14 files changed, 33 insertions(+), 43 deletions(-) diff --git a/include/swift/AST/PlatformKinds.def b/include/swift/AST/PlatformKinds.def index cc9b9ede0eae8..29ae59270d6ab 100644 --- a/include/swift/AST/PlatformKinds.def +++ b/include/swift/AST/PlatformKinds.def @@ -25,11 +25,11 @@ AVAILABILITY_PLATFORM(iOS, "iOS") AVAILABILITY_PLATFORM(tvOS, "tvOS") AVAILABILITY_PLATFORM(watchOS, "watchOS") -AVAILABILITY_PLATFORM(OSX, "macOS") +AVAILABILITY_PLATFORM(macOS, "macOS") AVAILABILITY_PLATFORM(iOSApplicationExtension, "application extensions for iOS") AVAILABILITY_PLATFORM(tvOSApplicationExtension, "application extensions for tvOS") AVAILABILITY_PLATFORM(watchOSApplicationExtension, "application extensions for watchOS") -AVAILABILITY_PLATFORM(OSXApplicationExtension, "application extensions for macOS") +AVAILABILITY_PLATFORM(macOSApplicationExtension, "application extensions for macOS") AVAILABILITY_PLATFORM(macCatalyst, "Mac Catalyst") AVAILABILITY_PLATFORM(macCatalystApplicationExtension, "application extensions for Mac Catalyst") diff --git a/lib/AST/Availability.cpp b/lib/AST/Availability.cpp index 48cb3911128e6..a1ebba5ca9003 100644 --- a/lib/AST/Availability.cpp +++ b/lib/AST/Availability.cpp @@ -358,7 +358,7 @@ AvailabilityContext ASTContext::getSwift53Availability() { return AvailabilityContext::alwaysAvailable(); llvm::VersionTuple macOVersion53(10, 16, 0); - macOVersion53 = canonicalizePlatformVersion(PlatformKind::OSX, macOVersion53); + macOVersion53 = canonicalizePlatformVersion(PlatformKind::macOS, macOVersion53); return AvailabilityContext( VersionRange::allGTE(macOVersion53)); } else if (target.isiOS()) { diff --git a/lib/AST/PlatformKind.cpp b/lib/AST/PlatformKind.cpp index 370a9f83b821b..ff6a00319b1aa 100644 --- a/lib/AST/PlatformKind.cpp +++ b/lib/AST/PlatformKind.cpp @@ -24,15 +24,6 @@ using namespace swift; StringRef swift::platformString(PlatformKind platform) { - // FIXME: Update PlatformKinds.def to use the macOS spelling by default. - switch (platform) { - case PlatformKind::OSX: - return "macOS"; - case PlatformKind::OSXApplicationExtension: - return "macOSApplicationExtension"; - default: break; - } - switch (platform) { case PlatformKind::none: return "*"; @@ -62,8 +53,8 @@ Optional swift::platformFromString(StringRef Name) { return llvm::StringSwitch>(Name) #define AVAILABILITY_PLATFORM(X, PrettyName) .Case(#X, PlatformKind::X) #include "swift/AST/PlatformKinds.def" - .Case("macOS", PlatformKind::OSX) - .Case("macOSApplicationExtension", PlatformKind::OSXApplicationExtension) + .Case("OSX", PlatformKind::macOS) + .Case("OSXApplicationExtension", PlatformKind::macOSApplicationExtension) .Default(Optional()); } @@ -73,7 +64,7 @@ static bool isPlatformActiveForTarget(PlatformKind Platform, if (Platform == PlatformKind::none) return true; - if (Platform == PlatformKind::OSXApplicationExtension || + if (Platform == PlatformKind::macOSApplicationExtension || Platform == PlatformKind::iOSApplicationExtension || Platform == PlatformKind::macCatalystApplicationExtension) if (!EnableAppExtensionRestrictions) @@ -81,8 +72,8 @@ static bool isPlatformActiveForTarget(PlatformKind Platform, // FIXME: This is an awful way to get the current OS. switch (Platform) { - case PlatformKind::OSX: - case PlatformKind::OSXApplicationExtension: + case PlatformKind::macOS: + case PlatformKind::macOSApplicationExtension: return Target.isMacOSX(); case PlatformKind::iOS: case PlatformKind::iOSApplicationExtension: @@ -118,8 +109,8 @@ bool swift::isPlatformActive(PlatformKind Platform, const LangOptions &LangOpts, PlatformKind swift::targetPlatform(const LangOptions &LangOpts) { if (LangOpts.Target.isMacOSX()) { return (LangOpts.EnableAppExtensionRestrictions - ? PlatformKind::OSXApplicationExtension - : PlatformKind::OSX); + ? PlatformKind::macOSApplicationExtension + : PlatformKind::macOS); } if (LangOpts.Target.isTvOS()) { @@ -171,8 +162,8 @@ llvm::VersionTuple swift::canonicalizePlatformVersion( // Canonicalize macOS version for macOS Big Sur to treat // 10.16 as 11.0. - if (platform == PlatformKind::OSX || - platform == PlatformKind::OSXApplicationExtension) { + if (platform == PlatformKind::macOS || + platform == PlatformKind::macOSApplicationExtension) { return llvm::Triple::getCanonicalVersionForOS(llvm::Triple::MacOSX, version); } diff --git a/lib/ClangImporter/ClangImporter.cpp b/lib/ClangImporter/ClangImporter.cpp index bffb8f4d85bcd..a5c98d3ea7a1b 100644 --- a/lib/ClangImporter/ClangImporter.cpp +++ b/lib/ClangImporter/ClangImporter.cpp @@ -1903,8 +1903,8 @@ PlatformAvailability::PlatformAvailability(const LangOptions &langOpts) deprecatedAsUnavailableMessage = ""; break; - case PlatformKind::OSX: - case PlatformKind::OSXApplicationExtension: + case PlatformKind::macOS: + case PlatformKind::macOSApplicationExtension: deprecatedAsUnavailableMessage = "APIs deprecated as of macOS 10.9 and earlier are unavailable in Swift"; break; @@ -1916,9 +1916,9 @@ PlatformAvailability::PlatformAvailability(const LangOptions &langOpts) bool PlatformAvailability::isPlatformRelevant(StringRef name) const { switch (platformKind) { - case PlatformKind::OSX: + case PlatformKind::macOS: return name == "macos"; - case PlatformKind::OSXApplicationExtension: + case PlatformKind::macOSApplicationExtension: return name == "macos" || name == "macos_app_extension"; case PlatformKind::iOS: @@ -1958,8 +1958,8 @@ bool PlatformAvailability::treatDeprecatedAsUnavailable( case PlatformKind::none: llvm_unreachable("version but no platform?"); - case PlatformKind::OSX: - case PlatformKind::OSXApplicationExtension: + case PlatformKind::macOS: + case PlatformKind::macOSApplicationExtension: // Anything deprecated in OSX 10.9.x and earlier is unavailable in Swift. return major < 10 || (major == 10 && (!minor.hasValue() || minor.getValue() <= 9)); diff --git a/lib/ClangImporter/ImportDecl.cpp b/lib/ClangImporter/ImportDecl.cpp index 524b26b59ff18..af7fd77e1597f 100644 --- a/lib/ClangImporter/ImportDecl.cpp +++ b/lib/ClangImporter/ImportDecl.cpp @@ -7613,12 +7613,12 @@ void ClangImporter::Implementation::importAttributes( auto platformK = llvm::StringSwitch>(Platform) .Case("ios", PlatformKind::iOS) - .Case("macos", PlatformKind::OSX) + .Case("macos", PlatformKind::macOS) .Case("tvos", PlatformKind::tvOS) .Case("watchos", PlatformKind::watchOS) .Case("ios_app_extension", PlatformKind::iOSApplicationExtension) .Case("macos_app_extension", - PlatformKind::OSXApplicationExtension) + PlatformKind::macOSApplicationExtension) .Case("tvos_app_extension", PlatformKind::tvOSApplicationExtension) .Case("watchos_app_extension", diff --git a/lib/Driver/DarwinToolChains.cpp b/lib/Driver/DarwinToolChains.cpp index a7153dc064e52..4245ff6641b57 100644 --- a/lib/Driver/DarwinToolChains.cpp +++ b/lib/Driver/DarwinToolChains.cpp @@ -615,7 +615,7 @@ toolchains::Darwin::addDeploymentTargetArgs(ArgStringList &Arguments, // The first deployment of arm64 for macOS is version 10.16; if (triple.isAArch64() && major <= 10 && minor < 16) { llvm::VersionTuple firstMacARM64e(10, 16, 0); - firstMacARM64e = canonicalizePlatformVersion(PlatformKind::OSX, + firstMacARM64e = canonicalizePlatformVersion(PlatformKind::macOS, firstMacARM64e); major = firstMacARM64e.getMajor(); minor = firstMacARM64e.getMinor().getValueOr(0); diff --git a/lib/IDE/CodeCompletion.cpp b/lib/IDE/CodeCompletion.cpp index db073e6604365..f700231f7dd59 100644 --- a/lib/IDE/CodeCompletion.cpp +++ b/lib/IDE/CodeCompletion.cpp @@ -4429,7 +4429,6 @@ class CompletionLookup final : public swift::VisibleDeclConsumer { if (ParamIndex == 0) { addDeclAttrParamKeyword("*", "Platform", false); - // For code completion, suggest 'macOS' instead of 'OSX'. #define AVAILABILITY_PLATFORM(X, PrettyName) \ addDeclAttrParamKeyword(swift::platformString(PlatformKind::X), \ "Platform", false); diff --git a/lib/PrintAsObjC/DeclAndTypePrinter.cpp b/lib/PrintAsObjC/DeclAndTypePrinter.cpp index 468b0c002230f..b848f04627854 100644 --- a/lib/PrintAsObjC/DeclAndTypePrinter.cpp +++ b/lib/PrintAsObjC/DeclAndTypePrinter.cpp @@ -821,7 +821,7 @@ class DeclAndTypePrinter::Implementation const char *plat; switch (AvAttr->Platform) { - case PlatformKind::OSX: + case PlatformKind::macOS: plat = "macos"; break; case PlatformKind::iOS: @@ -836,7 +836,7 @@ class DeclAndTypePrinter::Implementation case PlatformKind::watchOS: plat = "watchos"; break; - case PlatformKind::OSXApplicationExtension: + case PlatformKind::macOSApplicationExtension: plat = "macos_app_extension"; break; case PlatformKind::iOSApplicationExtension: diff --git a/lib/Sema/TypeCheckAvailability.cpp b/lib/Sema/TypeCheckAvailability.cpp index a62d73c95ff83..b6c1f91dd7be7 100644 --- a/lib/Sema/TypeCheckAvailability.cpp +++ b/lib/Sema/TypeCheckAvailability.cpp @@ -1253,8 +1253,8 @@ static bool fixAvailabilityByNarrowingNearbyVersionCheck( auto Platform = targetPlatform(Context.LangOpts); if (RunningVers.getMajor() != RequiredVers.getMajor()) return false; - if ((Platform == PlatformKind::OSX || - Platform == PlatformKind::OSXApplicationExtension) && + if ((Platform == PlatformKind::macOS || + Platform == PlatformKind::macOSApplicationExtension) && !(RunningVers.getMinor().hasValue() && RequiredVers.getMinor().hasValue() && RunningVers.getMinor().getValue() == diff --git a/lib/Serialization/SerializedModuleLoader.cpp b/lib/Serialization/SerializedModuleLoader.cpp index 3aded479fe5e2..25a81335ad015 100644 --- a/lib/Serialization/SerializedModuleLoader.cpp +++ b/lib/Serialization/SerializedModuleLoader.cpp @@ -625,7 +625,7 @@ getOSAndVersionForDiagnostics(const llvm::Triple &triple) { // macOS triples represent their versions differently, so we have to use the // special accessor. triple.getMacOSXVersion(major, minor, micro); - osName = swift::prettyPlatformString(PlatformKind::OSX); + osName = swift::prettyPlatformString(PlatformKind::macOS); } else { triple.getOSVersion(major, minor, micro); if (triple.isWatchOS()) { diff --git a/lib/SymbolGraphGen/AvailabilityMixin.cpp b/lib/SymbolGraphGen/AvailabilityMixin.cpp index eb0e7c1041522..aa446f412a83c 100644 --- a/lib/SymbolGraphGen/AvailabilityMixin.cpp +++ b/lib/SymbolGraphGen/AvailabilityMixin.cpp @@ -40,7 +40,7 @@ StringRef getDomain(const AvailableAttr &AvAttr) { return { "iOS" }; case swift::PlatformKind::macCatalyst: return { "macCatalyst" }; - case swift::PlatformKind::OSX: + case swift::PlatformKind::macOS: return { "macOS" }; case swift::PlatformKind::tvOS: return { "tvOS" }; @@ -50,7 +50,7 @@ StringRef getDomain(const AvailableAttr &AvAttr) { return { "iOSAppExtension" }; case swift::PlatformKind::macCatalystApplicationExtension: return { "macCatalystAppExtension" }; - case swift::PlatformKind::OSXApplicationExtension: + case swift::PlatformKind::macOSApplicationExtension: return { "macOSAppExtension" }; case swift::PlatformKind::tvOSApplicationExtension: return { "tvOSAppExtension" }; diff --git a/lib/TBDGen/TBDGen.cpp b/lib/TBDGen/TBDGen.cpp index 5d5a63d06d614..6679eeb1b98f3 100644 --- a/lib/TBDGen/TBDGen.cpp +++ b/lib/TBDGen/TBDGen.cpp @@ -257,8 +257,8 @@ getLinkerPlatformId(OriginallyDefinedInAttr::ActiveVersion Ver) { case swift::PlatformKind::watchOSApplicationExtension: return Ver.IsSimulator ? LinkerPlatformId::watchOS_sim: LinkerPlatformId::watchOS; - case swift::PlatformKind::OSX: - case swift::PlatformKind::OSXApplicationExtension: + case swift::PlatformKind::macOS: + case swift::PlatformKind::macOSApplicationExtension: return LinkerPlatformId::macOS; case swift::PlatformKind::macCatalyst: case swift::PlatformKind::macCatalystApplicationExtension: diff --git a/tools/SourceKit/lib/SwiftLang/SwiftDocSupport.cpp b/tools/SourceKit/lib/SwiftLang/SwiftDocSupport.cpp index 7fb63fd0e8aab..f59c9d0564bd1 100644 --- a/tools/SourceKit/lib/SwiftLang/SwiftDocSupport.cpp +++ b/tools/SourceKit/lib/SwiftLang/SwiftDocSupport.cpp @@ -686,7 +686,7 @@ static void reportAttributes(ASTContext &Ctx, PlatformUID = PlatformIOS; break; case PlatformKind::macCatalyst: PlatformUID = PlatformMacCatalyst; break; - case PlatformKind::OSX: + case PlatformKind::macOS: PlatformUID = PlatformOSX; break; case PlatformKind::tvOS: PlatformUID = PlatformtvOS; break; @@ -696,7 +696,7 @@ static void reportAttributes(ASTContext &Ctx, PlatformUID = PlatformIOSAppExt; break; case PlatformKind::macCatalystApplicationExtension: PlatformUID = PlatformMacCatalystAppExt; break; - case PlatformKind::OSXApplicationExtension: + case PlatformKind::macOSApplicationExtension: PlatformUID = PlatformOSXAppExt; break; case PlatformKind::tvOSApplicationExtension: PlatformUID = PlatformtvOSAppExt; break; diff --git a/tools/swift-api-digester/ModuleAnalyzerNodes.cpp b/tools/swift-api-digester/ModuleAnalyzerNodes.cpp index a48769a9bc2f0..265043d491e62 100644 --- a/tools/swift-api-digester/ModuleAnalyzerNodes.cpp +++ b/tools/swift-api-digester/ModuleAnalyzerNodes.cpp @@ -1331,7 +1331,7 @@ SDKNodeInitInfo::SDKNodeInitInfo(SDKContext &Ctx, Decl *D): SugaredGenericSig(Ctx.checkingABI()? printGenericSignature(Ctx, D, /*Canonical*/false): StringRef()), - IntromacOS(Ctx.getPlatformIntroVersion(D, PlatformKind::OSX)), + IntromacOS(Ctx.getPlatformIntroVersion(D, PlatformKind::macOS)), IntroiOS(Ctx.getPlatformIntroVersion(D, PlatformKind::iOS)), IntrotvOS(Ctx.getPlatformIntroVersion(D, PlatformKind::tvOS)), IntrowatchOS(Ctx.getPlatformIntroVersion(D, PlatformKind::watchOS)), From f8583099b96d8abeacdf9106bfefb24c2f713aad Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 8 Jul 2020 19:14:08 -0400 Subject: [PATCH 10/36] Sema: EmittedMembersRequest forces SynthesizeMainFunctionRequest --- lib/Sema/TypeCheckDecl.cpp | 6 ++++++ test/SILGen/attr_main_class.swift | 11 +++++++++++ test/SILGen/attr_main_class_2.swift | 11 +++++++++++ test/SILGen/attr_main_enum.swift | 11 +++++++++++ test/SILGen/attr_main_struct.swift | 11 +++++++++++ 5 files changed, 50 insertions(+) create mode 100644 test/SILGen/attr_main_class.swift create mode 100644 test/SILGen/attr_main_class_2.swift create mode 100644 test/SILGen/attr_main_enum.swift create mode 100644 test/SILGen/attr_main_struct.swift diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index 92638c7f25699..f2af26aa9d614 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -2523,6 +2523,12 @@ EmittedMembersRequest::evaluate(Evaluator &evaluator, forceConformance(Context.getProtocol(KnownProtocolKind::Hashable)); forceConformance(Context.getProtocol(KnownProtocolKind::Differentiable)); + // If the class has a @main attribute, we need to force synthesis of the + // $main function. + (void) evaluateOrDefault(Context.evaluator, + SynthesizeMainFunctionRequest{CD}, + nullptr); + for (auto *member : CD->getMembers()) { if (auto *var = dyn_cast(member)) { // The projected storage wrapper ($foo) might have dynamically-dispatched diff --git a/test/SILGen/attr_main_class.swift b/test/SILGen/attr_main_class.swift new file mode 100644 index 0000000000000..3e964ab8105ed --- /dev/null +++ b/test/SILGen/attr_main_class.swift @@ -0,0 +1,11 @@ +// RUN: %target-swift-emit-silgen -parse-as-library %s | %FileCheck %s + +@main class Horse { + static func main() {} +} + +// CHECK-LABEL: sil hidden [ossa] @$s15attr_main_class5HorseC5$mainyyFZ : $@convention(method) (@thick Horse.Type) -> () { + +// CHECK-LABEL: sil [ossa] @main : $@convention(c) (Int32, UnsafeMutablePointer>>) -> Int32 { +// CHECK: function_ref @$s15attr_main_class5HorseC5$mainyyFZ : $@convention(method) (@thick Horse.Type) -> () +// CHECK: } diff --git a/test/SILGen/attr_main_class_2.swift b/test/SILGen/attr_main_class_2.swift new file mode 100644 index 0000000000000..4af16cd7c3efd --- /dev/null +++ b/test/SILGen/attr_main_class_2.swift @@ -0,0 +1,11 @@ +// RUN: %target-swift-emit-silgen -parse-as-library %s | %FileCheck %s + +@main class Horse { + class func main() {} +} + +// CHECK-LABEL: sil hidden [ossa] @$s17attr_main_class_25HorseC5$mainyyFZ : $@convention(method) (@thick Horse.Type) -> () { + +// CHECK-LABEL: sil [ossa] @main : $@convention(c) (Int32, UnsafeMutablePointer>>) -> Int32 { +// CHECK: function_ref @$s17attr_main_class_25HorseC5$mainyyFZ : $@convention(method) (@thick Horse.Type) -> () +// CHECK: } diff --git a/test/SILGen/attr_main_enum.swift b/test/SILGen/attr_main_enum.swift new file mode 100644 index 0000000000000..900684a3dabc0 --- /dev/null +++ b/test/SILGen/attr_main_enum.swift @@ -0,0 +1,11 @@ +// RUN: %target-swift-emit-silgen -parse-as-library %s | %FileCheck %s + +@main enum Horse { + static func main() {} +} + +// CHECK-LABEL: sil hidden [ossa] @$s14attr_main_enum5HorseO5$mainyyFZ : $@convention(method) (@thin Horse.Type) -> () { + +// CHECK-LABEL: sil [ossa] @main : $@convention(c) (Int32, UnsafeMutablePointer>>) -> Int32 { +// CHECK: function_ref @$s14attr_main_enum5HorseO5$mainyyFZ : $@convention(method) (@thin Horse.Type) -> () +// CHECK: } diff --git a/test/SILGen/attr_main_struct.swift b/test/SILGen/attr_main_struct.swift new file mode 100644 index 0000000000000..1dc8e1fa45470 --- /dev/null +++ b/test/SILGen/attr_main_struct.swift @@ -0,0 +1,11 @@ +// RUN: %target-swift-emit-silgen -parse-as-library %s | %FileCheck %s + +@main struct Horse { + static func main() {} +} + +// CHECK-LABEL: sil hidden [ossa] @$s16attr_main_struct5HorseV5$mainyyFZ : $@convention(method) (@thin Horse.Type) -> () { + +// CHECK-LABEL: sil [ossa] @main : $@convention(c) (Int32, UnsafeMutablePointer>>) -> Int32 { +// CHECK: function_ref @$s16attr_main_struct5HorseV5$mainyyFZ : $@convention(method) (@thin Horse.Type) -> () +// CHECK: } From 7bc50889f4507ea74f0076c5e24f8c8085613c64 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 8 Jul 2020 19:19:45 -0400 Subject: [PATCH 11/36] Sema: Check declaration attributes before checking members This simplifies matters if checking an attribute adds members to the nominal type or extension. --- lib/Sema/TypeCheckAttr.cpp | 25 ------------------------- lib/Sema/TypeCheckDeclPrimary.cpp | 21 +++++++++++---------- test/Compatibility/accessibility.swift | 14 +++++++------- test/Sema/accessibility.swift | 14 +++++++------- 4 files changed, 25 insertions(+), 49 deletions(-) diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 40a63188d2836..7375764e368ee 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -1884,9 +1884,6 @@ SynthesizeMainFunctionRequest::evaluate(Evaluator &evaluator, return nullptr; } - SourceFile *file = cast(declContext->getModuleScopeContext()); - assert(file); - // Create a function // // func $main() { @@ -1954,28 +1951,6 @@ SynthesizeMainFunctionRequest::evaluate(Evaluator &evaluator, iterableDeclContext->addMember(func); - // This function must be type-checked. Why? Consider the following scenario: - // - // protocol AlmostMainable {} - // protocol ReallyMainable {} - // extension AlmostMainable where Self : ReallyMainable { - // static func main() {} - // } - // @main struct Main : AlmostMainable {} - // - // Note in particular that Main does not conform to ReallyMainable. - // - // In this case, resolveValueMember will find the function main in the - // extension, and so, since there is one candidate, the function $main will - // accordingly be formed as usual: - // - // func $main() { - // return Main.main() - // } - // - // Of course, this function's body does not type-check. - file->DelayedFunctions.push_back(func); - return func; } diff --git a/lib/Sema/TypeCheckDeclPrimary.cpp b/lib/Sema/TypeCheckDeclPrimary.cpp index 30ef739271870..2550d986a61f1 100644 --- a/lib/Sema/TypeCheckDeclPrimary.cpp +++ b/lib/Sema/TypeCheckDeclPrimary.cpp @@ -1797,13 +1797,13 @@ class DeclChecker : public DeclVisitor { checkGenericParams(ED); // Check for circular inheritance of the raw type. - (void)ED->hasCircularRawValue(); + (void) ED->hasCircularRawValue(); + + TypeChecker::checkDeclAttributes(ED); for (Decl *member : ED->getMembers()) visit(member); - TypeChecker::checkDeclAttributes(ED); - checkInheritanceClause(ED); checkAccessControl(ED); @@ -1845,13 +1845,13 @@ class DeclChecker : public DeclVisitor { installCodingKeysIfNecessary(SD); + TypeChecker::checkDeclAttributes(SD); + for (Decl *Member : SD->getMembers()) visit(Member); TypeChecker::checkPatternBindingCaptures(SD); - TypeChecker::checkDeclAttributes(SD); - checkInheritanceClause(SD); checkAccessControl(SD); @@ -1974,6 +1974,8 @@ class DeclChecker : public DeclVisitor { // Force creation of an implicit destructor, if any. (void) CD->getDestructor(); + TypeChecker::checkDeclAttributes(CD); + for (Decl *Member : CD->getEmittedMembers()) visit(Member); @@ -2084,8 +2086,6 @@ class DeclChecker : public DeclVisitor { } } - TypeChecker::checkDeclAttributes(CD); - checkInheritanceClause(CD); checkAccessControl(CD); @@ -2105,12 +2105,12 @@ class DeclChecker : public DeclVisitor { // Check for circular inheritance within the protocol. (void)PD->hasCircularInheritedProtocols(); + TypeChecker::checkDeclAttributes(PD); + // Check the members. for (auto Member : PD->getMembers()) visit(Member); - TypeChecker::checkDeclAttributes(PD); - checkAccessControl(PD); checkInheritanceClause(PD); @@ -2428,6 +2428,8 @@ class DeclChecker : public DeclVisitor { checkGenericParams(ED); + TypeChecker::checkDeclAttributes(ED); + for (Decl *Member : ED->getMembers()) visit(Member); @@ -2435,7 +2437,6 @@ class DeclChecker : public DeclVisitor { TypeChecker::checkConformancesInContext(ED); - TypeChecker::checkDeclAttributes(ED); checkAccessControl(ED); checkExplicitAvailability(ED); diff --git a/test/Compatibility/accessibility.swift b/test/Compatibility/accessibility.swift index e7402cd845c11..35b348b1b4e61 100644 --- a/test/Compatibility/accessibility.swift +++ b/test/Compatibility/accessibility.swift @@ -107,7 +107,7 @@ private extension PublicStruct { private func extImplPrivate() {} } public extension InternalStruct { // expected-error {{extension of internal struct cannot be declared public}} {{1-8=}} - public func extMemberPublic() {} // expected-warning {{'public' modifier is redundant for instance method declared in a public extension}} {{3-10=}} + public func extMemberPublic() {} fileprivate func extFuncPublic() {} private func extImplPublic() {} } @@ -127,12 +127,12 @@ private extension InternalStruct { private func extImplPrivate() {} } public extension FilePrivateStruct { // expected-error {{extension of fileprivate struct cannot be declared public}} {{1-8=}} - public func extMemberPublic() {} // expected-warning {{'public' modifier is redundant for instance method declared in a public extension}} {{3-10=}} + public func extMemberPublic() {} fileprivate func extFuncPublic() {} private func extImplPublic() {} } internal extension FilePrivateStruct { // expected-error {{extension of fileprivate struct cannot be declared internal}} {{1-10=}} - public func extMemberInternal() {} // expected-warning {{'public' modifier conflicts with extension's default access of 'internal'}} {{none}} + public func extMemberInternal() {} fileprivate func extFuncInternal() {} private func extImplInternal() {} } @@ -147,18 +147,18 @@ private extension FilePrivateStruct { private func extImplPrivate() {} } public extension PrivateStruct { // expected-error {{extension of private struct cannot be declared public}} {{1-8=}} - public func extMemberPublic() {} // expected-warning {{'public' modifier is redundant for instance method declared in a public extension}} {{3-10=}} + public func extMemberPublic() {} fileprivate func extFuncPublic() {} private func extImplPublic() {} } internal extension PrivateStruct { // expected-error {{extension of private struct cannot be declared internal}} {{1-10=}} - public func extMemberInternal() {} // expected-warning {{'public' modifier conflicts with extension's default access of 'internal'}} {{none}} + public func extMemberInternal() {} fileprivate func extFuncInternal() {} private func extImplInternal() {} } fileprivate extension PrivateStruct { // expected-error {{extension of private struct cannot be declared fileprivate}} {{1-13=}} - public func extMemberFilePrivate() {} // expected-warning {{'public' modifier conflicts with extension's default access of 'fileprivate'}} {{none}} - fileprivate func extFuncFilePrivate() {} // expected-warning {{'fileprivate' modifier is redundant for instance method declared in a fileprivate extension}} {{3-15=}} + public func extMemberFilePrivate() {} + fileprivate func extFuncFilePrivate() {} private func extImplFilePrivate() {} } private extension PrivateStruct { diff --git a/test/Sema/accessibility.swift b/test/Sema/accessibility.swift index b7e5e46e20f63..2e4d016488f2e 100644 --- a/test/Sema/accessibility.swift +++ b/test/Sema/accessibility.swift @@ -106,7 +106,7 @@ private extension PublicStruct { private func extImplPrivate() {} } public extension InternalStruct { // expected-error {{extension of internal struct cannot be declared public}} {{1-8=}} - public func extMemberPublic() {} // expected-warning {{'public' modifier is redundant for instance method declared in a public extension}} {{3-10=}} + public func extMemberPublic() {} fileprivate func extFuncPublic() {} private func extImplPublic() {} } @@ -126,12 +126,12 @@ private extension InternalStruct { private func extImplPrivate() {} } public extension FilePrivateStruct { // expected-error {{extension of fileprivate struct cannot be declared public}} {{1-8=}} - public func extMemberPublic() {} // expected-warning {{'public' modifier is redundant for instance method declared in a public extension}} {{3-10=}} + public func extMemberPublic() {} fileprivate func extFuncPublic() {} private func extImplPublic() {} } internal extension FilePrivateStruct { // expected-error {{extension of fileprivate struct cannot be declared internal}} {{1-10=}} - public func extMemberInternal() {} // expected-warning {{'public' modifier conflicts with extension's default access of 'internal'}} {{none}} + public func extMemberInternal() {} fileprivate func extFuncInternal() {} private func extImplInternal() {} } @@ -146,18 +146,18 @@ private extension FilePrivateStruct { private func extImplPrivate() {} } public extension PrivateStruct { // expected-error {{extension of private struct cannot be declared public}} {{1-8=}} - public func extMemberPublic() {} // expected-warning {{'public' modifier is redundant for instance method declared in a public extension}} {{3-10=}} + public func extMemberPublic() {} fileprivate func extFuncPublic() {} private func extImplPublic() {} } internal extension PrivateStruct { // expected-error {{extension of private struct cannot be declared internal}} {{1-10=}} - public func extMemberInternal() {} // expected-warning {{'public' modifier conflicts with extension's default access of 'internal'}} {{none}} + public func extMemberInternal() {} fileprivate func extFuncInternal() {} private func extImplInternal() {} } fileprivate extension PrivateStruct { // expected-error {{extension of private struct cannot be declared fileprivate}} {{1-13=}} - public func extMemberFilePrivate() {} // expected-warning {{'public' modifier conflicts with extension's default access of 'fileprivate'}} {{none}} - fileprivate func extFuncFilePrivate() {} // expected-warning {{'fileprivate' modifier is redundant for instance method declared in a fileprivate extension}} {{3-15=}} + public func extMemberFilePrivate() {} + fileprivate func extFuncFilePrivate() {} private func extImplFilePrivate() {} } private extension PrivateStruct { From 505d87ca26edaa5dac3b16f62495d806d83d63f7 Mon Sep 17 00:00:00 2001 From: Karoy Lorentey Date: Wed, 8 Jul 2020 19:57:04 -0700 Subject: [PATCH 12/36] [Foundation] Extract logging details to a standalone function --- stdlib/public/Darwin/Foundation/CheckClass.swift | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/stdlib/public/Darwin/Foundation/CheckClass.swift b/stdlib/public/Darwin/Foundation/CheckClass.swift index b7ac9da9cc8c2..4d94e9cd357e7 100644 --- a/stdlib/public/Darwin/Foundation/CheckClass.swift +++ b/stdlib/public/Darwin/Foundation/CheckClass.swift @@ -23,6 +23,11 @@ private func _isClassFirstSeen(_ theClass: AnyClass) -> Bool { } } +internal func _logRuntimeIssue(_ message: String) { + NSLog("%@", message) + _swift_reportToDebugger(0, message, nil) +} + extension NSKeyedUnarchiver { /// Checks if class `theClass` is good for archiving. /// @@ -56,8 +61,7 @@ extension NSKeyedUnarchiver { If there are no existing archives containing this class, you should choose a unique, prefixed name instead: "@objc(ABCMyModel)" """ - NSLog("%@", message) - _swift_reportToDebugger(0, message, nil) + _logRuntimeIssue(message) } return 1 } From 94e92636997a45150eee5311b6453c2eaec9956a Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 8 Jul 2020 22:37:15 -0400 Subject: [PATCH 13/36] Sema: Fix crash on circular reference in checkContextualRequirements() The call to getGenericSignature() might return nullptr if we encounter a circular reference. Fixes . --- lib/Sema/TypeCheckType.cpp | 11 ++++++++++- .../compiler_crashers_2_fixed/rdar64992293.swift | 12 ++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 validation-test/compiler_crashers_2_fixed/rdar64992293.swift diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index ac2b44d348d2a..f9b59d8462aaf 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -619,6 +619,8 @@ static Type checkContextualRequirements(Type type, return type; } + auto &ctx = dc->getASTContext(); + SourceLoc noteLoc; { // We are interested in either a contextual where clause or @@ -637,6 +639,13 @@ static Type checkContextualRequirements(Type type, const auto subMap = parentTy->getContextSubstitutions(decl->getDeclContext()); const auto genericSig = decl->getGenericSignature(); + if (!genericSig) { + ctx.Diags.diagnose(loc, diag::recursive_decl_reference, + decl->getDescriptiveKind(), decl->getName()); + decl->diagnose(diag::kind_declared_here, DescriptiveDeclKind::Type); + return ErrorType::get(ctx); + } + const auto result = TypeChecker::checkGenericArguments( dc, loc, noteLoc, type, @@ -647,7 +656,7 @@ static Type checkContextualRequirements(Type type, switch (result) { case RequirementCheckResult::Failure: case RequirementCheckResult::SubstitutionFailure: - return ErrorType::get(dc->getASTContext()); + return ErrorType::get(ctx); case RequirementCheckResult::Success: return type; } diff --git a/validation-test/compiler_crashers_2_fixed/rdar64992293.swift b/validation-test/compiler_crashers_2_fixed/rdar64992293.swift new file mode 100644 index 0000000000000..2baabbe4e5b36 --- /dev/null +++ b/validation-test/compiler_crashers_2_fixed/rdar64992293.swift @@ -0,0 +1,12 @@ +// RUN: not %target-swift-frontend -typecheck %s + +public protocol SomeProtocol {} + +public struct Impl: SomeProtocol where Param: SomeProtocol {} + +public struct Wrapper where Content: SomeProtocol {} + +public extension Wrapper where Content == Impl { + typealias WrapperParam = SomeProtocol +} + From 1f5433fe518ef83a43f23ee6dbe96ce0b43fae81 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Wed, 8 Jul 2020 22:49:44 -0400 Subject: [PATCH 14/36] Add regression test for rdar://64759168 --- .../compiler_crashers_2_fixed/rdar64759168.swift | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 validation-test/compiler_crashers_2_fixed/rdar64759168.swift diff --git a/validation-test/compiler_crashers_2_fixed/rdar64759168.swift b/validation-test/compiler_crashers_2_fixed/rdar64759168.swift new file mode 100644 index 0000000000000..8478e642557e0 --- /dev/null +++ b/validation-test/compiler_crashers_2_fixed/rdar64759168.swift @@ -0,0 +1,15 @@ +// RUN: %target-swift-frontend -emit-ir %s + +public class Clazz {} + +public protocol SelectableFieldValueProtocol {} + +public protocol FieldProtocol { + associatedtype SelectableValue : SelectableFieldValueProtocol +} + +public protocol SelectFieldValueCoordinatorDelegate { + associatedtype Field : Clazz, FieldProtocol +} + +public class SelectFieldValueCoordinator where Field == Delegate.Field {} From e4a8e0ea45521359b22d4a08ccce119b02a22e8f Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Wed, 8 Jul 2020 22:14:32 -0700 Subject: [PATCH 15/36] Remove -sil-merge-partial-modules This is now the default behaviour for -merge-modules. Stop passing it in the driver and remove it from FrontendOptions.td. --- include/swift/Option/FrontendOptions.td | 4 ---- lib/Driver/ToolChains.cpp | 3 --- 2 files changed, 7 deletions(-) diff --git a/include/swift/Option/FrontendOptions.td b/include/swift/Option/FrontendOptions.td index b9f4adbceb59a..d465a0f1dfb71 100644 --- a/include/swift/Option/FrontendOptions.td +++ b/include/swift/Option/FrontendOptions.td @@ -554,10 +554,6 @@ def sil_unroll_threshold : Separate<["-"], "sil-unroll-threshold">, MetaVarName<"<250>">, HelpText<"Controls the aggressiveness of loop unrolling">; -// FIXME: This option is now redundant and should eventually be removed. -def sil_merge_partial_modules : Flag<["-"], "sil-merge-partial-modules">, - Alias; - def sil_verify_all : Flag<["-"], "sil-verify-all">, HelpText<"Verify SIL after each transform">; diff --git a/lib/Driver/ToolChains.cpp b/lib/Driver/ToolChains.cpp index 3e158c76274cf..e3447e1fb9672 100644 --- a/lib/Driver/ToolChains.cpp +++ b/lib/Driver/ToolChains.cpp @@ -981,9 +981,6 @@ ToolChain::constructInvocation(const MergeModuleJobAction &job, // serialized ASTs. Arguments.push_back("-parse-as-library"); - // Merge serialized SIL from partial modules. - Arguments.push_back("-sil-merge-partial-modules"); - // Disable SIL optimization passes; we've already optimized the code in each // partial mode. Arguments.push_back("-disable-diagnostic-passes"); From b42910a430f8fb9dabf0d243236f921f644cdd66 Mon Sep 17 00:00:00 2001 From: Karoy Lorentey Date: Wed, 8 Jul 2020 22:33:17 -0700 Subject: [PATCH 16/36] [test] Reenable NSValueBridging tests on i386 Instead, disable tests we cannot run with a targeted platform conditional. --- test/stdlib/NSValueBridging.swift.gyb | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/test/stdlib/NSValueBridging.swift.gyb b/test/stdlib/NSValueBridging.swift.gyb index 626415bcc7b99..3b91eeb9fc86f 100644 --- a/test/stdlib/NSValueBridging.swift.gyb +++ b/test/stdlib/NSValueBridging.swift.gyb @@ -18,18 +18,11 @@ // // REQUIRES: objc_interop -// The UIKit overlay isn't present on iOS 10.3. -// UNSUPPORTED: CPU=i386 - import StdlibUnittest import StdlibUnittestFoundationExtras import Foundation import CoreGraphics -#if canImport(UIKit) -import UIKit -#endif - var nsValueBridging = TestSuite("NSValueBridging") func rangesEqual(_ x: NSRange, _ y: NSRange) -> Bool { @@ -73,7 +66,13 @@ ${ testCase("CGAffineTransform", "CGAffineTransform(rotationAngle: .pi)", N #endif -#if canImport(UIKit) +// The last supported iOS version for 32-bit platforms is iOS 10.3, which didn't +// ship with the UIKit overlay, so we cannot run NSValue bridging tests there. +// +// FIXME: Test back-deployment scenarios with the Swift 5.0 compatibility +// runtime rather than a freshly built stdlib. (rdar://62694723) +#if canImport(UIKit) && !(os(iOS) && (arch(armv7) || arch(armv7s) || arch(i386))) +import UIKit ${ testCase("CGRect", "CGRect(x: 17, y: 38, width: 6, height: 79)", "cgRect", "(==)") } ${ testCase("CGPoint", "CGPoint(x: 17, y: 38)", "cgPoint", "(==)") } From 3de9e5ac77da9a21ff652eaae3fa0a134536ca6f Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Thu, 9 Jul 2020 12:11:08 +0200 Subject: [PATCH 17/36] Rename synthesized initializer test to follow the pattern And also removed a test that is semantically a duplicate of the synthesized initializer test. --- ...s.swift => synthesized-initializers-silgen.swift} | 0 validation-test/SILGen/Inputs/cxx-types.h | 3 --- validation-test/SILGen/Inputs/module.modulemap | 3 --- .../SILGen/cxx-address-only-object-init.swift | 12 ------------ 4 files changed, 18 deletions(-) rename test/Interop/Cxx/class/{synthesized-initializers.swift => synthesized-initializers-silgen.swift} (100%) delete mode 100644 validation-test/SILGen/Inputs/cxx-types.h delete mode 100644 validation-test/SILGen/Inputs/module.modulemap delete mode 100644 validation-test/SILGen/cxx-address-only-object-init.swift diff --git a/test/Interop/Cxx/class/synthesized-initializers.swift b/test/Interop/Cxx/class/synthesized-initializers-silgen.swift similarity index 100% rename from test/Interop/Cxx/class/synthesized-initializers.swift rename to test/Interop/Cxx/class/synthesized-initializers-silgen.swift diff --git a/validation-test/SILGen/Inputs/cxx-types.h b/validation-test/SILGen/Inputs/cxx-types.h deleted file mode 100644 index 1a883e6f2e126..0000000000000 --- a/validation-test/SILGen/Inputs/cxx-types.h +++ /dev/null @@ -1,3 +0,0 @@ -struct HasCustomCopyConst { - HasCustomCopyConst(HasCustomCopyConst const&) { } -}; diff --git a/validation-test/SILGen/Inputs/module.modulemap b/validation-test/SILGen/Inputs/module.modulemap deleted file mode 100644 index cb4d7c21edee3..0000000000000 --- a/validation-test/SILGen/Inputs/module.modulemap +++ /dev/null @@ -1,3 +0,0 @@ -module CXXTypes { - header "cxx-types.h" -} diff --git a/validation-test/SILGen/cxx-address-only-object-init.swift b/validation-test/SILGen/cxx-address-only-object-init.swift deleted file mode 100644 index 0d39b9e4a231e..0000000000000 --- a/validation-test/SILGen/cxx-address-only-object-init.swift +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: %target-swift-frontend -enable-cxx-interop -I %S/Inputs %s -emit-silgen | %FileCheck %s - -import CXXTypes - -// Just make sure we create the object and don't crash. -// CHECK-LABEL: @$s4main4testyyF -// CHECK: alloc_stack -// CHECK: apply -// CHECK: return %{{[0-9]+}} : $() -public func test() { - let c = HasCustomCopyConst() -} From 03411af15a1780ddca1962e3a373e832761f3736 Mon Sep 17 00:00:00 2001 From: David Zarzycki Date: Thu, 9 Jul 2020 08:22:46 -0400 Subject: [PATCH 18/36] [SIL] Convert computeLoweredRValueType to CanTypeVisitor In practice, LLVM can only optimize trivial `if` chains into a `switch`. --- lib/SIL/IR/TypeLowering.cpp | 233 +++++++++++++++++++----------------- 1 file changed, 126 insertions(+), 107 deletions(-) diff --git a/lib/SIL/IR/TypeLowering.cpp b/lib/SIL/IR/TypeLowering.cpp index fd3e8f28fe6d2..06a4273f92dc1 100644 --- a/lib/SIL/IR/TypeLowering.cpp +++ b/lib/SIL/IR/TypeLowering.cpp @@ -1772,131 +1772,150 @@ CanType TypeConverter::computeLoweredRValueType(TypeExpansionContext forExpansion, AbstractionPattern origType, CanType substType) { - // AST function types are turned into SIL function types: - // - the type is uncurried as desired - // - types are turned into their unbridged equivalents, depending - // on the abstract CC - // - ownership conventions are deduced - // - a minimal substituted generic signature is extracted to represent - // possible ABI-compatible substitutions - if (auto substFnType = dyn_cast(substType)) { - // If the formal type uses a C convention, it is not formally - // abstractable, and it may be subject to implicit bridging. - auto extInfo = substFnType->getExtInfo(); - if (getSILFunctionLanguage(extInfo.getSILRepresentation()) - == SILFunctionLanguage::C) { - // The importer only applies fully-reversible bridging to the - // component types of C function pointers. - auto bridging = Bridgeability::Full; - if (extInfo.getSILRepresentation() - == SILFunctionTypeRepresentation::CFunctionPointer) - bridging = Bridgeability::None; - - // Bridge the parameters and result of the function type. - auto bridgedFnType = getBridgedFunctionType(origType, substFnType, - extInfo, bridging); - substFnType = bridgedFnType; - - // Also rewrite the type of the abstraction pattern. - auto signature = origType.getGenericSignatureOrNull(); - if (origType.isTypeParameter()) { - origType = AbstractionPattern(signature, bridgedFnType); - } else { - origType.rewriteType(signature, bridgedFnType); + class LoweredRValueTypeVisitor + : public CanTypeVisitor { + TypeConverter &TC; + TypeExpansionContext forExpansion; + AbstractionPattern origType; + + public: + LoweredRValueTypeVisitor(TypeConverter &TC, + TypeExpansionContext forExpansion, + AbstractionPattern origType) + : TC(TC), forExpansion(forExpansion), origType(origType) {} + + // AST function types are turned into SIL function types: + // - the type is uncurried as desired + // - types are turned into their unbridged equivalents, depending + // on the abstract CC + // - ownership conventions are deduced + // - a minimal substituted generic signature is extracted to represent + // possible ABI-compatible substitutions + CanType visitAnyFunctionType(CanAnyFunctionType substFnType) { + // If the formal type uses a C convention, it is not formally + // abstractable, and it may be subject to implicit bridging. + auto extInfo = substFnType->getExtInfo(); + if (getSILFunctionLanguage(extInfo.getSILRepresentation()) == + SILFunctionLanguage::C) { + // The importer only applies fully-reversible bridging to the + // component types of C function pointers. + auto bridging = Bridgeability::Full; + if (extInfo.getSILRepresentation() == + SILFunctionTypeRepresentation::CFunctionPointer) + bridging = Bridgeability::None; + + // Bridge the parameters and result of the function type. + auto bridgedFnType = + TC.getBridgedFunctionType(origType, substFnType, extInfo, bridging); + substFnType = bridgedFnType; + + // Also rewrite the type of the abstraction pattern. + auto signature = origType.getGenericSignatureOrNull(); + if (origType.isTypeParameter()) { + origType = AbstractionPattern(signature, bridgedFnType); + } else { + origType.rewriteType(signature, bridgedFnType); + } } + + return ::getNativeSILFunctionType(TC, forExpansion, origType, + substFnType); } - return getNativeSILFunctionType(*this, forExpansion, origType, substFnType); - } + // Ignore dynamic self types. + CanType visitDynamicSelfType(CanDynamicSelfType selfType) { + return TC.getLoweredRValueType(forExpansion, origType, + selfType.getSelfType()); + } - // Ignore dynamic self types. - if (auto selfType = dyn_cast(substType)) { - return getLoweredRValueType(forExpansion, origType, selfType.getSelfType()); - } + // Static metatypes are unitary and can optimized to a "thin" empty + // representation if the type also appears as a static metatype in the + // original abstraction pattern. + CanType visitMetatypeType(CanMetatypeType substMeta) { + // If the metatype has already been lowered, it will already carry its + // representation. + if (substMeta->hasRepresentation()) { + assert(substMeta->isLegalSILType()); + return substOpaqueTypesWithUnderlyingTypes(substMeta, forExpansion); + } - // Static metatypes are unitary and can optimized to a "thin" empty - // representation if the type also appears as a static metatype in the - // original abstraction pattern. - if (auto substMeta = dyn_cast(substType)) { - // If the metatype has already been lowered, it will already carry its - // representation. - if (substMeta->hasRepresentation()) { - assert(substMeta->isLegalSILType()); - return substOpaqueTypesWithUnderlyingTypes(substMeta, forExpansion); - } + MetatypeRepresentation repr; - MetatypeRepresentation repr; - - auto origMeta = origType.getAs(); - if (!origMeta) { - // If the metatype matches a dependent type, it must be thick. - assert(origType.isTypeParameterOrOpaqueArchetype()); - repr = MetatypeRepresentation::Thick; - } else { - // Otherwise, we're thin if the metatype is thinnable both - // substituted and in the abstraction pattern. - if (hasSingletonMetatype(substMeta.getInstanceType()) - && hasSingletonMetatype(origMeta.getInstanceType())) - repr = MetatypeRepresentation::Thin; - else + auto origMeta = origType.getAs(); + if (!origMeta) { + // If the metatype matches a dependent type, it must be thick. + assert(origType.isTypeParameterOrOpaqueArchetype()); repr = MetatypeRepresentation::Thick; - } + } else { + // Otherwise, we're thin if the metatype is thinnable both + // substituted and in the abstraction pattern. + if (hasSingletonMetatype(substMeta.getInstanceType()) && + hasSingletonMetatype(origMeta.getInstanceType())) + repr = MetatypeRepresentation::Thin; + else + repr = MetatypeRepresentation::Thick; + } - CanType instanceType = substOpaqueTypesWithUnderlyingTypes( - substMeta.getInstanceType(), forExpansion); + CanType instanceType = substOpaqueTypesWithUnderlyingTypes( + substMeta.getInstanceType(), forExpansion); - // Regardless of thinness, metatypes are always trivial. - return CanMetatypeType::get(instanceType, repr); - } + // Regardless of thinness, metatypes are always trivial. + return CanMetatypeType::get(instanceType, repr); + } - // Give existential metatypes @thick representation by default. - if (auto existMetatype = dyn_cast(substType)) { - if (existMetatype->hasRepresentation()) { - assert(existMetatype->isLegalSILType()); - return existMetatype; + // Give existential metatypes @thick representation by default. + CanType + visitExistentialMetatypeType(CanExistentialMetatypeType existMetatype) { + if (existMetatype->hasRepresentation()) { + assert(existMetatype->isLegalSILType()); + return existMetatype; + } + + return CanExistentialMetatypeType::get(existMetatype.getInstanceType(), + MetatypeRepresentation::Thick); } - return CanExistentialMetatypeType::get(existMetatype.getInstanceType(), - MetatypeRepresentation::Thick); - } + // Lower tuple element types. + CanType visitTupleType(CanTupleType substTupleType) { + return computeLoweredTupleType(TC, forExpansion, origType, + substTupleType); + } - // Lower tuple element types. - if (auto substTupleType = dyn_cast(substType)) { - return computeLoweredTupleType(*this, forExpansion, origType, - substTupleType); - } + // Lower the referent type of reference storage types. + CanType visitReferenceStorageType(CanReferenceStorageType substRefType) { + return computeLoweredReferenceStorageType(TC, forExpansion, origType, + substRefType); + } - // Lower the referent type of reference storage types. - if (auto substRefType = dyn_cast(substType)) { - return computeLoweredReferenceStorageType(*this, forExpansion, origType, - substRefType); - } + CanType visitSILFunctionType(CanSILFunctionType silFnTy) { + if (!silFnTy->hasOpaqueArchetype() || + !forExpansion.shouldLookThroughOpaqueTypeArchetypes()) + return silFnTy; + return silFnTy->substituteOpaqueArchetypes(TC, forExpansion); + } - // Lower the object type of optional types. - if (auto substObjectType = substType.getOptionalObjectType()) { - return computeLoweredOptionalType(*this, forExpansion, origType, - substType, substObjectType); - } + CanType visitType(CanType substType) { + // Lower the object type of optional types. + if (auto substObjectType = substType.getOptionalObjectType()) { + return computeLoweredOptionalType(TC, forExpansion, origType, substType, + substObjectType); + } - if (auto silFnTy = dyn_cast(substType)) { - if (!substType->hasOpaqueArchetype() || - !forExpansion.shouldLookThroughOpaqueTypeArchetypes()) - return substType; - return silFnTy->substituteOpaqueArchetypes(*this, forExpansion); - } + // The Swift type directly corresponds to the lowered type. + auto underlyingTy = + substOpaqueTypesWithUnderlyingTypes(substType, forExpansion, + /*allowLoweredTypes*/ true); + if (underlyingTy != substType) { + underlyingTy = + TC.computeLoweredRValueType(forExpansion, origType, underlyingTy); + } - // The Swift type directly corresponds to the lowered type. - auto underlyingTy = - substOpaqueTypesWithUnderlyingTypes(substType, forExpansion, - /*allowLoweredTypes*/ true); - if (underlyingTy != substType) { - underlyingTy = computeLoweredRValueType( - forExpansion, - origType, - underlyingTy); - } + return underlyingTy; + } + }; - return underlyingTy; + LoweredRValueTypeVisitor visitor(*this, forExpansion, origType); + return visitor.visit(substType); } const TypeLowering & From 40104ba8b1d65043a84b49a6255b9b5dddbceb65 Mon Sep 17 00:00:00 2001 From: martinboehme Date: Thu, 9 Jul 2020 15:59:26 +0200 Subject: [PATCH 19/36] Link against the C++ standard library when C++ interop is enabled (#30914) This doesn't yet allow including C++ headers on platforms where libc++ isn't the default; see comments in UnixToolChains.cpp for details. However, it does, for example, allow throwing and catching exceptions in C++ code used through interop, unblocking https://github.com/apple/swift/pull/30674/files. The flags (-enable-experimental-cxx-interop and -experimental-cxx-stdlib) carry "experimental" in the name to emphasize that C++ interop is still an experimental feature. Co-authored-by: Michael Forster --- include/swift/AST/DiagnosticsDriver.def | 4 ++ include/swift/Driver/ToolChain.h | 3 ++ include/swift/Option/Options.td | 8 ++++ lib/Driver/DarwinToolChains.cpp | 12 +++++ lib/Driver/ToolChains.cpp | 31 +++++++++++++ lib/Driver/UnixToolChains.cpp | 31 ++++--------- lib/Driver/WindowsToolChains.cpp | 35 ++++----------- test/Driver/cxx_interop.swift | 10 +++++ test/Driver/linker.swift | 58 +++++++++++++++++++++++++ 9 files changed, 142 insertions(+), 50 deletions(-) create mode 100644 test/Driver/cxx_interop.swift diff --git a/include/swift/AST/DiagnosticsDriver.def b/include/swift/AST/DiagnosticsDriver.def index 52d864ca5fe52..b6b6a3cd1e484 100644 --- a/include/swift/AST/DiagnosticsDriver.def +++ b/include/swift/AST/DiagnosticsDriver.def @@ -186,6 +186,10 @@ ERROR(cannot_find_migration_script, none, ERROR(error_darwin_static_stdlib_not_supported, none, "-static-stdlib is no longer supported on Apple platforms", ()) +ERROR(error_darwin_only_supports_libcxx, none, + "The only C++ standard library supported on Apple platforms is libc++", + ()) + WARNING(warn_drv_darwin_sdk_invalid_settings, none, "SDK settings were ignored because 'SDKSettings.json' could not be parsed", ()) diff --git a/include/swift/Driver/ToolChain.h b/include/swift/Driver/ToolChain.h index cd8bebfdeceaa..cd55451a01c8b 100644 --- a/include/swift/Driver/ToolChain.h +++ b/include/swift/Driver/ToolChain.h @@ -294,6 +294,9 @@ class ToolChain { void getClangLibraryPath(const llvm::opt::ArgList &Args, SmallString<128> &LibPath) const; + // Returns the Clang driver executable to use for linking. + const char *getClangLinkerDriver(const llvm::opt::ArgList &Args) const; + /// Returns the name the clang library for a given sanitizer would have on /// the current toolchain. /// diff --git a/include/swift/Option/Options.td b/include/swift/Option/Options.td index bf07e0615b055..4dc519e24b1d5 100644 --- a/include/swift/Option/Options.td +++ b/include/swift/Option/Options.td @@ -556,6 +556,14 @@ def disable_direct_intramodule_dependencies : Flag<["-"], Flags<[FrontendOption, HelpHidden]>, HelpText<"Disable experimental dependency tracking that never cascades">; +def enable_experimental_cxx_interop : + Flag<["-"], "enable-experimental-cxx-interop">, + HelpText<"Allow importing C++ modules into Swift (experimental feature)">; + +def experimental_cxx_stdlib : + Separate<["-"], "experimental-cxx-stdlib">, + HelpText<"C++ standard library to use; forwarded to Clang's -stdlib flag">; + // Diagnostic control options def suppress_warnings : Flag<["-"], "suppress-warnings">, diff --git a/lib/Driver/DarwinToolChains.cpp b/lib/Driver/DarwinToolChains.cpp index 9ac64c30c5206..6ea07d3fb70bb 100644 --- a/lib/Driver/DarwinToolChains.cpp +++ b/lib/Driver/DarwinToolChains.cpp @@ -795,6 +795,11 @@ toolchains::Darwin::constructInvocation(const DynamicLinkJobAction &job, Arguments.push_back("-arch"); Arguments.push_back(context.Args.MakeArgString(getTriple().getArchName())); + // On Darwin, we only support libc++. + if (context.Args.hasArg(options::OPT_enable_experimental_cxx_interop)) { + Arguments.push_back("-lc++"); + } + addArgsToLinkStdlib(Arguments, job, context); addProfileGenerationArgs(Arguments, context); @@ -938,6 +943,13 @@ toolchains::Darwin::validateArguments(DiagnosticEngine &diags, if (args.hasArg(options::OPT_static_stdlib)) { diags.diagnose(SourceLoc(), diag::error_darwin_static_stdlib_not_supported); } + + // If a C++ standard library is specified, it has to be libc++. + if (auto arg = args.getLastArg(options::OPT_experimental_cxx_stdlib)) { + if (StringRef(arg->getValue()) != "libc++") { + diags.diagnose(SourceLoc(), diag::error_darwin_only_supports_libcxx); + } + } } void diff --git a/lib/Driver/ToolChains.cpp b/lib/Driver/ToolChains.cpp index 3e158c76274cf..2e1e320ebfe75 100644 --- a/lib/Driver/ToolChains.cpp +++ b/lib/Driver/ToolChains.cpp @@ -167,6 +167,17 @@ void ToolChain::addCommonFrontendArgs(const OutputInfo &OI, arguments.push_back("-disable-objc-interop"); } + // Add flags for C++ interop. + if (inputArgs.hasArg(options::OPT_enable_experimental_cxx_interop)) { + arguments.push_back("-enable-cxx-interop"); + } + if (const Arg *arg = + inputArgs.getLastArg(options::OPT_experimental_cxx_stdlib)) { + arguments.push_back("-Xcc"); + arguments.push_back( + inputArgs.MakeArgString(Twine("-stdlib=") + arg->getValue())); + } + // Handle the CPU and its preferences. inputArgs.AddLastArg(arguments, options::OPT_target_cpu); @@ -1330,6 +1341,26 @@ void ToolChain::getRuntimeLibraryPaths(SmallVectorImpl &runtimeLibP } } +const char *ToolChain::getClangLinkerDriver( + const llvm::opt::ArgList &Args) const { + // We don't use `clang++` unconditionally because we want to avoid pulling in + // a C++ standard library if it's not needed, in particular because the + // standard library that `clang++` selects by default may not be the one that + // is desired. + const char *LinkerDriver = + Args.hasArg(options::OPT_enable_experimental_cxx_interop) ? "clang++" + : "clang"; + if (const Arg *A = Args.getLastArg(options::OPT_tools_directory)) { + StringRef toolchainPath(A->getValue()); + + // If there is a linker driver in the toolchain folder, use that instead. + if (auto tool = llvm::sys::findProgramByName(LinkerDriver, {toolchainPath})) + LinkerDriver = Args.MakeArgString(tool.get()); + } + + return LinkerDriver; +} + bool ToolChain::sanitizerRuntimeLibExists(const ArgList &args, StringRef sanitizerName, bool shared) const { diff --git a/lib/Driver/UnixToolChains.cpp b/lib/Driver/UnixToolChains.cpp index 14a5e88a54de0..246ce6a8b7a2c 100644 --- a/lib/Driver/UnixToolChains.cpp +++ b/lib/Driver/UnixToolChains.cpp @@ -181,31 +181,9 @@ toolchains::GenericUnix::constructInvocation(const DynamicLinkJobAction &job, } // Configure the toolchain. - // - // By default use the system `clang` to perform the link. We use `clang` for - // the driver here because we do not wish to select a particular C++ runtime. - // Furthermore, until C++ interop is enabled, we cannot have a dependency on - // C++ code from pure Swift code. If linked libraries are C++ based, they - // should properly link C++. In the case of static linking, the user can - // explicitly specify the C++ runtime to link against. This is particularly - // important for platforms like android where as it is a Linux platform, the - // default C++ runtime is `libstdc++` which is unsupported on the target but - // as the builds are usually cross-compiled from Linux, libstdc++ is going to - // be present. This results in linking the wrong version of libstdc++ - // generating invalid binaries. It is also possible to use different C++ - // runtimes than the default C++ runtime for the platform (e.g. libc++ on - // Windows rather than msvcprt). When C++ interop is enabled, we will need to - // surface this via a driver flag. For now, opt for the simpler approach of - // just using `clang` and avoid a dependency on the C++ runtime. - const char *Clang = "clang"; if (const Arg *A = context.Args.getLastArg(options::OPT_tools_directory)) { StringRef toolchainPath(A->getValue()); - // If there is a clang in the toolchain folder, use that instead. - if (auto tool = llvm::sys::findProgramByName("clang", {toolchainPath})) { - Clang = context.Args.MakeArgString(tool.get()); - } - // Look for binutils in the toolchain folder. Arguments.push_back("-B"); Arguments.push_back(context.Args.MakeArgString(A->getValue())); @@ -307,6 +285,13 @@ toolchains::GenericUnix::constructInvocation(const DynamicLinkJobAction &job, } } + // Link against the desired C++ standard library. + if (const Arg *A = + context.Args.getLastArg(options::OPT_experimental_cxx_stdlib)) { + Arguments.push_back( + context.Args.MakeArgString(Twine("-stdlib=") + A->getValue())); + } + // Explicitly pass the target to the linker Arguments.push_back( context.Args.MakeArgString("--target=" + getTriple().str())); @@ -352,7 +337,7 @@ toolchains::GenericUnix::constructInvocation(const DynamicLinkJobAction &job, Arguments.push_back( context.Args.MakeArgString(context.Output.getPrimaryOutputFilename())); - InvocationInfo II{Clang, Arguments}; + InvocationInfo II{getClangLinkerDriver(context.Args), Arguments}; II.allowsResponseFiles = true; return II; diff --git a/lib/Driver/WindowsToolChains.cpp b/lib/Driver/WindowsToolChains.cpp index 749f0c902a80c..7cd26f17bc22d 100644 --- a/lib/Driver/WindowsToolChains.cpp +++ b/lib/Driver/WindowsToolChains.cpp @@ -84,32 +84,6 @@ toolchains::Windows::constructInvocation(const DynamicLinkJobAction &job, Arguments.push_back("/DEBUG"); } - // Configure the toolchain. - // - // By default use the system `clang` to perform the link. We use `clang` for - // the driver here because we do not wish to select a particular C++ runtime. - // Furthermore, until C++ interop is enabled, we cannot have a dependency on - // C++ code from pure Swift code. If linked libraries are C++ based, they - // should properly link C++. In the case of static linking, the user can - // explicitly specify the C++ runtime to link against. This is particularly - // important for platforms like android where as it is a Linux platform, the - // default C++ runtime is `libstdc++` which is unsupported on the target but - // as the builds are usually cross-compiled from Linux, libstdc++ is going to - // be present. This results in linking the wrong version of libstdc++ - // generating invalid binaries. It is also possible to use different C++ - // runtimes than the default C++ runtime for the platform (e.g. libc++ on - // Windows rather than msvcprt). When C++ interop is enabled, we will need to - // surface this via a driver flag. For now, opt for the simpler approach of - // just using `clang` and avoid a dependency on the C++ runtime. - const char *Clang = "clang"; - if (const Arg *A = context.Args.getLastArg(options::OPT_tools_directory)) { - StringRef toolchainPath(A->getValue()); - - // If there is a clang in the toolchain folder, use that instead. - if (auto tool = llvm::sys::findProgramByName("clang", {toolchainPath})) - Clang = context.Args.MakeArgString(tool.get()); - } - // Rely on `-libc` to correctly identify the MSVC Runtime Library. We use // `-nostartfiles` as that limits the difference to just the // `-defaultlib:libcmt` which is passed unconditionally with the `clang` @@ -159,6 +133,13 @@ toolchains::Windows::constructInvocation(const DynamicLinkJobAction &job, Arguments.push_back(context.Args.MakeArgString(context.OI.SDKPath)); } + // Link against the desired C++ standard library. + if (const Arg *A = + context.Args.getLastArg(options::OPT_experimental_cxx_stdlib)) { + Arguments.push_back(context.Args.MakeArgString( + Twine("-stdlib=") + A->getValue())); + } + if (job.getKind() == LinkKind::Executable) { if (context.OI.SelectedSanitizers & SanitizerKind::Address) addLinkRuntimeLib(context.Args, Arguments, @@ -196,7 +177,7 @@ toolchains::Windows::constructInvocation(const DynamicLinkJobAction &job, Arguments.push_back( context.Args.MakeArgString(context.Output.getPrimaryOutputFilename())); - InvocationInfo II{Clang, Arguments}; + InvocationInfo II{getClangLinkerDriver(context.Args), Arguments}; II.allowsResponseFiles = true; return II; diff --git a/test/Driver/cxx_interop.swift b/test/Driver/cxx_interop.swift new file mode 100644 index 0000000000000..7795e80b7777b --- /dev/null +++ b/test/Driver/cxx_interop.swift @@ -0,0 +1,10 @@ +// RUN: %swiftc_driver -driver-print-jobs -target x86_64-apple-macosx10.9 %s -enable-experimental-cxx-interop 2>^1 | %FileCheck -check-prefix ENABLE %s + +// RUN: %swiftc_driver -driver-print-jobs -target x86_64-apple-macosx10.9 %s -enable-experimental-cxx-interop -experimental-cxx-stdlib libc++ 2>^1 | %FileCheck -check-prefix STDLIB %s + +// ENABLE: swift +// ENABLE: -enable-cxx-interop + +// STDLIB: swift +// STDLIB-DAG: -enable-cxx-interop +// STDLIB-DAG: -Xcc -stdlib=libc++ diff --git a/test/Driver/linker.swift b/test/Driver/linker.swift index 991963830c960..339a3b964540d 100644 --- a/test/Driver/linker.swift +++ b/test/Driver/linker.swift @@ -101,6 +101,21 @@ // INFERRED_NAMED_DARWIN tests above: 'libLINKER.dylib'. // RUN: %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-apple-macosx10.9 -emit-library %s -o libLINKER.dylib | %FileCheck -check-prefix INFERRED_NAME_DARWIN %s +// On Darwin, when C++ interop is turned on, we link against libc++ explicitly +// regardless of whether -experimental-cxx-stdlib is specified or not. So also +// run a test where C++ interop is turned off to make sure we don't link +// against libc++ in this case. +// RUN: %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-apple-ios7.1 %s 2>&1 | %FileCheck -check-prefix IOS-no-cxx-interop %s +// RUN: %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-apple-ios7.1 -enable-experimental-cxx-interop %s 2>&1 | %FileCheck -check-prefix IOS-cxx-interop-libcxx %s +// RUN: %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-apple-ios7.1 -enable-experimental-cxx-interop -experimental-cxx-stdlib libc++ %s 2>&1 | %FileCheck -check-prefix IOS-cxx-interop-libcxx %s +// RUN: not %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-apple-ios7.1 -enable-experimental-cxx-interop -experimental-cxx-stdlib libstdc++ %s 2>&1 | %FileCheck -check-prefix IOS-cxx-interop-libstdcxx %s + +// RUN: %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-unknown-linux-gnu -enable-experimental-cxx-interop %s 2>&1 | %FileCheck -check-prefix LINUX-cxx-interop %s +// RUN: %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-unknown-linux-gnu -enable-experimental-cxx-interop -experimental-cxx-stdlib libc++ %s 2>&1 | %FileCheck -check-prefix LINUX-cxx-interop-libcxx %s + +// RUN: %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-unknown-windows-msvc -enable-experimental-cxx-interop %s 2>&1 | %FileCheck -check-prefix WINDOWS-cxx-interop %s +// RUN: %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-unknown-windows-msvc -enable-experimental-cxx-interop -experimental-cxx-stdlib libc++ %s 2>&1 | %FileCheck -check-prefix WINDOWS-cxx-interop-libcxx %s + // Check reading the SDKSettings.json from an SDK // RUN: %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-apple-macosx10.9 -sdk %S/Inputs/MacOSX10.15.versioned.sdk %s 2>&1 | %FileCheck -check-prefix MACOS_10_15 %s // RUN: %swiftc_driver -sdk "" -driver-print-jobs -target x86_64-apple-macosx10.9 -sdk %S/Inputs/MacOSX10.15.4.versioned.sdk %s 2>&1 | %FileCheck -check-prefix MACOS_10_15_4 %s @@ -424,6 +439,49 @@ // INFERRED_NAME_WINDOWS: -o LINKER.dll // INFERRED_NAME_WASI: -o libLINKER.so +// Instead of a single "NOT" check for this run, we would really want to check +// for all of the driver arguments that we _do_ expect, and then use an +// --implicit-check-not to check that -lc++ doesn't occur. +// However, --implicit-check-not has a bug where it fails to flag the +// unexpected text when it occurs after text matched by a CHECK-DAG; see +// https://bugs.llvm.org/show_bug.cgi?id=45629 +// For this reason, we use a single "NOT" check for the time being here. +// The same consideration applies to the Linux and Windows cases below. +// IOS-no-cxx-interop-NOT: -lc++ + +// IOS-cxx-interop-libcxx: swift +// IOS-cxx-interop-libcxx-DAG: -enable-cxx-interop +// IOS-cxx-interop-libcxx-DAG: -o [[OBJECTFILE:.*]] + +// IOS-cxx-interop-libcxx: {{(bin/)?}}ld{{"? }} +// IOS-cxx-interop-libcxx-DAG: [[OBJECTFILE]] +// IOS-cxx-interop-libcxx-DAG: -lc++ +// IOS-cxx-interop-libcxx: -o linker + +// IOS-cxx-interop-libstdcxx: error: The only C++ standard library supported on Apple platforms is libc++ + +// LINUX-cxx-interop-NOT: -stdlib + +// LINUX-cxx-interop-libcxx: swift +// LINUX-cxx-interop-libcxx-DAG: -enable-cxx-interop +// LINUX-cxx-interop-libcxx-DAG: -o [[OBJECTFILE:.*]] + +// LINUX-cxx-interop-libcxx: clang++{{(\.exe)?"? }} +// LINUX-cxx-interop-libcxx-DAG: [[OBJECTFILE]] +// LINUX-cxx-interop-libcxx-DAG: -stdlib=libc++ +// LINUX-cxx-interop-libcxx: -o linker + +// WINDOWS-cxx-interop-NOT: -stdlib + +// WINDOWS-cxx-interop-libcxx: swift +// WINDOWS-cxx-interop-libcxx-DAG: -enable-cxx-interop +// WINDOWS-cxx-interop-libcxx-DAG: -o [[OBJECTFILE:.*]] + +// WINDOWS-cxx-interop-libcxx: clang++{{(\.exe)?"? }} +// WINDOWS-cxx-interop-libcxx-DAG: [[OBJECTFILE]] +// WINDOWS-cxx-interop-libcxx-DAG: -stdlib=libc++ +// WINDOWS-cxx-interop-libcxx: -o linker + // Test ld detection. We use hard links to make sure // the Swift driver really thinks it's been moved. From 7fccbadce415bd68ad4d0c5d1a5f4b50c83774a7 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 9 Jul 2020 07:18:14 -0700 Subject: [PATCH 20/36] [AutoDiff] NFC: Reimplement `PullbackCloner` using pimpl pattern. (#32778) Reimplement `PullbackCloner` using the pointer-to-implementation pattern. `PullbackCloner.h` is now tiny: `PullbackCloner` exposes only a `bool run()` entry point. All of the implementation is moved to `PullbackCloner::Implementation` in `PullbackCloner.cpp`. Benefits of this approach: - A main benefit is that methods can be defined directly in `PullbackCloner.cpp` without needing to separately declare them in `PullbackCloner.h`. - There is now no code duplication between `PullbackCloner.h` and `PullbackCloner.cpp`. - Consequently, method documentation is easier to read because it appears directly on method definitions, instead of on method declarations in a separate file. This is important for documentation of `PullbackCloner` instruction visitor methods, which explain pullback transformation rules. - Incremental recompilation may be faster since `PullbackCloner.h` changes less often. Partially resolves SR-13182. --- .../Differentiation/PullbackCloner.h | 497 +-- .../Differentiation/PullbackCloner.cpp | 3319 +++++++++-------- 2 files changed, 1847 insertions(+), 1969 deletions(-) diff --git a/include/swift/SILOptimizer/Differentiation/PullbackCloner.h b/include/swift/SILOptimizer/Differentiation/PullbackCloner.h index 514557c0f03e9..37d666e16463c 100644 --- a/include/swift/SILOptimizer/Differentiation/PullbackCloner.h +++ b/include/swift/SILOptimizer/Differentiation/PullbackCloner.h @@ -38,501 +38,22 @@ namespace autodiff { class ADContext; class VJPCloner; -class PullbackCloner final : public SILInstructionVisitor { -private: - /// The parent VJP cloner. - VJPCloner &vjpCloner; - - /// Dominance info for the original function. - DominanceInfo *domInfo = nullptr; - - /// Post-dominance info for the original function. - PostDominanceInfo *postDomInfo = nullptr; - - /// Post-order info for the original function. - PostOrderFunctionInfo *postOrderInfo = nullptr; - - /// Mapping from original basic blocks to corresponding pullback basic blocks. - /// Pullback basic blocks always have the predecessor as the single argument. - llvm::DenseMap pullbackBBMap; - - /// Mapping from original basic blocks and original values to corresponding - /// adjoint values. - llvm::DenseMap, AdjointValue> valueMap; - - /// Mapping from original basic blocks and original values to corresponding - /// adjoint buffers. - llvm::DenseMap, SILValue> bufferMap; - - /// Mapping from pullback basic blocks to pullback struct arguments. - llvm::DenseMap pullbackStructArguments; - - /// Mapping from pullback struct field declarations to pullback struct - /// elements destructured from the linear map basic block argument. In the - /// beginning of each pullback basic block, the block's pullback struct is - /// destructured into individual elements stored here. - llvm::DenseMap pullbackStructElements; - - /// Mapping from original basic blocks and successor basic blocks to - /// corresponding pullback trampoline basic blocks. Trampoline basic blocks - /// take additional arguments in addition to the predecessor enum argument. - llvm::DenseMap, SILBasicBlock *> - pullbackTrampolineBBMap; - - /// Mapping from original basic blocks to dominated active values. - llvm::DenseMap> activeValues; - - /// Mapping from original basic blocks and original active values to - /// corresponding pullback block arguments. - llvm::DenseMap, SILArgument *> - activeValuePullbackBBArgumentMap; - - /// Mapping from original basic blocks to local temporary values to be cleaned - /// up. This is populated when pullback emission is run on one basic block and - /// cleaned before processing another basic block. - llvm::DenseMap> - blockTemporaries; - - /// The main builder. - SILBuilder builder; - - /// An auxiliary local allocation builder. - SILBuilder localAllocBuilder; - - /// Stack buffers allocated for storing local adjoint values. - SmallVector functionLocalAllocations; - - /// A set used to remember local allocations that were destroyed. - llvm::SmallDenseSet destroyedLocalAllocations; - - /// The seed arguments of the pullback function. - SmallVector seeds; - - llvm::BumpPtrAllocator allocator; - - bool errorOccurred = false; - - ADContext &getContext() const; - SILModule &getModule() const; - ASTContext &getASTContext() const; - SILFunction &getOriginal() const; - SILFunction &getPullback() const; - SILDifferentiabilityWitness *getWitness() const; - DifferentiationInvoker getInvoker() const; - LinearMapInfo &getPullbackInfo(); - const SILAutoDiffIndices getIndices() const; - const DifferentiableActivityInfo &getActivityInfo() const; +/// A helper class for generating pullback functions. +class PullbackCloner final { + struct Implementation; + Implementation &impl; public: - explicit PullbackCloner(VJPCloner &vjpCloner); - -private: - //--------------------------------------------------------------------------// - // Pullback struct mapping - //--------------------------------------------------------------------------// - - void initializePullbackStructElements(SILBasicBlock *origBB, - SILInstructionResultArray values); - - /// Returns the pullback struct element value corresponding to the given - /// original block and pullback struct field. - SILValue getPullbackStructElement(SILBasicBlock *origBB, VarDecl *field); - - //--------------------------------------------------------------------------// - // Type transformer - //--------------------------------------------------------------------------// - - /// Get the type lowering for the given AST type. - const Lowering::TypeLowering &getTypeLowering(Type type); - - /// Remap any archetypes into the current function's context. - SILType remapType(SILType ty); - - Optional getTangentSpace(CanType type); - - /// Returns the tangent value category of the given value. - SILValueCategory getTangentValueCategory(SILValue v); - - /// Assuming the given type conforms to `Differentiable` after remapping, - /// returns the associated tangent space type. - SILType getRemappedTangentType(SILType type); - - /// Substitutes all replacement types of the given substitution map using the - /// pullback function's substitution map. - SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap); - - //--------------------------------------------------------------------------// - // Temporary value management - //--------------------------------------------------------------------------// - - /// Record a temporary value for cleanup before its block's terminator. - SILValue recordTemporary(SILValue value); - - /// Clean up all temporary values for the given pullback block. - void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc); - - //--------------------------------------------------------------------------// - // Adjoint value factory methods - //--------------------------------------------------------------------------// - - AdjointValue makeZeroAdjointValue(SILType type); - - AdjointValue makeConcreteAdjointValue(SILValue value); - - template - AdjointValue makeAggregateAdjointValue(SILType type, EltRange elements); - - //--------------------------------------------------------------------------// - // Adjoint value materialization - //--------------------------------------------------------------------------// - - /// Materializes an adjoint value. The type of the given adjoint value must be - /// loadable. - SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc); - - /// Materializes an adjoint value indirectly to a SIL buffer. - void materializeAdjointIndirect(AdjointValue val, SILValue destBuffer, - SILLocation loc); - - //--------------------------------------------------------------------------// - // Helpers for adjoint value materialization - //--------------------------------------------------------------------------// - - /// Emits a zero value into the given address by calling - /// `AdditiveArithmetic.zero`. The given type must conform to - /// `AdditiveArithmetic`. - void emitZeroIndirect(CanType type, SILValue address, SILLocation loc); - - /// Emits a zero value by calling `AdditiveArithmetic.zero`. The given type - /// must conform to `AdditiveArithmetic` and be loadable in SIL. - SILValue emitZeroDirect(CanType type, SILLocation loc); - - //--------------------------------------------------------------------------// - // Adjoint value mapping - //--------------------------------------------------------------------------// - - /// Returns true if the given value in the original function has a - /// corresponding adjoint value. - bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const; - - /// Initializes the adjoint value for the original value. Asserts that the - /// original value does not already have an adjoint value. - void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue, - AdjointValue adjointValue); - - /// Returns the adjoint value for a value in the original function. - /// - /// This method first tries to find an existing entry in the adjoint value - /// mapping. If no entry exists, creates a zero adjoint value. - AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue); - - /// Adds `newAdjointValue` to the adjoint value for `originalValue` and sets - /// the sum as the new adjoint value. - void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue, - AdjointValue newAdjointValue, SILLocation loc); - - /// Get the pullback block argument corresponding to the given original block - /// and active value. - SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB, - SILValue activeValue); - - //--------------------------------------------------------------------------// - // Adjoint value accumulation - //--------------------------------------------------------------------------// - - /// Given two adjoint values, accumulates them and returns their sum. - AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs, - SILLocation loc); - - /// Generates code returning `result = lhs + rhs`. - /// - /// Given two materialized adjoint values, accumulates them and returns their - /// sum. The adjoint values must have a loadable type. - SILValue accumulateDirect(SILValue lhs, SILValue rhs, SILLocation loc); - - /// Generates code for `resultAddress = lhsAddress + rhsAddress`. - /// - /// Given two addresses with the same `AdditiveArithmetic`-conforming type, - /// accumulates them into a result address using `AdditiveArithmetic.+`. - void accumulateIndirect(SILValue resultAddress, SILValue lhsAddress, - SILValue rhsAddress, SILLocation loc); - - /// Generates code for `lhsDestAddress += rhsAddress`. - /// - /// Given two addresses with the same `AdditiveArithmetic`-conforming type, - /// accumulates the rhs into the lhs using `AdditiveArithmetic.+=`. - void accumulateIndirect(SILValue lhsDestAddress, SILValue rhsAddress, - SILLocation loc); - - //--------------------------------------------------------------------------// - // Adjoint buffer mapping - //--------------------------------------------------------------------------// - - /// If the given original value is an address projection, returns a - /// corresponding adjoint projection to be used as its adjoint buffer. - /// - /// Helper function for `getAdjointBuffer`. - SILValue getAdjointProjection(SILBasicBlock *origBB, SILValue originalValue); - - /// Returns the adjoint buffer for the original value. - /// - /// This method first tries to find an existing entry in the adjoint buffer - /// mapping. If no entry exists, creates a zero adjoint buffer. - SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue); - - /// Initializes the adjoint buffer for the original value. Asserts that the - /// original value does not already have an adjoint buffer. - void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, - SILValue adjointBuffer); - - /// Accumulates `rhsAddress` into the adjoint buffer corresponding to the - /// original value. - void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, - SILValue rhsAddress, SILLocation loc); - - /// Given the adjoint value of an array initialized from an - /// `array.uninitialized_intrinsic` application and an array element index, - /// returns an `alloc_stack` containing the adjoint value of the array element - /// at the given index by applying `Array.TangentVector.subscript`. - AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint, - int eltIndex, SILLocation loc); - - /// Given the adjoint value of an array initialized from an - /// `array.uninitialized_intrinsic` application, accumulates the adjoint - /// value's elements into the adjoint buffers of its element addresses. - void accumulateArrayLiteralElementAddressAdjoints( - SILBasicBlock *origBB, SILValue originalValue, - AdjointValue arrayAdjointValue, SILLocation loc); - - /// Returns a next insertion point for creating a local allocation: either - /// before the previous local allocation, or at the start of the pullback - /// entry if no local allocations exist. - /// - /// Helper for `createFunctionLocalAllocation`. - SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint(); - - /// Creates and returns a local allocation with the given type. - /// - /// Local allocations are created uninitialized in the pullback entry and - /// deallocated in the pullback exit. All local allocations not in - /// `destroyedLocalAllocations` are also destroyed in the pullback exit. + /// Creates a pullback cloner from a parent VJP cloner. /// - /// Helper for `getAdjointBuffer`. - AllocStackInst *createFunctionLocalAllocation(SILType type, SILLocation loc); - - //--------------------------------------------------------------------------// - // CFG mapping - //--------------------------------------------------------------------------// - - SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) { - return pullbackBBMap.lookup(originalBlock); - } - - SILBasicBlock *getPullbackTrampolineBlock(SILBasicBlock *originalBlock, - SILBasicBlock *successorBlock) { - return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock}); - } - - //--------------------------------------------------------------------------// - // Debugging utilities - //--------------------------------------------------------------------------// - - void printAdjointValueMapping(); - void printAdjointBufferMapping(); - -public: - //--------------------------------------------------------------------------// - // Entry point - //--------------------------------------------------------------------------// + /// The parent VJP cloner stores the original function and an empty + /// to-be-generated pullback function. + explicit PullbackCloner(VJPCloner &vjpCloner); + ~PullbackCloner(); /// Performs pullback generation on the empty pullback function. Returns true /// if any error occurs. bool run(); - - /// Performs pullback generation on the empty pullback function, given that - /// the original function is a "semantic member accessor". - /// - /// "Semantic member accessors" are attached to member properties that have a - /// corresponding tangent stored property in the parent `TangentVector` type. - /// These accessors have special-case pullback generation based on their - /// semantic behavior. - /// - /// Returns true if any error occurs. - bool runForSemanticMemberAccessor(); - bool runForSemanticMemberGetter(); - bool runForSemanticMemberSetter(); - - /// If original result is non-varied, it will always have a zero derivative. - /// Skip full pullback generation and simply emit zero derivatives for wrt - /// parameters. - void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult); - - using TrampolineBlockSet = SmallPtrSet; - - /// Determines the pullback successor block for a given original block and one - /// of its predecessors. When a trampoline block is necessary, emits code into - /// the trampoline block to trampoline the original block's active value's - /// adjoint values. - /// - /// Populates `pullbackTrampolineBlockMap`, which maps active values' adjoint - /// values to the pullback successor blocks in which they are used. This - /// allows us to release those values in pullback successor blocks that do not - /// use them. - SILBasicBlock * - buildPullbackSuccessor(SILBasicBlock *origBB, SILBasicBlock *origPredBB, - llvm::SmallDenseMap - &pullbackTrampolineBlockMap); - - /// Emits pullback code in the corresponding pullback block. - void visitSILBasicBlock(SILBasicBlock *bb); - - void visit(SILInstruction *inst); - - void visitSILInstruction(SILInstruction *inst); - - void visitApplyInst(ApplyInst *ai); - - void visitBeginApplyInst(BeginApplyInst *bai); - - /// Handle `struct` instruction. - /// Original: y = struct (x0, x1, x2, ...) - /// Adjoint: adj[x0] += struct_extract adj[y], #x0 - /// adj[x1] += struct_extract adj[y], #x1 - /// adj[x2] += struct_extract adj[y], #x2 - /// ... - void visitStructInst(StructInst *si); - - /// Handle `struct_extract` instruction. - /// Original: y = struct_extract x, #field - /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) - /// ^~~~~~~ - /// field in tangent space corresponding to #field - void visitStructExtractInst(StructExtractInst *sei); - - /// Handle `ref_element_addr` instruction. - /// Original: y = ref_element_addr x, - /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) - /// ^~~~~~~ - /// field in tangent space corresponding to #field - void visitRefElementAddrInst(RefElementAddrInst *reai); - - /// Handle `tuple` instruction. - /// Original: y = tuple (x0, x1, x2, ...) - /// Adjoint: (adj[x0], adj[x1], adj[x2], ...) += destructure_tuple adj[y] - /// ^~~ - /// excluding non-differentiable elements - void visitTupleInst(TupleInst *ti); - - /// Handle `tuple_extract` instruction. - /// Original: y = tuple_extract x, - /// Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0) - /// ^~~~~~ - /// n'-th element, where n' is tuple tangent space - /// index corresponding to n - void visitTupleExtractInst(TupleExtractInst *tei); - - /// Handle `destructure_tuple` instruction. - /// Original: (y0, ..., yn) = destructure_tuple x - /// Adjoint: adj[x].0 += adj[y0] - /// ... - /// adj[x].n += adj[yn] - void visitDestructureTupleInst(DestructureTupleInst *dti); - - /// Handle `load` or `load_borrow` instruction - /// Original: y = load/load_borrow x - /// Adjoint: adj[x] += adj[y] - void visitLoadOperation(SingleValueInstruction *inst); - void visitLoadInst(LoadInst *li) { visitLoadOperation(li); } - void visitLoadBorrowInst(LoadBorrowInst *lbi) { visitLoadOperation(lbi); } - - /// Handle `store` or `store_borrow` instruction. - /// Original: store/store_borrow x to y - /// Adjoint: adj[x] += load adj[y]; adj[y] = 0 - void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc, - SILValue origDest); - void visitStoreInst(StoreInst *si); - void visitStoreBorrowInst(StoreBorrowInst *sbi) { - visitStoreOperation(sbi->getParent(), sbi->getLoc(), sbi->getSrc(), - sbi->getDest()); - } - - /// Handle `copy_addr` instruction. - /// Original: copy_addr x to y - /// Adjoint: adj[x] += adj[y]; adj[y] = 0 - void visitCopyAddrInst(CopyAddrInst *cai); - - /// Handle `copy_value` instruction. - /// Original: y = copy_value x - /// Adjoint: adj[x] += adj[y] - void visitCopyValueInst(CopyValueInst *cvi); - - /// Handle `begin_borrow` instruction. - /// Original: y = begin_borrow x - /// Adjoint: adj[x] += adj[y] - void visitBeginBorrowInst(BeginBorrowInst *bbi); - - /// Handle `begin_access` instruction. - /// Original: y = begin_access x - /// Adjoint: nothing - void visitBeginAccessInst(BeginAccessInst *bai); - - /// Handle `unconditional_checked_cast_addr` instruction. - /// Original: y = unconditional_checked_cast_addr x - /// Adjoint: adj[x] += unconditional_checked_cast_addr adj[y] - void visitUnconditionalCheckedCastAddrInst( - UnconditionalCheckedCastAddrInst *uccai); - - /// Handle `unchecked_ref_cast` instruction. - /// Original: y = unchecked_ref_cast x - /// Adjoint: adj[x] += adj[y] (assuming x' and y' have the same type) - void visitUncheckedRefCastInst(UncheckedRefCastInst *urci); - - /// Handle `upcast` instruction. - /// Original: y = upcast x - /// Adjoint: adj[x] += adj[y] (assuming x' and y' have the same type) - void visitUpcastInst(UpcastInst *ui); - -#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst); -#undef NOT_DIFFERENTIABLE - -#define NO_ADJOINT(INST) \ - void visit##INST##Inst(INST##Inst *inst) {} - // Terminators. - NO_ADJOINT(Return) - NO_ADJOINT(Branch) - NO_ADJOINT(CondBranch) - - // Address projections. - NO_ADJOINT(StructElementAddr) - NO_ADJOINT(TupleElementAddr) - - // Array literal initialization address projections. - NO_ADJOINT(PointerToAddress) - NO_ADJOINT(IndexAddr) - - // Memory allocation/access. - NO_ADJOINT(AllocStack) - NO_ADJOINT(DeallocStack) - NO_ADJOINT(EndAccess) - - // Debugging/reference counting instructions. - NO_ADJOINT(DebugValue) - NO_ADJOINT(DebugValueAddr) - NO_ADJOINT(RetainValue) - NO_ADJOINT(RetainValueAddr) - NO_ADJOINT(ReleaseValue) - NO_ADJOINT(ReleaseValueAddr) - NO_ADJOINT(StrongRetain) - NO_ADJOINT(StrongRelease) - NO_ADJOINT(UnownedRetain) - NO_ADJOINT(UnownedRelease) - NO_ADJOINT(StrongRetainUnowned) - NO_ADJOINT(DestroyValue) - NO_ADJOINT(DestroyAddr) - - // Value ownership. - NO_ADJOINT(EndBorrow) -#undef NO_DERIVATIVE }; } // end namespace autodiff diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 81e999aabc514..73f5efba6666b 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -42,653 +42,1037 @@ namespace autodiff { class ADContext; class VJPCloner; -PullbackCloner::PullbackCloner(VJPCloner &vjpCloner) - : vjpCloner(vjpCloner), builder(getPullback()), - localAllocBuilder(getPullback()) { - // Get dominance and post-order info for the original function. - auto &passManager = getContext().getPassManager(); - auto *domAnalysis = passManager.getAnalysis(); - auto *postDomAnalysis = passManager.getAnalysis(); - auto *postOrderAnalysis = passManager.getAnalysis(); - domInfo = domAnalysis->get(vjpCloner.original); - postDomInfo = postDomAnalysis->get(vjpCloner.original); - postOrderInfo = postOrderAnalysis->get(vjpCloner.original); -} +/// The implementation class for `PullbackCloner`. +/// +/// The implementation class is a `SILInstructionVisitor`. Effectively, it acts +/// as a `SILCloner` that visits basic blocks in post-order and that visits +/// instructions per basic block in reverse order. This visitation order is +/// necessary for generating pullback functions, whose control flow graph is +/// ~a transposed version of the original function's control flow graph. +struct PullbackCloner::Implementation final + : public SILInstructionVisitor { +private: + /// The parent VJP cloner. + VJPCloner &vjpCloner; + + /// Dominance info for the original function. + DominanceInfo *domInfo = nullptr; + + /// Post-dominance info for the original function. + PostDominanceInfo *postDomInfo = nullptr; + + /// Post-order info for the original function. + PostOrderFunctionInfo *postOrderInfo = nullptr; + + /// Mapping from original basic blocks to corresponding pullback basic blocks. + /// Pullback basic blocks always have the predecessor as the single argument. + llvm::DenseMap pullbackBBMap; + + /// Mapping from original basic blocks and original values to corresponding + /// adjoint values. + llvm::DenseMap, AdjointValue> valueMap; + + /// Mapping from original basic blocks and original values to corresponding + /// adjoint buffers. + llvm::DenseMap, SILValue> bufferMap; + + /// Mapping from pullback basic blocks to pullback struct arguments. + llvm::DenseMap pullbackStructArguments; + + /// Mapping from pullback struct field declarations to pullback struct + /// elements destructured from the linear map basic block argument. In the + /// beginning of each pullback basic block, the block's pullback struct is + /// destructured into individual elements stored here. + llvm::DenseMap pullbackStructElements; + + /// Mapping from original basic blocks and successor basic blocks to + /// corresponding pullback trampoline basic blocks. Trampoline basic blocks + /// take additional arguments in addition to the predecessor enum argument. + llvm::DenseMap, SILBasicBlock *> + pullbackTrampolineBBMap; + + /// Mapping from original basic blocks to dominated active values. + llvm::DenseMap> activeValues; + + /// Mapping from original basic blocks and original active values to + /// corresponding pullback block arguments. + llvm::DenseMap, SILArgument *> + activeValuePullbackBBArgumentMap; + + /// Mapping from original basic blocks to local temporary values to be cleaned + /// up. This is populated when pullback emission is run on one basic block and + /// cleaned before processing another basic block. + llvm::DenseMap> + blockTemporaries; + + /// The main builder. + SILBuilder builder; + + /// An auxiliary local allocation builder. + SILBuilder localAllocBuilder; + + /// Stack buffers allocated for storing local adjoint values. + SmallVector functionLocalAllocations; + + /// A set used to remember local allocations that were destroyed. + llvm::SmallDenseSet destroyedLocalAllocations; + + /// The seed arguments of the pullback function. + SmallVector seeds; + + llvm::BumpPtrAllocator allocator; + + bool errorOccurred = false; + + ADContext &getContext() const { return vjpCloner.context; } + SILModule &getModule() const { return getContext().getModule(); } + ASTContext &getASTContext() const { return getPullback().getASTContext(); } + SILFunction &getOriginal() const { return *vjpCloner.original; } + SILFunction &getPullback() const { return *vjpCloner.pullback; } + SILDifferentiabilityWitness *getWitness() const { return vjpCloner.witness; } + DifferentiationInvoker getInvoker() const { return vjpCloner.invoker; } + LinearMapInfo &getPullbackInfo() { return vjpCloner.pullbackInfo; } + const SILAutoDiffIndices getIndices() const { return vjpCloner.getIndices(); } + const DifferentiableActivityInfo &getActivityInfo() const { + return vjpCloner.activityInfo; + } + +public: + explicit Implementation(VJPCloner &vjpCloner); + +private: + //--------------------------------------------------------------------------// + // Pullback struct mapping + //--------------------------------------------------------------------------// + + void initializePullbackStructElements(SILBasicBlock *origBB, + SILInstructionResultArray values) { + auto *pbStructDecl = getPullbackInfo().getLinearMapStruct(origBB); + assert(pbStructDecl->getStoredProperties().size() == values.size() && + "The number of pullback struct fields must equal the number of " + "pullback struct element values"); + for (auto pair : llvm::zip(pbStructDecl->getStoredProperties(), values)) { + assert(std::get<1>(pair).getOwnershipKind() != + ValueOwnershipKind::Guaranteed && + "Pullback struct elements must be @owned"); + auto insertion = + pullbackStructElements.insert({std::get<0>(pair), std::get<1>(pair)}); + (void)insertion; + assert(insertion.second && "A pullback struct element already exists!"); + } + } -ADContext &PullbackCloner::getContext() const { return vjpCloner.context; } + /// Returns the pullback struct element value corresponding to the given + /// original block and pullback struct field. + SILValue getPullbackStructElement(SILBasicBlock *origBB, VarDecl *field) { + assert(getPullbackInfo().getLinearMapStruct(origBB) == + cast(field->getDeclContext())); + assert(pullbackStructElements.count(field) && + "Pullback struct element for this field does not exist!"); + return pullbackStructElements.lookup(field); + } -SILModule &PullbackCloner::getModule() const { - return getContext().getModule(); -} + //--------------------------------------------------------------------------// + // Type transformer + //--------------------------------------------------------------------------// -ASTContext &PullbackCloner::getASTContext() const { - return getPullback().getASTContext(); -} + /// Get the type lowering for the given AST type. + const Lowering::TypeLowering &getTypeLowering(Type type) { + auto pbGenSig = + getPullback().getLoweredFunctionType()->getSubstGenericSignature(); + Lowering::AbstractionPattern pattern(pbGenSig, + type->getCanonicalType(pbGenSig)); + return getPullback().getTypeLowering(pattern, type); + } -SILFunction &PullbackCloner::getOriginal() const { return *vjpCloner.original; } + /// Remap any archetypes into the current function's context. + SILType remapType(SILType ty) { + if (ty.hasArchetype()) + ty = ty.mapTypeOutOfContext(); + auto remappedType = ty.getASTType()->getCanonicalType( + getPullback().getLoweredFunctionType()->getSubstGenericSignature()); + auto remappedSILType = + SILType::getPrimitiveType(remappedType, ty.getCategory()); + return getPullback().mapTypeIntoContext(remappedSILType); + } -SILFunction &PullbackCloner::getPullback() const { return *vjpCloner.pullback; } + Optional getTangentSpace(CanType type) { + // Use witness generic signature to remap types. + if (auto witnessGenSig = getWitness()->getDerivativeGenericSignature()) + type = witnessGenSig->getCanonicalTypeInContext(type); + return type->getAutoDiffTangentSpace( + LookUpConformanceInModule(getModule().getSwiftModule())); + } -SILDifferentiabilityWitness *PullbackCloner::getWitness() const { - return vjpCloner.witness; -} + /// Returns the tangent value category of the given value. + SILValueCategory getTangentValueCategory(SILValue v) { + // Tangent value category table: + // + // Let $L be a loadable type and $*A be an address-only type. + // + // Original type | Tangent type loadable? | Tangent value category and type + // --------------|------------------------|-------------------------------- + // $L | loadable | object, $L' (no mismatch) + // $*A | loadable | address, $*L' (create a buffer) + // $L | address-only | address, $*A' (no alternative) + // $*A | address-only | address, $*A' (no alternative) + + // TODO(SR-13077): Make "tangent value category" depend solely on whether + // the tangent type is loadable or address-only. + // + // For loadable tangent types, using symbolic adjoint values instead of + // concrete adjoint buffers is more efficient. + + // Quick check: if the value has an address type, the tangent value category + // is currently always "address". + if (v->getType().isAddress()) + return SILValueCategory::Address; + // If the value has an object type and the tangent type is not address-only, + // then the tangent value category is "object". + auto tanSpace = getTangentSpace(remapType(v->getType()).getASTType()); + auto tanASTType = tanSpace->getCanonicalType(); + if (v->getType().isObject() && getTypeLowering(tanASTType).isLoadable()) + return SILValueCategory::Object; + // Otherwise, the tangent value category is "address". + return SILValueCategory::Address; + } -DifferentiationInvoker PullbackCloner::getInvoker() const { - return vjpCloner.invoker; -} + /// Assuming the given type conforms to `Differentiable` after remapping, + /// returns the associated tangent space type. + SILType getRemappedTangentType(SILType type) { + return SILType::getPrimitiveType( + getTangentSpace(remapType(type).getASTType())->getCanonicalType(), + type.getCategory()); + } -LinearMapInfo &PullbackCloner::getPullbackInfo() { - return vjpCloner.pullbackInfo; -} + /// Substitutes all replacement types of the given substitution map using the + /// pullback function's substitution map. + SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) { + return substMap.subst(getPullback().getForwardingSubstitutionMap()); + } -const SILAutoDiffIndices PullbackCloner::getIndices() const { - return vjpCloner.getIndices(); -} + //--------------------------------------------------------------------------// + // Temporary value management + //--------------------------------------------------------------------------// -const DifferentiableActivityInfo &PullbackCloner::getActivityInfo() const { - return vjpCloner.activityInfo; -} + /// Record a temporary value for cleanup before its block's terminator. + SILValue recordTemporary(SILValue value) { + assert(value->getType().isObject()); + assert(value->getFunction() == &getPullback()); + auto inserted = blockTemporaries[value->getParentBlock()].insert(value); + (void)inserted; + LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value); + assert(inserted && "Temporary already recorded?"); + return value; + } -//--------------------------------------------------------------------------// -// Pullback struct mapping -//--------------------------------------------------------------------------// + /// Clean up all temporary values for the given pullback block. + void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc) { + assert(bb->getParent() == &getPullback()); + LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for pullback bb" + << bb->getDebugID() << '\n'); + for (auto temp : blockTemporaries[bb]) + builder.emitDestroyValueOperation(loc, temp); + blockTemporaries[bb].clear(); + } -void PullbackCloner::initializePullbackStructElements( - SILBasicBlock *origBB, SILInstructionResultArray values) { - auto *pbStructDecl = getPullbackInfo().getLinearMapStruct(origBB); - assert(pbStructDecl->getStoredProperties().size() == values.size() && - "The number of pullback struct fields must equal the number of " - "pullback struct element values"); - for (auto pair : llvm::zip(pbStructDecl->getStoredProperties(), values)) { - assert(std::get<1>(pair).getOwnershipKind() != - ValueOwnershipKind::Guaranteed && - "Pullback struct elements must be @owned"); - auto insertion = - pullbackStructElements.insert({std::get<0>(pair), std::get<1>(pair)}); - (void)insertion; - assert(insertion.second && "A pullback struct element already exists!"); + //--------------------------------------------------------------------------// + // Adjoint value factory methods + //--------------------------------------------------------------------------// + + AdjointValue makeZeroAdjointValue(SILType type) { + return AdjointValue::createZero(allocator, remapType(type)); } -} -SILValue PullbackCloner::getPullbackStructElement(SILBasicBlock *origBB, - VarDecl *field) { - assert(getPullbackInfo().getLinearMapStruct(origBB) == - cast(field->getDeclContext())); - assert(pullbackStructElements.count(field) && - "Pullback struct element for this field does not exist!"); - return pullbackStructElements.lookup(field); -} + AdjointValue makeConcreteAdjointValue(SILValue value) { + return AdjointValue::createConcrete(allocator, value); + } -//--------------------------------------------------------------------------// -// Temporary value management -//--------------------------------------------------------------------------// + template + AdjointValue makeAggregateAdjointValue(SILType type, EltRange elements) { + AdjointValue *buf = reinterpret_cast(allocator.Allocate( + elements.size() * sizeof(AdjointValue), alignof(AdjointValue))); + MutableArrayRef elementsCopy(buf, elements.size()); + std::uninitialized_copy(elements.begin(), elements.end(), + elementsCopy.begin()); + return AdjointValue::createAggregate(allocator, remapType(type), + elementsCopy); + } -SILValue PullbackCloner::recordTemporary(SILValue value) { - assert(value->getType().isObject()); - assert(value->getFunction() == &getPullback()); - auto inserted = blockTemporaries[value->getParentBlock()].insert(value); - (void)inserted; - LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value); - assert(inserted && "Temporary already recorded?"); - return value; -} + //--------------------------------------------------------------------------// + // Adjoint value materialization + //--------------------------------------------------------------------------// -void PullbackCloner::cleanUpTemporariesForBlock(SILBasicBlock *bb, - SILLocation loc) { - assert(bb->getParent() == &getPullback()); - LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for pullback bb" - << bb->getDebugID() << '\n'); - for (auto temp : blockTemporaries[bb]) - builder.emitDestroyValueOperation(loc, temp); - blockTemporaries[bb].clear(); -} + /// Materializes an adjoint value. The type of the given adjoint value must be + /// loadable. + SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc) { + assert(val.getType().isObject()); + LLVM_DEBUG(getADDebugStream() + << "Materializing adjoint for " << val << '\n'); + switch (val.getKind()) { + case AdjointValueKind::Zero: + return recordTemporary(emitZeroDirect(val.getType().getASTType(), loc)); + case AdjointValueKind::Aggregate: { + SmallVector elements; + for (auto i : range(val.getNumAggregateElements())) { + auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc); + elements.push_back(builder.emitCopyValueOperation(loc, eltVal)); + } + if (val.getType().is()) + return recordTemporary( + builder.createTuple(loc, val.getType(), elements)); + else + return recordTemporary( + builder.createStruct(loc, val.getType(), elements)); + } + case AdjointValueKind::Concrete: + return val.getConcreteValue(); + } + } -//--------------------------------------------------------------------------// -// Type transformer -//--------------------------------------------------------------------------// + /// Materializes an adjoint value indirectly to a SIL buffer. + void materializeAdjointIndirect(AdjointValue val, SILValue destAddress, + SILLocation loc) { + assert(destAddress->getType().isAddress()); + switch (val.getKind()) { + /// If adjoint value is a symbolic zero, emit a call to + /// `AdditiveArithmetic.zero`. + case AdjointValueKind::Zero: + emitZeroIndirect(val.getSwiftType(), destAddress, loc); + break; + /// If adjoint value is a symbolic aggregate (tuple or struct), recursively + /// materialize materialize the symbolic tuple or struct, filling the + /// buffer. + case AdjointValueKind::Aggregate: { + if (auto *tupTy = val.getSwiftType()->getAs()) { + for (auto idx : range(val.getNumAggregateElements())) { + auto eltTy = SILType::getPrimitiveAddressType( + tupTy->getElementType(idx)->getCanonicalType()); + auto *eltBuf = + builder.createTupleElementAddr(loc, destAddress, idx, eltTy); + materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc); + } + } else if (auto *structDecl = + val.getSwiftType()->getStructOrBoundGenericStruct()) { + auto fieldIt = structDecl->getStoredProperties().begin(); + for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end(); + ++fieldIt, ++i) { + auto eltBuf = + builder.createStructElementAddr(loc, destAddress, *fieldIt); + materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc); + } + } else { + llvm_unreachable("Not an aggregate type"); + } + break; + } + /// If adjoint value is concrete, it is already materialized. Store it in + /// the destination address. + case AdjointValueKind::Concrete: + auto concreteVal = val.getConcreteValue(); + builder.emitStoreValueOperation(loc, concreteVal, destAddress, + StoreOwnershipQualifier::Init); + break; + } + } -const Lowering::TypeLowering &PullbackCloner::getTypeLowering(Type type) { - auto pbGenSig = - getPullback().getLoweredFunctionType()->getSubstGenericSignature(); - Lowering::AbstractionPattern pattern(pbGenSig, - type->getCanonicalType(pbGenSig)); - return getPullback().getTypeLowering(pattern, type); -} + //--------------------------------------------------------------------------// + // Helpers for adjoint value materialization + //--------------------------------------------------------------------------// -/// Remap any archetypes into the current function's context. -SILType PullbackCloner::remapType(SILType ty) { - if (ty.hasArchetype()) - ty = ty.mapTypeOutOfContext(); - auto remappedType = ty.getASTType()->getCanonicalType( - getPullback().getLoweredFunctionType()->getSubstGenericSignature()); - auto remappedSILType = - SILType::getPrimitiveType(remappedType, ty.getCategory()); - return getPullback().mapTypeIntoContext(remappedSILType); -} + /// Emits a zero value into the given address by calling + /// `AdditiveArithmetic.zero`. The given type must conform to + /// `AdditiveArithmetic`. + void emitZeroIndirect(CanType type, SILValue address, SILLocation loc) { + auto tangentSpace = getTangentSpace(type); + assert(tangentSpace && "No tangent space for this type"); + switch (tangentSpace->getKind()) { + case TangentSpace::Kind::TangentVector: + emitZeroIntoBuffer(builder, type, address, loc); + return; + case TangentSpace::Kind::Tuple: { + auto tupleType = tangentSpace->getTuple(); + SmallVector zeroElements; + for (unsigned i : range(tupleType->getNumElements())) { + auto eltAddr = builder.createTupleElementAddr(loc, address, i); + emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), + eltAddr, loc); + } + return; + } + } + } -Optional PullbackCloner::getTangentSpace(CanType type) { - // Use witness generic signature to remap types. - if (auto witnessGenSig = getWitness()->getDerivativeGenericSignature()) - type = witnessGenSig->getCanonicalTypeInContext(type); - return type->getAutoDiffTangentSpace( - LookUpConformanceInModule(getModule().getSwiftModule())); -} + /// Emits a zero value by calling `AdditiveArithmetic.zero`. The given type + /// must conform to `AdditiveArithmetic` and be loadable in SIL. + SILValue emitZeroDirect(CanType type, SILLocation loc) { + auto silType = getModule().Types.getLoweredLoadableType( + type, TypeExpansionContext::minimal(), getModule()); + auto *alloc = builder.createAllocStack(loc, silType); + emitZeroIndirect(type, alloc, loc); + auto zeroValue = builder.emitLoadValueOperation( + loc, alloc, LoadOwnershipQualifier::Take); + builder.createDeallocStack(loc, alloc); + return zeroValue; + } -SILValueCategory PullbackCloner::getTangentValueCategory(SILValue v) { - // Tangent value category table: - // - // Let $L be a loadable type and $*A be an address-only type. - // - // Original type | Tangent type loadable? | Tangent value category and type - // --------------|------------------------|-------------------------------- - // $L | loadable | object, $L' (no mismatch) - // $*A | loadable | address, $*L' (create a buffer) - // $L | address-only | address, $*A' (no alternative) - // $*A | address-only | address, $*A' (no alternative) - - // TODO(SR-13077): Make "tangent value category" depend solely on whether the - // tangent type is loadable or address-only. - // - // For loadable tangent types, using symbolic adjoint values instead of - // concrete adjoint buffers is more efficient. - - // Quick check: if the value has an address type, the tangent value category - // is currently always "address". - if (v->getType().isAddress()) - return SILValueCategory::Address; - // If the value has an object type and the tangent type is not address-only, - // then the tangent value category is "object". - auto tanSpace = getTangentSpace(remapType(v->getType()).getASTType()); - auto tanASTType = tanSpace->getCanonicalType(); - if (v->getType().isObject() && getTypeLowering(tanASTType).isLoadable()) - return SILValueCategory::Object; - // Otherwise, the tangent value category is "address". - return SILValueCategory::Address; -} + //--------------------------------------------------------------------------// + // Adjoint value mapping + //--------------------------------------------------------------------------// -SILType PullbackCloner::getRemappedTangentType(SILType type) { - return SILType::getPrimitiveType( - getTangentSpace(remapType(type).getASTType())->getCanonicalType(), - type.getCategory()); -} + /// Returns true if the given value in the original function has a + /// corresponding adjoint value. + bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const { + assert(origBB->getParent() == &getOriginal()); + assert(originalValue->getType().isObject()); + return valueMap.count({origBB, originalValue}); + } -SubstitutionMap PullbackCloner::remapSubstitutionMap(SubstitutionMap substMap) { - return substMap.subst(getPullback().getForwardingSubstitutionMap()); -} + /// Initializes the adjoint value for the original value. Asserts that the + /// original value does not already have an adjoint value. + void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue, + AdjointValue adjointValue) { + LLVM_DEBUG(getADDebugStream() + << "Setting adjoint value for " << originalValue); + assert(origBB->getParent() == &getOriginal()); + assert(originalValue->getType().isObject()); + assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); + assert(adjointValue.getType().isObject()); + assert(originalValue->getFunction() == &getOriginal()); + // The adjoint value must be in the tangent space. + assert(adjointValue.getType() == + getRemappedTangentType(originalValue->getType())); + auto insertion = + valueMap.try_emplace({origBB, originalValue}, adjointValue); + LLVM_DEBUG(getADDebugStream() + << "The new adjoint value, replacing the existing one, is: " + << insertion.first->getSecond()); + if (!insertion.second) + insertion.first->getSecond() = adjointValue; + } + + /// Returns the adjoint value for a value in the original function. + /// + /// This method first tries to find an existing entry in the adjoint value + /// mapping. If no entry exists, creates a zero adjoint value. + AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue) { + assert(origBB->getParent() == &getOriginal()); + assert(originalValue->getType().isObject()); + assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); + assert(originalValue->getFunction() == &getOriginal()); + auto insertion = valueMap.try_emplace( + {origBB, originalValue}, + makeZeroAdjointValue(getRemappedTangentType(originalValue->getType()))); + auto it = insertion.first; + return it->getSecond(); + } + + /// Adds `newAdjointValue` to the adjoint value for `originalValue` and sets + /// the sum as the new adjoint value. + void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue, + AdjointValue newAdjointValue, SILLocation loc) { + assert(origBB->getParent() == &getOriginal()); + assert(originalValue->getType().isObject()); + assert(newAdjointValue.getType().isObject()); + assert(originalValue->getFunction() == &getOriginal()); + LLVM_DEBUG(getADDebugStream() + << "Adding adjoint value for " << originalValue); + // The adjoint value must be in the tangent space. + assert(newAdjointValue.getType() == + getRemappedTangentType(originalValue->getType())); + auto insertion = + valueMap.try_emplace({origBB, originalValue}, newAdjointValue); + auto inserted = insertion.second; + if (inserted) + return; + // If adjoint already exists, accumulate the adjoint onto the existing + // adjoint. + auto it = insertion.first; + auto existingValue = it->getSecond(); + valueMap.erase(it); + auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue, loc); + // If the original value is the `Array` result of an + // `array.uninitialized_intrinsic` application, accumulate adjoint buffers + // for the array element addresses. + accumulateArrayLiteralElementAddressAdjoints(origBB, originalValue, adjVal, + loc); + setAdjointValue(origBB, originalValue, adjVal); + } + + /// Get the pullback block argument corresponding to the given original block + /// and active value. + SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB, + SILValue activeValue) { + assert(getTangentValueCategory(activeValue) == SILValueCategory::Object); + assert(origBB->getParent() == &getOriginal()); + auto pullbackBBArg = + activeValuePullbackBBArgumentMap[{origBB, activeValue}]; + assert(pullbackBBArg); + assert(pullbackBBArg->getParent() == getPullbackBlock(origBB)); + return pullbackBBArg; + } + + //--------------------------------------------------------------------------// + // Adjoint value accumulation + //--------------------------------------------------------------------------// + + /// Given two adjoint values, accumulates them and returns their sum. + AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs, + SILLocation loc); + + /// Generates code returning `result = lhs + rhs`. + /// + /// Given two materialized adjoint values, accumulates them and returns their + /// sum. The adjoint values must have a loadable type. + SILValue accumulateDirect(SILValue lhs, SILValue rhs, SILLocation loc); + + /// Generates code for `resultAddress = lhsAddress + rhsAddress`. + /// + /// Given two addresses with the same `AdditiveArithmetic`-conforming type, + /// accumulates them into a result address using `AdditiveArithmetic.+`. + void accumulateIndirect(SILValue resultAddress, SILValue lhsAddress, + SILValue rhsAddress, SILLocation loc); + + /// Generates code for `lhsDestAddress += rhsAddress`. + /// + /// Given two addresses with the same `AdditiveArithmetic`-conforming type, + /// accumulates the rhs into the lhs using `AdditiveArithmetic.+=`. + void accumulateIndirect(SILValue lhsDestAddress, SILValue rhsAddress, + SILLocation loc); + + //--------------------------------------------------------------------------// + // Adjoint buffer mapping + //--------------------------------------------------------------------------// + + /// If the given original value is an address projection, returns a + /// corresponding adjoint projection to be used as its adjoint buffer. + /// + /// Helper function for `getAdjointBuffer`. + SILValue getAdjointProjection(SILBasicBlock *origBB, SILValue originalValue); + + /// Returns the adjoint buffer for the original value. + /// + /// This method first tries to find an existing entry in the adjoint buffer + /// mapping. If no entry exists, creates a zero adjoint buffer. + SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue) { + assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); + assert(originalValue->getFunction() == &getOriginal()); + auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue()); + if (!insertion.second) // not inserted + return insertion.first->getSecond(); + + // If the original buffer is a projection, return a corresponding projection + // into the adjoint buffer. + if (auto adjProj = getAdjointProjection(origBB, originalValue)) + return (bufferMap[{origBB, originalValue}] = adjProj); + + auto bufType = getRemappedTangentType(originalValue->getType()); + // Set insertion point for local allocation builder: before the last local + // allocation, or at the start of the pullback function's entry if no local + // allocations exist yet. + auto *newBuf = createFunctionLocalAllocation( + bufType, RegularLocation::getAutoGeneratedLocation()); + // Temporarily change global builder insertion point and emit zero into the + // local allocation. + auto insertionPoint = builder.getInsertionBB(); + builder.setInsertionPoint(localAllocBuilder.getInsertionBB(), + localAllocBuilder.getInsertionPoint()); + emitZeroIndirect(bufType.getASTType(), newBuf, newBuf->getLoc()); + builder.setInsertionPoint(insertionPoint); + return (insertion.first->getSecond() = newBuf); + } + + /// Initializes the adjoint buffer for the original value. Asserts that the + /// original value does not already have an adjoint buffer. + void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, + SILValue adjointBuffer) { + assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); + auto insertion = + bufferMap.try_emplace({origBB, originalValue}, adjointBuffer); + assert(insertion.second && "Adjoint buffer already exists"); + (void)insertion; + } -//--------------------------------------------------------------------------// -// Adjoint value mapping -//--------------------------------------------------------------------------// + /// Accumulates `rhsAddress` into the adjoint buffer corresponding to the + /// original value. + void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, + SILValue rhsAddress, SILLocation loc) { + assert(getTangentValueCategory(originalValue) == + SILValueCategory::Address && + rhsAddress->getType().isAddress()); + assert(originalValue->getFunction() == &getOriginal()); + assert(rhsAddress->getFunction() == &getPullback()); + auto adjointBuffer = getAdjointBuffer(origBB, originalValue); + accumulateIndirect(adjointBuffer, rhsAddress, loc); + } + + /// Returns a next insertion point for creating a local allocation: either + /// before the previous local allocation, or at the start of the pullback + /// entry if no local allocations exist. + /// + /// Helper for `createFunctionLocalAllocation`. + SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint() { + // If there are no local allocations, insert at the pullback entry start. + if (functionLocalAllocations.empty()) + return getPullback().getEntryBlock()->begin(); + // Otherwise, insert before the last local allocation. Inserting before + // rather than after ensures that allocation and zero initialization + // instructions are grouped together. + auto lastLocalAlloc = functionLocalAllocations.back(); + return lastLocalAlloc->getDefiningInstruction()->getIterator(); + } + + /// Creates and returns a local allocation with the given type. + /// + /// Local allocations are created uninitialized in the pullback entry and + /// deallocated in the pullback exit. All local allocations not in + /// `destroyedLocalAllocations` are also destroyed in the pullback exit. + /// + /// Helper for `getAdjointBuffer`. + AllocStackInst *createFunctionLocalAllocation(SILType type, SILLocation loc) { + // Set insertion point for local allocation builder: before the last local + // allocation, or at the start of the pullback function's entry if no local + // allocations exist yet. + localAllocBuilder.setInsertionPoint( + getPullback().getEntryBlock(), + getNextFunctionLocalAllocationInsertionPoint()); + // Create and return local allocation. + auto *alloc = localAllocBuilder.createAllocStack(loc, type); + functionLocalAllocations.push_back(alloc); + return alloc; + } + + //--------------------------------------------------------------------------// + // Array literal initialization differentiation + //--------------------------------------------------------------------------// + + /// Given the adjoint value of an array initialized from an + /// `array.uninitialized_intrinsic` application and an array element index, + /// returns an `alloc_stack` containing the adjoint value of the array element + /// at the given index by applying `Array.TangentVector.subscript`. + AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint, + int eltIndex, SILLocation loc); + + /// Given the adjoint value of an array initialized from an + /// `array.uninitialized_intrinsic` application, accumulates the adjoint + /// value's elements into the adjoint buffers of its element addresses. + void accumulateArrayLiteralElementAddressAdjoints( + SILBasicBlock *origBB, SILValue originalValue, + AdjointValue arrayAdjointValue, SILLocation loc); + + //--------------------------------------------------------------------------// + // CFG mapping + //--------------------------------------------------------------------------// + + SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) { + return pullbackBBMap.lookup(originalBlock); + } + + SILBasicBlock *getPullbackTrampolineBlock(SILBasicBlock *originalBlock, + SILBasicBlock *successorBlock) { + return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock}); + } + + //--------------------------------------------------------------------------// + // Debugging utilities + //--------------------------------------------------------------------------// + + void printAdjointValueMapping() { + // Group original/adjoint values by basic block. + llvm::DenseMap> tmp; + for (auto pair : valueMap) { + auto origPair = pair.first; + auto *origBB = origPair.first; + auto origValue = origPair.second; + auto adjValue = pair.second; + tmp[origBB].insert({origValue, adjValue}); + } + // Print original/adjoint values per basic block. + auto &s = getADDebugStream() << "Adjoint value mapping:\n"; + for (auto &origBB : getOriginal()) { + if (!pullbackBBMap.count(&origBB)) + continue; + auto bbValueMap = tmp[&origBB]; + s << "bb" << origBB.getDebugID(); + s << " (size " << bbValueMap.size() << "):\n"; + for (auto valuePair : bbValueMap) { + auto origValue = valuePair.first; + auto adjValue = valuePair.second; + s << "ORIG: " << origValue; + s << "ADJ: " << adjValue << '\n'; + } + s << '\n'; + } + } -bool PullbackCloner::hasAdjointValue(SILBasicBlock *origBB, - SILValue originalValue) const { - assert(origBB->getParent() == &getOriginal()); - assert(originalValue->getType().isObject()); - return valueMap.count({origBB, originalValue}); -} + void printAdjointBufferMapping() { + // Group original/adjoint buffers by basic block. + llvm::DenseMap> tmp; + for (auto pair : bufferMap) { + auto origPair = pair.first; + auto *origBB = origPair.first; + auto origBuf = origPair.second; + auto adjBuf = pair.second; + tmp[origBB][origBuf] = adjBuf; + } + // Print original/adjoint buffers per basic block. + auto &s = getADDebugStream() << "Adjoint buffer mapping:\n"; + for (auto &origBB : getOriginal()) { + if (!pullbackBBMap.count(&origBB)) + continue; + auto bbBufferMap = tmp[&origBB]; + s << "bb" << origBB.getDebugID(); + s << " (size " << bbBufferMap.size() << "):\n"; + for (auto valuePair : bbBufferMap) { + auto origBuf = valuePair.first; + auto adjBuf = valuePair.second; + s << "ORIG: " << origBuf; + s << "ADJ: " << adjBuf << '\n'; + } + s << '\n'; + } + } -void PullbackCloner::setAdjointValue(SILBasicBlock *origBB, - SILValue originalValue, - AdjointValue adjointValue) { - LLVM_DEBUG(getADDebugStream() - << "Setting adjoint value for " << originalValue); - assert(origBB->getParent() == &getOriginal()); - assert(originalValue->getType().isObject()); - assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); - assert(adjointValue.getType().isObject()); - assert(originalValue->getFunction() == &getOriginal()); - // The adjoint value must be in the tangent space. - assert(adjointValue.getType() == - getRemappedTangentType(originalValue->getType())); - auto insertion = valueMap.try_emplace({origBB, originalValue}, adjointValue); - LLVM_DEBUG(getADDebugStream() - << "The new adjoint value, replacing the existing one, is: " - << insertion.first->getSecond()); - if (!insertion.second) - insertion.first->getSecond() = adjointValue; -} +public: + //--------------------------------------------------------------------------// + // Entry point + //--------------------------------------------------------------------------// + + /// Performs pullback generation on the empty pullback function. Returns true + /// if any error occurs. + bool run(); + + /// Performs pullback generation on the empty pullback function, given that + /// the original function is a "semantic member accessor". + /// + /// "Semantic member accessors" are attached to member properties that have a + /// corresponding tangent stored property in the parent `TangentVector` type. + /// These accessors have special-case pullback generation based on their + /// semantic behavior. + /// + /// Returns true if any error occurs. + bool runForSemanticMemberAccessor(); + bool runForSemanticMemberGetter(); + bool runForSemanticMemberSetter(); + + /// If original result is non-varied, it will always have a zero derivative. + /// Skip full pullback generation and simply emit zero derivatives for wrt + /// parameters. + void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult); + + using TrampolineBlockSet = SmallPtrSet; + + /// Determines the pullback successor block for a given original block and one + /// of its predecessors. When a trampoline block is necessary, emits code into + /// the trampoline block to trampoline the original block's active value's + /// adjoint values. + /// + /// Populates `pullbackTrampolineBlockMap`, which maps active values' adjoint + /// values to the pullback successor blocks in which they are used. This + /// allows us to release those values in pullback successor blocks that do not + /// use them. + SILBasicBlock * + buildPullbackSuccessor(SILBasicBlock *origBB, SILBasicBlock *origPredBB, + llvm::SmallDenseMap + &pullbackTrampolineBlockMap); + + /// Emits pullback code in the corresponding pullback block. + void visitSILBasicBlock(SILBasicBlock *bb); + + void visit(SILInstruction *inst) { + if (errorOccurred) + return; -AdjointValue PullbackCloner::getAdjointValue(SILBasicBlock *origBB, - SILValue originalValue) { - assert(origBB->getParent() == &getOriginal()); - assert(originalValue->getType().isObject()); - assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); - assert(originalValue->getFunction() == &getOriginal()); - auto insertion = valueMap.try_emplace( - {origBB, originalValue}, - makeZeroAdjointValue(getRemappedTangentType(originalValue->getType()))); - auto it = insertion.first; - return it->getSecond(); -} + LLVM_DEBUG(getADDebugStream() + << "PullbackCloner visited:\n[ORIG]" << *inst); +#ifndef NDEBUG + auto beforeInsertion = std::prev(builder.getInsertionPoint()); +#endif + SILInstructionVisitor::visit(inst); + LLVM_DEBUG({ + auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback:\n"; + auto afterInsertion = builder.getInsertionPoint(); + for (auto it = ++beforeInsertion; it != afterInsertion; ++it) + s << *it; + }); + } + + /// Fallback instruction visitor for unhandled instructions. + /// Emit a general non-differentiability diagnostic. + void visitSILInstruction(SILInstruction *inst) { + LLVM_DEBUG(getADDebugStream() + << "Unhandled instruction in PullbackCloner: " << *inst); + getContext().emitNondifferentiabilityError( + inst, getInvoker(), diag::autodiff_expression_not_differentiable_note); + errorOccurred = true; + } -void PullbackCloner::addAdjointValue(SILBasicBlock *origBB, - SILValue originalValue, - AdjointValue newAdjointValue, - SILLocation loc) { - assert(origBB->getParent() == &getOriginal()); - assert(originalValue->getType().isObject()); - assert(newAdjointValue.getType().isObject()); - assert(originalValue->getFunction() == &getOriginal()); - LLVM_DEBUG(getADDebugStream() - << "Adding adjoint value for " << originalValue); - // The adjoint value must be in the tangent space. - assert(newAdjointValue.getType() == - getRemappedTangentType(originalValue->getType())); - auto insertion = - valueMap.try_emplace({origBB, originalValue}, newAdjointValue); - auto inserted = insertion.second; - if (inserted) - return; - // If adjoint already exists, accumulate the adjoint onto the existing - // adjoint. - auto it = insertion.first; - auto existingValue = it->getSecond(); - valueMap.erase(it); - auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue, loc); - // If the original value is the `Array` result of an - // `array.uninitialized_intrinsic` application, accumulate adjoint buffers - // for the array element addresses. - accumulateArrayLiteralElementAddressAdjoints(origBB, originalValue, adjVal, - loc); - setAdjointValue(origBB, originalValue, adjVal); -} + /// Handle `apply` instruction. + /// Original: (y0, y1, ...) = apply @fn (x0, x1, ...) + /// Adjoint: (adj[x0], adj[x1], ...) += apply @fn_pullback (adj[y0], ...) + void visitApplyInst(ApplyInst *ai) { + assert(getPullbackInfo().shouldDifferentiateApplySite(ai)); + // Skip `array.uninitialized_intrinsic` applications, which have special + // `store` and `copy_addr` support. + if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) + return; + auto loc = ai->getLoc(); + auto *bb = ai->getParent(); + // Handle `array.finalize_intrinsic` applications. + // `array.finalize_intrinsic` semantically behaves like an identity + // function. + if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) { + assert(ai->getNumArguments() == 1 && + "Expected intrinsic to have one operand"); + // Accumulate result's adjoint into argument's adjoint. + auto adjResult = getAdjointValue(bb, ai); + auto origArg = ai->getArgumentsWithoutIndirectResults().front(); + addAdjointValue(bb, origArg, adjResult, loc); + return; + } + // Replace a call to a function with a call to its pullback. + auto &nestedApplyInfo = getContext().getNestedApplyInfo(); + auto applyInfoLookup = nestedApplyInfo.find(ai); + // If no `NestedApplyInfo` was found, then this task doesn't need to be + // differentiated. + if (applyInfoLookup == nestedApplyInfo.end()) { + // Must not be active. + assert(!getActivityInfo().isActive(ai, getIndices())); + return; + } + auto applyInfo = applyInfoLookup->getSecond(); + + // Get the pullback. + auto *field = getPullbackInfo().lookUpLinearMapDecl(ai); + assert(field); + auto pullback = getPullbackStructElement(ai->getParent(), field); + + // Get the original result of the `apply` instruction. + SmallVector origDirectResults; + forEachApplyDirectResult(ai, [&](SILValue directResult) { + origDirectResults.push_back(directResult); + }); + SmallVector origAllResults; + collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); + // Append `inout` arguments after original results. + for (auto paramIdx : applyInfo.indices.parameters->getIndices()) { + auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( + ai->getNumIndirectResults() + paramIdx); + if (!paramInfo.isIndirectMutating()) + continue; + origAllResults.push_back( + ai->getArgumentsWithoutIndirectResults()[paramIdx]); + } -void PullbackCloner::accumulateArrayLiteralElementAddressAdjoints( - SILBasicBlock *origBB, SILValue originalValue, - AdjointValue arrayAdjointValue, SILLocation loc) { - // Return if the original value is not the `Array` result of an - // `array.uninitialized_intrinsic` application. - auto *dti = dyn_cast_or_null( - originalValue->getDefiningInstruction()); - if (!dti) - return; - if (!ArraySemanticsCall(dti->getOperand(), - semantics::ARRAY_UNINITIALIZED_INTRINSIC)) - return; - if (originalValue != dti->getResult(0)) - return; - // Accumulate the array's adjoint value into the adjoint buffers of its - // element addresses: `pointer_to_address` and `index_addr` instructions. - LLVM_DEBUG(getADDebugStream() - << "Accumulating adjoint value for array literal into element " - "address adjoint buffers" - << originalValue); - auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc); - builder.setInsertionPoint(arrayAdjoint->getParentBlock()); - for (auto use : dti->getResult(1)->getUses()) { - auto *ptai = dyn_cast(use->getUser()); - auto adjBuf = getAdjointBuffer(origBB, ptai); - auto *eltAdjBuf = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc); - accumulateIndirect(adjBuf, eltAdjBuf, loc); - for (auto use : ptai->getUses()) { - if (auto *iai = dyn_cast(use->getUser())) { - auto *ili = cast(iai->getIndex()); - auto eltIndex = ili->getValue().getLimitedValue(); - auto adjBuf = getAdjointBuffer(origBB, iai); - auto *eltAdjBuf = - getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, loc); - accumulateIndirect(adjBuf, eltAdjBuf, loc); + // Get callee pullback arguments. + SmallVector args; + + // Handle callee pullback indirect results. + // Create local allocations for these and destroy them after the call. + auto pullbackType = + remapType(pullback->getType()).castTo(); + auto actualPullbackType = applyInfo.originalPullbackType + ? *applyInfo.originalPullbackType + : pullbackType; + actualPullbackType = actualPullbackType->getUnsubstitutedType(getModule()); + SmallVector pullbackIndirectResults; + for (auto indRes : actualPullbackType->getIndirectFormalResults()) { + auto *alloc = builder.createAllocStack( + loc, remapType(indRes.getSILStorageInterfaceType())); + pullbackIndirectResults.push_back(alloc); + args.push_back(alloc); + } + + // Collect callee pullback formal arguments. + for (auto resultIndex : applyInfo.indices.results->getIndices()) { + assert(resultIndex < origAllResults.size()); + auto origResult = origAllResults[resultIndex]; + // Get the seed (i.e. adjoint value of the original result). + SILValue seed; + switch (getTangentValueCategory(origResult)) { + case SILValueCategory::Object: + seed = materializeAdjointDirect(getAdjointValue(bb, origResult), loc); + break; + case SILValueCategory::Address: + seed = getAdjointBuffer(bb, origResult); + break; } + args.push_back(seed); } - } -} -SILArgument * -PullbackCloner::getActiveValuePullbackBlockArgument(SILBasicBlock *origBB, - SILValue activeValue) { - assert(getTangentValueCategory(activeValue) == SILValueCategory::Object); - assert(origBB->getParent() == &getOriginal()); - auto pullbackBBArg = activeValuePullbackBBArgumentMap[{origBB, activeValue}]; - assert(pullbackBBArg); - assert(pullbackBBArg->getParent() == getPullbackBlock(origBB)); - return pullbackBBArg; -} + // If callee pullback was reabstracted in VJP, reabstract callee pullback. + if (applyInfo.originalPullbackType) { + SILOptFunctionBuilder fb(getContext().getTransform()); + pullback = reabstractFunction( + builder, fb, loc, pullback, *applyInfo.originalPullbackType, + [this](SubstitutionMap subs) -> SubstitutionMap { + return this->remapSubstitutionMap(subs); + }); + } -//--------------------------------------------------------------------------// -// Adjoint buffer mapping -//--------------------------------------------------------------------------// + // Call the callee pullback. + auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), + args, /*isNonThrowing*/ false); + builder.emitDestroyValueOperation(loc, pullback); + + // Extract all results from `pullbackCall`. + SmallVector dirResults; + extractAllElements(pullbackCall, builder, dirResults); + // Get all results in type-defined order. + SmallVector allResults; + collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults); + + LLVM_DEBUG({ + auto &s = getADDebugStream(); + s << "All results of the nested pullback call:\n"; + llvm::for_each(allResults, [&](SILValue v) { s << v; }); + }); + + // Accumulate adjoints for original differentiation parameters. + auto allResultsIt = allResults.begin(); + for (unsigned i : applyInfo.indices.parameters->getIndices()) { + auto origArg = ai->getArgument(ai->getNumIndirectResults() + i); + // Skip adjoint accumulation for `inout` arguments. + auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( + ai->getNumIndirectResults() + i); + if (paramInfo.isIndirectMutating()) + continue; + auto tan = *allResultsIt++; + if (tan->getType().isAddress()) { + addToAdjointBuffer(bb, origArg, tan, loc); + } else { + if (origArg->getType().isAddress()) { + auto *tmpBuf = builder.createAllocStack(loc, tan->getType()); + builder.emitStoreValueOperation(loc, tan, tmpBuf, + StoreOwnershipQualifier::Init); + addToAdjointBuffer(bb, origArg, tmpBuf, loc); + builder.emitDestroyAddrAndFold(loc, tmpBuf); + builder.createDeallocStack(loc, tmpBuf); + } else { + recordTemporary(tan); + addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc); + } + } + } + // Destroy unused pullback direct results. Needed for pullback results from + // VJPs extracted from `@differentiable` function callees, where the + // `@differentiable` function's differentiation parameter indices are a + // superset of the active `apply` parameter indices. + while (allResultsIt != allResults.end()) { + auto unusedPullbackDirectResult = *allResultsIt++; + if (unusedPullbackDirectResult->getType().isAddress()) + continue; + builder.emitDestroyValueOperation(loc, unusedPullbackDirectResult); + } + // Destroy and deallocate pullback indirect results. + for (auto *alloc : llvm::reverse(pullbackIndirectResults)) { + builder.emitDestroyAddrAndFold(loc, alloc); + builder.createDeallocStack(loc, alloc); + } + } -SILValue PullbackCloner::getAdjointProjection(SILBasicBlock *origBB, - SILValue originalProjection) { - // Handle `struct_element_addr`. - // Adjoint projection: a `struct_element_addr` into the base adjoint buffer. - if (auto *seai = dyn_cast(originalProjection)) { - assert(!seai->getField()->getAttrs().hasAttribute() && - "`@noDerivative` struct projections should never be active"); - auto adjSource = getAdjointBuffer(origBB, seai->getOperand()); - auto structType = remapType(seai->getOperand()->getType()).getASTType(); - auto *tanField = - getTangentStoredProperty(getContext(), seai, structType, getInvoker()); - assert(tanField && "Invalid projections should have been diagnosed"); - return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField); + void visitBeginApplyInst(BeginApplyInst *bai) { + // Diagnose `begin_apply` instructions. + // Coroutine differentiation is not yet supported. + getContext().emitNondifferentiabilityError( + bai, getInvoker(), diag::autodiff_coroutines_not_supported); + errorOccurred = true; + return; } - // Handle `tuple_element_addr`. - // Adjoint projection: a `tuple_element_addr` into the base adjoint buffer. - if (auto *teai = dyn_cast(originalProjection)) { - auto source = teai->getOperand(); - auto adjSource = getAdjointBuffer(origBB, source); - if (!adjSource->getType().is()) - return adjSource; - auto origTupleTy = source->getType().castTo(); - unsigned adjIndex = 0; - for (unsigned i : range(teai->getFieldNo())) { - if (getTangentSpace( - origTupleTy->getElement(i).getType()->getCanonicalType())) - ++adjIndex; + + /// Handle `struct` instruction. + /// Original: y = struct (x0, x1, x2, ...) + /// Adjoint: adj[x0] += struct_extract adj[y], #x0 + /// adj[x1] += struct_extract adj[y], #x1 + /// adj[x2] += struct_extract adj[y], #x2 + /// ... + void visitStructInst(StructInst *si) { + auto *bb = si->getParent(); + auto loc = si->getLoc(); + auto *structDecl = si->getStructDecl(); + auto av = getAdjointValue(bb, si); + switch (av.getKind()) { + case AdjointValueKind::Zero: + for (auto *field : structDecl->getStoredProperties()) { + auto fv = si->getFieldValue(field); + addAdjointValue( + bb, fv, makeZeroAdjointValue(getRemappedTangentType(fv->getType())), + loc); + } + break; + case AdjointValueKind::Concrete: { + auto adjStruct = materializeAdjointDirect(std::move(av), loc); + auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct); + + // Find the struct `TangentVector` type. + auto structTy = remapType(si->getType()).getASTType(); +#ifndef NDEBUG + auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); + assert(!getTypeLowering(tangentVectorTy).isAddressOnly()); + assert(tangentVectorTy->getStructOrBoundGenericStruct()); +#endif + + // Accumulate adjoints for the fields of the `struct` operand. + unsigned fieldIndex = 0; + for (auto it = structDecl->getStoredProperties().begin(); + it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) { + VarDecl *field = *it; + if (field->getAttrs().hasAttribute()) + continue; + // Find the corresponding field in the tangent space. + auto *tanField = getTangentStoredProperty(getContext(), field, structTy, + loc, getInvoker()); + if (!tanField) { + errorOccurred = true; + return; + } + auto tanElt = dti->getResult(fieldIndex); + addAdjointValue(bb, si->getFieldValue(field), + makeConcreteAdjointValue(tanElt), si->getLoc()); + } + break; + } + case AdjointValueKind::Aggregate: { + // Note: All user-called initializations go through the calls to the + // initializer, and synthesized initializers only have one level of struct + // formation which will not result into any aggregate adjoint valeus. + llvm_unreachable("Aggregate adjoint values should not occur for `struct` " + "instructions"); + } } - return builder.createTupleElementAddr(teai->getLoc(), adjSource, adjIndex); } - // Handle `ref_element_addr`. - // Adjoint projection: a local allocation initialized with the corresponding - // field value from the class's base adjoint value. - if (auto *reai = dyn_cast(originalProjection)) { - assert(!reai->getField()->getAttrs().hasAttribute() && - "`@noDerivative` class projections should never be active"); - auto loc = reai->getLoc(); - // Get the class operand, stripping `begin_borrow`. - auto classOperand = stripBorrow(reai->getOperand()); - auto classType = remapType(reai->getOperand()->getType()).getASTType(); + + /// Handle `struct_extract` instruction. + /// Original: y = struct_extract x, #field + /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) + /// ^~~~~~~ + /// field in tangent space corresponding to #field + void visitStructExtractInst(StructExtractInst *sei) { + auto *bb = sei->getParent(); + auto loc = getValidLocation(sei); + auto structTy = remapType(sei->getOperand()->getType()).getASTType(); + auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); + assert(!getTypeLowering(tangentVectorTy).isAddressOnly()); + auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy); + auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); + assert(tangentVectorDecl); + // Find the corresponding field in the tangent space. auto *tanField = - getTangentStoredProperty(getContext(), reai->getField(), classType, - reai->getLoc(), getInvoker()); + getTangentStoredProperty(getContext(), sei, structTy, getInvoker()); assert(tanField && "Invalid projections should have been diagnosed"); - // Create a local allocation for the element adjoint buffer. - auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType(); - auto eltTanSILType = - remapType(SILType::getPrimitiveAddressType(eltTanType)); - auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc); - // Check the class operand's `TangentVector` value category. - switch (getTangentValueCategory(classOperand)) { - case SILValueCategory::Object: { - // Get the class operand's adjoint value. Currently, it must be a - // `TangentVector` struct. - auto adjClass = - materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc); - builder.emitScopedBorrowOperation( - loc, adjClass, [&](SILValue borrowedAdjClass) { - // Initialize the element adjoint buffer with the base adjoint - // value. - auto *adjElt = - builder.createStructExtract(loc, borrowedAdjClass, tanField); - auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt); - builder.emitStoreValueOperation(loc, adjEltCopy, eltAdjBuffer, - StoreOwnershipQualifier::Init); - }); - return eltAdjBuffer; - } - case SILValueCategory::Address: { - // Get the class operand's adjoint buffer. Currently, it must be a - // `TangentVector` struct. - auto adjClass = getAdjointBuffer(origBB, classOperand); - // Initialize the element adjoint buffer with the base adjoint buffer. - auto *adjElt = builder.createStructElementAddr(loc, adjClass, tanField); - builder.createCopyAddr(loc, adjElt, eltAdjBuffer, IsNotTake, - IsInitialization); - return eltAdjBuffer; - } - } - } - // Handle `begin_access`. - // Adjoint projection: the base adjoint buffer itself. - if (auto *bai = dyn_cast(originalProjection)) { - auto adjBase = getAdjointBuffer(origBB, bai->getOperand()); - if (errorOccurred) - return (bufferMap[{origBB, originalProjection}] = SILValue()); - // Return the base buffer's adjoint buffer. - return adjBase; - } - // Handle `array.uninitialized_intrinsic` application element addresses. - // Adjoint projection: a local allocation initialized by applying - // `Array.TangentVector.subscript` to the base array's adjoint value. - auto *ai = - getAllocateUninitializedArrayIntrinsicElementAddress(originalProjection); - auto *definingInst = dyn_cast_or_null( - originalProjection->getDefiningInstruction()); - bool isAllocateUninitializedArrayIntrinsicElementAddress = - ai && definingInst && - (isa(definingInst) || - isa(definingInst)); - if (isAllocateUninitializedArrayIntrinsicElementAddress) { - // Get the array element index of the result address. - int eltIndex = 0; - if (auto *iai = dyn_cast(definingInst)) { - auto *ili = cast(iai->getIndex()); - eltIndex = ili->getValue().getLimitedValue(); - } - // Get the array adjoint value. - SILValue arrayAdjoint; - assert(ai && "Expected `array.uninitialized_intrinsic` application"); - for (auto use : ai->getUses()) { - auto *dti = dyn_cast(use->getUser()); - if (!dti) - continue; - assert(!arrayAdjoint && "Array adjoint already found"); - // The first `destructure_tuple` result is the `Array` value. - auto arrayValue = dti->getResult(0); - arrayAdjoint = materializeAdjointDirect( - getAdjointValue(origBB, arrayValue), definingInst->getLoc()); - } - assert(arrayAdjoint && "Array does not have adjoint value"); - // Apply `Array.TangentVector.subscript` to get array element adjoint value. - auto *eltAdjBuffer = - getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc()); - return eltAdjBuffer; - } - return SILValue(); -} - -SILValue &PullbackCloner::getAdjointBuffer(SILBasicBlock *origBB, - SILValue originalValue) { - assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); - assert(originalValue->getFunction() == &getOriginal()); - auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue()); - if (!insertion.second) // not inserted - return insertion.first->getSecond(); - - // If the original buffer is a projection, return a corresponding projection - // into the adjoint buffer. - if (auto adjProj = getAdjointProjection(origBB, originalValue)) - return (bufferMap[{origBB, originalValue}] = adjProj); - - auto bufType = getRemappedTangentType(originalValue->getType()); - // Set insertion point for local allocation builder: before the last local - // allocation, or at the start of the pullback function's entry if no local - // allocations exist yet. - auto *newBuf = createFunctionLocalAllocation( - bufType, RegularLocation::getAutoGeneratedLocation()); - // Temporarily change global builder insertion point and emit zero into the - // local allocation. - auto insertionPoint = builder.getInsertionBB(); - builder.setInsertionPoint(localAllocBuilder.getInsertionBB(), - localAllocBuilder.getInsertionPoint()); - emitZeroIndirect(bufType.getASTType(), newBuf, newBuf->getLoc()); - builder.setInsertionPoint(insertionPoint); - return (insertion.first->getSecond() = newBuf); -} - -void PullbackCloner::setAdjointBuffer(SILBasicBlock *origBB, - SILValue originalValue, - SILValue adjointBuffer) { - assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); - auto insertion = - bufferMap.try_emplace({origBB, originalValue}, adjointBuffer); - assert(insertion.second && "Adjoint buffer already exists"); - (void)insertion; -} - -void PullbackCloner::addToAdjointBuffer(SILBasicBlock *origBB, - SILValue originalValue, - SILValue rhsAddress, SILLocation loc) { - assert(getTangentValueCategory(originalValue) == SILValueCategory::Address && - rhsAddress->getType().isAddress()); - assert(originalValue->getFunction() == &getOriginal()); - assert(rhsAddress->getFunction() == &getPullback()); - auto adjointBuffer = getAdjointBuffer(origBB, originalValue); - accumulateIndirect(adjointBuffer, rhsAddress, loc); -} - -SILBasicBlock::iterator -PullbackCloner::getNextFunctionLocalAllocationInsertionPoint() { - // If there are no local allocations, insert at the pullback entry start. - if (functionLocalAllocations.empty()) - return getPullback().getEntryBlock()->begin(); - // Otherwise, insert before the last local allocation. Inserting before - // rather than after ensures that allocation and zero initialization - // instructions are grouped together. - auto lastLocalAlloc = functionLocalAllocations.back(); - return lastLocalAlloc->getDefiningInstruction()->getIterator(); -} - -AllocStackInst *PullbackCloner::createFunctionLocalAllocation(SILType type, - SILLocation loc) { - // Set insertion point for local allocation builder: before the last local - // allocation, or at the start of the pullback function's entry if no local - // allocations exist yet. - localAllocBuilder.setInsertionPoint( - getPullback().getEntryBlock(), - getNextFunctionLocalAllocationInsertionPoint()); - // Create and return local allocation. - auto *alloc = localAllocBuilder.createAllocStack(loc, type); - functionLocalAllocations.push_back(alloc); - return alloc; -} - -//--------------------------------------------------------------------------// -// Debugging utilities -//--------------------------------------------------------------------------// - -void PullbackCloner::printAdjointValueMapping() { - // Group original/adjoint values by basic block. - llvm::DenseMap> tmp; - for (auto pair : valueMap) { - auto origPair = pair.first; - auto *origBB = origPair.first; - auto origValue = origPair.second; - auto adjValue = pair.second; - tmp[origBB].insert({origValue, adjValue}); - } - // Print original/adjoint values per basic block. - auto &s = getADDebugStream() << "Adjoint value mapping:\n"; - for (auto &origBB : getOriginal()) { - if (!pullbackBBMap.count(&origBB)) - continue; - auto bbValueMap = tmp[&origBB]; - s << "bb" << origBB.getDebugID(); - s << " (size " << bbValueMap.size() << "):\n"; - for (auto valuePair : bbValueMap) { - auto origValue = valuePair.first; - auto adjValue = valuePair.second; - s << "ORIG: " << origValue; - s << "ADJ: " << adjValue << '\n'; - } - s << '\n'; - } -} - -void PullbackCloner::printAdjointBufferMapping() { - // Group original/adjoint buffers by basic block. - llvm::DenseMap> tmp; - for (auto pair : bufferMap) { - auto origPair = pair.first; - auto *origBB = origPair.first; - auto origBuf = origPair.second; - auto adjBuf = pair.second; - tmp[origBB][origBuf] = adjBuf; - } - // Print original/adjoint buffers per basic block. - auto &s = getADDebugStream() << "Adjoint buffer mapping:\n"; - for (auto &origBB : getOriginal()) { - if (!pullbackBBMap.count(&origBB)) - continue; - auto bbBufferMap = tmp[&origBB]; - s << "bb" << origBB.getDebugID(); - s << " (size " << bbBufferMap.size() << "):\n"; - for (auto valuePair : bbBufferMap) { - auto origBuf = valuePair.first; - auto adjBuf = valuePair.second; - s << "ORIG: " << origBuf; - s << "ADJ: " << adjBuf << '\n'; - } - s << '\n'; - } -} - -//--------------------------------------------------------------------------// -// Member accessor pullback generation -//--------------------------------------------------------------------------// - -bool PullbackCloner::runForSemanticMemberAccessor() { - auto &original = getOriginal(); - auto *accessor = cast(original.getDeclContext()->getAsDecl()); - switch (accessor->getAccessorKind()) { - case AccessorKind::Get: - return runForSemanticMemberGetter(); - case AccessorKind::Set: - return runForSemanticMemberSetter(); - // TODO(SR-12640): Support `modify` accessors. - default: - llvm_unreachable("Unsupported accessor kind; inconsistent with " - "`isSemanticMemberAccessor`?"); - } -} - -bool PullbackCloner::runForSemanticMemberGetter() { - auto &original = getOriginal(); - auto &pullback = getPullback(); - auto pbLoc = getPullback().getLocation(); - - auto *accessor = cast(original.getDeclContext()->getAsDecl()); - assert(accessor->getAccessorKind() == AccessorKind::Get); - - auto *origEntry = original.getEntryBlock(); - auto *pbEntry = pullback.getEntryBlock(); - builder.setInsertionPoint(pbEntry); - - // Get getter argument and result values. - // Getter type: $(Self) -> Result - // Pullback type: $(Result', PB_Struct) -> Self' - assert(original.getLoweredFunctionType()->getNumParameters() == 1); - assert(pullback.getLoweredFunctionType()->getNumParameters() == 2); - assert(pullback.getLoweredFunctionType()->getNumResults() == 1); - SILValue origSelf = original.getArgumentsWithoutIndirectResults().front(); - - SmallVector origFormalResults; - collectAllFormalResultsInTypeOrder(original, origFormalResults); - assert(getIndices().results->getNumIndices() == 1 && - "Getter should have one semantic result"); - auto origResult = origFormalResults[*getIndices().results->begin()]; - - auto tangentVectorSILTy = pullback.getConventions().getSingleSILResultType( - TypeExpansionContext::minimal()); - auto tangentVectorTy = tangentVectorSILTy.getASTType(); - auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); - - // Look up the corresponding field in the tangent space. - auto *origField = cast(accessor->getStorage()); - auto baseType = remapType(origSelf->getType()).getASTType(); - auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, - pbLoc, getInvoker()); - if (!tanField) { - errorOccurred = true; - return true; - } - - // Switch based on the base tangent struct's value category. - // TODO(TF-1255): Simplify using unified adjoint value data structure. - switch (tangentVectorSILTy.getCategory()) { - case SILValueCategory::Object: { - auto adjResult = getAdjointValue(origEntry, origResult); - switch (adjResult.getKind()) { + // Accumulate adjoint for the `struct_extract` operand. + auto av = getAdjointValue(bb, sei); + switch (av.getKind()) { case AdjointValueKind::Zero: - addAdjointValue(origEntry, origSelf, - makeZeroAdjointValue(tangentVectorSILTy), pbLoc); + addAdjointValue(bb, sei->getOperand(), + makeZeroAdjointValue(tangentVectorSILTy), loc); break; case AdjointValueKind::Concrete: case AdjointValueKind::Aggregate: { SmallVector eltVals; for (auto *field : tangentVectorDecl->getStoredProperties()) { if (field == tanField) { - eltVals.push_back(adjResult); + eltVals.push_back(av); } else { auto substMap = tangentVectorTy->getMemberSubstitutionMap( field->getModuleContext(), field); @@ -698,116 +1082,461 @@ bool PullbackCloner::runForSemanticMemberGetter() { eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); } } - addAdjointValue(origEntry, origSelf, + addAdjointValue(bb, sei->getOperand(), makeAggregateAdjointValue(tangentVectorSILTy, eltVals), - pbLoc); + loc); } } - break; } - case SILValueCategory::Address: { - assert(pullback.getIndirectResults().size() == 1); - auto pbIndRes = pullback.getIndirectResults().front(); - auto *adjSelf = createFunctionLocalAllocation( - pbIndRes->getType().getObjectType(), pbLoc); - setAdjointBuffer(origEntry, origSelf, adjSelf); - for (auto *field : tangentVectorDecl->getStoredProperties()) { - auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, field); - if (field == tanField) { - // Switch based on the property's value category. - // TODO(TF-1255): Simplify using unified adjoint value data structure. - switch (origResult->getType().getCategory()) { - case SILValueCategory::Object: { - auto adjResult = getAdjointValue(origEntry, origResult); - auto adjResultValue = materializeAdjointDirect(adjResult, pbLoc); - auto adjResultValueCopy = - builder.emitCopyValueOperation(pbLoc, adjResultValue); - builder.emitStoreValueOperation(pbLoc, adjResultValueCopy, adjSelfElt, - StoreOwnershipQualifier::Init); - break; - } - case SILValueCategory::Address: { - auto adjResult = getAdjointBuffer(origEntry, origResult); - builder.createCopyAddr(pbLoc, adjResult, adjSelfElt, IsTake, - IsInitialization); - destroyedLocalAllocations.insert(adjResult); - break; - } + + /// Handle `ref_element_addr` instruction. + /// Original: y = ref_element_addr x, + /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) + /// ^~~~~~~ + /// field in tangent space corresponding to #field + void visitRefElementAddrInst(RefElementAddrInst *reai) { + auto *bb = reai->getParent(); + auto loc = reai->getLoc(); + auto adjBuf = getAdjointBuffer(bb, reai); + auto classOperand = reai->getOperand(); + auto classType = remapType(reai->getOperand()->getType()).getASTType(); + auto *tanField = + getTangentStoredProperty(getContext(), reai, classType, getInvoker()); + assert(tanField && "Invalid projections should have been diagnosed"); + switch (getTangentValueCategory(classOperand)) { + case SILValueCategory::Object: { + auto classTy = remapType(classOperand->getType()).getASTType(); + auto tangentVectorTy = getTangentSpace(classTy)->getCanonicalType(); + auto tangentVectorSILTy = + SILType::getPrimitiveObjectType(tangentVectorTy); + auto *tangentVectorDecl = + tangentVectorTy->getStructOrBoundGenericStruct(); + // Accumulate adjoint for the `ref_element_addr` operand. + SmallVector eltVals; + for (auto *field : tangentVectorDecl->getStoredProperties()) { + if (field == tanField) { + auto adjElt = builder.emitLoadValueOperation( + reai->getLoc(), adjBuf, LoadOwnershipQualifier::Copy); + eltVals.push_back(makeConcreteAdjointValue(adjElt)); + recordTemporary(adjElt); + } else { + auto substMap = tangentVectorTy->getMemberSubstitutionMap( + field->getModuleContext(), field); + auto fieldTy = field->getType().subst(substMap); + auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType(); + assert(fieldSILTy.isObject()); + eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); } - } else { - auto fieldType = pullback.mapTypeIntoContext(field->getInterfaceType()) - ->getCanonicalType(); - emitZeroIndirect(fieldType, adjSelfElt, pbLoc); } + addAdjointValue(bb, classOperand, + makeAggregateAdjointValue(tangentVectorSILTy, eltVals), + loc); + break; + } + case SILValueCategory::Address: { + auto adjBufClass = getAdjointBuffer(bb, classOperand); + auto adjBufElt = + builder.createStructElementAddr(loc, adjBufClass, tanField); + accumulateIndirect(adjBufElt, adjBuf, loc); + break; + } } - break; - } } - return false; -} -bool PullbackCloner::runForSemanticMemberSetter() { - auto &original = getOriginal(); - auto &pullback = getPullback(); - auto pbLoc = getPullback().getLocation(); + /// Handle `tuple` instruction. + /// Original: y = tuple (x0, x1, x2, ...) + /// Adjoint: (adj[x0], adj[x1], adj[x2], ...) += destructure_tuple adj[y] + /// ^~~ + /// excluding non-differentiable elements + void visitTupleInst(TupleInst *ti) { + auto *bb = ti->getParent(); + auto av = getAdjointValue(bb, ti); + switch (av.getKind()) { + case AdjointValueKind::Zero: + for (auto elt : ti->getElements()) { + if (!getTangentSpace(elt->getType().getASTType())) + continue; + addAdjointValue( + bb, elt, + makeZeroAdjointValue(getRemappedTangentType(elt->getType())), + ti->getLoc()); + } + break; + case AdjointValueKind::Concrete: { + auto adjVal = av.getConcreteValue(); + unsigned adjIdx = 0; + auto adjValCopy = builder.emitCopyValueOperation(ti->getLoc(), adjVal); + SmallVector adjElts; + if (!adjVal->getType().getAs()) { + recordTemporary(adjValCopy); + adjElts.push_back(adjValCopy); + } else { + auto *dti = builder.createDestructureTuple(ti->getLoc(), adjValCopy); + for (auto adjElt : dti->getResults()) + recordTemporary(adjElt); + adjElts.append(dti->getResults().begin(), dti->getResults().end()); + } + // Accumulate adjoints for `tuple` operands, skipping the + // non-differentiable ones. + for (auto i : range(ti->getNumOperands())) { + if (!getTangentSpace(ti->getOperand(i)->getType().getASTType())) + continue; + auto adjElt = adjElts[adjIdx++]; + addAdjointValue(bb, ti->getOperand(i), makeConcreteAdjointValue(adjElt), + ti->getLoc()); + } + break; + } + case AdjointValueKind::Aggregate: + unsigned adjIdx = 0; + for (auto i : range(ti->getElements().size())) { + if (!getTangentSpace(ti->getElement(i)->getType().getASTType())) + continue; + addAdjointValue(bb, ti->getElement(i), av.getAggregateElement(adjIdx++), + ti->getLoc()); + } + break; + } + } - auto *accessor = cast(original.getDeclContext()->getAsDecl()); - assert(accessor->getAccessorKind() == AccessorKind::Set); + /// Handle `tuple_extract` instruction. + /// Original: y = tuple_extract x, + /// Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0) + /// ^~~~~~ + /// n'-th element, where n' is tuple tangent space + /// index corresponding to n + void visitTupleExtractInst(TupleExtractInst *tei) { + auto *bb = tei->getParent(); + auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType()); + auto av = getAdjointValue(bb, tei); + switch (av.getKind()) { + case AdjointValueKind::Zero: + addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy), + tei->getLoc()); + break; + case AdjointValueKind::Aggregate: + case AdjointValueKind::Concrete: { + auto tupleTy = tei->getTupleType(); + auto tupleTanTupleTy = tupleTanTy.getAs(); + if (!tupleTanTupleTy) { + addAdjointValue(bb, tei->getOperand(), av, tei->getLoc()); + break; + } + SmallVector elements; + unsigned adjIdx = 0; + for (unsigned i : range(tupleTy->getNumElements())) { + if (!getTangentSpace( + tupleTy->getElement(i).getType()->getCanonicalType())) + continue; + if (tei->getFieldNo() == i) + elements.push_back(av); + else + elements.push_back(makeZeroAdjointValue( + getRemappedTangentType(SILType::getPrimitiveObjectType( + tupleTanTupleTy->getElementType(adjIdx++) + ->getCanonicalType())))); + } + if (elements.size() == 1) { + addAdjointValue(bb, tei->getOperand(), elements.front(), tei->getLoc()); + break; + } + addAdjointValue(bb, tei->getOperand(), + makeAggregateAdjointValue(tupleTanTy, elements), + tei->getLoc()); + break; + } + } + } - auto *origEntry = original.getEntryBlock(); - auto *pbEntry = pullback.getEntryBlock(); - builder.setInsertionPoint(pbEntry); + /// Handle `destructure_tuple` instruction. + /// Original: (y0, ..., yn) = destructure_tuple x + /// Adjoint: adj[x].0 += adj[y0] + /// ... + /// adj[x].n += adj[yn] + void visitDestructureTupleInst(DestructureTupleInst *dti) { + auto *bb = dti->getParent(); + auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType()); + SmallVector adjValues; + for (auto origElt : dti->getResults()) { + if (!getTangentSpace(remapType(origElt->getType()).getASTType())) + continue; + adjValues.push_back(getAdjointValue(bb, origElt)); + } + // Handle tuple tangent type. + // Add adjoints for every tuple element that has a tangent space. + if (tupleTanTy.is()) { + assert(adjValues.size() > 1); + addAdjointValue(bb, dti->getOperand(), + makeAggregateAdjointValue(tupleTanTy, adjValues), + dti->getLoc()); + } + // Handle non-tuple tangent type. + // Add adjoint for the single tuple element that has a tangent space. + else { + assert(adjValues.size() == 1); + addAdjointValue(bb, dti->getOperand(), adjValues.front(), dti->getLoc()); + } + } - // Get setter argument values. - // Setter type: $(inout Self, Argument) -> () - // Pullback type (wrt self): $(inout Self', PB_Struct) -> () - // Pullback type (wrt both): $(inout Self', PB_Struct) -> Argument' - assert(original.getLoweredFunctionType()->getNumParameters() == 2); - assert(pullback.getLoweredFunctionType()->getNumParameters() == 2); - assert(pullback.getLoweredFunctionType()->getNumResults() == 0 || - pullback.getLoweredFunctionType()->getNumResults() == 1); + /// Handle `load` or `load_borrow` instruction + /// Original: y = load/load_borrow x + /// Adjoint: adj[x] += adj[y] + void visitLoadOperation(SingleValueInstruction *inst) { + assert(isa(inst) || isa(inst)); + auto *bb = inst->getParent(); + auto loc = inst->getLoc(); + switch (getTangentValueCategory(inst)) { + case SILValueCategory::Object: { + auto adjVal = materializeAdjointDirect(getAdjointValue(bb, inst), loc); + // Allocate a local buffer and store the adjoint value. This buffer will + // be used for accumulation into the adjoint buffer. + auto adjBuf = builder.createAllocStack(loc, adjVal->getType()); + auto copy = builder.emitCopyValueOperation(loc, adjVal); + builder.emitStoreValueOperation(loc, copy, adjBuf, + StoreOwnershipQualifier::Init); + // Accumulate the adjoint value in the local buffer into the adjoint + // buffer. + addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc); + builder.emitDestroyAddr(loc, adjBuf); + builder.createDeallocStack(loc, adjBuf); + break; + } + case SILValueCategory::Address: { + auto adjBuf = getAdjointBuffer(bb, inst); + addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc); + break; + } + } + } + void visitLoadInst(LoadInst *li) { visitLoadOperation(li); } + void visitLoadBorrowInst(LoadBorrowInst *lbi) { visitLoadOperation(lbi); } - SILValue origArg = original.getArgumentsWithoutIndirectResults()[0]; - SILValue origSelf = original.getArgumentsWithoutIndirectResults()[1]; + /// Handle `store` or `store_borrow` instruction. + /// Original: store/store_borrow x to y + /// Adjoint: adj[x] += load adj[y]; adj[y] = 0 + void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc, + SILValue origDest) { + auto &adjBuf = getAdjointBuffer(bb, origDest); + switch (getTangentValueCategory(origSrc)) { + case SILValueCategory::Object: { + auto adjVal = builder.emitLoadValueOperation( + loc, adjBuf, LoadOwnershipQualifier::Take); + recordTemporary(adjVal); + addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc); + emitZeroIndirect(adjBuf->getType().getASTType(), adjBuf, loc); + break; + } + case SILValueCategory::Address: { + addToAdjointBuffer(bb, origSrc, adjBuf, loc); + builder.emitDestroyAddr(loc, adjBuf); + emitZeroIndirect(adjBuf->getType().getASTType(), adjBuf, loc); + break; + } + } + } + void visitStoreInst(StoreInst *si) { + visitStoreOperation(si->getParent(), si->getLoc(), si->getSrc(), + si->getDest()); + } + void visitStoreBorrowInst(StoreBorrowInst *sbi) { + visitStoreOperation(sbi->getParent(), sbi->getLoc(), sbi->getSrc(), + sbi->getDest()); + } - // Look up the corresponding field in the tangent space. - auto *origField = cast(accessor->getStorage()); - auto baseType = remapType(origSelf->getType()).getASTType(); - auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, - pbLoc, getInvoker()); - if (!tanField) { - errorOccurred = true; - return true; + /// Handle `copy_addr` instruction. + /// Original: copy_addr x to y + /// Adjoint: adj[x] += adj[y]; adj[y] = 0 + void visitCopyAddrInst(CopyAddrInst *cai) { + auto *bb = cai->getParent(); + auto &adjDest = getAdjointBuffer(bb, cai->getDest()); + auto destType = remapType(adjDest->getType()); + addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc()); + builder.emitDestroyAddrAndFold(cai->getLoc(), adjDest); + emitZeroIndirect(destType.getASTType(), adjDest, cai->getLoc()); } - auto adjSelf = getAdjointBuffer(origEntry, origSelf); - auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField); - // Switch based on the property's value category. - // TODO(TF-1255): Simplify using unified adjoint value data structure. - switch (origArg->getType().getCategory()) { - case SILValueCategory::Object: { - auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt, - LoadOwnershipQualifier::Take); - setAdjointValue(origEntry, origArg, makeConcreteAdjointValue(adjArg)); - blockTemporaries[pbEntry].insert(adjArg); - break; + /// Handle `copy_value` instruction. + /// Original: y = copy_value x + /// Adjoint: adj[x] += adj[y] + void visitCopyValueInst(CopyValueInst *cvi) { + auto *bb = cvi->getParent(); + switch (getTangentValueCategory(cvi)) { + case SILValueCategory::Object: { + auto adj = getAdjointValue(bb, cvi); + addAdjointValue(bb, cvi->getOperand(), adj, cvi->getLoc()); + break; + } + case SILValueCategory::Address: { + auto &adjDest = getAdjointBuffer(bb, cvi); + auto destType = remapType(adjDest->getType()); + addToAdjointBuffer(bb, cvi->getOperand(), adjDest, cvi->getLoc()); + builder.emitDestroyAddrAndFold(cvi->getLoc(), adjDest); + emitZeroIndirect(destType.getASTType(), adjDest, cvi->getLoc()); + break; + } + } } - case SILValueCategory::Address: { - addToAdjointBuffer(origEntry, origArg, adjSelfElt, pbLoc); - builder.emitDestroyOperation(pbLoc, adjSelfElt); - break; + + /// Handle `begin_borrow` instruction. + /// Original: y = begin_borrow x + /// Adjoint: adj[x] += adj[y] + void visitBeginBorrowInst(BeginBorrowInst *bbi) { + auto *bb = bbi->getParent(); + switch (getTangentValueCategory(bbi)) { + case SILValueCategory::Object: { + auto adj = getAdjointValue(bb, bbi); + addAdjointValue(bb, bbi->getOperand(), adj, bbi->getLoc()); + break; + } + case SILValueCategory::Address: { + auto &adjDest = getAdjointBuffer(bb, bbi); + auto destType = remapType(adjDest->getType()); + addToAdjointBuffer(bb, bbi->getOperand(), adjDest, bbi->getLoc()); + builder.emitDestroyAddrAndFold(bbi->getLoc(), adjDest); + emitZeroIndirect(destType.getASTType(), adjDest, bbi->getLoc()); + break; + } + } } + + /// Handle `begin_access` instruction. + /// Original: y = begin_access x + /// Adjoint: nothing + void visitBeginAccessInst(BeginAccessInst *bai) { + // Check for non-differentiable writes. + if (bai->getAccessKind() == SILAccessKind::Modify) { + if (isa(bai->getSource())) { + getContext().emitNondifferentiabilityError( + bai, getInvoker(), + diag::autodiff_cannot_differentiate_writes_to_global_variables); + errorOccurred = true; + return; + } + if (isa(bai->getSource())) { + getContext().emitNondifferentiabilityError( + bai, getInvoker(), + diag::autodiff_cannot_differentiate_writes_to_mutable_captures); + errorOccurred = true; + return; + } + } } - emitZeroIndirect(adjSelfElt->getType().getASTType(), adjSelfElt, pbLoc); - return false; + /// Handle `unconditional_checked_cast_addr` instruction. + /// Original: y = unconditional_checked_cast_addr x + /// Adjoint: adj[x] += unconditional_checked_cast_addr adj[y] + void visitUnconditionalCheckedCastAddrInst( + UnconditionalCheckedCastAddrInst *uccai) { + auto *bb = uccai->getParent(); + auto &adjDest = getAdjointBuffer(bb, uccai->getDest()); + auto &adjSrc = getAdjointBuffer(bb, uccai->getSrc()); + auto destType = remapType(adjDest->getType()); + auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType()); + builder.createUnconditionalCheckedCastAddr( + uccai->getLoc(), adjDest, adjDest->getType().getASTType(), castBuf, + adjSrc->getType().getASTType()); + addToAdjointBuffer(bb, uccai->getSrc(), castBuf, uccai->getLoc()); + builder.emitDestroyAddrAndFold(uccai->getLoc(), castBuf); + builder.createDeallocStack(uccai->getLoc(), castBuf); + emitZeroIndirect(destType.getASTType(), adjDest, uccai->getLoc()); + } + + /// Handle `unchecked_ref_cast` instruction. + /// Original: y = unchecked_ref_cast x + /// Adjoint: adj[x] += adj[y] + /// (assuming adj[x] and adj[y] have the same type) + void visitUncheckedRefCastInst(UncheckedRefCastInst *urci) { + auto *bb = urci->getParent(); + assert(urci->getOperand()->getType().isObject()); + assert(getRemappedTangentType(urci->getOperand()->getType()) == + getRemappedTangentType(urci->getType()) && + "Operand/result must have the same `TangentVector` type"); + auto adj = getAdjointValue(bb, urci); + addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc()); + } + + /// Handle `upcast` instruction. + /// Original: y = upcast x + /// Adjoint: adj[x] += adj[y] + /// (assuming adj[x] and adj[y] have the same type) + void visitUpcastInst(UpcastInst *ui) { + auto *bb = ui->getParent(); + assert(ui->getOperand()->getType().isObject()); + assert(getRemappedTangentType(ui->getOperand()->getType()) == + getRemappedTangentType(ui->getType()) && + "Operand/result must have the same `TangentVector` type"); + auto adj = getAdjointValue(bb, ui); + addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc()); + } + +#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst); +#undef NOT_DIFFERENTIABLE + +#define NO_ADJOINT(INST) \ + void visit##INST##Inst(INST##Inst *inst) {} + // Terminators. + NO_ADJOINT(Return) + NO_ADJOINT(Branch) + NO_ADJOINT(CondBranch) + + // Address projections. + NO_ADJOINT(StructElementAddr) + NO_ADJOINT(TupleElementAddr) + + // Array literal initialization address projections. + NO_ADJOINT(PointerToAddress) + NO_ADJOINT(IndexAddr) + + // Memory allocation/access. + NO_ADJOINT(AllocStack) + NO_ADJOINT(DeallocStack) + NO_ADJOINT(EndAccess) + + // Debugging/reference counting instructions. + NO_ADJOINT(DebugValue) + NO_ADJOINT(DebugValueAddr) + NO_ADJOINT(RetainValue) + NO_ADJOINT(RetainValueAddr) + NO_ADJOINT(ReleaseValue) + NO_ADJOINT(ReleaseValueAddr) + NO_ADJOINT(StrongRetain) + NO_ADJOINT(StrongRelease) + NO_ADJOINT(UnownedRetain) + NO_ADJOINT(UnownedRelease) + NO_ADJOINT(StrongRetainUnowned) + NO_ADJOINT(DestroyValue) + NO_ADJOINT(DestroyAddr) + + // Value ownership. + NO_ADJOINT(EndBorrow) +#undef NO_ADJOINT +}; + +PullbackCloner::Implementation::Implementation(VJPCloner &vjpCloner) + : vjpCloner(vjpCloner), builder(getPullback()), + localAllocBuilder(getPullback()) { + // Get dominance and post-order info for the original function. + auto &passManager = getContext().getPassManager(); + auto *domAnalysis = passManager.getAnalysis(); + auto *postDomAnalysis = passManager.getAnalysis(); + auto *postOrderAnalysis = passManager.getAnalysis(); + domInfo = domAnalysis->get(vjpCloner.original); + postDomInfo = postDomAnalysis->get(vjpCloner.original); + postOrderInfo = postOrderAnalysis->get(vjpCloner.original); } +PullbackCloner::PullbackCloner(VJPCloner &vjpCloner) + : impl(*new Implementation(vjpCloner)) {} + +PullbackCloner::~PullbackCloner() { delete &impl; } + //--------------------------------------------------------------------------// // Entry point //--------------------------------------------------------------------------// -bool PullbackCloner::run() { +bool PullbackCloner::run() { return impl.run(); } + +bool PullbackCloner::Implementation::run() { PrettyStackTraceSILFunction trace("generating pullback for", &getOriginal()); auto &original = getOriginal(); auto &pullback = getPullback(); @@ -1165,7 +1894,7 @@ bool PullbackCloner::run() { return errorOccurred; } -void PullbackCloner::emitZeroDerivativesForNonvariedResult( +void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult( SILValue origNonvariedResult) { auto &pullback = getPullback(); auto pbLoc = getPullback().getLocation(); @@ -1208,7 +1937,7 @@ void PullbackCloner::emitZeroDerivativesForNonvariedResult( << pullback); } -SILBasicBlock *PullbackCloner::buildPullbackSuccessor( +SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor( SILBasicBlock *origBB, SILBasicBlock *origPredBB, SmallDenseMap &pullbackTrampolineBlockMap) { // Get the pullback block and optional pullback trampoline block of the @@ -1294,7 +2023,7 @@ SILBasicBlock *PullbackCloner::buildPullbackSuccessor( return pullbackTrampolineBB; } -void PullbackCloner::visitSILBasicBlock(SILBasicBlock *bb) { +void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { auto pbLoc = getPullback().getLocation(); // Get the corresponding pullback basic block. auto *pbBB = getPullbackBlock(bb); @@ -1323,8 +2052,9 @@ void PullbackCloner::visitSILBasicBlock(SILBasicBlock *bb) { // Emit a branching terminator for the block. // If the original block is the original entry, then the pullback block is - // the pullback exit. This is handled specially in `PullbackCloner::run()`, - // so we leave the block non-terminated. + // the pullback exit. This is handled specially in + // `PullbackCloner::Implementation::run()`, so we leave the block + // non-terminated. if (bb->isEntry()) return; @@ -1401,841 +2131,337 @@ void PullbackCloner::visitSILBasicBlock(SILBasicBlock *bb) { pullbackSuccessorCases); } -void PullbackCloner::visit(SILInstruction *inst) { - if (errorOccurred) - return; - - LLVM_DEBUG(getADDebugStream() << "PullbackCloner visited:\n[ORIG]" << *inst); -#ifndef NDEBUG - auto beforeInsertion = std::prev(builder.getInsertionPoint()); -#endif - SILInstructionVisitor::visit(inst); - LLVM_DEBUG({ - auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback:\n"; - auto afterInsertion = builder.getInsertionPoint(); - for (auto it = ++beforeInsertion; it != afterInsertion; ++it) - s << *it; - }); -} - -void PullbackCloner::visitSILInstruction(SILInstruction *inst) { - LLVM_DEBUG(getADDebugStream() - << "Unhandled instruction in PullbackCloner: " << *inst); - getContext().emitNondifferentiabilityError( - inst, getInvoker(), diag::autodiff_expression_not_differentiable_note); - errorOccurred = true; -} +//--------------------------------------------------------------------------// +// Member accessor pullback generation +//--------------------------------------------------------------------------// -AllocStackInst * -PullbackCloner::getArrayAdjointElementBuffer(SILValue arrayAdjoint, - int eltIndex, SILLocation loc) { - auto &ctx = builder.getASTContext(); - auto arrayTanType = cast(arrayAdjoint->getType().getASTType()); - auto arrayType = arrayTanType->getParent()->castTo(); - auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType(); - auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType)); - // Get `function_ref` and generic signature of - // `Array.TangentVector.subscript.getter`. - auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct(); - auto subscriptLookup = - arrayTanStructDecl->lookupDirect(DeclBaseName::createSubscript()); - SubscriptDecl *subscriptDecl = nullptr; - for (auto *candidate : subscriptLookup) { - auto candidateModule = candidate->getModuleContext(); - if (candidateModule->getName() == ctx.Id_Differentiation || - candidateModule->isStdlibModule()) { - assert(!subscriptDecl && "Multiple `Array.TangentVector.subscript`s"); - subscriptDecl = cast(candidate); -#ifdef NDEBUG - break; -#endif - } +bool PullbackCloner::Implementation::runForSemanticMemberAccessor() { + auto &original = getOriginal(); + auto *accessor = cast(original.getDeclContext()->getAsDecl()); + switch (accessor->getAccessorKind()) { + case AccessorKind::Get: + return runForSemanticMemberGetter(); + case AccessorKind::Set: + return runForSemanticMemberSetter(); + // TODO(SR-12640): Support `modify` accessors. + default: + llvm_unreachable("Unsupported accessor kind; inconsistent with " + "`isSemanticMemberAccessor`?"); } - assert(subscriptDecl && "No `Array.TangentVector.subscript`"); - auto *subscriptGetterDecl = - subscriptDecl->getOpaqueAccessor(AccessorKind::Get); - assert(subscriptGetterDecl && "No `Array.TangentVector.subscript` getter"); - SILOptFunctionBuilder fb(getContext().getTransform()); - auto *subscriptGetterFn = fb.getOrCreateFunction( - loc, SILDeclRef(subscriptGetterDecl), NotForDefinition); - // %subscript_fn = function_ref @Array.TangentVector.subscript.getter - auto *subscriptFnRef = builder.createFunctionRef(loc, subscriptGetterFn); - auto subscriptFnGenSig = - subscriptGetterFn->getLoweredFunctionType()->getSubstGenericSignature(); - // Apply `Array.TangentVector.subscript.getter` to get array element adjoint - // buffer. - // %index_literal = integer_literal $Builtin.IntXX, - auto builtinIntType = - SILType::getPrimitiveObjectType(ctx.getIntDecl() - ->getStoredProperties() - .front() - ->getInterfaceType() - ->getCanonicalType()); - auto *eltIndexLiteral = - builder.createIntegerLiteral(loc, builtinIntType, eltIndex); - auto intType = SILType::getPrimitiveObjectType( - ctx.getIntDecl()->getDeclaredType()->getCanonicalType()); - // %index_int = struct $Int (%index_literal) - auto *eltIndexInt = builder.createStruct(loc, intType, {eltIndexLiteral}); - auto *swiftModule = getModule().getSwiftModule(); - auto *diffProto = ctx.getProtocol(KnownProtocolKind::Differentiable); - auto diffConf = swiftModule->lookupConformance(eltTanType, diffProto); - assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); - auto *addArithProto = ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic); - auto addArithConf = swiftModule->lookupConformance(eltTanType, addArithProto); - assert(!addArithConf.isInvalid() && - "Missing conformance to `AdditiveArithmetic`"); - auto subMap = SubstitutionMap::get(subscriptFnGenSig, {eltTanType}, - {addArithConf, diffConf}); - // %elt_adj = alloc_stack $T.TangentVector - // Create and register a local allocation. - auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc); - // Temporarily change global builder insertion point and emit zero into the - // local allocation. - auto insertionPoint = builder.getInsertionBB(); - builder.setInsertionPoint(localAllocBuilder.getInsertionBB(), - localAllocBuilder.getInsertionPoint()); - emitZeroIndirect(eltTanType, eltAdjBuffer, loc); - builder.setInsertionPoint(insertionPoint); - // Immediately destroy the emitted zero value. - // NOTE: It is not efficient to emit a zero value then immediately destroy - // it. However, it was the easiest way to to avoid "lifetime mismatch in - // predecessors" memory lifetime verification errors for control flow - // differentiation. - // Perhaps we can avoid emitting a zero value if local allocations are created - // per pullback bb instead of all in the pullback entry: TF-1075. - builder.emitDestroyOperation(loc, eltAdjBuffer); - // apply %subscript_fn(%elt_adj, %index_int, %array_adj) - builder.createApply(loc, subscriptFnRef, subMap, - {eltAdjBuffer, eltIndexInt, arrayAdjoint}); - return eltAdjBuffer; } -void PullbackCloner::visitApplyInst(ApplyInst *ai) { - assert(getPullbackInfo().shouldDifferentiateApplySite(ai)); - // Skip `array.uninitialized_intrinsic` applications, which have special - // `store` and `copy_addr` support. - if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) - return; - auto loc = ai->getLoc(); - auto *bb = ai->getParent(); - // Handle `array.finalize_intrinsic` applications. `array.finalize_intrinsic` - // semantically behaves like an identity function. - if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) { - assert(ai->getNumArguments() == 1 && - "Expected intrinsic to have one operand"); - // Accumulate result's adjoint into argument's adjoint. - auto adjResult = getAdjointValue(bb, ai); - auto origArg = ai->getArgumentsWithoutIndirectResults().front(); - addAdjointValue(bb, origArg, adjResult, loc); - return; - } - // Replace a call to a function with a call to its pullback. - auto &nestedApplyInfo = getContext().getNestedApplyInfo(); - auto applyInfoLookup = nestedApplyInfo.find(ai); - // If no `NestedApplyInfo` was found, then this task doesn't need to be - // differentiated. - if (applyInfoLookup == nestedApplyInfo.end()) { - // Must not be active. - assert(!getActivityInfo().isActive(ai, getIndices())); - return; - } - auto applyInfo = applyInfoLookup->getSecond(); - - // Get the pullback. - auto *field = getPullbackInfo().lookUpLinearMapDecl(ai); - assert(field); - auto pullback = getPullbackStructElement(ai->getParent(), field); - - // Get the original result of the `apply` instruction. - SmallVector origDirectResults; - forEachApplyDirectResult(ai, [&](SILValue directResult) { - origDirectResults.push_back(directResult); - }); - SmallVector origAllResults; - collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); - // Append `inout` arguments after original results. - for (auto paramIdx : applyInfo.indices.parameters->getIndices()) { - auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( - ai->getNumIndirectResults() + paramIdx); - if (!paramInfo.isIndirectMutating()) - continue; - origAllResults.push_back( - ai->getArgumentsWithoutIndirectResults()[paramIdx]); - } - - // Get callee pullback arguments. - SmallVector args; - - // Handle callee pullback indirect results. - // Create local allocations for these and destroy them after the call. - auto pullbackType = remapType(pullback->getType()).castTo(); - auto actualPullbackType = applyInfo.originalPullbackType - ? *applyInfo.originalPullbackType - : pullbackType; - actualPullbackType = actualPullbackType->getUnsubstitutedType(getModule()); - SmallVector pullbackIndirectResults; - for (auto indRes : actualPullbackType->getIndirectFormalResults()) { - auto *alloc = builder.createAllocStack( - loc, remapType(indRes.getSILStorageInterfaceType())); - pullbackIndirectResults.push_back(alloc); - args.push_back(alloc); - } - - // Collect callee pullback formal arguments. - for (auto resultIndex : applyInfo.indices.results->getIndices()) { - assert(resultIndex < origAllResults.size()); - auto origResult = origAllResults[resultIndex]; - // Get the seed (i.e. adjoint value of the original result). - SILValue seed; - switch (getTangentValueCategory(origResult)) { - case SILValueCategory::Object: - seed = materializeAdjointDirect(getAdjointValue(bb, origResult), loc); - break; - case SILValueCategory::Address: - seed = getAdjointBuffer(bb, origResult); - break; - } - args.push_back(seed); - } - - // If callee pullback was reabstracted in VJP, reabstract callee pullback. - if (applyInfo.originalPullbackType) { - SILOptFunctionBuilder fb(getContext().getTransform()); - pullback = reabstractFunction( - builder, fb, loc, pullback, *applyInfo.originalPullbackType, - [this](SubstitutionMap subs) -> SubstitutionMap { - return this->remapSubstitutionMap(subs); - }); - } - - // Call the callee pullback. - auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), - args, /*isNonThrowing*/ false); - builder.emitDestroyValueOperation(loc, pullback); - - // Extract all results from `pullbackCall`. - SmallVector dirResults; - extractAllElements(pullbackCall, builder, dirResults); - // Get all results in type-defined order. - SmallVector allResults; - collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults); - - LLVM_DEBUG({ - auto &s = getADDebugStream(); - s << "All results of the nested pullback call:\n"; - llvm::for_each(allResults, [&](SILValue v) { s << v; }); - }); - - // Accumulate adjoints for original differentiation parameters. - auto allResultsIt = allResults.begin(); - for (unsigned i : applyInfo.indices.parameters->getIndices()) { - auto origArg = ai->getArgument(ai->getNumIndirectResults() + i); - // Skip adjoint accumulation for `inout` arguments. - auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( - ai->getNumIndirectResults() + i); - if (paramInfo.isIndirectMutating()) - continue; - auto tan = *allResultsIt++; - if (tan->getType().isAddress()) { - addToAdjointBuffer(bb, origArg, tan, loc); - } else { - if (origArg->getType().isAddress()) { - auto *tmpBuf = builder.createAllocStack(loc, tan->getType()); - builder.emitStoreValueOperation(loc, tan, tmpBuf, - StoreOwnershipQualifier::Init); - addToAdjointBuffer(bb, origArg, tmpBuf, loc); - builder.emitDestroyAddrAndFold(loc, tmpBuf); - builder.createDeallocStack(loc, tmpBuf); - } else { - recordTemporary(tan); - addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc); - } - } - } - // Destroy unused pullback direct results. Needed for pullback results from - // VJPs extracted from `@differentiable` function callees, where the - // `@differentiable` function's differentiation parameter indices are a - // superset of the active `apply` parameter indices. - while (allResultsIt != allResults.end()) { - auto unusedPullbackDirectResult = *allResultsIt++; - if (unusedPullbackDirectResult->getType().isAddress()) - continue; - builder.emitDestroyValueOperation(loc, unusedPullbackDirectResult); - } - // Destroy and deallocate pullback indirect results. - for (auto *alloc : llvm::reverse(pullbackIndirectResults)) { - builder.emitDestroyAddrAndFold(loc, alloc); - builder.createDeallocStack(loc, alloc); - } -} +bool PullbackCloner::Implementation::runForSemanticMemberGetter() { + auto &original = getOriginal(); + auto &pullback = getPullback(); + auto pbLoc = getPullback().getLocation(); -void PullbackCloner::visitStructInst(StructInst *si) { - auto *bb = si->getParent(); - auto loc = si->getLoc(); - auto *structDecl = si->getStructDecl(); - auto av = getAdjointValue(bb, si); - switch (av.getKind()) { - case AdjointValueKind::Zero: - for (auto *field : structDecl->getStoredProperties()) { - auto fv = si->getFieldValue(field); - addAdjointValue( - bb, fv, makeZeroAdjointValue(getRemappedTangentType(fv->getType())), - loc); - } - break; - case AdjointValueKind::Concrete: { - auto adjStruct = materializeAdjointDirect(std::move(av), loc); - auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct); + auto *accessor = cast(original.getDeclContext()->getAsDecl()); + assert(accessor->getAccessorKind() == AccessorKind::Get); - // Find the struct `TangentVector` type. - auto structTy = remapType(si->getType()).getASTType(); -#ifndef NDEBUG - auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); - assert(!getTypeLowering(tangentVectorTy).isAddressOnly()); - assert(tangentVectorTy->getStructOrBoundGenericStruct()); -#endif + auto *origEntry = original.getEntryBlock(); + auto *pbEntry = pullback.getEntryBlock(); + builder.setInsertionPoint(pbEntry); - // Accumulate adjoints for the fields of the `struct` operand. - unsigned fieldIndex = 0; - for (auto it = structDecl->getStoredProperties().begin(); - it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) { - VarDecl *field = *it; - if (field->getAttrs().hasAttribute()) - continue; - // Find the corresponding field in the tangent space. - auto *tanField = getTangentStoredProperty(getContext(), field, structTy, - loc, getInvoker()); - if (!tanField) { - errorOccurred = true; - return; - } - auto tanElt = dti->getResult(fieldIndex); - addAdjointValue(bb, si->getFieldValue(field), - makeConcreteAdjointValue(tanElt), si->getLoc()); - } - break; - } - case AdjointValueKind::Aggregate: { - // Note: All user-called initializations go through the calls to the - // initializer, and synthesized initializers only have one level of struct - // formation which will not result into any aggregate adjoint valeus. - llvm_unreachable("Aggregate adjoint values should not occur for `struct` " - "instructions"); - } - } -} + // Get getter argument and result values. + // Getter type: $(Self) -> Result + // Pullback type: $(Result', PB_Struct) -> Self' + assert(original.getLoweredFunctionType()->getNumParameters() == 1); + assert(pullback.getLoweredFunctionType()->getNumParameters() == 2); + assert(pullback.getLoweredFunctionType()->getNumResults() == 1); + SILValue origSelf = original.getArgumentsWithoutIndirectResults().front(); -void PullbackCloner::visitBeginApplyInst(BeginApplyInst *bai) { - // Diagnose `begin_apply` instructions. - // Coroutine differentiation is not yet supported. - getContext().emitNondifferentiabilityError( - bai, getInvoker(), diag::autodiff_coroutines_not_supported); - errorOccurred = true; - return; -} + SmallVector origFormalResults; + collectAllFormalResultsInTypeOrder(original, origFormalResults); + assert(getIndices().results->getNumIndices() == 1 && + "Getter should have one semantic result"); + auto origResult = origFormalResults[*getIndices().results->begin()]; -void PullbackCloner::visitStructExtractInst(StructExtractInst *sei) { - auto *bb = sei->getParent(); - auto loc = getValidLocation(sei); - auto structTy = remapType(sei->getOperand()->getType()).getASTType(); - auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); - assert(!getTypeLowering(tangentVectorTy).isAddressOnly()); - auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy); + auto tangentVectorSILTy = pullback.getConventions().getSingleSILResultType( + TypeExpansionContext::minimal()); + auto tangentVectorTy = tangentVectorSILTy.getASTType(); auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); - assert(tangentVectorDecl); - // Find the corresponding field in the tangent space. - auto *tanField = - getTangentStoredProperty(getContext(), sei, structTy, getInvoker()); - assert(tanField && "Invalid projections should have been diagnosed"); - // Accumulate adjoint for the `struct_extract` operand. - auto av = getAdjointValue(bb, sei); - switch (av.getKind()) { - case AdjointValueKind::Zero: - addAdjointValue(bb, sei->getOperand(), - makeZeroAdjointValue(tangentVectorSILTy), loc); - break; - case AdjointValueKind::Concrete: - case AdjointValueKind::Aggregate: { - SmallVector eltVals; - for (auto *field : tangentVectorDecl->getStoredProperties()) { - if (field == tanField) { - eltVals.push_back(av); - } else { - auto substMap = tangentVectorTy->getMemberSubstitutionMap( - field->getModuleContext(), field); - auto fieldTy = field->getType().subst(substMap); - auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType(); - assert(fieldSILTy.isObject()); - eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); - } - } - addAdjointValue(bb, sei->getOperand(), - makeAggregateAdjointValue(tangentVectorSILTy, eltVals), - loc); - } + + // Look up the corresponding field in the tangent space. + auto *origField = cast(accessor->getStorage()); + auto baseType = remapType(origSelf->getType()).getASTType(); + auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, + pbLoc, getInvoker()); + if (!tanField) { + errorOccurred = true; + return true; } -} -void PullbackCloner::visitRefElementAddrInst(RefElementAddrInst *reai) { - auto *bb = reai->getParent(); - auto loc = reai->getLoc(); - auto adjBuf = getAdjointBuffer(bb, reai); - auto classOperand = reai->getOperand(); - auto classType = remapType(reai->getOperand()->getType()).getASTType(); - auto *tanField = - getTangentStoredProperty(getContext(), reai, classType, getInvoker()); - assert(tanField && "Invalid projections should have been diagnosed"); - switch (getTangentValueCategory(classOperand)) { + // Switch based on the base tangent struct's value category. + // TODO(TF-1255): Simplify using unified adjoint value data structure. + switch (tangentVectorSILTy.getCategory()) { case SILValueCategory::Object: { - auto classTy = remapType(classOperand->getType()).getASTType(); - auto tangentVectorTy = getTangentSpace(classTy)->getCanonicalType(); - auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy); - auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); - // Accumulate adjoint for the `ref_element_addr` operand. - SmallVector eltVals; - for (auto *field : tangentVectorDecl->getStoredProperties()) { - if (field == tanField) { - auto adjElt = builder.emitLoadValueOperation( - reai->getLoc(), adjBuf, LoadOwnershipQualifier::Copy); - eltVals.push_back(makeConcreteAdjointValue(adjElt)); - recordTemporary(adjElt); - } else { - auto substMap = tangentVectorTy->getMemberSubstitutionMap( - field->getModuleContext(), field); - auto fieldTy = field->getType().subst(substMap); - auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType(); - assert(fieldSILTy.isObject()); - eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); + auto adjResult = getAdjointValue(origEntry, origResult); + switch (adjResult.getKind()) { + case AdjointValueKind::Zero: + addAdjointValue(origEntry, origSelf, + makeZeroAdjointValue(tangentVectorSILTy), pbLoc); + break; + case AdjointValueKind::Concrete: + case AdjointValueKind::Aggregate: { + SmallVector eltVals; + for (auto *field : tangentVectorDecl->getStoredProperties()) { + if (field == tanField) { + eltVals.push_back(adjResult); + } else { + auto substMap = tangentVectorTy->getMemberSubstitutionMap( + field->getModuleContext(), field); + auto fieldTy = field->getType().subst(substMap); + auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType(); + assert(fieldSILTy.isObject()); + eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); + } } + addAdjointValue(origEntry, origSelf, + makeAggregateAdjointValue(tangentVectorSILTy, eltVals), + pbLoc); } - addAdjointValue(bb, classOperand, - makeAggregateAdjointValue(tangentVectorSILTy, eltVals), - loc); - break; - } - case SILValueCategory::Address: { - auto adjBufClass = getAdjointBuffer(bb, classOperand); - auto adjBufElt = - builder.createStructElementAddr(loc, adjBufClass, tanField); - accumulateIndirect(adjBufElt, adjBuf, loc); - break; - } - } -} - -void PullbackCloner::visitTupleInst(TupleInst *ti) { - auto *bb = ti->getParent(); - auto av = getAdjointValue(bb, ti); - switch (av.getKind()) { - case AdjointValueKind::Zero: - for (auto elt : ti->getElements()) { - if (!getTangentSpace(elt->getType().getASTType())) - continue; - addAdjointValue( - bb, elt, makeZeroAdjointValue(getRemappedTangentType(elt->getType())), - ti->getLoc()); - } - break; - case AdjointValueKind::Concrete: { - auto adjVal = av.getConcreteValue(); - unsigned adjIdx = 0; - auto adjValCopy = builder.emitCopyValueOperation(ti->getLoc(), adjVal); - SmallVector adjElts; - if (!adjVal->getType().getAs()) { - recordTemporary(adjValCopy); - adjElts.push_back(adjValCopy); - } else { - auto *dti = builder.createDestructureTuple(ti->getLoc(), adjValCopy); - for (auto adjElt : dti->getResults()) - recordTemporary(adjElt); - adjElts.append(dti->getResults().begin(), dti->getResults().end()); - } - // Accumulate adjoints for `tuple` operands, skipping the - // non-differentiable ones. - for (auto i : range(ti->getNumOperands())) { - if (!getTangentSpace(ti->getOperand(i)->getType().getASTType())) - continue; - auto adjElt = adjElts[adjIdx++]; - addAdjointValue(bb, ti->getOperand(i), makeConcreteAdjointValue(adjElt), - ti->getLoc()); - } - break; - } - case AdjointValueKind::Aggregate: - unsigned adjIdx = 0; - for (auto i : range(ti->getElements().size())) { - if (!getTangentSpace(ti->getElement(i)->getType().getASTType())) - continue; - addAdjointValue(bb, ti->getElement(i), av.getAggregateElement(adjIdx++), - ti->getLoc()); } break; } -} - -void PullbackCloner::visitTupleExtractInst(TupleExtractInst *tei) { - auto *bb = tei->getParent(); - auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType()); - auto av = getAdjointValue(bb, tei); - switch (av.getKind()) { - case AdjointValueKind::Zero: - addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy), - tei->getLoc()); - break; - case AdjointValueKind::Aggregate: - case AdjointValueKind::Concrete: { - auto tupleTy = tei->getTupleType(); - auto tupleTanTupleTy = tupleTanTy.getAs(); - if (!tupleTanTupleTy) { - addAdjointValue(bb, tei->getOperand(), av, tei->getLoc()); - break; - } - SmallVector elements; - unsigned adjIdx = 0; - for (unsigned i : range(tupleTy->getNumElements())) { - if (!getTangentSpace( - tupleTy->getElement(i).getType()->getCanonicalType())) - continue; - if (tei->getFieldNo() == i) - elements.push_back(av); - else - elements.push_back(makeZeroAdjointValue( - getRemappedTangentType(SILType::getPrimitiveObjectType( - tupleTanTupleTy->getElementType(adjIdx++) - ->getCanonicalType())))); - } - if (elements.size() == 1) { - addAdjointValue(bb, tei->getOperand(), elements.front(), tei->getLoc()); - break; + case SILValueCategory::Address: { + assert(pullback.getIndirectResults().size() == 1); + auto pbIndRes = pullback.getIndirectResults().front(); + auto *adjSelf = createFunctionLocalAllocation( + pbIndRes->getType().getObjectType(), pbLoc); + setAdjointBuffer(origEntry, origSelf, adjSelf); + for (auto *field : tangentVectorDecl->getStoredProperties()) { + auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, field); + if (field == tanField) { + // Switch based on the property's value category. + // TODO(TF-1255): Simplify using unified adjoint value data structure. + switch (origResult->getType().getCategory()) { + case SILValueCategory::Object: { + auto adjResult = getAdjointValue(origEntry, origResult); + auto adjResultValue = materializeAdjointDirect(adjResult, pbLoc); + auto adjResultValueCopy = + builder.emitCopyValueOperation(pbLoc, adjResultValue); + builder.emitStoreValueOperation(pbLoc, adjResultValueCopy, adjSelfElt, + StoreOwnershipQualifier::Init); + break; + } + case SILValueCategory::Address: { + auto adjResult = getAdjointBuffer(origEntry, origResult); + builder.createCopyAddr(pbLoc, adjResult, adjSelfElt, IsTake, + IsInitialization); + destroyedLocalAllocations.insert(adjResult); + break; + } + } + } else { + auto fieldType = pullback.mapTypeIntoContext(field->getInterfaceType()) + ->getCanonicalType(); + emitZeroIndirect(fieldType, adjSelfElt, pbLoc); + } } - addAdjointValue(bb, tei->getOperand(), - makeAggregateAdjointValue(tupleTanTy, elements), - tei->getLoc()); break; } } + return false; } -void PullbackCloner::visitDestructureTupleInst(DestructureTupleInst *dti) { - auto *bb = dti->getParent(); - auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType()); - SmallVector adjValues; - for (auto origElt : dti->getResults()) { - if (!getTangentSpace(remapType(origElt->getType()).getASTType())) - continue; - adjValues.push_back(getAdjointValue(bb, origElt)); - } - // Handle tuple tangent type. - // Add adjoints for every tuple element that has a tangent space. - if (tupleTanTy.is()) { - assert(adjValues.size() > 1); - addAdjointValue(bb, dti->getOperand(), - makeAggregateAdjointValue(tupleTanTy, adjValues), - dti->getLoc()); - } - // Handle non-tuple tangent type. - // Add adjoint for the single tuple element that has a tangent space. - else { - assert(adjValues.size() == 1); - addAdjointValue(bb, dti->getOperand(), adjValues.front(), dti->getLoc()); - } -} +bool PullbackCloner::Implementation::runForSemanticMemberSetter() { + auto &original = getOriginal(); + auto &pullback = getPullback(); + auto pbLoc = getPullback().getLocation(); -void PullbackCloner::visitLoadOperation(SingleValueInstruction *inst) { - assert(isa(inst) || isa(inst)); - auto *bb = inst->getParent(); - auto loc = inst->getLoc(); - switch (getTangentValueCategory(inst)) { - case SILValueCategory::Object: { - auto adjVal = materializeAdjointDirect(getAdjointValue(bb, inst), loc); - // Allocate a local buffer and store the adjoint value. This buffer will be - // used for accumulation into the adjoint buffer. - auto adjBuf = builder.createAllocStack(loc, adjVal->getType()); - auto copy = builder.emitCopyValueOperation(loc, adjVal); - builder.emitStoreValueOperation(loc, copy, adjBuf, - StoreOwnershipQualifier::Init); - // Accumulate the adjoint value in the local buffer into the adjoint buffer. - addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc); - builder.emitDestroyAddr(loc, adjBuf); - builder.createDeallocStack(loc, adjBuf); - break; - } - case SILValueCategory::Address: { - auto adjBuf = getAdjointBuffer(bb, inst); - addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc); - break; - } - } -} + auto *accessor = cast(original.getDeclContext()->getAsDecl()); + assert(accessor->getAccessorKind() == AccessorKind::Set); -void PullbackCloner::visitStoreOperation(SILBasicBlock *bb, SILLocation loc, - SILValue origSrc, SILValue origDest) { - auto &adjBuf = getAdjointBuffer(bb, origDest); - switch (getTangentValueCategory(origSrc)) { - case SILValueCategory::Object: { - auto adjVal = builder.emitLoadValueOperation(loc, adjBuf, - LoadOwnershipQualifier::Take); - recordTemporary(adjVal); - addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc); - emitZeroIndirect(adjBuf->getType().getASTType(), adjBuf, loc); - break; - } - case SILValueCategory::Address: { - addToAdjointBuffer(bb, origSrc, adjBuf, loc); - builder.emitDestroyAddr(loc, adjBuf); - emitZeroIndirect(adjBuf->getType().getASTType(), adjBuf, loc); - break; - } - } -} + auto *origEntry = original.getEntryBlock(); + auto *pbEntry = pullback.getEntryBlock(); + builder.setInsertionPoint(pbEntry); -void PullbackCloner::visitStoreInst(StoreInst *si) { - visitStoreOperation(si->getParent(), si->getLoc(), si->getSrc(), - si->getDest()); -} + // Get setter argument values. + // Setter type: $(inout Self, Argument) -> () + // Pullback type (wrt self): $(inout Self', PB_Struct) -> () + // Pullback type (wrt both): $(inout Self', PB_Struct) -> Argument' + assert(original.getLoweredFunctionType()->getNumParameters() == 2); + assert(pullback.getLoweredFunctionType()->getNumParameters() == 2); + assert(pullback.getLoweredFunctionType()->getNumResults() == 0 || + pullback.getLoweredFunctionType()->getNumResults() == 1); -void PullbackCloner::visitCopyAddrInst(CopyAddrInst *cai) { - auto *bb = cai->getParent(); - auto &adjDest = getAdjointBuffer(bb, cai->getDest()); - auto destType = remapType(adjDest->getType()); - addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc()); - builder.emitDestroyAddrAndFold(cai->getLoc(), adjDest); - emitZeroIndirect(destType.getASTType(), adjDest, cai->getLoc()); -} + SILValue origArg = original.getArgumentsWithoutIndirectResults()[0]; + SILValue origSelf = original.getArgumentsWithoutIndirectResults()[1]; -void PullbackCloner::visitCopyValueInst(CopyValueInst *cvi) { - auto *bb = cvi->getParent(); - switch (getTangentValueCategory(cvi)) { - case SILValueCategory::Object: { - auto adj = getAdjointValue(bb, cvi); - addAdjointValue(bb, cvi->getOperand(), adj, cvi->getLoc()); - break; - } - case SILValueCategory::Address: { - auto &adjDest = getAdjointBuffer(bb, cvi); - auto destType = remapType(adjDest->getType()); - addToAdjointBuffer(bb, cvi->getOperand(), adjDest, cvi->getLoc()); - builder.emitDestroyAddrAndFold(cvi->getLoc(), adjDest); - emitZeroIndirect(destType.getASTType(), adjDest, cvi->getLoc()); - break; - } + // Look up the corresponding field in the tangent space. + auto *origField = cast(accessor->getStorage()); + auto baseType = remapType(origSelf->getType()).getASTType(); + auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, + pbLoc, getInvoker()); + if (!tanField) { + errorOccurred = true; + return true; } -} -void PullbackCloner::visitBeginBorrowInst(BeginBorrowInst *bbi) { - auto *bb = bbi->getParent(); - switch (getTangentValueCategory(bbi)) { + auto adjSelf = getAdjointBuffer(origEntry, origSelf); + auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField); + // Switch based on the property's value category. + // TODO(TF-1255): Simplify using unified adjoint value data structure. + switch (origArg->getType().getCategory()) { case SILValueCategory::Object: { - auto adj = getAdjointValue(bb, bbi); - addAdjointValue(bb, bbi->getOperand(), adj, bbi->getLoc()); + auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt, + LoadOwnershipQualifier::Take); + setAdjointValue(origEntry, origArg, makeConcreteAdjointValue(adjArg)); + blockTemporaries[pbEntry].insert(adjArg); break; } case SILValueCategory::Address: { - auto &adjDest = getAdjointBuffer(bb, bbi); - auto destType = remapType(adjDest->getType()); - addToAdjointBuffer(bb, bbi->getOperand(), adjDest, bbi->getLoc()); - builder.emitDestroyAddrAndFold(bbi->getLoc(), adjDest); - emitZeroIndirect(destType.getASTType(), adjDest, bbi->getLoc()); + addToAdjointBuffer(origEntry, origArg, adjSelfElt, pbLoc); + builder.emitDestroyOperation(pbLoc, adjSelfElt); break; } } -} - -void PullbackCloner::visitBeginAccessInst(BeginAccessInst *bai) { - // Check for non-differentiable writes. - if (bai->getAccessKind() == SILAccessKind::Modify) { - if (isa(bai->getSource())) { - getContext().emitNondifferentiabilityError( - bai, getInvoker(), - diag::autodiff_cannot_differentiate_writes_to_global_variables); - errorOccurred = true; - return; - } - if (isa(bai->getSource())) { - getContext().emitNondifferentiabilityError( - bai, getInvoker(), - diag::autodiff_cannot_differentiate_writes_to_mutable_captures); - errorOccurred = true; - return; - } - } -} - -void PullbackCloner::visitUnconditionalCheckedCastAddrInst( - UnconditionalCheckedCastAddrInst *uccai) { - auto *bb = uccai->getParent(); - auto &adjDest = getAdjointBuffer(bb, uccai->getDest()); - auto &adjSrc = getAdjointBuffer(bb, uccai->getSrc()); - auto destType = remapType(adjDest->getType()); - auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType()); - builder.createUnconditionalCheckedCastAddr( - uccai->getLoc(), adjDest, adjDest->getType().getASTType(), castBuf, - adjSrc->getType().getASTType()); - addToAdjointBuffer(bb, uccai->getSrc(), castBuf, uccai->getLoc()); - builder.emitDestroyAddrAndFold(uccai->getLoc(), castBuf); - builder.createDeallocStack(uccai->getLoc(), castBuf); - emitZeroIndirect(destType.getASTType(), adjDest, uccai->getLoc()); -} - -void PullbackCloner::visitUncheckedRefCastInst(UncheckedRefCastInst *urci) { - auto *bb = urci->getParent(); - assert(urci->getOperand()->getType().isObject()); - assert(getRemappedTangentType(urci->getOperand()->getType()) == - getRemappedTangentType(urci->getType()) && - "Operand/result must have the same `TangentVector` type"); - auto adj = getAdjointValue(bb, urci); - addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc()); -} - -void PullbackCloner::visitUpcastInst(UpcastInst *ui) { - auto *bb = ui->getParent(); - assert(ui->getOperand()->getType().isObject()); - assert(getRemappedTangentType(ui->getOperand()->getType()) == - getRemappedTangentType(ui->getType()) && - "Operand/result must have the same `TangentVector` type"); - auto adj = getAdjointValue(bb, ui); - addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc()); -} - -#define NOT_DIFFERENTIABLE(INST, DIAG) \ - void PullbackCloner::visit##INST##Inst(INST##Inst *inst) { \ - getContext().emitNondifferentiabilityError(inst, getInvoker(), \ - diag::DIAG); \ - errorOccurred = true; \ - return; \ - } -#undef NOT_DIFFERENTIABLE - -AdjointValue PullbackCloner::makeZeroAdjointValue(SILType type) { - return AdjointValue::createZero(allocator, remapType(type)); -} + emitZeroIndirect(adjSelfElt->getType().getASTType(), adjSelfElt, pbLoc); -AdjointValue PullbackCloner::makeConcreteAdjointValue(SILValue value) { - return AdjointValue::createConcrete(allocator, value); + return false; } -template -AdjointValue PullbackCloner::makeAggregateAdjointValue(SILType type, - EltRange elements) { - AdjointValue *buf = reinterpret_cast(allocator.Allocate( - elements.size() * sizeof(AdjointValue), alignof(AdjointValue))); - MutableArrayRef elementsCopy(buf, elements.size()); - std::uninitialized_copy(elements.begin(), elements.end(), - elementsCopy.begin()); - return AdjointValue::createAggregate(allocator, remapType(type), - elementsCopy); -} +//--------------------------------------------------------------------------// +// Adjoint buffer mapping +//--------------------------------------------------------------------------// -SILValue PullbackCloner::materializeAdjointDirect(AdjointValue val, - SILLocation loc) { - assert(val.getType().isObject()); - LLVM_DEBUG(getADDebugStream() << "Materializing adjoint for " << val << '\n'); - switch (val.getKind()) { - case AdjointValueKind::Zero: - return recordTemporary(emitZeroDirect(val.getType().getASTType(), loc)); - case AdjointValueKind::Aggregate: { - SmallVector elements; - for (auto i : range(val.getNumAggregateElements())) { - auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc); - elements.push_back(builder.emitCopyValueOperation(loc, eltVal)); - } - if (val.getType().is()) - return recordTemporary(builder.createTuple(loc, val.getType(), elements)); - else - return recordTemporary( - builder.createStruct(loc, val.getType(), elements)); +SILValue PullbackCloner::Implementation::getAdjointProjection( + SILBasicBlock *origBB, SILValue originalProjection) { + // Handle `struct_element_addr`. + // Adjoint projection: a `struct_element_addr` into the base adjoint buffer. + if (auto *seai = dyn_cast(originalProjection)) { + assert(!seai->getField()->getAttrs().hasAttribute() && + "`@noDerivative` struct projections should never be active"); + auto adjSource = getAdjointBuffer(origBB, seai->getOperand()); + auto structType = remapType(seai->getOperand()->getType()).getASTType(); + auto *tanField = + getTangentStoredProperty(getContext(), seai, structType, getInvoker()); + assert(tanField && "Invalid projections should have been diagnosed"); + return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField); } - case AdjointValueKind::Concrete: - return val.getConcreteValue(); + // Handle `tuple_element_addr`. + // Adjoint projection: a `tuple_element_addr` into the base adjoint buffer. + if (auto *teai = dyn_cast(originalProjection)) { + auto source = teai->getOperand(); + auto adjSource = getAdjointBuffer(origBB, source); + if (!adjSource->getType().is()) + return adjSource; + auto origTupleTy = source->getType().castTo(); + unsigned adjIndex = 0; + for (unsigned i : range(teai->getFieldNo())) { + if (getTangentSpace( + origTupleTy->getElement(i).getType()->getCanonicalType())) + ++adjIndex; + } + return builder.createTupleElementAddr(teai->getLoc(), adjSource, adjIndex); } - llvm_unreachable("invalid value kind"); -} - -void PullbackCloner::materializeAdjointIndirect(AdjointValue val, - SILValue destAddress, - SILLocation loc) { - assert(destAddress->getType().isAddress()); - switch (val.getKind()) { - /// If adjoint value is a symbolic zero, emit a call to - /// `AdditiveArithmetic.zero`. - case AdjointValueKind::Zero: - emitZeroIndirect(val.getSwiftType(), destAddress, loc); - break; - /// If adjoint value is a symbolic aggregate (tuple or struct), recursively - /// materialize materialize the symbolic tuple or struct, filling the - /// buffer. - case AdjointValueKind::Aggregate: { - if (auto *tupTy = val.getSwiftType()->getAs()) { - for (auto idx : range(val.getNumAggregateElements())) { - auto eltTy = SILType::getPrimitiveAddressType( - tupTy->getElementType(idx)->getCanonicalType()); - auto *eltBuf = - builder.createTupleElementAddr(loc, destAddress, idx, eltTy); - materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc); - } - } else if (auto *structDecl = - val.getSwiftType()->getStructOrBoundGenericStruct()) { - auto fieldIt = structDecl->getStoredProperties().begin(); - for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end(); - ++fieldIt, ++i) { - auto eltBuf = - builder.createStructElementAddr(loc, destAddress, *fieldIt); - materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc); - } - } else { - llvm_unreachable("Not an aggregate type"); + // Handle `ref_element_addr`. + // Adjoint projection: a local allocation initialized with the corresponding + // field value from the class's base adjoint value. + if (auto *reai = dyn_cast(originalProjection)) { + assert(!reai->getField()->getAttrs().hasAttribute() && + "`@noDerivative` class projections should never be active"); + auto loc = reai->getLoc(); + // Get the class operand, stripping `begin_borrow`. + auto classOperand = stripBorrow(reai->getOperand()); + auto classType = remapType(reai->getOperand()->getType()).getASTType(); + auto *tanField = + getTangentStoredProperty(getContext(), reai->getField(), classType, + reai->getLoc(), getInvoker()); + assert(tanField && "Invalid projections should have been diagnosed"); + // Create a local allocation for the element adjoint buffer. + auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType(); + auto eltTanSILType = + remapType(SILType::getPrimitiveAddressType(eltTanType)); + auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc); + // Check the class operand's `TangentVector` value category. + switch (getTangentValueCategory(classOperand)) { + case SILValueCategory::Object: { + // Get the class operand's adjoint value. Currently, it must be a + // `TangentVector` struct. + auto adjClass = + materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc); + builder.emitScopedBorrowOperation( + loc, adjClass, [&](SILValue borrowedAdjClass) { + // Initialize the element adjoint buffer with the base adjoint + // value. + auto *adjElt = + builder.createStructExtract(loc, borrowedAdjClass, tanField); + auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt); + builder.emitStoreValueOperation(loc, adjEltCopy, eltAdjBuffer, + StoreOwnershipQualifier::Init); + }); + return eltAdjBuffer; + } + case SILValueCategory::Address: { + // Get the class operand's adjoint buffer. Currently, it must be a + // `TangentVector` struct. + auto adjClass = getAdjointBuffer(origBB, classOperand); + // Initialize the element adjoint buffer with the base adjoint buffer. + auto *adjElt = builder.createStructElementAddr(loc, adjClass, tanField); + builder.createCopyAddr(loc, adjElt, eltAdjBuffer, IsNotTake, + IsInitialization); + return eltAdjBuffer; + } } - break; } - /// If adjoint value is concrete, it is already materialized. Store it in the - /// destination address. - case AdjointValueKind::Concrete: - auto concreteVal = val.getConcreteValue(); - builder.emitStoreValueOperation(loc, concreteVal, destAddress, - StoreOwnershipQualifier::Init); - break; + // Handle `begin_access`. + // Adjoint projection: the base adjoint buffer itself. + if (auto *bai = dyn_cast(originalProjection)) { + auto adjBase = getAdjointBuffer(origBB, bai->getOperand()); + if (errorOccurred) + return (bufferMap[{origBB, originalProjection}] = SILValue()); + // Return the base buffer's adjoint buffer. + return adjBase; } -} - -void PullbackCloner::emitZeroIndirect(CanType type, SILValue address, - SILLocation loc) { - auto tangentSpace = getTangentSpace(type); - assert(tangentSpace && "No tangent space for this type"); - switch (tangentSpace->getKind()) { - case TangentSpace::Kind::TangentVector: - emitZeroIntoBuffer(builder, type, address, loc); - return; - case TangentSpace::Kind::Tuple: { - auto tupleType = tangentSpace->getTuple(); - SmallVector zeroElements; - for (unsigned i : range(tupleType->getNumElements())) { - auto eltAddr = builder.createTupleElementAddr(loc, address, i); - emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), - eltAddr, loc); + // Handle `array.uninitialized_intrinsic` application element addresses. + // Adjoint projection: a local allocation initialized by applying + // `Array.TangentVector.subscript` to the base array's adjoint value. + auto *ai = + getAllocateUninitializedArrayIntrinsicElementAddress(originalProjection); + auto *definingInst = dyn_cast_or_null( + originalProjection->getDefiningInstruction()); + bool isAllocateUninitializedArrayIntrinsicElementAddress = + ai && definingInst && + (isa(definingInst) || + isa(definingInst)); + if (isAllocateUninitializedArrayIntrinsicElementAddress) { + // Get the array element index of the result address. + int eltIndex = 0; + if (auto *iai = dyn_cast(definingInst)) { + auto *ili = cast(iai->getIndex()); + eltIndex = ili->getValue().getLimitedValue(); } - return; - } + // Get the array adjoint value. + SILValue arrayAdjoint; + assert(ai && "Expected `array.uninitialized_intrinsic` application"); + for (auto use : ai->getUses()) { + auto *dti = dyn_cast(use->getUser()); + if (!dti) + continue; + assert(!arrayAdjoint && "Array adjoint already found"); + // The first `destructure_tuple` result is the `Array` value. + auto arrayValue = dti->getResult(0); + arrayAdjoint = materializeAdjointDirect( + getAdjointValue(origBB, arrayValue), definingInst->getLoc()); + } + assert(arrayAdjoint && "Array does not have adjoint value"); + // Apply `Array.TangentVector.subscript` to get array element adjoint value. + auto *eltAdjBuffer = + getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc()); + return eltAdjBuffer; } + return SILValue(); } -SILValue PullbackCloner::emitZeroDirect(CanType type, SILLocation loc) { - auto silType = getModule().Types.getLoweredLoadableType( - type, TypeExpansionContext::minimal(), getModule()); - auto *buffer = builder.createAllocStack(loc, silType); - emitZeroIndirect(type, buffer, loc); - auto loaded = - builder.emitLoadValueOperation(loc, buffer, LoadOwnershipQualifier::Take); - builder.createDeallocStack(loc, buffer); - return loaded; -} +//----------------------------------------------------------------------------// +// Adjoint value accumulation +//----------------------------------------------------------------------------// -AdjointValue PullbackCloner::accumulateAdjointsDirect(AdjointValue lhs, - AdjointValue rhs, - SILLocation loc) { +AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect( + AdjointValue lhs, AdjointValue rhs, SILLocation loc) { LLVM_DEBUG(getADDebugStream() << "Materializing adjoint directly.\nLHS: " << lhs << "\nRHS: " << rhs << '\n'); - switch (lhs.getKind()) { // x case AdjointValueKind::Concrete: { @@ -2303,12 +2529,12 @@ AdjointValue PullbackCloner::accumulateAdjointsDirect(AdjointValue lhs, } } } - llvm_unreachable("invalid LHS kind"); + llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715 } -SILValue PullbackCloner::accumulateDirect(SILValue lhs, SILValue rhs, - SILLocation loc) { - // TODO: Optimize for the case when lhs == rhs. +SILValue PullbackCloner::Implementation::accumulateDirect(SILValue lhs, + SILValue rhs, + SILLocation loc) { LLVM_DEBUG(getADDebugStream() << "Emitting adjoint accumulation for lhs: " << lhs << " and rhs: " << rhs); assert(lhs->getType() == rhs->getType() && "Adjoints must have equal types!"); @@ -2353,13 +2579,13 @@ SILValue PullbackCloner::accumulateDirect(SILValue lhs, SILValue rhs, return builder.createTuple(loc, adjointTy, adjElements); } } - llvm_unreachable("invalid tangent space"); + llvm_unreachable("Invalid tangent space"); // silences MSVC C4715 } -void PullbackCloner::accumulateIndirect(SILValue resultAddress, - SILValue lhsAddress, - SILValue rhsAddress, SILLocation loc) { - // TODO: Optimize for the case when lhs == rhs. +void PullbackCloner::Implementation::accumulateIndirect(SILValue resultAddress, + SILValue lhsAddress, + SILValue rhsAddress, + SILLocation loc) { assert(lhsAddress->getType() == rhsAddress->getType() && "Adjoint values must have same type!"); assert(lhsAddress->getType().isAddress() && @@ -2416,8 +2642,9 @@ void PullbackCloner::accumulateIndirect(SILValue resultAddress, } } -void PullbackCloner::accumulateIndirect(SILValue lhsDestAddress, - SILValue rhsAddress, SILLocation loc) { +void PullbackCloner::Implementation::accumulateIndirect(SILValue lhsDestAddress, + SILValue rhsAddress, + SILLocation loc) { assert(lhsDestAddress->getType().isAddress() && rhsAddress->getType().isAddress()); assert(lhsDestAddress->getFunction() == &getPullback()); @@ -2468,5 +2695,135 @@ void PullbackCloner::accumulateIndirect(SILValue lhsDestAddress, } } +//----------------------------------------------------------------------------// +// Array literal initialization differentiation +//----------------------------------------------------------------------------// + +void PullbackCloner::Implementation:: + accumulateArrayLiteralElementAddressAdjoints(SILBasicBlock *origBB, + SILValue originalValue, + AdjointValue arrayAdjointValue, + SILLocation loc) { + // Return if the original value is not the `Array` result of an + // `array.uninitialized_intrinsic` application. + auto *dti = dyn_cast_or_null( + originalValue->getDefiningInstruction()); + if (!dti) + return; + if (!ArraySemanticsCall(dti->getOperand(), + semantics::ARRAY_UNINITIALIZED_INTRINSIC)) + return; + if (originalValue != dti->getResult(0)) + return; + // Accumulate the array's adjoint value into the adjoint buffers of its + // element addresses: `pointer_to_address` and `index_addr` instructions. + LLVM_DEBUG(getADDebugStream() + << "Accumulating adjoint value for array literal into element " + "address adjoint buffers" + << originalValue); + auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc); + builder.setInsertionPoint(arrayAdjoint->getParentBlock()); + for (auto use : dti->getResult(1)->getUses()) { + auto *ptai = dyn_cast(use->getUser()); + auto adjBuf = getAdjointBuffer(origBB, ptai); + auto *eltAdjBuf = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc); + accumulateIndirect(adjBuf, eltAdjBuf, loc); + for (auto use : ptai->getUses()) { + if (auto *iai = dyn_cast(use->getUser())) { + auto *ili = cast(iai->getIndex()); + auto eltIndex = ili->getValue().getLimitedValue(); + auto adjBuf = getAdjointBuffer(origBB, iai); + auto *eltAdjBuf = + getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, loc); + accumulateIndirect(adjBuf, eltAdjBuf, loc); + } + } + } +} + +AllocStackInst *PullbackCloner::Implementation::getArrayAdjointElementBuffer( + SILValue arrayAdjoint, int eltIndex, SILLocation loc) { + auto &ctx = builder.getASTContext(); + auto arrayTanType = cast(arrayAdjoint->getType().getASTType()); + auto arrayType = arrayTanType->getParent()->castTo(); + auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType(); + auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType)); + // Get `function_ref` and generic signature of + // `Array.TangentVector.subscript.getter`. + auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct(); + auto subscriptLookup = + arrayTanStructDecl->lookupDirect(DeclBaseName::createSubscript()); + SubscriptDecl *subscriptDecl = nullptr; + for (auto *candidate : subscriptLookup) { + auto candidateModule = candidate->getModuleContext(); + if (candidateModule->getName() == ctx.Id_Differentiation || + candidateModule->isStdlibModule()) { + assert(!subscriptDecl && "Multiple `Array.TangentVector.subscript`s"); + subscriptDecl = cast(candidate); +#ifdef NDEBUG + break; +#endif + } + } + assert(subscriptDecl && "No `Array.TangentVector.subscript`"); + auto *subscriptGetterDecl = + subscriptDecl->getOpaqueAccessor(AccessorKind::Get); + assert(subscriptGetterDecl && "No `Array.TangentVector.subscript` getter"); + SILOptFunctionBuilder fb(getContext().getTransform()); + auto *subscriptGetterFn = fb.getOrCreateFunction( + loc, SILDeclRef(subscriptGetterDecl), NotForDefinition); + // %subscript_fn = function_ref @Array.TangentVector.subscript.getter + auto *subscriptFnRef = builder.createFunctionRef(loc, subscriptGetterFn); + auto subscriptFnGenSig = + subscriptGetterFn->getLoweredFunctionType()->getSubstGenericSignature(); + // Apply `Array.TangentVector.subscript.getter` to get array element adjoint + // buffer. + // %index_literal = integer_literal $Builtin.IntXX, + auto builtinIntType = + SILType::getPrimitiveObjectType(ctx.getIntDecl() + ->getStoredProperties() + .front() + ->getInterfaceType() + ->getCanonicalType()); + auto *eltIndexLiteral = + builder.createIntegerLiteral(loc, builtinIntType, eltIndex); + auto intType = SILType::getPrimitiveObjectType( + ctx.getIntDecl()->getDeclaredType()->getCanonicalType()); + // %index_int = struct $Int (%index_literal) + auto *eltIndexInt = builder.createStruct(loc, intType, {eltIndexLiteral}); + auto *swiftModule = getModule().getSwiftModule(); + auto *diffProto = ctx.getProtocol(KnownProtocolKind::Differentiable); + auto diffConf = swiftModule->lookupConformance(eltTanType, diffProto); + assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); + auto *addArithProto = ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic); + auto addArithConf = swiftModule->lookupConformance(eltTanType, addArithProto); + assert(!addArithConf.isInvalid() && + "Missing conformance to `AdditiveArithmetic`"); + auto subMap = SubstitutionMap::get(subscriptFnGenSig, {eltTanType}, + {addArithConf, diffConf}); + // %elt_adj = alloc_stack $T.TangentVector + // Create and register a local allocation. + auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc); + // Temporarily change global builder insertion point and emit zero into the + // local allocation. + auto insertionPoint = builder.getInsertionBB(); + builder.setInsertionPoint(localAllocBuilder.getInsertionBB(), + localAllocBuilder.getInsertionPoint()); + emitZeroIndirect(eltTanType, eltAdjBuffer, loc); + builder.setInsertionPoint(insertionPoint); + // Immediately destroy the emitted zero value. + // NOTE: It is not efficient to emit a zero value then immediately destroy + // it. However, it was the easiest way to to avoid "lifetime mismatch in + // predecessors" memory lifetime verification errors for control flow + // differentiation. + // Perhaps we can avoid emitting a zero value if local allocations are created + // per pullback bb instead of all in the pullback entry: TF-1075. + builder.emitDestroyOperation(loc, eltAdjBuffer); + // apply %subscript_fn(%elt_adj, %index_int, %array_adj) + builder.createApply(loc, subscriptFnRef, subMap, + {eltAdjBuffer, eltIndexInt, arrayAdjoint}); + return eltAdjBuffer; +} + } // end namespace autodiff } // end namespace swift From 6f97c7a2fb721f5e812965e4b0e81a47989e68fc Mon Sep 17 00:00:00 2001 From: Saleem Abdulrasool Date: Thu, 9 Jul 2020 08:27:12 -0700 Subject: [PATCH 21/36] test: make `test_util` more python 3 friendly The last set of changes to make it backwards compatible with Python 2 required converting the arguments. That is not compatible on Python 3 unfortunately. Only perform that on Python 2 to make the utility compatible with 2 and 3. --- utils/incrparse/test_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/incrparse/test_util.py b/utils/incrparse/test_util.py index 35c965276c17e..e1a0109b7a1d2 100755 --- a/utils/incrparse/test_util.py +++ b/utils/incrparse/test_util.py @@ -22,7 +22,8 @@ def escapeCmdArg(arg): def run_command(cmd): - cmd = list(map(lambda s: s.encode('utf-8'), cmd)) + if sys.version_info[0] < 3: + cmd = list(map(lambda s: s.encode('utf-8'), cmd)) print(' '.join([escapeCmdArg(arg) for arg in cmd])) return subprocess.check_output(cmd, stderr=subprocess.STDOUT) From 67073c7a670d94019e28c7d1719ffcec14060581 Mon Sep 17 00:00:00 2001 From: Florian Friedrich Date: Thu, 9 Jul 2020 18:05:25 +0200 Subject: [PATCH 22/36] Fix typo in comparison in comments and help text --- include/swift/Driver/Compilation.h | 4 ++-- include/swift/Option/Options.td | 4 ++-- stdlib/public/core/StringComparable.swift | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/swift/Driver/Compilation.h b/include/swift/Driver/Compilation.h index 840681c69e868..1da048fa802b5 100644 --- a/include/swift/Driver/Compilation.h +++ b/include/swift/Driver/Compilation.h @@ -56,7 +56,7 @@ namespace driver { enum class OutputLevel { /// Indicates that normal output should be produced. Normal, - + /// Indicates that only jobs should be printed and not run. (-###) PrintJobs, @@ -83,7 +83,7 @@ class Compilation { const bool &EnableIncrementalBuild; const bool EnableSourceRangeDependencies; - /// If not empty, the path to use to log the comparision. + /// If not empty, the path to use to log the comparison. const StringRef CompareIncrementalSchemesPath; const unsigned SwiftInputCount; diff --git a/include/swift/Option/Options.td b/include/swift/Option/Options.td index 4dc519e24b1d5..4918ef188c961 100644 --- a/include/swift/Option/Options.td +++ b/include/swift/Option/Options.td @@ -154,7 +154,7 @@ Flag<["-"], "enable-only-one-dependency-file">, Flags<[DoesNotAffectIncrementalB def disable_only_one_dependency_file : Flag<["-"], "disable-only-one-dependency-file">, Flags<[DoesNotAffectIncrementalBuild]>, HelpText<"Disables incremental build optimization that only produces one dependencies file">; - + def enable_source_range_dependencies : Flag<["-"], "enable-source-range-dependencies">, Flags<[]>, @@ -167,7 +167,7 @@ HelpText<"Print a simple message comparing dependencies with source ranges (w/ f def driver_compare_incremental_schemes_path : Separate<["-"], "driver-compare-incremental-schemes-path">, Flags<[ArgumentIsPath,DoesNotAffectIncrementalBuild]>, -HelpText<"Path to use for machine-readable comparision">, +HelpText<"Path to use for machine-readable comparison">, MetaVarName<"">; def driver_compare_incremental_schemes_path_EQ : diff --git a/stdlib/public/core/StringComparable.swift b/stdlib/public/core/StringComparable.swift index 9ddc860c24454..74686053d3863 100644 --- a/stdlib/public/core/StringComparable.swift +++ b/stdlib/public/core/StringComparable.swift @@ -65,7 +65,7 @@ extension StringProtocol { } extension String: Equatable { - @inlinable @inline(__always) // For the bitwise comparision + @inlinable @inline(__always) // For the bitwise comparison @_effects(readonly) @_semantics("string.equals") public static func == (lhs: String, rhs: String) -> Bool { @@ -74,7 +74,7 @@ extension String: Equatable { } extension String: Comparable { - @inlinable @inline(__always) // For the bitwise comparision + @inlinable @inline(__always) // For the bitwise comparison @_effects(readonly) public static func < (lhs: String, rhs: String) -> Bool { return _stringCompare(lhs._guts, rhs._guts, expecting: .less) From 6df3d90143902fc361df13159115db61e4bd4fa9 Mon Sep 17 00:00:00 2001 From: Anthony Latsis Date: Thu, 9 Jul 2020 00:39:30 +0300 Subject: [PATCH 23/36] [NFC] AssociatedTypeInference: Improve dumping of an inference solution --- lib/Sema/TypeCheckProtocolInference.cpp | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/lib/Sema/TypeCheckProtocolInference.cpp b/lib/Sema/TypeCheckProtocolInference.cpp index 81b4aad879314..1e9528f2c5bba 100644 --- a/lib/Sema/TypeCheckProtocolInference.cpp +++ b/lib/Sema/TypeCheckProtocolInference.cpp @@ -75,21 +75,26 @@ void InferredAssociatedTypesByWitness::dump(llvm::raw_ostream &out, } void InferredTypeWitnessesSolution::dump() const { + const auto numValueWitnesses = ValueWitnesses.size(); llvm::errs() << "Type Witnesses:\n"; for (auto &typeWitness : TypeWitnesses) { llvm::errs() << " " << typeWitness.first->getName() << " := "; typeWitness.second.first->print(llvm::errs()); - llvm::errs() << " value " << typeWitness.second.second << '\n'; + if (typeWitness.second.second == numValueWitnesses) { + llvm::errs() << ", abstract"; + } else { + llvm::errs() << ", inferred from $" << typeWitness.second.second; + } + llvm::errs() << '\n'; } llvm::errs() << "Value Witnesses:\n"; for (unsigned i : indices(ValueWitnesses)) { - auto &valueWitness = ValueWitnesses[i]; - llvm::errs() << i << ": " << (Decl*)valueWitness.first - << ' ' << valueWitness.first->getBaseName() << '\n'; - valueWitness.first->getDeclContext()->printContext(llvm::errs()); - llvm::errs() << " for " << (Decl*)valueWitness.second - << ' ' << valueWitness.second->getBaseName() << '\n'; - valueWitness.second->getDeclContext()->printContext(llvm::errs()); + const auto &valueWitness = ValueWitnesses[i]; + llvm::errs() << '$' << i << ":\n "; + valueWitness.first->dumpRef(llvm::errs()); + llvm::errs() << " ->\n "; + valueWitness.second->dumpRef(llvm::errs()); + llvm::errs() << '\n'; } } From 42b19b75c7438f5c1244938df5a66879c75582b6 Mon Sep 17 00:00:00 2001 From: Anthony Latsis Date: Thu, 9 Jul 2020 18:34:41 +0300 Subject: [PATCH 24/36] [NFC] Remove a redundant location print after dumping a decl reference --- lib/Sema/Constraint.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/Sema/Constraint.cpp b/lib/Sema/Constraint.cpp index 0756b06ba7b6d..1037ce1b80ab5 100644 --- a/lib/Sema/Constraint.cpp +++ b/lib/Sema/Constraint.cpp @@ -383,9 +383,6 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm) const { auto decl = overload.getDecl(); decl->dumpRef(Out); Out << " : " << decl->getInterfaceType(); - if (!sm || !decl->getLoc().isValid()) return; - Out << " at "; - decl->getLoc().print(Out, *sm); }; switch (overload.getKind()) { From 09ee4a60b1f34a3f47ef2f55ecdc87e8c2d7595b Mon Sep 17 00:00:00 2001 From: Saleem Abdulrasool Date: Thu, 9 Jul 2020 10:54:56 -0700 Subject: [PATCH 25/36] Update test_util.py Appease the python linter. --- utils/incrparse/test_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/incrparse/test_util.py b/utils/incrparse/test_util.py index e1a0109b7a1d2..2d3936d4be5cb 100755 --- a/utils/incrparse/test_util.py +++ b/utils/incrparse/test_util.py @@ -23,7 +23,7 @@ def escapeCmdArg(arg): def run_command(cmd): if sys.version_info[0] < 3: - cmd = list(map(lambda s: s.encode('utf-8'), cmd)) + cmd = list(map(lambda s: s.encode('utf-8'), cmd)) print(' '.join([escapeCmdArg(arg) for arg in cmd])) return subprocess.check_output(cmd, stderr=subprocess.STDOUT) From 2701a0809b919e84e5b5d1a10ed61934465bf95c Mon Sep 17 00:00:00 2001 From: Nate Chandler Date: Thu, 9 Jul 2020 09:50:00 -0700 Subject: [PATCH 26/36] [metadata prespecialization] Ptrauth for compared protocol conformances. Two protocol conformance descriptors are passed to swift_compareProtocolConformanceDecriptors from generic metadata accessors when there is a canonical prespecialization and one of the generic arguments has a protocol requirement. Previously, the descriptors were incorrectly being passed without ptrauth processing: one from the witness table in the arguments that are passed in to the accessor and one known statically. Here, the descriptor in the witness table is authed using the ProtocolConformanceDescriptor schema. Then, both descriptors are signed using the ProtocolConformanceDescriptorsAsArguments schema. Finally, in the runtime function, the descriptors are authed. --- include/swift/ABI/MetadataValues.h | 3 +++ include/swift/AST/IRGenOptions.h | 6 ++++++ lib/IRGen/GenPointerAuth.cpp | 3 +++ lib/IRGen/GenPointerAuth.h | 2 ++ lib/IRGen/IRGen.cpp | 5 +++++ lib/IRGen/MetadataRequest.cpp | 28 +++++++++++++++++++++++++++- stdlib/public/runtime/Metadata.cpp | 5 +++++ 7 files changed, 51 insertions(+), 1 deletion(-) diff --git a/include/swift/ABI/MetadataValues.h b/include/swift/ABI/MetadataValues.h index 31b60023e1938..6ab98e9e37e56 100644 --- a/include/swift/ABI/MetadataValues.h +++ b/include/swift/ABI/MetadataValues.h @@ -1118,6 +1118,9 @@ namespace SpecialPointerAuthDiscriminators { /// Runtime function variables exported by the runtime. const uint16_t RuntimeFunctionEntry = 0x625b; + /// Protocol conformance descriptors. + const uint16_t ProtocolConformanceDescriptor = 0xc6eb; + /// Value witness functions. const uint16_t InitializeBufferWithCopyOfBuffer = 0xda4a; const uint16_t Destroy = 0x04f8; diff --git a/include/swift/AST/IRGenOptions.h b/include/swift/AST/IRGenOptions.h index 9b55d54d28d6a..9f70678ca6e05 100644 --- a/include/swift/AST/IRGenOptions.h +++ b/include/swift/AST/IRGenOptions.h @@ -117,6 +117,12 @@ struct PointerAuthOptions : clang::PointerAuthOptions { /// Type descriptor data pointers when passed as arguments. PointerAuthSchema TypeDescriptorsAsArguments; + /// Protocol conformance descriptors. + PointerAuthSchema ProtocolConformanceDescriptors; + + /// Protocol conformance descriptors when passed as arguments. + PointerAuthSchema ProtocolConformanceDescriptorsAsArguments; + /// Resumption functions from yield-once coroutines. PointerAuthSchema YieldOnceResumeFunctions; diff --git a/lib/IRGen/GenPointerAuth.cpp b/lib/IRGen/GenPointerAuth.cpp index 10edd40a0dea2..fa6821b267900 100644 --- a/lib/IRGen/GenPointerAuth.cpp +++ b/lib/IRGen/GenPointerAuth.cpp @@ -309,6 +309,9 @@ PointerAuthEntity::getDeclDiscriminator(IRGenModule &IGM) const { case Special::TypeDescriptor: case Special::TypeDescriptorAsArgument: return SpecialPointerAuthDiscriminators::TypeDescriptor; + case Special::ProtocolConformanceDescriptor: + case Special::ProtocolConformanceDescriptorAsArgument: + return SpecialPointerAuthDiscriminators::ProtocolConformanceDescriptor; case Special::PartialApplyCapture: return PointerAuthDiscriminator_PartialApplyCapture; case Special::KeyPathDestroy: diff --git a/lib/IRGen/GenPointerAuth.h b/lib/IRGen/GenPointerAuth.h index c30e4daf50668..8640f68cbf30f 100644 --- a/lib/IRGen/GenPointerAuth.h +++ b/lib/IRGen/GenPointerAuth.h @@ -62,6 +62,8 @@ class PointerAuthEntity { KeyPathInitializer, KeyPathMetadataAccessor, DynamicReplacementKey, + ProtocolConformanceDescriptor, + ProtocolConformanceDescriptorAsArgument, }; private: diff --git a/lib/IRGen/IRGen.cpp b/lib/IRGen/IRGen.cpp index ca88f2efaccff..a0c233979069f 100644 --- a/lib/IRGen/IRGen.cpp +++ b/lib/IRGen/IRGen.cpp @@ -651,6 +651,11 @@ static void setPointerAuthOptions(PointerAuthOptions &opts, opts.SwiftDynamicReplacementKeys = PointerAuthSchema(dataKey, /*address*/ true, Discrimination::Decl); + opts.ProtocolConformanceDescriptors = + PointerAuthSchema(dataKey, /*address*/ true, Discrimination::Decl); + opts.ProtocolConformanceDescriptorsAsArguments = + PointerAuthSchema(dataKey, /*address*/ false, Discrimination::Decl); + // Coroutine resumption functions are never stored globally in the ABI, // so we can do some things that aren't normally okay to do. However, // we can't use ASIB because that would break ARM64 interoperation. diff --git a/lib/IRGen/MetadataRequest.cpp b/lib/IRGen/MetadataRequest.cpp index e69c72d0914ac..ebb2f45913961 100644 --- a/lib/IRGen/MetadataRequest.cpp +++ b/lib/IRGen/MetadataRequest.cpp @@ -1970,7 +1970,7 @@ static void emitCanonicalSpecializationsForGenericTypeMetadataAccessFunction( } else { RootProtocolConformance *rootConformance = concreteConformance->getRootConformance(); - auto *expectedDescriptor = + llvm::Value *expectedDescriptor = IGF.IGM.getAddrOfProtocolConformanceDescriptor(rootConformance); auto *witnessTable = valueAtIndex(requirementIndex); auto *witnessBuffer = @@ -1981,6 +1981,32 @@ static void emitCanonicalSpecializationsForGenericTypeMetadataAccessFunction( uncastProvidedDescriptor, IGM.ProtocolConformanceDescriptorPtrTy); + // Auth the stored descriptor. + auto storedScheme = + IGF.IGM.getOptions().PointerAuth.ProtocolConformanceDescriptors; + if (storedScheme) { + auto authInfo = PointerAuthInfo::emit( + IGF, storedScheme, witnessTable, + PointerAuthEntity::Special::ProtocolConformanceDescriptor); + providedDescriptor = + emitPointerAuthAuth(IGF, providedDescriptor, authInfo); + } + + // Sign the descriptors. + auto argScheme = + IGF.IGM.getOptions() + .PointerAuth.ProtocolConformanceDescriptorsAsArguments; + if (argScheme) { + auto authInfo = PointerAuthInfo::emit( + IGF, argScheme, nullptr, + PointerAuthEntity::Special:: + ProtocolConformanceDescriptorAsArgument); + expectedDescriptor = + emitPointerAuthSign(IGF, expectedDescriptor, authInfo); + providedDescriptor = + emitPointerAuthSign(IGF, providedDescriptor, authInfo); + } + auto *call = IGF.Builder.CreateCall( IGF.IGM.getCompareProtocolConformanceDescriptorsFn(), {providedDescriptor, expectedDescriptor}); diff --git a/stdlib/public/runtime/Metadata.cpp b/stdlib/public/runtime/Metadata.cpp index 5703748dfc182..18dc96a828c6e 100644 --- a/stdlib/public/runtime/Metadata.cpp +++ b/stdlib/public/runtime/Metadata.cpp @@ -4996,6 +4996,11 @@ const WitnessTable *swift::swift_getAssociatedConformanceWitness( bool swift::swift_compareProtocolConformanceDescriptors( const ProtocolConformanceDescriptor *lhs, const ProtocolConformanceDescriptor *rhs) { + lhs = swift_auth_data_non_address( + lhs, SpecialPointerAuthDiscriminators::ProtocolConformanceDescriptor); + rhs = swift_auth_data_non_address( + rhs, SpecialPointerAuthDiscriminators::ProtocolConformanceDescriptor); + return MetadataCacheKey::compareProtocolConformanceDescriptors(lhs, rhs) == 0; } From 3543cf85c454aaad21d3c23009b7d3f6d93141cd Mon Sep 17 00:00:00 2001 From: Andrew Trick Date: Thu, 9 Jul 2020 11:09:06 -0700 Subject: [PATCH 27/36] Cleanup MemAccessUtils. Organize the utilities in this file by section to make subsequent diffs easier to read and prepare for adding more utilities. --- include/swift/SIL/MemAccessUtils.h | 218 ++++++++++-------- lib/SIL/Utils/MemAccessUtils.cpp | 340 ++++++++++++++--------------- 2 files changed, 294 insertions(+), 264 deletions(-) diff --git a/include/swift/SIL/MemAccessUtils.h b/include/swift/SIL/MemAccessUtils.h index 6670c09730ee3..cb4c918ab1384 100644 --- a/include/swift/SIL/MemAccessUtils.h +++ b/include/swift/SIL/MemAccessUtils.h @@ -44,6 +44,10 @@ #include "swift/SIL/SILInstruction.h" #include "llvm/ADT/DenseMap.h" +//===----------------------------------------------------------------------===// +// MARK: General Helpers +//===----------------------------------------------------------------------===// + namespace swift { /// Get the base address of a formal access by stripping access markers and @@ -103,6 +107,14 @@ inline bool accessKindMayConflict(SILAccessKind a, SILAccessKind b) { return !(a == SILAccessKind::Read && b == SILAccessKind::Read); } +} // end namespace swift + +//===----------------------------------------------------------------------===// +// MARK: AccessedStorage +//===----------------------------------------------------------------------===// + +namespace swift { + /// Represents the identity of a storage object being accessed. /// /// AccessedStorage is carefully designed to solve three problems: @@ -406,24 +418,23 @@ class AccessedStorage { bool operator==(const AccessedStorage &) const = delete; bool operator!=(const AccessedStorage &) const = delete; }; + } // end namespace swift namespace llvm { - /// Enable using AccessedStorage as a key in DenseMap. /// Do *not* include any extra pass data in key equality. template <> struct DenseMapInfo { static swift::AccessedStorage getEmptyKey() { return swift::AccessedStorage(swift::SILValue::getFromOpaqueValue( - llvm::DenseMapInfo::getEmptyKey()), - swift::AccessedStorage::Unidentified); + llvm::DenseMapInfo::getEmptyKey()), + swift::AccessedStorage::Unidentified); } static swift::AccessedStorage getTombstoneKey() { - return swift::AccessedStorage( - swift::SILValue::getFromOpaqueValue( - llvm::DenseMapInfo::getTombstoneKey()), - swift::AccessedStorage::Unidentified); + return swift::AccessedStorage(swift::SILValue::getFromOpaqueValue( + llvm::DenseMapInfo::getTombstoneKey()), + swift::AccessedStorage::Unidentified); } static unsigned getHashValue(swift::AccessedStorage storage) { @@ -450,63 +461,10 @@ template <> struct DenseMapInfo { return LHS.hasIdenticalBase(RHS); } }; - -} // end namespace llvm +} // namespace llvm namespace swift { -/// Abstract CRTP class for a visitor passed to \c visitAccessUseDefChain. -template -class AccessUseDefChainVisitor { -protected: - Impl &asImpl() { - return static_cast(*this); - } -public: - // Subclasses can provide a method for any identified access base: - // Result visitBase(SILValue base, AccessedStorage::Kind kind); - - // Visitors for specific identified access kinds. These default to calling out - // to visitIdentified. - - Result visitClassAccess(RefElementAddrInst *field) { - return asImpl().visitBase(field, AccessedStorage::Class); - } - Result visitArgumentAccess(SILFunctionArgument *arg) { - return asImpl().visitBase(arg, AccessedStorage::Argument); - } - Result visitBoxAccess(AllocBoxInst *box) { - return asImpl().visitBase(box, AccessedStorage::Box); - } - /// The argument may be either a GlobalAddrInst or the ApplyInst for a global accessor function. - Result visitGlobalAccess(SILValue global) { - return asImpl().visitBase(global, AccessedStorage::Global); - } - Result visitYieldAccess(BeginApplyResult *yield) { - return asImpl().visitBase(yield, AccessedStorage::Yield); - } - Result visitStackAccess(AllocStackInst *stack) { - return asImpl().visitBase(stack, AccessedStorage::Stack); - } - Result visitNestedAccess(BeginAccessInst *access) { - return asImpl().visitBase(access, AccessedStorage::Nested); - } - - // Visitors for unidentified base values. - - Result visitUnidentified(SILValue base) { - return asImpl().visitBase(base, AccessedStorage::Unidentified); - } - - // Subclasses must provide implementations to visit non-access bases - // and phi arguments, and for incomplete projections from the access: - // void visitNonAccess(SILValue base); - // void visitPhi(SILPhiArgument *phi); - // void visitIncomplete(SILValue projectedAddr, SILValue parentAddr); - - Result visit(SILValue sourceAddr); -}; - /// Given an address accessed by an instruction that reads or modifies /// memory, return an AccessedStorage object that identifies the formal access. /// @@ -534,6 +492,14 @@ AccessedStorage findAccessedStorage(SILValue sourceAddr); /// access has Unsafe enforcement. AccessedStorage findAccessedStorageNonNested(SILValue sourceAddr); +} // end namespace swift + +//===----------------------------------------------------------------------===// +// MARK: Helper API +//===----------------------------------------------------------------------===// + +namespace swift { + /// Return true if the given address operand is used by a memory operation that /// initializes the memory at that address, implying that the previous value is /// uninitialized. @@ -550,6 +516,24 @@ bool memInstMustInitialize(Operand *memOper); bool isSingleInitAllocStack(AllocStackInst *asi, SmallVectorImpl &destroyingUses); +/// Return true if the given address value is produced by a special address +/// producer that is only used for local initialization, not formal access. +bool isAddressForLocalInitOnly(SILValue sourceAddr); + +/// Return true if the given apply invokes a global addressor defined in another +/// module. +bool isExternalGlobalAddressor(ApplyInst *AI); + +/// Return true if the given StructExtractInst extracts the RawPointer from +/// Unsafe[Mutable]Pointer. +bool isUnsafePointerExtraction(StructExtractInst *SEI); + +/// Given a block argument address base, check if it is actually a box projected +/// from a switch_enum. This is a valid pattern at any SIL stage resulting in a +/// block-type phi. In later SIL stages, the optimizer may form address-type +/// phis, causing this assert if called on those cases. +void checkSwitchEnumBlockArg(SILPhiArgument *arg); + /// Return true if the given address producer may be the source of a formal /// access (a read or write of a potentially aliased, user visible variable). /// @@ -560,17 +544,14 @@ bool isSingleInitAllocStack(AllocStackInst *asi, /// storage = findAccessedStorage(address) /// needsAccessMarker = storage && isPossibleFormalAccessBase(storage) /// -/// Warning: This is only valid for SIL with well-formed accessed. For example, +/// Warning: This is only valid for SIL with well-formed accesses. For example, /// it will not handle address-type phis. Optimization passes after /// DiagnoseStaticExclusivity may violate these assumptions. -bool isPossibleFormalAccessBase(const AccessedStorage &storage, SILFunction *F); - -/// Visit each address accessed by the given memory operation. /// -/// This only visits instructions that modify memory in some user-visible way, -/// which could be considered part of a formal access. -void visitAccessedAddress(SILInstruction *I, - llvm::function_ref visitor); +/// This is not a member of AccessedStorage because it only makes sense to use +/// in SILGen before access markers are emitted, or when verifying access +/// markers. +bool isPossibleFormalAccessBase(const AccessedStorage &storage, SILFunction *F); /// Perform a RAUW operation on begin_access with it's own source operand. /// Then erase the begin_access and all associated end_access instructions. @@ -580,20 +561,65 @@ void visitAccessedAddress(SILInstruction *I, /// instruction following this begin_access was not also erased. SILBasicBlock::iterator removeBeginAccess(BeginAccessInst *beginAccess); -/// Return true if the given address value is produced by a special address -/// producer that is only used for local initialization, not formal access. -bool isAddressForLocalInitOnly(SILValue sourceAddr); -/// Return true if the given apply invokes a global addressor defined in another -/// module. -bool isExternalGlobalAddressor(ApplyInst *AI); -/// Return true if the given StructExtractInst extracts the RawPointer from -/// Unsafe[Mutable]Pointer. -bool isUnsafePointerExtraction(StructExtractInst *SEI); -/// Given a block argument address base, check if it is actually a box projected -/// from a switch_enum. This is a valid pattern at any SIL stage resulting in a -/// block-type phi. In later SIL stages, the optimizer may form address-type -/// phis, causing this assert if called on those cases. -void checkSwitchEnumBlockArg(SILPhiArgument *arg); +} // end namespace swift + +//===----------------------------------------------------------------------===// +// MARK: AccessUseDefChainVisitor +//===----------------------------------------------------------------------===// + +namespace swift { + +/// Abstract CRTP class for a visitor passed to \c visitAccessUseDefChain. +template +class AccessUseDefChainVisitor { +protected: + Impl &asImpl() { return static_cast(*this); } + +public: + // Subclasses can provide a method for any identified access base: + // Result visitBase(SILValue base, AccessedStorage::Kind kind); + + // Visitors for specific identified access kinds. These default to calling out + // to visitIdentified. + + Result visitClassAccess(RefElementAddrInst *field) { + return asImpl().visitBase(field, AccessedStorage::Class); + } + Result visitArgumentAccess(SILFunctionArgument *arg) { + return asImpl().visitBase(arg, AccessedStorage::Argument); + } + Result visitBoxAccess(AllocBoxInst *box) { + return asImpl().visitBase(box, AccessedStorage::Box); + } + /// The argument may be either a GlobalAddrInst or the ApplyInst for a global + /// accessor function. + Result visitGlobalAccess(SILValue global) { + return asImpl().visitBase(global, AccessedStorage::Global); + } + Result visitYieldAccess(BeginApplyResult *yield) { + return asImpl().visitBase(yield, AccessedStorage::Yield); + } + Result visitStackAccess(AllocStackInst *stack) { + return asImpl().visitBase(stack, AccessedStorage::Stack); + } + Result visitNestedAccess(BeginAccessInst *access) { + return asImpl().visitBase(access, AccessedStorage::Nested); + } + + // Visitors for unidentified base values. + + Result visitUnidentified(SILValue base) { + return asImpl().visitBase(base, AccessedStorage::Unidentified); + } + + // Subclasses must provide implementations to visit non-access bases + // and phi arguments, and for incomplete projections from the access: + // void visitNonAccess(SILValue base); + // void visitPhi(SILPhiArgument *phi); + // void visitIncomplete(SILValue projectedAddr, SILValue parentAddr); + + Result visit(SILValue sourceAddr); +}; template Result AccessUseDefChainVisitor::visit(SILValue sourceAddr) { @@ -727,11 +753,12 @@ Result AccessUseDefChainVisitor::visit(SILValue sourceAddr) { return asImpl().visitIncomplete(sourceAddr, cast(sourceAddr)->getOperand(0)); - // Access to a Builtin.RawPointer. Treat this like the inductive cases - // above because some RawPointers originate from identified locations. See - // the special case for global addressors, which return RawPointer, - // above. AddressToPointer is also handled because it results from inlining a - // global addressor without folding the AddressToPointer->PointerToAddress. + // Access to a Builtin.RawPointer. Treat this like the inductive cases above + // because some RawPointers originate from identified locations. See the + // special case for global addressors, which return RawPointer, + // above. AddressToPointer is also handled because it results from inlining + // a global addressor without folding the + // AddressToPointer->PointerToAddress. // // If the inductive search does not find a valid addressor, it will // eventually reach the default case that returns in invalid location. This @@ -768,4 +795,19 @@ Result AccessUseDefChainVisitor::visit(SILValue sourceAddr) { } // end namespace swift +//===----------------------------------------------------------------------===// +// MARK: Verification +//===----------------------------------------------------------------------===// + +namespace swift { + +/// Visit each address accessed by the given memory operation. +/// +/// This only visits instructions that modify memory in some user-visible way, +/// which could be considered part of a formal access. +void visitAccessedAddress(SILInstruction *I, + llvm::function_ref visitor); + +} // end namespace swift + #endif diff --git a/lib/SIL/Utils/MemAccessUtils.cpp b/lib/SIL/Utils/MemAccessUtils.cpp index 6f041f59995e3..2841239bf0e0c 100644 --- a/lib/SIL/Utils/MemAccessUtils.cpp +++ b/lib/SIL/Utils/MemAccessUtils.cpp @@ -21,6 +21,10 @@ using namespace swift; +//===----------------------------------------------------------------------===// +// MARK: General Helpers +//===----------------------------------------------------------------------===// + SILValue swift::stripAccessMarkers(SILValue v) { while (auto *bai = dyn_cast(v)) { v = bai->getOperand(); @@ -82,6 +86,10 @@ bool swift::isLetAddress(SILValue accessedAddress) { return false; } +//===----------------------------------------------------------------------===// +// MARK: AccessedStorage +//===----------------------------------------------------------------------===// + AccessedStorage::AccessedStorage(SILValue base, Kind kind) { assert(base && "invalid storage base"); initKind(kind); @@ -138,6 +146,20 @@ AccessedStorage::AccessedStorage(SILValue base, Kind kind) { } } +// Return true if the given access is on a 'let' lvalue. +bool AccessedStorage::isLetAccess(SILFunction *F) const { + if (auto *decl = dyn_cast_or_null(getDecl())) + return decl->isLet(); + + // It's unclear whether a global will ever be missing it's varDecl, but + // technically we only preserve it for debug info. So if we don't have a decl, + // check the flag on SILGlobalVariable, which is guaranteed valid, + if (getKind() == AccessedStorage::Global) + return getGlobal()->isLet(); + + return false; +} + const ValueDecl *AccessedStorage::getDecl() const { switch (getKind()) { case Box: @@ -217,93 +239,7 @@ void AccessedStorage::print(raw_ostream &os) const { void AccessedStorage::dump() const { print(llvm::dbgs()); } -// Return true if the given apply invokes a global addressor defined in another -// module. -bool swift::isExternalGlobalAddressor(ApplyInst *AI) { - FullApplySite apply(AI); - auto *funcRef = apply.getReferencedFunctionOrNull(); - if (!funcRef) - return false; - - return funcRef->isGlobalInit() && funcRef->isExternalDeclaration(); -} - -// Return true if the given StructExtractInst extracts the RawPointer from -// Unsafe[Mutable]Pointer. -bool swift::isUnsafePointerExtraction(StructExtractInst *SEI) { - assert(isa(SEI->getType().getASTType())); - auto &C = SEI->getModule().getASTContext(); - auto *decl = SEI->getStructDecl(); - return decl == C.getUnsafeMutablePointerDecl() - || decl == C.getUnsafePointerDecl(); -} - -// Given a block argument address base, check if it is actually a box projected -// from a switch_enum. This is a valid pattern at any SIL stage resulting in a -// block-type phi. In later SIL stages, the optimizer may form address-type -// phis, causing this assert if called on those cases. -void swift::checkSwitchEnumBlockArg(SILPhiArgument *arg) { - assert(!arg->getType().isAddress()); - SILBasicBlock *Pred = arg->getParent()->getSinglePredecessorBlock(); - if (!Pred || !isa(Pred->getTerminator())) { - arg->dump(); - llvm_unreachable("unexpected box source."); - } -} - -/// Return true if the given address value is produced by a special address -/// producer that is only used for local initialization, not formal access. -bool swift::isAddressForLocalInitOnly(SILValue sourceAddr) { - switch (sourceAddr->getKind()) { - default: - return false; - - // Value to address conversions: the operand is the non-address source - // value. These allow local mutation of the value but should never be used - // for formal access of an lvalue. - case ValueKind::OpenExistentialBoxInst: - case ValueKind::ProjectExistentialBoxInst: - return true; - - // Self-evident local initialization. - case ValueKind::InitEnumDataAddrInst: - case ValueKind::InitExistentialAddrInst: - case ValueKind::AllocExistentialBoxInst: - case ValueKind::AllocValueBufferInst: - case ValueKind::ProjectValueBufferInst: - return true; - } -} - namespace { -// The result of an accessed storage query. A complete result produces an -// AccessedStorage object, which may or may not be valid. An incomplete result -// produces a SILValue representing the source address for use with deeper -// queries. The source address is not necessarily a SIL address type because -// the query can extend past pointer-to-address casts. -class AccessedStorageResult { - AccessedStorage storage; - SILValue address; - bool complete; - - explicit AccessedStorageResult(SILValue address) - : address(address), complete(false) {} - -public: - AccessedStorageResult(const AccessedStorage &storage) - : storage(storage), complete(true) {} - - static AccessedStorageResult incomplete(SILValue address) { - return AccessedStorageResult(address); - } - - bool isComplete() const { return complete; } - - AccessedStorage getStorage() const { assert(complete); return storage; } - - SILValue getAddress() const { assert(!complete); return address; } -}; - struct FindAccessedStorageVisitor : public AccessUseDefChainVisitor { @@ -372,19 +308,9 @@ AccessedStorage swift::findAccessedStorageNonNested(SILValue sourceAddr) { } } -// Return true if the given access is on a 'let' lvalue. -bool AccessedStorage::isLetAccess(SILFunction *F) const { - if (auto *decl = dyn_cast_or_null(getDecl())) - return decl->isLet(); - - // It's unclear whether a global will ever be missing it's varDecl, but - // technically we only preserve it for debug info. So if we don't have a decl, - // check the flag on SILGlobalVariable, which is guaranteed valid, - if (getKind() == AccessedStorage::Global) - return getGlobal()->isLet(); - - return false; -} +//===----------------------------------------------------------------------===// +// MARK: Helper API +//===----------------------------------------------------------------------===// static bool isScratchBuffer(SILValue value) { // Special case unsafe value buffer access. @@ -419,6 +345,125 @@ bool swift::memInstMustInitialize(Operand *memOper) { } } +bool swift::isSingleInitAllocStack(AllocStackInst *asi, + SmallVectorImpl &destroyingUses) { + // For now, we just look through projections and rely on memInstMustInitialize + // to classify all other uses as init or not. + SmallVector worklist(asi->getUses()); + bool foundInit = false; + + while (!worklist.empty()) { + auto *use = worklist.pop_back_val(); + auto *user = use->getUser(); + + if (Projection::isAddressProjection(user) + || isa(user)) { + // Look through address projections. + for (SILValue r : user->getResults()) { + llvm::copy(r->getUses(), std::back_inserter(worklist)); + } + continue; + } + + if (auto *li = dyn_cast(user)) { + // If we are not taking, + if (li->getOwnershipQualifier() != LoadOwnershipQualifier::Take) { + continue; + } + // Treat load [take] as a write. + return false; + } + + switch (user->getKind()) { + default: + break; + case SILInstructionKind::DestroyAddrInst: + destroyingUses.push_back(use); + continue; + case SILInstructionKind::DeallocStackInst: + case SILInstructionKind::LoadBorrowInst: + case SILInstructionKind::DebugValueAddrInst: + continue; + } + + // See if we have an initializer and that such initializer is in the same + // block. + if (memInstMustInitialize(use)) { + if (user->getParent() != asi->getParent() || foundInit) { + return false; + } + + foundInit = true; + continue; + } + + // Otherwise, if we have found something not in our whitelist, return false. + return false; + } + + // We did not find any users that we did not understand. So we can + // conservatively return true here. + return true; +} + +/// Return true if the given address value is produced by a special address +/// producer that is only used for local initialization, not formal access. +bool swift::isAddressForLocalInitOnly(SILValue sourceAddr) { + switch (sourceAddr->getKind()) { + default: + return false; + + // Value to address conversions: the operand is the non-address source + // value. These allow local mutation of the value but should never be used + // for formal access of an lvalue. + case ValueKind::OpenExistentialBoxInst: + case ValueKind::ProjectExistentialBoxInst: + return true; + + // Self-evident local initialization. + case ValueKind::InitEnumDataAddrInst: + case ValueKind::InitExistentialAddrInst: + case ValueKind::AllocExistentialBoxInst: + case ValueKind::AllocValueBufferInst: + case ValueKind::ProjectValueBufferInst: + return true; + } +} + +// Return true if the given apply invokes a global addressor defined in another +// module. +bool swift::isExternalGlobalAddressor(ApplyInst *AI) { + FullApplySite apply(AI); + auto *funcRef = apply.getReferencedFunctionOrNull(); + if (!funcRef) + return false; + + return funcRef->isGlobalInit() && funcRef->isExternalDeclaration(); +} + +// Return true if the given StructExtractInst extracts the RawPointer from +// Unsafe[Mutable]Pointer. +bool swift::isUnsafePointerExtraction(StructExtractInst *SEI) { + assert(isa(SEI->getType().getASTType())); + auto &C = SEI->getModule().getASTContext(); + auto *decl = SEI->getStructDecl(); + return decl == C.getUnsafeMutablePointerDecl() + || decl == C.getUnsafePointerDecl(); +} + +// Given a block argument address base, check if it is actually a box projected +// from a switch_enum. This is a valid pattern at any SIL stage resulting in a +// block-type phi. In later SIL stages, the optimizer may form address-type +// phis, causing this assert if called on those cases. +void swift::checkSwitchEnumBlockArg(SILPhiArgument *arg) { + assert(!arg->getType().isAddress()); + SILBasicBlock *Pred = arg->getParent()->getSinglePredecessorBlock(); + if (!Pred || !isa(Pred->getTerminator())) { + arg->dump(); + llvm_unreachable("unexpected box source."); + } +} + bool swift::isPossibleFormalAccessBase(const AccessedStorage &storage, SILFunction *F) { switch (storage.getKind()) { @@ -477,6 +522,26 @@ bool swift::isPossibleFormalAccessBase(const AccessedStorage &storage, return true; } +SILBasicBlock::iterator swift::removeBeginAccess(BeginAccessInst *beginAccess) { + while (!beginAccess->use_empty()) { + Operand *op = *beginAccess->use_begin(); + + // Delete any associated end_access instructions. + if (auto endAccess = dyn_cast(op->getUser())) { + endAccess->eraseFromParent(); + + // Forward all other uses to the original address. + } else { + op->set(beginAccess->getSource()); + } + } + return beginAccess->getParent()->erase(beginAccess); +} + +//===----------------------------------------------------------------------===// +// Verification +//===----------------------------------------------------------------------===// + /// Helper for visitApplyAccesses that visits address-type call arguments, /// including arguments to @noescape functions that are passed as closures to /// the current call. @@ -707,80 +772,3 @@ void swift::visitAccessedAddress(SILInstruction *I, return; } } - -SILBasicBlock::iterator swift::removeBeginAccess(BeginAccessInst *beginAccess) { - while (!beginAccess->use_empty()) { - Operand *op = *beginAccess->use_begin(); - - // Delete any associated end_access instructions. - if (auto endAccess = dyn_cast(op->getUser())) { - endAccess->eraseFromParent(); - - // Forward all other uses to the original address. - } else { - op->set(beginAccess->getSource()); - } - } - return beginAccess->getParent()->erase(beginAccess); -} - -bool swift::isSingleInitAllocStack(AllocStackInst *asi, - SmallVectorImpl &destroyingUses) { - // For now, we just look through projections and rely on memInstMustInitialize - // to classify all other uses as init or not. - SmallVector worklist(asi->getUses()); - bool foundInit = false; - - while (!worklist.empty()) { - auto *use = worklist.pop_back_val(); - auto *user = use->getUser(); - - if (Projection::isAddressProjection(user) || - isa(user)) { - // Look through address projections. - for (SILValue r : user->getResults()) { - llvm::copy(r->getUses(), std::back_inserter(worklist)); - } - continue; - } - - if (auto *li = dyn_cast(user)) { - // If we are not taking, - if (li->getOwnershipQualifier() != LoadOwnershipQualifier::Take) { - continue; - } - // Treat load [take] as a write. - return false; - } - - switch (user->getKind()) { - default: - break; - case SILInstructionKind::DestroyAddrInst: - destroyingUses.push_back(use); - continue; - case SILInstructionKind::DeallocStackInst: - case SILInstructionKind::LoadBorrowInst: - case SILInstructionKind::DebugValueAddrInst: - continue; - } - - // See if we have an initializer and that such initializer is in the same - // block. - if (memInstMustInitialize(use)) { - if (user->getParent() != asi->getParent() || foundInit) { - return false; - } - - foundInit = true; - continue; - } - - // Otherwise, if we have found something not in our whitelist, return false. - return false; - } - - // We did not find any users that we did not understand. So we can - // conservatively return true here. - return true; -} From 17be66c3d62464ee9d46df8388200a77e6a5385e Mon Sep 17 00:00:00 2001 From: Rintaro Ishizaki Date: Wed, 8 Jul 2020 21:12:12 -0700 Subject: [PATCH 28/36] [PlaceholderExpansion] Omit return type in closure signature Return type in the closure signature is often redundant when expanding placeholders, because the type of the clossures are usually inferred from the context (i.e. calling function), users don't need to write the return type explicitly. They are not only redundant, but also sometimes harmful when the return type is a generic parameter or its requirement. Actually, there is no correct spelling in such cases. So omit the return type and the parentheses around the parameter clause. rdar://problem/63607976 --- lib/IDE/CodeCompletion.cpp | 10 +----- .../complete_multiple_trailingclosure.swift | 36 +++++++++---------- ..._multiple_trailingclosure_signatures.swift | 12 +++---- ...multiple_trailing_closure_signatures.swift | 10 +++--- ...de-expand-multiple-trailing-closures.swift | 8 ++--- test/SourceKit/CodeExpand/code-expand.swift | 12 +++---- tools/SourceKit/lib/SwiftLang/SwiftEditor.cpp | 32 +++-------------- 7 files changed, 44 insertions(+), 76 deletions(-) diff --git a/lib/IDE/CodeCompletion.cpp b/lib/IDE/CodeCompletion.cpp index aff9699acf673..0c59dc224403b 100644 --- a/lib/IDE/CodeCompletion.cpp +++ b/lib/IDE/CodeCompletion.cpp @@ -1016,10 +1016,6 @@ void CodeCompletionResultBuilder::addCallParameter(Identifier Name, SmallString<32> buffer; llvm::raw_svector_ostream OS(buffer); - bool returnsVoid = AFT->getResult()->isVoid(); - bool hasSignature = !returnsVoid || !AFT->getParams().empty(); - if (hasSignature) - OS << "("; bool firstParam = true; for (const auto ¶m : AFT->getParams()) { if (!firstParam) @@ -1038,12 +1034,8 @@ void CodeCompletionResultBuilder::addCallParameter(Identifier Name, OS << "#>"; } } - if (hasSignature) - OS << ")"; - if (!returnsVoid) - OS << " -> " << AFT->getResult()->getString(PO); - if (hasSignature) + if (!firstParam) OS << " in"; addChunkWithText( diff --git a/test/IDE/complete_multiple_trailingclosure.swift b/test/IDE/complete_multiple_trailingclosure.swift index efef1d98aca64..d78d6ee6bdcb5 100644 --- a/test/IDE/complete_multiple_trailingclosure.swift +++ b/test/IDE/complete_multiple_trailingclosure.swift @@ -24,11 +24,11 @@ func testGlobalFunc() { { 1 } #^GLOBALFUNC_SAMELINE^# #^GLOBALFUNC_NEWLINE^# // GLOBALFUNC_SAMELINE: Begin completions, 1 items -// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {() -> String in|}#}[#() -> String#]; +// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {|}#}[#() -> String#]; // GLOBALFUNC_SAMELINE: End completions // GLOBALFUNC_NEWLINE: Begin completions, 1 items -// GLOBALFUNC_NEWLINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {() -> String in|}#}[#() -> String#]; +// GLOBALFUNC_NEWLINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {|}#}[#() -> String#]; // GLOBALFUNC_NEWLINE: End completions globalFunc1() @@ -53,14 +53,14 @@ func testMethod(value: MyStruct) { } #^METHOD_SAMELINE^# #^METHOD_NEWLINE^# // METHOD_SAMELINE: Begin completions, 4 items -// METHOD_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: (() -> String)? {() -> String in|}#}[#(() -> String)?#]; +// METHOD_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: (() -> String)? {|}#}[#(() -> String)?#]; // METHOD_SAMELINE-DAG: Decl[InstanceMethod]/CurrNominal: .enumFunc()[#Void#]; // METHOD_SAMELINE-DAG: Decl[InfixOperatorFunction]/OtherModule[Swift]/IsSystem: [' ']+ {#SimpleEnum#}[#SimpleEnum#]; // METHOD_SAMELINE-DAG: Keyword[self]/CurrNominal: .self[#SimpleEnum#]; // METHOD_SAMELINE: End completions // METHOD_NEWLINE: Begin completions -// METHOD_NEWLINE-DAG: Pattern/ExprSpecific: {#fn2: (() -> String)? {() -> String in|}#}[#(() -> String)?#]; +// METHOD_NEWLINE-DAG: Pattern/ExprSpecific: {#fn2: (() -> String)? {|}#}[#(() -> String)?#]; // METHOD_NEWLINE-DAG: Keyword[class]/None: class; // METHOD_NEWLINE-DAG: Keyword[if]/None: if; // METHOD_NEWLINE-DAG: Keyword[try]/None: try; @@ -84,15 +84,15 @@ func testOverloadedInit() { #^INIT_OVERLOADED_NEWLINE^# // INIT_OVERLOADED_SAMELINE: Begin completions, 4 items -// INIT_OVERLOADED_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {() -> String in|}#}[#() -> String#]; -// INIT_OVERLOADED_SAMELINE-DAG: Pattern/ExprSpecific: {#fn3: () -> String {() -> String in|}#}[#() -> String#]; +// INIT_OVERLOADED_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {|}#}[#() -> String#]; +// INIT_OVERLOADED_SAMELINE-DAG: Pattern/ExprSpecific: {#fn3: () -> String {|}#}[#() -> String#]; // INIT_OVERLOADED_SAMELINE-DAG: Decl[InstanceMethod]/CurrNominal: .testStructMethod()[#Void#]; // INIT_OVERLOADED_SAMELINE-DAG: Keyword[self]/CurrNominal: .self[#TestStruct#]; // INIT_OVERLOADED_SAMELINE: End completions // INIT_OVERLOADED_NEWLINE: Begin completions -// INIT_OVERLOADED_NEWLINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {() -> String in|}#}[#() -> String#]; -// INIT_OVERLOADED_NEWLINE-DAG: Pattern/ExprSpecific: {#fn3: () -> String {() -> String in|}#}[#() -> String#]; +// INIT_OVERLOADED_NEWLINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {|}#}[#() -> String#]; +// INIT_OVERLOADED_NEWLINE-DAG: Pattern/ExprSpecific: {#fn3: () -> String {|}#}[#() -> String#]; // INIT_OVERLOADED_NEWLINE-DAG: Keyword[class]/None: class; // INIT_OVERLOADED_NEWLINE-DAG: Keyword[if]/None: if; // INIT_OVERLOADED_NEWLINE-DAG: Keyword[try]/None: try; @@ -111,15 +111,15 @@ func testOptionalInit() { #^INIT_OPTIONAL_NEWLINE^# // INIT_OPTIONAL_SAMELINE: Begin completions, 4 items -// INIT_OPTIONAL_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {() -> String in|}#}[#() -> String#]; -// INIT_OPTIONAL_SAMELINE-DAG: Pattern/ExprSpecific: {#fn3: () -> String {() -> String in|}#}[#() -> String#]; +// INIT_OPTIONAL_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {|}#}[#() -> String#]; +// INIT_OPTIONAL_SAMELINE-DAG: Pattern/ExprSpecific: {#fn3: () -> String {|}#}[#() -> String#]; // INIT_OPTIONAL_SAMELINE-DAG: Decl[InstanceMethod]/CurrNominal: .testStructMethod()[#Void#]; // INIT_OPTIONAL_SAMELINE-DAG: Keyword[self]/CurrNominal: .self[#TestStruct2#]; // INIT_OPTIONAL_SAMELINE: End completions // INIT_OPTIONAL_NEWLINE: Begin completions -// INIT_OPTIONAL_NEWLINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {() -> String in|}#}[#() -> String#]; -// INIT_OPTIONAL_NEWLINE-DAG: Pattern/ExprSpecific: {#fn3: () -> String {() -> String in|}#}[#() -> String#]; +// INIT_OPTIONAL_NEWLINE-DAG: Pattern/ExprSpecific: {#fn2: () -> String {|}#}[#() -> String#]; +// INIT_OPTIONAL_NEWLINE-DAG: Pattern/ExprSpecific: {#fn3: () -> String {|}#}[#() -> String#]; // INIT_OPTIONAL_NEWLINE-DAG: Keyword[class]/None: class; // INIT_OPTIONAL_NEWLINE-DAG: Keyword[if]/None: if; // INIT_OPTIONAL_NEWLINE-DAG: Keyword[try]/None: try; @@ -139,11 +139,11 @@ func testOptionalInit() { #^INIT_REQUIRED_NEWLINE_1^# // INIT_REQUIRED_SAMELINE_1: Begin completions, 1 items -// INIT_REQUIRED_SAMELINE_1-DAG: Pattern/ExprSpecific: {#fn2: () -> String {() -> String in|}#}[#() -> String#]; +// INIT_REQUIRED_SAMELINE_1-DAG: Pattern/ExprSpecific: {#fn2: () -> String {|}#}[#() -> String#]; // INIT_REQUIRED_SAMELINE_1: End completions // INIT_REQUIRED_NEWLINE_1: Begin completions, 1 items -// INIT_REQUIRED_NEWLINE_1-DAG: Pattern/ExprSpecific: {#fn2: () -> String {() -> String in|}#}[#() -> String#]; +// INIT_REQUIRED_NEWLINE_1-DAG: Pattern/ExprSpecific: {#fn2: () -> String {|}#}[#() -> String#]; // INIT_REQUIRED_NEWLINE_1: End completions // missing 'fn3'. @@ -155,11 +155,11 @@ func testOptionalInit() { #^INIT_REQUIRED_NEWLINE_2^# // INIT_REQUIRED_SAMELINE_2: Begin completions, 1 items -// INIT_REQUIRED_SAMELINE_2-DAG: Pattern/ExprSpecific: {#fn3: () -> String {() -> String in|}#}[#() -> String#]; +// INIT_REQUIRED_SAMELINE_2-DAG: Pattern/ExprSpecific: {#fn3: () -> String {|}#}[#() -> String#]; // INIT_REQUIRED_SAMELINE_2: End completions // INIT_REQUIRED_NEWLINE_2: Begin completions, 1 items -// INIT_REQUIRED_NEWLINE_2-DAG: Pattern/ExprSpecific: {#fn3: () -> String {() -> String in|}#}[#() -> String#]; +// INIT_REQUIRED_NEWLINE_2-DAG: Pattern/ExprSpecific: {#fn3: () -> String {|}#}[#() -> String#]; // INIT_REQUIRED_NEWLINE_2: End completions // Call is completed. @@ -218,14 +218,14 @@ struct TestNominalMember: P { #^MEMBERDECL_NEWLINE^# // MEMBERDECL_SAMELINE: Begin completions, 4 items -// MEMBERDECL_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: (() -> String)? {() -> String in|}#}[#(() -> String)?#]; name=fn2: (() -> String)? +// MEMBERDECL_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: (() -> String)? {|}#}[#(() -> String)?#]; name=fn2: (() -> String)? // MEMBERDECL_SAMELINE-DAG: Decl[InstanceMethod]/CurrNominal: .enumFunc()[#Void#]; name=enumFunc() // MEMBERDECL_SAMELINE-DAG: Decl[InfixOperatorFunction]/OtherModule[Swift]/IsSystem: [' ']+ {#SimpleEnum#}[#SimpleEnum#]; name=+ SimpleEnum // MEMBERDECL_SAMELINE-DAG: Keyword[self]/CurrNominal: .self[#SimpleEnum#]; name=self // MEMBERDECL_SAMELINE: End completions // MEMBERDECL_NEWLINE: Begin completions -// MEMBERDECL_NEWLINE-DAG: Pattern/ExprSpecific: {#fn2: (() -> String)? {() -> String in|}#}[#(() -> String)?#]; name=fn2: (() -> String)? +// MEMBERDECL_NEWLINE-DAG: Pattern/ExprSpecific: {#fn2: (() -> String)? {|}#}[#(() -> String)?#]; name=fn2: (() -> String)? // MEMBERDECL_NEWLINE-DAG: Keyword[enum]/None: enum; name=enum // MEMBERDECL_NEWLINE-DAG: Keyword[func]/None: func; name=func // MEMBERDECL_NEWLINE-DAG: Keyword[private]/None: private; name=private diff --git a/test/IDE/complete_multiple_trailingclosure_signatures.swift b/test/IDE/complete_multiple_trailingclosure_signatures.swift index ca07b2d53065d..88560fe45464c 100644 --- a/test/IDE/complete_multiple_trailingclosure_signatures.swift +++ b/test/IDE/complete_multiple_trailingclosure_signatures.swift @@ -17,12 +17,12 @@ func test() { // GLOBALFUNC_SAMELINE: Begin completions // GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn2: () -> Void {|}#}[#() -> Void#]; -// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn3: (Int) -> Void {(<#Int#>) in|}#}[#(Int) -> Void#]; -// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn4: (Int, String) -> Void {(<#Int#>, <#String#>) in|}#}[#(Int, String) -> Void#]; -// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn5: (Int, String) -> Int {(<#Int#>, <#String#>) -> Int in|}#}[#(Int, String) -> Int#]; +// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn3: (Int) -> Void {<#Int#> in|}#}[#(Int) -> Void#]; +// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn4: (Int, String) -> Void {<#Int#>, <#String#> in|}#}[#(Int, String) -> Void#]; +// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn5: (Int, String) -> Int {<#Int#>, <#String#> in|}#}[#(Int, String) -> Int#]; // FIXME: recover names -// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn6: (Int, String) -> Int {(<#Int#>, <#String#>) -> Int in|}#}[#(Int, String) -> Int#]; -// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn7: (inout Int) -> Void {(<#inout Int#>) in|}#}[#(inout Int) -> Void#]; -// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn8: (Int...) -> Void {(<#Int...#>) in|}#}[#(Int...) -> Void#]; +// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn6: (Int, String) -> Int {<#Int#>, <#String#> in|}#}[#(Int, String) -> Int#]; +// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn7: (inout Int) -> Void {<#inout Int#> in|}#}[#(inout Int) -> Void#]; +// GLOBALFUNC_SAMELINE-DAG: Pattern/ExprSpecific: {#fn8: (Int...) -> Void {<#Int...#> in|}#}[#(Int...) -> Void#]; // GLOBALFUNC_SAMELINE: End completions } diff --git a/test/SourceKit/CodeComplete/multiple_trailing_closure_signatures.swift b/test/SourceKit/CodeComplete/multiple_trailing_closure_signatures.swift index 31d0b944aa348..8d5af8c44d630 100644 --- a/test/SourceKit/CodeComplete/multiple_trailing_closure_signatures.swift +++ b/test/SourceKit/CodeComplete/multiple_trailing_closure_signatures.swift @@ -19,11 +19,11 @@ func func1( // CHECK: key.results: [ // CHECK-DAG: key.sourcetext: "fn2: {\n<#code#>\n}" -// CHECK-DAG: key.sourcetext: "fn3: { (<#Int#>) in\n<#code#>\n}" -// CHECK-DAG: key.sourcetext: "fn4: { (<#Int#>, <#String#>) in\n<#code#>\n}", -// CHECK-DAG: key.sourcetext: "fn5: { (<#Int#>, <#String#>) -> Int in\n<#code#>\n}", -// CHECK-DAG: key.sourcetext: "fn7: { (<#inout Int#>) in\n<#code#>\n}", -// CHECK-DAG: key.sourcetext: "fn8: { (<#Int...#>) in\n<#code#>\n}", +// CHECK-DAG: key.sourcetext: "fn3: { <#Int#> in\n<#code#>\n}" +// CHECK-DAG: key.sourcetext: "fn4: { <#Int#>, <#String#> in\n<#code#>\n}", +// CHECK-DAG: key.sourcetext: "fn5: { <#Int#>, <#String#> in\n<#code#>\n}", +// CHECK-DAG: key.sourcetext: "fn7: { <#inout Int#> in\n<#code#>\n}", +// CHECK-DAG: key.sourcetext: "fn8: { <#Int...#> in\n<#code#>\n}", // CHECK: ] // DESCRIPTION-NOT: key.description: "fn{{[0-9]*}}: { diff --git a/test/SourceKit/CodeExpand/code-expand-multiple-trailing-closures.swift b/test/SourceKit/CodeExpand/code-expand-multiple-trailing-closures.swift index e1a2428e3a537..c5811be8a179f 100644 --- a/test/SourceKit/CodeExpand/code-expand-multiple-trailing-closures.swift +++ b/test/SourceKit/CodeExpand/code-expand-multiple-trailing-closures.swift @@ -89,18 +89,18 @@ nonTrailingAndTrailing2(a: <#T##() -> ()#>, b: <#T##Int#> c: <#T##() -> ()#>) withTypesAndLabels1(a: <#T##(_ booly: Bool, inty: Int) -> ()#>, b: <#T##(solo: Xyz) -> ()#>) -// CHECK: withTypesAndLabels1 { (booly, inty) in +// CHECK: withTypesAndLabels1 { booly, inty in // CHECK-NEXT: <#code#> -// CHECK-NEXT: } b: { (solo) in +// CHECK-NEXT: } b: { solo in // CHECK-NEXT: <#code#> // CHECK-NEXT: } func reset_parser1() {} withTypes1(a: <#T##(Bool, Int) -> ()#>, b: <#T##() -> Int#>) -// CHECK: withTypes1 { (<#Bool#>, <#Int#>) in +// CHECK: withTypes1 { <#Bool#>, <#Int#> in // CHECK-NEXT: <#code#> -// CHECK-NEXT: } b: { () -> Int in +// CHECK-NEXT: } b: { // CHECK-NEXT: <#code#> // CHECK-NEXT: } diff --git a/test/SourceKit/CodeExpand/code-expand.swift b/test/SourceKit/CodeExpand/code-expand.swift index ab5f00fcbec62..33c2a82643e53 100644 --- a/test/SourceKit/CodeExpand/code-expand.swift +++ b/test/SourceKit/CodeExpand/code-expand.swift @@ -11,23 +11,23 @@ foo(x: <#T##() -> Void#>, y: <#T##Int#>) // CHECK-NEXT: }, y: Int) anArr.indexOfObjectPassingTest(<#T##predicate: ((AnyObject!, Int, UnsafePointer) -> Bool)?##((AnyObject!, Int, UnsafePointer) -> Bool)?#>) -// CHECK: anArr.indexOfObjectPassingTest { (<#AnyObject!#>, <#Int#>, <#UnsafePointer#>) -> Bool in +// CHECK: anArr.indexOfObjectPassingTest { <#AnyObject!#>, <#Int#>, <#UnsafePointer#> in // CHECK-NEXT: <#code#> // CHECK-NEXT: } anArr.indexOfObjectPassingTest(<#T##predicate: ((_ obj: AnyObject!, _ idx: Int, _ stop: UnsafePointer) -> Bool)?##((_ obj: AnyObject!, _ idx: Int, _ stop: UnsafePointer) -> Bool)?#>) -// CHECK: anArr.indexOfObjectPassingTest { (obj, idx, stop) -> Bool in +// CHECK: anArr.indexOfObjectPassingTest { obj, idx, stop in // CHECK-NEXT: <#code#> // CHECK-NEXT: } anArr.indexOfObjectAtIndexes(<#T##s: NSIndexSet?##NSIndexSet?#>, options: <#T##NSEnumerationOptions#>, passingTest: <#T##((AnyObject!, Int, UnsafePointer) -> Bool)?#>) -// CHECK: anArr.indexOfObjectAtIndexes(NSIndexSet?, options: NSEnumerationOptions) { (<#AnyObject!#>, <#Int#>, <#UnsafePointer#>) -> Bool in +// CHECK: anArr.indexOfObjectAtIndexes(NSIndexSet?, options: NSEnumerationOptions) { <#AnyObject!#>, <#Int#>, <#UnsafePointer#> in // CHECK-NEXT: <#code#> // CHECK-NEXT: } if anArr.indexOfObjectPassingTest(<#T##predicate: ((AnyObject!, Int, UnsafePointer) -> Bool)?##((AnyObject!, Int, UnsafePointer) -> Bool)?#>) { } -// CHECK: if anArr.indexOfObjectPassingTest({ (<#AnyObject!#>, <#Int#>, <#UnsafePointer#>) -> Bool in +// CHECK: if anArr.indexOfObjectPassingTest({ <#AnyObject!#>, <#Int#>, <#UnsafePointer#> in // CHECK-NEXT: <#code#> // CHECK-NEXT: }) { // CHECK-NEXT: } @@ -48,10 +48,10 @@ do { } foo(x: <#T##Self.SegueIdentifier -> Void#>) -// CHECK: foo { (<#Self.SegueIdentifier#>) in +// CHECK: foo { <#Self.SegueIdentifier#> in store.requestAccessToEntityType(<#T##entityType: EKEntityType##EKEntityType#>, completion: <#T##EKEventStoreRequestAccessCompletionHandler##EKEventStoreRequestAccessCompletionHandler##(Bool, NSError?) -> Void#>) -// CHECK: store.requestAccessToEntityType(EKEntityType) { (<#Bool#>, <#NSError?#>) in +// CHECK: store.requestAccessToEntityType(EKEntityType) { <#Bool#>, <#NSError?#> in // CHECK-NEXT: <#code#> // CHECK-NEXT: } diff --git a/tools/SourceKit/lib/SwiftLang/SwiftEditor.cpp b/tools/SourceKit/lib/SwiftLang/SwiftEditor.cpp index 68943da63c200..ab573292cd31d 100644 --- a/tools/SourceKit/lib/SwiftLang/SwiftEditor.cpp +++ b/tools/SourceKit/lib/SwiftLang/SwiftEditor.cpp @@ -2129,23 +2129,10 @@ void SwiftEditorDocument::formatText(unsigned Line, unsigned Length, Consumer.recordAffectedLineRange(LineRange.startLine(), LineRange.lineCount()); } -bool isReturningVoid(const SourceManager &SM, CharSourceRange Range) { - if (Range.isInvalid()) - return false; - StringRef Text = SM.extractText(Range); - return "()" == Text || "Void" == Text; -} - static void printClosureBody(const PlaceholderExpansionScanner::ClosureInfo &closure, llvm::raw_ostream &OS, const SourceManager &SM) { - bool ReturningVoid = isReturningVoid(SM, closure.ReturnTypeRange); - - bool HasSignature = !closure.Params.empty() || - (closure.ReturnTypeRange.isValid() && !ReturningVoid); bool FirstParam = true; - if (HasSignature) - OS << "("; for (auto &Param : closure.Params) { if (!FirstParam) OS << ", "; @@ -2153,30 +2140,19 @@ printClosureBody(const PlaceholderExpansionScanner::ClosureInfo &closure, if (Param.NameRange.isValid()) { // If we have a parameter name, just output the name as is and skip // the type. For example: - // <#(arg1: Int, arg2: Int)#> turns into (arg1, arg2). + // <#(arg1: Int, arg2: Int)#> turns into '{ arg1, arg2 in'. OS << SM.extractText(Param.NameRange); } else { // If we only have the parameter type, output the type as a // placeholder. For example: - // <#(Int, Int)#> turns into (<#Int#>, <#Int#>). + // <#(Int, Int)#> turns into '{ <#Int#>, <#Int#> in'. OS << "<#"; OS << SM.extractText(Param.TypeRange); OS << "#>"; } } - if (HasSignature) - OS << ") "; - if (closure.ReturnTypeRange.isValid()) { - auto ReturnTypeText = SM.extractText(closure.ReturnTypeRange); - - // We need return type if it is not Void. - if (!ReturningVoid) { - OS << "-> "; - OS << ReturnTypeText << " "; - } - } - if (HasSignature) - OS << "in"; + if (!FirstParam) + OS << " in"; OS << "\n" << getCodePlaceholder() << "\n"; } From f0749951069e9344a85657c79a8e3155a599363d Mon Sep 17 00:00:00 2001 From: Saleem Abdulrasool Date: Thu, 9 Jul 2020 15:54:34 -0700 Subject: [PATCH 29/36] test: partially undo changes to `PathSanitizingFileCheck` The regular expression engine escaped the strings differently across python 2 and 3. Using a raw string makes this simpler to understand and obsoletes the comment. This change also now properly allows the replacement to occur in the same way on 2 and 3. --- utils/PathSanitizingFileCheck | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/utils/PathSanitizingFileCheck b/utils/PathSanitizingFileCheck index 3496895ec0d07..030fe04fdb2b0 100755 --- a/utils/PathSanitizingFileCheck +++ b/utils/PathSanitizingFileCheck @@ -80,10 +80,9 @@ constants.""") # Since we want to use pattern as a regex in some platforms, we need # to escape it first, and then replace the escaped slash # literal (r'\\/') for our platform-dependent slash regex. - stdin = re.sub(re.sub('\\\\/' if sys.version_info[0] < 3 else r'[/\\]', + stdin = re.sub(re.sub(r'\\/' if sys.version_info[0] < 3 else r'/', slashes_re, re.escape(pattern)), - replacement, - stdin) + replacement, stdin) if args.dry_run: print(stdin) From 0b5dbb111a5554d38e8626f86afff8b3909490e9 Mon Sep 17 00:00:00 2001 From: Rintaro Ishizaki Date: Thu, 9 Jul 2020 16:48:45 -0700 Subject: [PATCH 30/36] [swift-ide-test] Add indicator of "reusing ASTContext" to the result --- tools/swift-ide-test/swift-ide-test.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tools/swift-ide-test/swift-ide-test.cpp b/tools/swift-ide-test/swift-ide-test.cpp index c32dedc0d7d62..531b9da23d3f6 100644 --- a/tools/swift-ide-test/swift-ide-test.cpp +++ b/tools/swift-ide-test/swift-ide-test.cpp @@ -1198,6 +1198,7 @@ static int doBatchCodeCompletion(const CompilerInvocation &InitInvok, PrintingDiagnosticConsumer PrintDiags; auto completionStart = std::chrono::high_resolution_clock::now(); + bool wasASTContextReused = false; bool isSuccess = CompletionInst.performOperation( Invocation, /*Args=*/{}, FileSystem, completionBuffer.get(), Offset, /*EnableASTCaching=*/true, Error, @@ -1217,11 +1218,15 @@ static int doBatchCodeCompletion(const CompilerInvocation &InitInvok, auto *SF = CI.getCodeCompletionFile(); performCodeCompletionSecondPass(*SF, *callbacksFactory); + wasASTContextReused = reusingASTContext; }); auto completionEnd = std::chrono::high_resolution_clock::now(); auto elapsed = std::chrono::duration_cast( completionEnd - completionStart); - llvm::errs() << "Elapsed: " << elapsed.count() << " msec\n"; + llvm::errs() << "Elapsed: " << elapsed.count() << " msec"; + if (wasASTContextReused) + llvm::errs() << " (reusing ASTContext)"; + llvm::errs() << "\n"; OS.flush(); if (OutputDir.empty()) { From b77d68a9f32b371278060f133a9514d7ba4bd5d5 Mon Sep 17 00:00:00 2001 From: Vedant Kumar Date: Thu, 9 Jul 2020 17:10:16 -0700 Subject: [PATCH 31/36] [build-script] Tie llvm, swift, and lldb to the same sysroot rdar://62895058 --- utils/build-script-impl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/utils/build-script-impl b/utils/build-script-impl index 71057c73630cb..6505a7ea71dc7 100755 --- a/utils/build-script-impl +++ b/utils/build-script-impl @@ -705,6 +705,10 @@ function set_build_options_for_host() { # in the compiler checks CMake performs -DCMAKE_OSX_ARCHITECTURES="${architecture}" ) + + lldb_cmake_options+=( + -DCMAKE_OSX_SYSROOT:PATH="${cmake_os_sysroot}" + ) ;; esac From 01a44bbdf929f0eaf81dbb881e33394db1ce76f9 Mon Sep 17 00:00:00 2001 From: Nate Chandler Date: Thu, 9 Jul 2020 17:09:26 -0700 Subject: [PATCH 32/36] [Test] Xfail attr/attr_originally_definedin_backward_compatibility.swift. rdar://problem/64298096 --- .../attr/attr_originally_definedin_backward_compatibility.swift | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/attr/attr_originally_definedin_backward_compatibility.swift b/test/attr/attr_originally_definedin_backward_compatibility.swift index c1327748cf5b6..5978ba0d92f58 100644 --- a/test/attr/attr_originally_definedin_backward_compatibility.swift +++ b/test/attr/attr_originally_definedin_backward_compatibility.swift @@ -1,6 +1,8 @@ // REQUIRES: executable_test // REQUIRES: OS=macosx || OS=ios // UNSUPPORTED: DARWIN_SIMULATOR=ios +// rdar://problem/64298096 +// XFAIL: OS=ios && CPU=arm64 // // RUN: %empty-directory(%t) // From 383482b6577247f47aff66fd56c5aff0778441ae Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 9 Jul 2020 20:02:10 -0700 Subject: [PATCH 33/36] [AutoDiff] Fix derivative generic signature same-type requirements. (#32803) Fix SILGen for `@derivative` attributes where the derivative generic signature is equal to the original generic signature and has all concrete generic parameters (i.e. all generic parameters are bound to concrete types via same-type requirements). SILGen should emit a differentiability witness with no generic signature. This is already done for `@differentiable` attributes. Resolves TF-1292. --- lib/SILGen/SILGen.cpp | 47 ++++- lib/Sema/TypeCheckAttr.cpp | 18 -- ...ntiability_witness_generic_signature.swift | 177 ++++++++++++++++++ 3 files changed, 221 insertions(+), 21 deletions(-) create mode 100644 test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 7f2b690297c55..bc93cdf2387a3 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -935,6 +935,43 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, emitDifferentiabilityWitnessesForFunction(constant, F); } +/// Returns the SIL differentiability witness generic signature given the +/// original declaration's generic signature and the derivative generic +/// signature. +/// +/// In general, the differentiability witness generic signature is equal to the +/// derivative generic signature. +/// +/// Edge case, if two conditions are satisfied: +/// 1. The derivative generic signature is equal to the original generic +/// signature. +/// 2. The derivative generic signature has *all concrete* generic parameters +/// (i.e. all generic parameters are bound to concrete types via same-type +/// requirements). +/// +/// Then the differentiability witness generic signature is `nullptr`. +/// +/// Both the original and derivative declarations are lowered to SIL functions +/// with a fully concrete type and no generic signature, so the +/// differentiability witness should similarly have no generic signature. +static GenericSignature +getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig, + GenericSignature derivativeGenSig) { + // If there is no derivative generic signature, return the original generic + // signature. + if (!derivativeGenSig) + return origGenSig; + // If derivative generic signature has all concrete generic parameters and is + // equal to the original generic signature, return `nullptr`. + auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature(); + auto origCanGenSig = origGenSig.getCanonicalSignature(); + if (origCanGenSig == derivativeCanGenSig && + derivativeCanGenSig->areAllParamsConcrete()) + return GenericSignature(); + // Otherwise, return the derivative generic signature. + return derivativeGenSig; +} + void SILGenModule::emitDifferentiabilityWitnessesForFunction( SILDeclRef constant, SILFunction *F) { // Visit `@derivative` attributes and generate SIL differentiability @@ -955,8 +992,11 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( diffAttr->getDerivativeGenericSignature()) && "Type-checking should resolve derivative generic signatures for " "all original SIL functions with generic signatures"); + auto witnessGenSig = getDifferentiabilityWitnessGenericSignature( + AFD->getGenericSignature(), + diffAttr->getDerivativeGenericSignature()); AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices, - diffAttr->getDerivativeGenericSignature()); + witnessGenSig); emitDifferentiabilityWitness(AFD, F, config, /*jvp*/ nullptr, /*vjp*/ nullptr, diffAttr); } @@ -975,10 +1015,11 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( auto origDeclRef = SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD)); auto *origFn = getFunction(origDeclRef, NotForDefinition); - auto derivativeGenSig = AFD->getGenericSignature(); + auto witnessGenSig = getDifferentiabilityWitnessGenericSignature( + origAFD->getGenericSignature(), AFD->getGenericSignature()); auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices, - derivativeGenSig); + witnessGenSig); emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp, derivAttr); } diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 7375764e368ee..ac10198e2bae3 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -4194,24 +4194,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature( attr->getLocation(), /*allowConcreteGenericParams=*/true); } - // Set the resolved derivative generic signature in the attribute. - // Do not set the derivative generic signature if the original function's - // generic signature is equal to `derivativeGenSig` and all generic parameters - // are concrete. In that case, the original function and derivative functions - // are all lowered as SIL functions with no generic signature (specialized - // with concrete types from same-type requirements), so the derivative generic - // signature should not be set. - auto skipDerivativeGenericSignature = [&] { - auto origCanGenSig = - original->getGenericSignature().getCanonicalSignature(); - auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature(); - if (!derivativeCanGenSig) - return false; - return origCanGenSig == derivativeCanGenSig && - derivativeCanGenSig->areAllParamsConcrete(); - }; - if (skipDerivativeGenericSignature()) - derivativeGenSig = GenericSignature(); attr->setDerivativeGenericSignature(derivativeGenSig); return false; } diff --git a/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift b/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift new file mode 100644 index 0000000000000..ff198e430e33c --- /dev/null +++ b/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift @@ -0,0 +1,177 @@ +// RUN: %target-swift-emit-silgen -verify -module-name main %s | %FileCheck %s +// RUN: %target-swift-emit-sil -verify -module-name main %s + +// NOTE(SR-11950): SILParser crashes for SILGen round-trip. + +// This file tests: +// - The "derivative generic signature" of `@differentiable` and `@derivative` +// attributes. +// - The generic signature of lowered SIL differentiability witnesses. + +// Context: +// - For `@differentiable` attributes: the derivative generic signature is +// resolved from the original declaration's generic signature and additional +// `where` clause requirements. +// - For `@derivative` attributes: the derivative generic signature is the +// attributed declaration's generic signature. + +import _Differentiation + +//===----------------------------------------------------------------------===// +// Same-type requirements +//===----------------------------------------------------------------------===// + +// Test original declaration with a generic signature and derivative generic +// signature where all generic parameters are concrete (i.e. bound to concrete +// types via same-type requirements). + +struct AllConcrete: Differentiable {} + +extension AllConcrete { + // Original generic signature: `` + // Derivative generic signature: `` + // Witness generic signature: `` + @_silgen_name("allconcrete_where_gensig_constrained") + @differentiable(where T == Float) + func whereClauseGenericSignatureConstrained() -> AllConcrete { + return self + } +} +extension AllConcrete where T == Float { + @derivative(of: whereClauseGenericSignatureConstrained) + func jvpWhereClauseGenericSignatureConstrained() -> ( + value: AllConcrete, differential: (TangentVector) -> TangentVector + ) { + (whereClauseGenericSignatureConstrained(), { $0 }) + } +} + +// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_where_gensig_constrained : $@convention(method) (AllConcrete) -> AllConcrete { +// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszl : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) +// CHECK-NEXT: } + +// If a `@differentiable` or `@derivative` attribute satisfies two conditions: +// 1. The derivative generic signature is equal to the original generic signature. +// 2. The derivative generic signature has *all concrete* generic parameters. +// +// Then the attribute should be lowered to a SIL differentiability witness with +// *no* derivative generic signature. + +extension AllConcrete where T == Float { + // Original generic signature: `` + // Derivative generic signature: `` + // Witness generic signature: none + @_silgen_name("allconcrete_original_gensig") + @differentiable + func originalGenericSignature() -> AllConcrete { + return self + } + + @derivative(of: originalGenericSignature) + func jvpOriginalGenericSignature() -> ( + value: AllConcrete, differential: (TangentVector) -> TangentVector + ) { + (originalGenericSignature(), { $0 }) + } + +// CHECK-LABEL: // differentiability witness for allconcrete_original_gensig +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete) -> AllConcrete { +// CHECK-NEXT: jvp: @AD__allconcrete_original_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) +// CHECK-NEXT: } + + // Original generic signature: `` + // Derivative generic signature: `` (explicit `where` clause) + // Witness generic signature: none + @_silgen_name("allconcrete_where_gensig") + @differentiable(where T == Float) + func whereClauseGenericSignature() -> AllConcrete { + return self + } + + @derivative(of: whereClauseGenericSignature) + func jvpWhereClauseGenericSignature() -> ( + value: AllConcrete, differential: (TangentVector) -> TangentVector + ) { + (whereClauseGenericSignature(), { $0 }) + } + +// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete) -> AllConcrete { +// CHECK-NEXT: jvp: @AD__allconcrete_where_gensig__jvp_src_0_wrt_0 : $@convention(method) (AllConcrete) -> (AllConcrete, @owned @callee_guaranteed (AllConcrete.TangentVector) -> AllConcrete.TangentVector) +// CHECK-NEXT: } +} + +// Test original declaration with a generic signature and derivative generic +// signature where *not* all generic parameters are concrete. +// types via same-type requirements). + +struct NotAllConcrete: Differentiable {} + +extension NotAllConcrete { + // Original generic signature: `` + // Derivative generic signature: `` + // Witness generic signature: `` (not all concrete) + @_silgen_name("notallconcrete_where_gensig_constrained") + @differentiable(where T == Float) + func whereClauseGenericSignatureConstrained() -> NotAllConcrete { + return self + } +} +extension NotAllConcrete where T == Float { + @derivative(of: whereClauseGenericSignatureConstrained) + func jvpWhereClauseGenericSignatureConstrained() -> ( + value: NotAllConcrete, differential: (TangentVector) -> TangentVector + ) { + (whereClauseGenericSignatureConstrained(), { $0 }) + } +} + +// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @notallconcrete_where_gensig_constrained : $@convention(method) (NotAllConcrete) -> NotAllConcrete { +// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig_constrained__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) +// CHECK-NEXT: } + +extension NotAllConcrete where T == Float { + // Original generic signature: `` + // Derivative generic signature: `` + // Witness generic signature: `` (not all concrete) + @_silgen_name("notallconcrete_original_gensig") + @differentiable + func originalGenericSignature() -> NotAllConcrete { + return self + } + + @derivative(of: originalGenericSignature) + func jvpOriginalGenericSignature() -> ( + value: NotAllConcrete, differential: (TangentVector) -> TangentVector + ) { + (originalGenericSignature(), { $0 }) + } + +// CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @notallconcrete_original_gensig : $@convention(method) (NotAllConcrete) -> NotAllConcrete { +// CHECK-NEXT: jvp: @AD__notallconcrete_original_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) +// CHECK-NEXT: } + + // Original generic signature: `` + // Derivative generic signature: `` (explicit `where` clause) + // Witness generic signature: `` (not all concrete) + @_silgen_name("notallconcrete_where_gensig") + @differentiable(where T == Float) + func whereClauseGenericSignature() -> NotAllConcrete { + return self + } + + @derivative(of: whereClauseGenericSignature) + func jvpWhereClauseGenericSignature() -> ( + value: NotAllConcrete, differential: (TangentVector) -> TangentVector + ) { + (whereClauseGenericSignature(), { $0 }) + } + +// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig +// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @notallconcrete_where_gensig : $@convention(method) (NotAllConcrete) -> NotAllConcrete { +// CHECK-NEXT: jvp: @AD__notallconcrete_where_gensig__jvp_src_0_wrt_0_SfRszr0_l : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete) -> (NotAllConcrete, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for .TangentVector, NotAllConcrete.TangentVector>) +// CHECK-NEXT: } +} From 2daf6cdb9d0950db512889faccfb0af4191ec367 Mon Sep 17 00:00:00 2001 From: Michael Gottesman Date: Thu, 9 Jul 2020 23:04:37 -0700 Subject: [PATCH 34/36] [gardening] Update SILLowerAggregateInstrs to be of a more modern style. This is a really simple pass that isn't going to be touched for a long time after I am done fixing the pass for ownership. So it makes sense to clean it up now. I am doing this as a separate commit before updating for ownership. --- .../Transforms/SILLowerAggregateInstrs.cpp | 218 +++++++++--------- 1 file changed, 114 insertions(+), 104 deletions(-) diff --git a/lib/SILOptimizer/Transforms/SILLowerAggregateInstrs.cpp b/lib/SILOptimizer/Transforms/SILLowerAggregateInstrs.cpp index bb03fe517951d..331b44b0fc329 100644 --- a/lib/SILOptimizer/Transforms/SILLowerAggregateInstrs.cpp +++ b/lib/SILOptimizer/Transforms/SILLowerAggregateInstrs.cpp @@ -9,9 +9,12 @@ // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// -// -// Simplify aggregate instructions into scalar instructions. -// +/// +/// \file +/// +/// Simplify aggregate instructions into scalar instructions using simple +/// peephole transformations. +/// //===----------------------------------------------------------------------===// #define DEBUG_TYPE "sil-lower-aggregate-instrs" @@ -26,6 +29,7 @@ #include "swift/SILOptimizer/Utils/InstOptUtils.h" #include "llvm/ADT/Statistic.h" #include "llvm/Support/Debug.h" + using namespace swift; using namespace swift::Lowering; @@ -70,154 +74,156 @@ STATISTIC(NumExpand, "Number of instructions expanded"); /// // no retain of %new! /// // no load/release of %old! /// store %new to %1 : $*T -static bool expandCopyAddr(CopyAddrInst *CA) { - SILModule &M = CA->getModule(); - SILFunction *F = CA->getFunction(); - SILValue Source = CA->getSrc(); +static bool expandCopyAddr(CopyAddrInst *cai) { + SILFunction *fn = cai->getFunction(); + SILModule &module = cai->getModule(); + SILValue source = cai->getSrc(); // If we have an address only type don't do anything. - SILType SrcType = Source->getType(); - if (SrcType.isAddressOnly(*F)) + SILType srcType = source->getType(); + if (srcType.isAddressOnly(*fn)) return false; - bool expand = shouldExpand(M, SrcType.getObjectType()); + bool expand = shouldExpand(module, srcType.getObjectType()); using TypeExpansionKind = Lowering::TypeLowering::TypeExpansionKind; auto expansionKind = expand ? TypeExpansionKind::MostDerivedDescendents : TypeExpansionKind::None; - SILBuilderWithScope Builder(CA); + SILBuilderWithScope builder(cai); // %new = load %0 : $*T - LoadInst *New = Builder.createLoad(CA->getLoc(), Source, - LoadOwnershipQualifier::Unqualified); + LoadInst *newValue = builder.createLoad(cai->getLoc(), source, + LoadOwnershipQualifier::Unqualified); - SILValue Destination = CA->getDest(); + SILValue destAddr = cai->getDest(); // If our object type is not trivial, we may need to release the old value and // retain the new one. - auto &TL = F->getTypeLowering(SrcType); + auto &typeLowering = fn->getTypeLowering(srcType); // If we have a non-trivial type... - if (!TL.isTrivial()) { - + if (!typeLowering.isTrivial()) { // If we are not initializing: // %old = load %1 : $*T - IsInitialization_t IsInit = CA->isInitializationOfDest(); - LoadInst *Old = nullptr; - if (IsInitialization_t::IsNotInitialization == IsInit) { - Old = Builder.createLoad(CA->getLoc(), Destination, - LoadOwnershipQualifier::Unqualified); + auto isInit = cai->isInitializationOfDest(); + LoadInst *oldValue = nullptr; + if (IsInitialization_t::IsNotInitialization == isInit) { + oldValue = builder.createLoad(cai->getLoc(), destAddr, + LoadOwnershipQualifier::Unqualified); } // If we are not taking and have a reference type: // strong_retain %new : $*T // or if we have a non-trivial non-reference type. // retain_value %new : $*T - IsTake_t IsTake = CA->isTakeOfSrc(); - if (IsTake_t::IsNotTake == IsTake) { - TL.emitLoweredCopyValue(Builder, CA->getLoc(), New, expansionKind); + if (IsTake_t::IsNotTake == cai->isTakeOfSrc()) { + typeLowering.emitLoweredCopyValue(builder, cai->getLoc(), newValue, + expansionKind); } // If we are not initializing: // strong_release %old : $*T // *or* // release_value %old : $*T - if (Old) { - TL.emitLoweredDestroyValue(Builder, CA->getLoc(), Old, expansionKind); + if (oldValue) { + typeLowering.emitLoweredDestroyValue(builder, cai->getLoc(), oldValue, + expansionKind); } } // Create the store. - Builder.createStore(CA->getLoc(), New, Destination, + builder.createStore(cai->getLoc(), newValue, destAddr, StoreOwnershipQualifier::Unqualified); ++NumExpand; return true; } -static bool expandDestroyAddr(DestroyAddrInst *DA) { - SILFunction *F = DA->getFunction(); - SILModule &Module = DA->getModule(); - SILBuilderWithScope Builder(DA); +static bool expandDestroyAddr(DestroyAddrInst *dai) { + SILFunction *fn = dai->getFunction(); + SILModule &module = dai->getModule(); + SILBuilderWithScope builder(dai); // Strength reduce destroy_addr inst into release/store if // we have a non-address only type. - SILValue Addr = DA->getOperand(); + SILValue addr = dai->getOperand(); // If we have an address only type, do nothing. - SILType Type = Addr->getType(); - if (Type.isAddressOnly(*F)) + SILType type = addr->getType(); + if (type.isAddressOnly(*fn)) return false; - bool expand = shouldExpand(Module, Type.getObjectType()); + bool expand = shouldExpand(module, type.getObjectType()); // If we have a non-trivial type... - if (!Type.isTrivial(*F)) { + if (!type.isTrivial(*fn)) { // If we have a type with reference semantics, emit a load/strong release. - LoadInst *LI = Builder.createLoad(DA->getLoc(), Addr, + LoadInst *li = builder.createLoad(dai->getLoc(), addr, LoadOwnershipQualifier::Unqualified); - auto &TL = F->getTypeLowering(Type); + auto &typeLowering = fn->getTypeLowering(type); using TypeExpansionKind = Lowering::TypeLowering::TypeExpansionKind; auto expansionKind = expand ? TypeExpansionKind::MostDerivedDescendents : TypeExpansionKind::None; - TL.emitLoweredDestroyValue(Builder, DA->getLoc(), LI, expansionKind); + typeLowering.emitLoweredDestroyValue(builder, dai->getLoc(), li, + expansionKind); } ++NumExpand; return true; } -static bool expandReleaseValue(ReleaseValueInst *DV) { - SILFunction *F = DV->getFunction(); - SILModule &Module = DV->getModule(); - SILBuilderWithScope Builder(DV); +static bool expandReleaseValue(ReleaseValueInst *rvi) { + SILFunction *fn = rvi->getFunction(); + SILModule &module = rvi->getModule(); + SILBuilderWithScope builder(rvi); // Strength reduce destroy_addr inst into release/store if // we have a non-address only type. - SILValue Value = DV->getOperand(); + SILValue value = rvi->getOperand(); // If we have an address only type, do nothing. - SILType Type = Value->getType(); - assert(!SILModuleConventions(Module).useLoweredAddresses() - || Type.isLoadable(*F) && - "release_value should never be called on a non-loadable type."); + SILType type = value->getType(); + assert(!SILModuleConventions(module).useLoweredAddresses() || + type.isLoadable(*fn) && + "release_value should never be called on a non-loadable type."); - if (!shouldExpand(Module, Type.getObjectType())) + if (!shouldExpand(module, type.getObjectType())) return false; - auto &TL = F->getTypeLowering(Type); - TL.emitLoweredDestroyValueMostDerivedDescendents(Builder, DV->getLoc(), - Value); + auto &TL = fn->getTypeLowering(type); + TL.emitLoweredDestroyValueMostDerivedDescendents(builder, rvi->getLoc(), + value); - LLVM_DEBUG(llvm::dbgs() << " Expanding Destroy Value: " << *DV); + LLVM_DEBUG(llvm::dbgs() << " Expanding: " << *rvi); ++NumExpand; return true; } -static bool expandRetainValue(RetainValueInst *CV) { - SILFunction *F = CV->getFunction(); - SILModule &Module = CV->getModule(); - SILBuilderWithScope Builder(CV); +static bool expandRetainValue(RetainValueInst *rvi) { + SILFunction *fn = rvi->getFunction(); + SILModule &module = rvi->getModule(); + SILBuilderWithScope builder(rvi); // Strength reduce destroy_addr inst into release/store if // we have a non-address only type. - SILValue Value = CV->getOperand(); + SILValue value = rvi->getOperand(); // If we have an address only type, do nothing. - SILType Type = Value->getType(); - assert(!SILModuleConventions(Module).useLoweredAddresses() - || Type.isLoadable(*F) && - "Copy Value can only be called on loadable types."); + SILType type = value->getType(); + assert(!SILModuleConventions(module).useLoweredAddresses() || + type.isLoadable(*fn) && + "Copy Value can only be called on loadable types."); - if (!shouldExpand(Module, Type.getObjectType())) + if (!shouldExpand(module, type.getObjectType())) return false; - auto &TL = F->getTypeLowering(Type); - TL.emitLoweredCopyValueMostDerivedDescendents(Builder, CV->getLoc(), Value); + auto &typeLowering = fn->getTypeLowering(type); + typeLowering.emitLoweredCopyValueMostDerivedDescendents(builder, + rvi->getLoc(), value); - LLVM_DEBUG(llvm::dbgs() << " Expanding Copy Value: " << *CV); + LLVM_DEBUG(llvm::dbgs() << " Expanding: " << *rvi); ++NumExpand; return true; @@ -227,73 +233,77 @@ static bool expandRetainValue(RetainValueInst *CV) { // Top Level Driver //===----------------------------------------------------------------------===// -static bool processFunction(SILFunction &Fn) { - bool Changed = false; - for (auto BI = Fn.begin(), BE = Fn.end(); BI != BE; ++BI) { - auto II = BI->begin(), IE = BI->end(); - while (II != IE) { - SILInstruction *Inst = &*II; +static bool processFunction(SILFunction &fn) { + bool changed = false; + for (auto &block : fn) { + auto ii = block.begin(), ie = block.end(); + while (ii != ie) { + SILInstruction *inst = &*ii; - LLVM_DEBUG(llvm::dbgs() << "Visiting: " << *Inst); + LLVM_DEBUG(llvm::dbgs() << "Visiting: " << *inst); - if (auto *CA = dyn_cast(Inst)) - if (expandCopyAddr(CA)) { - ++II; - CA->eraseFromParent(); - Changed = true; + if (auto *cai = dyn_cast(inst)) + if (expandCopyAddr(cai)) { + ++ii; + cai->eraseFromParent(); + changed = true; continue; } - if (auto *DA = dyn_cast(Inst)) - if (expandDestroyAddr(DA)) { - ++II; - DA->eraseFromParent(); - Changed = true; + if (auto *dai = dyn_cast(inst)) + if (expandDestroyAddr(dai)) { + ++ii; + dai->eraseFromParent(); + changed = true; continue; } - if (auto *CV = dyn_cast(Inst)) - if (expandRetainValue(CV)) { - ++II; - CV->eraseFromParent(); - Changed = true; + if (auto *rvi = dyn_cast(inst)) + if (expandRetainValue(rvi)) { + ++ii; + rvi->eraseFromParent(); + changed = true; continue; } - if (auto *DV = dyn_cast(Inst)) - if (expandReleaseValue(DV)) { - ++II; - DV->eraseFromParent(); - Changed = true; + if (auto *rvi = dyn_cast(inst)) + if (expandReleaseValue(rvi)) { + ++ii; + rvi->eraseFromParent(); + changed = true; continue; } - ++II; + ++ii; } } - return Changed; + return changed; } +//===----------------------------------------------------------------------===// +// Top Level Entrypoint +//===----------------------------------------------------------------------===// + namespace { + class SILLowerAggregate : public SILFunctionTransform { /// The entry point to the transformation. void run() override { - SILFunction *F = getFunction(); + SILFunction *f = getFunction(); // FIXME: Can we support ownership? - if (F->hasOwnership()) + if (f->hasOwnership()) return; - LLVM_DEBUG(llvm::dbgs() << "***** LowerAggregate on function: " << - F->getName() << " *****\n"); - bool Changed = processFunction(*F); - if (Changed) { + LLVM_DEBUG(llvm::dbgs() << "***** LowerAggregate on function: " + << f->getName() << " *****\n"); + bool changed = processFunction(*f); + if (changed) { invalidateAnalysis(SILAnalysis::InvalidationKind::CallsAndInstructions); } } - }; -} // end anonymous namespace +} // end anonymous namespace SILTransform *swift::createLowerAggregateInstrs() { return new SILLowerAggregate(); From c6812cee68c79b97ce617edcedc33f1ae8b2bb23 Mon Sep 17 00:00:00 2001 From: Nate Chandler Date: Fri, 10 Jul 2020 06:21:11 -0700 Subject: [PATCH 35/36] [Test] Mark executable test appropriately. --- validation-test/Runtime/rdar64672291.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/validation-test/Runtime/rdar64672291.swift b/validation-test/Runtime/rdar64672291.swift index 03636624486a7..5b1a15ab76f44 100644 --- a/validation-test/Runtime/rdar64672291.swift +++ b/validation-test/Runtime/rdar64672291.swift @@ -1,5 +1,6 @@ // RUN: %target-run-simple-swift // REQUIRES: objc_interop +// REQUIRES: executable_test import Foundation From c5270a596874f3bcafdfa65d2bd02f060de99649 Mon Sep 17 00:00:00 2001 From: Nate Chandler Date: Thu, 9 Jul 2020 17:25:50 -0700 Subject: [PATCH 36/36] [Test] Xfail stdlib/NSValueBridging.swift rdar://problem/64995079 --- test/stdlib/NSValueBridging.swift.gyb | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/stdlib/NSValueBridging.swift.gyb b/test/stdlib/NSValueBridging.swift.gyb index 3b91eeb9fc86f..6ca31835917ba 100644 --- a/test/stdlib/NSValueBridging.swift.gyb +++ b/test/stdlib/NSValueBridging.swift.gyb @@ -15,6 +15,8 @@ // RUN: %target-codesign %t.out // RUN: %target-run %t.out // REQUIRES: executable_test +// rdar://problem/64995079 +// XFAIL: OS=ios && CPU=armv7s // // REQUIRES: objc_interop