diff --git a/include/swift/AST/KnownProtocols.def b/include/swift/AST/KnownProtocols.def index 0d0c0f84dd15e..e7eab230dea3e 100644 --- a/include/swift/AST/KnownProtocols.def +++ b/include/swift/AST/KnownProtocols.def @@ -115,6 +115,7 @@ PROTOCOL(CxxSet) PROTOCOL(CxxRandomAccessCollection) PROTOCOL(CxxSequence) PROTOCOL(CxxUniqueSet) +PROTOCOL(CxxVector) PROTOCOL(UnsafeCxxInputIterator) PROTOCOL(UnsafeCxxMutableInputIterator) PROTOCOL(UnsafeCxxRandomAccessIterator) diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index eb08423bd4fa1..1dbb19ca78a9c 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -1139,6 +1139,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const { case KnownProtocolKind::CxxSet: case KnownProtocolKind::CxxSequence: case KnownProtocolKind::CxxUniqueSet: + case KnownProtocolKind::CxxVector: case KnownProtocolKind::UnsafeCxxInputIterator: case KnownProtocolKind::UnsafeCxxMutableInputIterator: case KnownProtocolKind::UnsafeCxxRandomAccessIterator: diff --git a/lib/ClangImporter/ClangDerivedConformances.cpp b/lib/ClangImporter/ClangDerivedConformances.cpp index 0c2038afcc590..dbf16809738e4 100644 --- a/lib/ClangImporter/ClangDerivedConformances.cpp +++ b/lib/ClangImporter/ClangDerivedConformances.cpp @@ -927,3 +927,47 @@ void swift::conformToCxxDictionaryIfNeeded( insert->getResultInterfaceType()); impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxDictionary}); } + +void swift::conformToCxxVectorIfNeeded(ClangImporter::Implementation &impl, + NominalTypeDecl *decl, + const clang::CXXRecordDecl *clangDecl) { + PrettyStackTraceDecl trace("conforming to CxxVector", decl); + + assert(decl); + assert(clangDecl); + ASTContext &ctx = decl->getASTContext(); + + // Only auto-conform types from the C++ standard library. Custom user types + // might have a similar interface but different semantics. + if (!isStdDecl(clangDecl, {"vector"})) + return; + + auto valueType = lookupDirectSingleWithoutExtensions( + decl, ctx.getIdentifier("value_type")); + auto iterType = lookupDirectSingleWithoutExtensions( + decl, ctx.getIdentifier("const_iterator")); + if (!valueType || !iterType) + return; + + ProtocolDecl *cxxRandomAccessIteratorProto = + ctx.getProtocol(KnownProtocolKind::UnsafeCxxRandomAccessIterator); + if (!cxxRandomAccessIteratorProto) + return; + + auto rawIteratorTy = iterType->getUnderlyingType(); + + // Check if RawIterator conforms to UnsafeCxxRandomAccessIterator. + ModuleDecl *module = decl->getModuleContext(); + auto rawIteratorConformanceRef = + module->lookupConformance(rawIteratorTy, cxxRandomAccessIteratorProto); + if (!isConcreteAndValid(rawIteratorConformanceRef, module)) + return; + + impl.addSynthesizedTypealias(decl, ctx.Id_Element, + valueType->getUnderlyingType()); + impl.addSynthesizedTypealias(decl, ctx.Id_ArrayLiteralElement, + valueType->getUnderlyingType()); + impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawIterator"), + rawIteratorTy); + impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxVector}); +} diff --git a/lib/ClangImporter/ClangDerivedConformances.h b/lib/ClangImporter/ClangDerivedConformances.h index 3faeb3efdee05..ac3e1f5822528 100644 --- a/lib/ClangImporter/ClangDerivedConformances.h +++ b/lib/ClangImporter/ClangDerivedConformances.h @@ -59,6 +59,12 @@ void conformToCxxDictionaryIfNeeded(ClangImporter::Implementation &impl, NominalTypeDecl *decl, const clang::CXXRecordDecl *clangDecl); +/// If the decl is an instantiation of C++ `std::vector`, synthesize a +/// conformance to CxxVector, which is defined in the Cxx module. +void conformToCxxVectorIfNeeded(ClangImporter::Implementation &impl, + NominalTypeDecl *decl, + const clang::CXXRecordDecl *clangDecl); + } // namespace swift #endif // SWIFT_CLANG_DERIVED_CONFORMANCES_H diff --git a/lib/ClangImporter/ImportDecl.cpp b/lib/ClangImporter/ImportDecl.cpp index 9e4c5c72c132e..a0cd3082048fe 100644 --- a/lib/ClangImporter/ImportDecl.cpp +++ b/lib/ClangImporter/ImportDecl.cpp @@ -2797,6 +2797,7 @@ namespace { conformToCxxDictionaryIfNeeded(Impl, nominalDecl, decl); conformToCxxPairIfNeeded(Impl, nominalDecl, decl); conformToCxxOptionalIfNeeded(Impl, nominalDecl, decl); + conformToCxxVectorIfNeeded(Impl, nominalDecl, decl); } if (auto *ntd = dyn_cast(result)) diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp index 31c65151b2dce..02d5293c263e3 100644 --- a/lib/IRGen/GenMeta.cpp +++ b/lib/IRGen/GenMeta.cpp @@ -6310,6 +6310,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) { case KnownProtocolKind::CxxSet: case KnownProtocolKind::CxxSequence: case KnownProtocolKind::CxxUniqueSet: + case KnownProtocolKind::CxxVector: case KnownProtocolKind::UnsafeCxxInputIterator: case KnownProtocolKind::UnsafeCxxMutableInputIterator: case KnownProtocolKind::UnsafeCxxRandomAccessIterator: diff --git a/stdlib/public/Cxx/CMakeLists.txt b/stdlib/public/Cxx/CMakeLists.txt index f71de4e37173d..8ca70e3dd7f2e 100644 --- a/stdlib/public/Cxx/CMakeLists.txt +++ b/stdlib/public/Cxx/CMakeLists.txt @@ -19,6 +19,7 @@ add_swift_target_library(swiftCxx ${SWIFT_CXX_LIBRARY_KIND} NO_LINK_NAME IS_STDL CxxSet.swift CxxRandomAccessCollection.swift CxxSequence.swift + CxxVector.swift UnsafeCxxIterators.swift SWIFT_COMPILE_FLAGS ${SWIFT_RUNTIME_SWIFT_COMPILE_FLAGS} ${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS} diff --git a/stdlib/public/Cxx/CxxVector.swift b/stdlib/public/Cxx/CxxVector.swift new file mode 100644 index 0000000000000..03ed6b60ec641 --- /dev/null +++ b/stdlib/public/Cxx/CxxVector.swift @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +/// A C++ type that represents a vector of values. +/// +/// C++ standard library type `std::vector` conforms to this protocol. +public protocol CxxVector: ExpressibleByArrayLiteral { + associatedtype Element + associatedtype RawIterator: UnsafeCxxRandomAccessIterator + where RawIterator.Pointee == Element + + init() + + mutating func push_back(_ element: Element) +} + +extension CxxVector { + /// Creates a C++ vector containing the elements of a Swift Sequence. + /// + /// This initializes the vector by copying every element of the sequence. + /// + /// - Complexity: O(*n*), where *n* is the number of elements in the Swift + /// sequence + @inlinable + public init(_ sequence: __shared S) where S.Element == Element { + self.init() + for item in sequence { + self.push_back(item) + } + } +} + +extension CxxVector { + @inlinable + public init(arrayLiteral elements: Element...) { + self.init(elements) + } +} diff --git a/test/Interop/Cxx/stdlib/Inputs/std-vector.h b/test/Interop/Cxx/stdlib/Inputs/std-vector.h index f7c375ad854ff..050d3e3c0ad4b 100644 --- a/test/Interop/Cxx/stdlib/Inputs/std-vector.h +++ b/test/Interop/Cxx/stdlib/Inputs/std-vector.h @@ -9,4 +9,8 @@ using VectorOfString = std::vector; inline Vector initVector() { return {}; } +inline std::string takesVectorOfString(const VectorOfString &v) { + return v.front(); +} + #endif // TEST_INTEROP_CXX_STDLIB_INPUTS_STD_VECTOR_H \ No newline at end of file diff --git a/test/Interop/Cxx/stdlib/use-std-vector.swift b/test/Interop/Cxx/stdlib/use-std-vector.swift index a51a979034355..123b979eb9eba 100644 --- a/test/Interop/Cxx/stdlib/use-std-vector.swift +++ b/test/Interop/Cxx/stdlib/use-std-vector.swift @@ -19,6 +19,48 @@ StdVectorTestSuite.test("VectorOfInt.init") { expectTrue(v.empty()) } +StdVectorTestSuite.test("VectorOfInt.init(sequence)") { + let v = Vector([]) + expectEqual(v.size(), 0) + expectTrue(v.empty()) + + let v2 = Vector([1, 2, 3]) + expectEqual(v2.size(), 3) + expectFalse(v2.empty()) + expectEqual(v2[0], 1) + expectEqual(v2[1], 2) + expectEqual(v2[2], 3) +} + +StdVectorTestSuite.test("VectorOfString.init(sequence)") { + let v = VectorOfString([]) + expectEqual(v.size(), 0) + expectTrue(v.empty()) + + let v2 = VectorOfString(["", "ab", "abc"]) + expectEqual(v2.size(), 3) + expectFalse(v2.empty()) + expectEqual(v2[0], "") + expectEqual(v2[1], "ab") + expectEqual(v2[2], "abc") + + let first = takesVectorOfString(["abc", "qwe"]) + expectEqual(first, "abc") +} + +StdVectorTestSuite.test("VectorOfInt as ExpressibleByArrayLiteral") { + let v: Vector = [] + expectEqual(v.size(), 0) + expectTrue(v.empty()) + + let v2: Vector = [1, 2, 3] + expectEqual(v2.size(), 3) + expectFalse(v2.empty()) + expectEqual(v2[0], 1) + expectEqual(v2[1], 2) + expectEqual(v2[2], 3) +} + StdVectorTestSuite.test("VectorOfInt.push_back") { var v = Vector() let _42: CInt = 42