Skip to content

Commit 0680b7b

Browse files
authored
Add co-operative cancellation to async writer and passthrough source (#1414)
Motivation: Whenever we create continuations we should be careful to add support for co-operative cancellation via a cancellation handler. Modifications: - Add co-operative cancellation to the async write and passthrough source - Tests Result: Better cancellation support
1 parent 938d141 commit 0680b7b

File tree

6 files changed

+147
-51
lines changed

6 files changed

+147
-51
lines changed

Sources/GRPC/AsyncAwaitSupport/AsyncWriter.swift

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -245,50 +245,45 @@ internal final actor AsyncWriter<Delegate: AsyncWriterDelegate>: Sendable {
245245
/// have been suspended.
246246
@inlinable
247247
internal func write(_ element: Element) async throws {
248-
try await withCheckedThrowingContinuation { continuation in
249-
self._write(element, continuation: continuation)
250-
}
251-
}
252-
253-
@inlinable
254-
internal func _write(_ element: Element, continuation: CheckedContinuation<Void, Error>) {
255248
// There are three outcomes of writing:
256249
// - write the element directly (if the writer isn't paused and no writes are pending)
257250
// - queue the element (the writer is paused or there are writes already pending)
258251
// - error (the writer is complete or the queue is full).
259-
260-
if self._completionState.isPendingOrCompleted {
261-
continuation.resume(throwing: GRPCAsyncWriterError.alreadyFinished)
262-
} else if !self._isPaused, self._pendingElements.isEmpty {
263-
self._delegate.write(element)
264-
continuation.resume()
265-
} else if self._pendingElements.count < self._maxPendingElements {
266-
// The continuation will be resumed later.
267-
self._pendingElements.append(PendingElement(element, continuation: continuation))
268-
} else {
269-
continuation.resume(throwing: GRPCAsyncWriterError.tooManyPendingWrites)
252+
return try await withTaskCancellationHandler {
253+
if self._completionState.isPendingOrCompleted {
254+
throw GRPCAsyncWriterError.alreadyFinished
255+
} else if !self._isPaused, self._pendingElements.isEmpty {
256+
self._delegate.write(element)
257+
} else if self._pendingElements.count < self._maxPendingElements {
258+
// The continuation will be resumed later.
259+
try await withCheckedThrowingContinuation { continuation in
260+
self._pendingElements.append(PendingElement(element, continuation: continuation))
261+
}
262+
} else {
263+
throw GRPCAsyncWriterError.tooManyPendingWrites
264+
}
265+
} onCancel: {
266+
self.cancelAsynchronously()
270267
}
271268
}
272269

273270
/// Write the final element
274271
@inlinable
275272
internal func finish(_ end: End) async throws {
276-
try await withCheckedThrowingContinuation { continuation in
277-
self._finish(end, continuation: continuation)
278-
}
279-
}
280-
281-
@inlinable
282-
internal func _finish(_ end: End, continuation: CheckedContinuation<Void, Error>) {
283-
if self._completionState.isPendingOrCompleted {
284-
continuation.resume(throwing: GRPCAsyncWriterError.alreadyFinished)
285-
} else if !self._isPaused, self._pendingElements.isEmpty {
286-
self._completionState = .completed
287-
self._delegate.writeEnd(end)
288-
continuation.resume()
289-
} else {
290-
// Either we're paused or there are pending writes which must be consumed first.
291-
self._completionState = .pending(PendingEnd(end, continuation: continuation))
273+
return try await withTaskCancellationHandler {
274+
if self._completionState.isPendingOrCompleted {
275+
throw GRPCAsyncWriterError.alreadyFinished
276+
} else if !self._isPaused, self._pendingElements.isEmpty {
277+
self._completionState = .completed
278+
self._delegate.writeEnd(end)
279+
} else {
280+
try await withCheckedThrowingContinuation { continuation in
281+
// Either we're paused or there are pending writes which must be consumed first.
282+
self._completionState = .pending(PendingEnd(end, continuation: continuation))
283+
}
284+
}
285+
} onCancel: {
286+
self.cancelAsynchronously()
292287
}
293288
}
294289
}

Sources/GRPC/AsyncAwaitSupport/PassthroughMessageSequence.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ internal struct PassthroughMessageSequence<Element, Failure: Error>: AsyncSequen
5050

5151
@inlinable
5252
internal func next() async throws -> Element? {
53+
// The storage handles co-operative cancellation, so we don't bother checking here.
5354
return try await self._storage.consumeNextElement()
5455
}
5556
}

Sources/GRPC/AsyncAwaitSupport/PassthroughMessageSource.swift

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,14 @@ internal final class PassthroughMessageSource<Element, Failure: Error> {
108108
let result: _YieldResult = self._lock.withLock {
109109
if self._isTerminated {
110110
return .alreadyTerminated
111-
} else if let continuation = self._continuation {
111+
} else {
112112
self._isTerminated = isTerminator
113+
}
114+
115+
if let continuation = self._continuation {
113116
self._continuation = nil
114117
return .resume(continuation)
115118
} else {
116-
self._isTerminated = isTerminator
117119
self._continuationResults.append(continuationResult)
118120
return .queued(self._continuationResults.count)
119121
}
@@ -138,28 +140,31 @@ internal final class PassthroughMessageSource<Element, Failure: Error> {
138140

139141
@inlinable
140142
internal func consumeNextElement() async throws -> Element? {
141-
return try await withCheckedThrowingContinuation {
142-
self._consumeNextElement(continuation: $0)
143+
self._lock.lock()
144+
if let nextResult = self._continuationResults.popFirst() {
145+
self._lock.unlock()
146+
return try nextResult.get()
147+
} else if self._isTerminated {
148+
self._lock.unlock()
149+
return nil
143150
}
144-
}
145151

146-
@inlinable
147-
internal func _consumeNextElement(continuation: CheckedContinuation<Element?, Error>) {
148-
let continuationResult: _ContinuationResult? = self._lock.withLock {
149-
if let nextResult = self._continuationResults.popFirst() {
150-
return nextResult
151-
} else if self._isTerminated {
152-
return .success(nil)
153-
} else {
152+
// Slow path; we need a continuation.
153+
return try await withTaskCancellationHandler {
154+
try await withCheckedThrowingContinuation { continuation in
154155
// Nothing buffered and not terminated yet: save the continuation for later.
155156
precondition(self._continuation == nil)
156157
self._continuation = continuation
157-
return nil
158+
self._lock.unlock()
159+
}
160+
} onCancel: {
161+
let continuation: CheckedContinuation<Element?, Error>? = self._lock.withLock {
162+
let cont = self._continuation
163+
self._continuation = nil
164+
return cont
158165
}
159-
}
160166

161-
if let continuationResult = continuationResult {
162-
continuation.resume(with: continuationResult)
167+
continuation?.resume(throwing: CancellationError())
163168
}
164169
}
165170
}

Tests/GRPCTests/AsyncAwaitSupport/AsyncWriterTests.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,34 @@ internal class AsyncWriterTests: GRPCTestCase {
243243
XCTAssertTrue(delegate.elements.isEmpty)
244244
XCTAssertNil(delegate.end)
245245
}
246+
247+
func testCooperativeCancellationOnWrite() async throws {
248+
let delegate = CollectingDelegate<String, Void>()
249+
let writer = AsyncWriter(isWritable: false, delegate: delegate)
250+
try await withTaskCancelledAfter(nanoseconds: 100_000) {
251+
do {
252+
// Without co-operative cancellation then this will suspend indefinitely.
253+
try await writer.write("I should be cancelled")
254+
XCTFail("write(_:) should throw CancellationError")
255+
} catch {
256+
XCTAssert(error is CancellationError)
257+
}
258+
}
259+
}
260+
261+
func testCooperativeCancellationOnFinish() async throws {
262+
let delegate = CollectingDelegate<String, Void>()
263+
let writer = AsyncWriter(isWritable: false, delegate: delegate)
264+
try await withTaskCancelledAfter(nanoseconds: 100_000) {
265+
do {
266+
// Without co-operative cancellation then this will suspend indefinitely.
267+
try await writer.finish()
268+
XCTFail("finish() should throw CancellationError")
269+
} catch {
270+
XCTAssert(error is CancellationError)
271+
}
272+
}
273+
}
246274
}
247275

248276
fileprivate final class CollectingDelegate<

Tests/GRPCTests/AsyncAwaitSupport/PassthroughMessageSourceTests.swift

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,33 @@ class PassthroughMessageSourceTests: GRPCTestCase {
126126
}
127127
}
128128
}
129+
130+
func testCooperativeCancellationOfSourceOnNext() async throws {
131+
let source = PassthroughMessageSource<String, TestError>()
132+
try await withTaskCancelledAfter(nanoseconds: 100_000) {
133+
do {
134+
_ = try await source.consumeNextElement()
135+
XCTFail("consumeNextElement() should throw CancellationError")
136+
} catch {
137+
XCTAssert(error is CancellationError)
138+
}
139+
}
140+
}
141+
142+
func testCooperativeCancellationOfSequenceOnNext() async throws {
143+
let source = PassthroughMessageSource<String, TestError>()
144+
let sequence = PassthroughMessageSequence(consuming: source)
145+
try await withTaskCancelledAfter(nanoseconds: 100_000) {
146+
do {
147+
for try await _ in sequence {
148+
XCTFail("consumeNextElement() should throw CancellationError")
149+
}
150+
XCTFail("consumeNextElement() should throw CancellationError")
151+
} catch {
152+
XCTAssert(error is CancellationError)
153+
}
154+
}
155+
}
129156
}
130157

131158
fileprivate struct TestError: Error {}

Tests/GRPCTests/AsyncAwaitSupport/XCTest+AsyncAwait.swift

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,44 @@ internal func XCTAssertThrowsError<T>(
3131
}
3232
}
3333

34+
fileprivate enum TaskResult<Result> {
35+
case operation(Result)
36+
case cancellation
37+
}
38+
39+
@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
40+
func withTaskCancelledAfter<Result>(
41+
nanoseconds: UInt64,
42+
operation: @escaping @Sendable () async -> Result
43+
) async throws {
44+
try await withThrowingTaskGroup(of: TaskResult<Result>.self) { group in
45+
group.addTask {
46+
return .operation(await operation())
47+
}
48+
49+
group.addTask {
50+
try await Task.sleep(nanoseconds: nanoseconds)
51+
return .cancellation
52+
}
53+
54+
// Only the sleeping task can throw if it's cancelled, in which case we want to throw.
55+
let firstResult = try await group.next()
56+
// A task completed, cancel the rest.
57+
group.cancelAll()
58+
59+
// Check which task completed.
60+
switch firstResult {
61+
case .cancellation:
62+
() // Fine, what we expect.
63+
case .operation:
64+
XCTFail("Operation completed before cancellation")
65+
case .none:
66+
XCTFail("No tasks completed")
67+
}
68+
69+
// Wait for the other task. The operation cannot, only the sleeping task can.
70+
try await group.waitForAll()
71+
}
72+
}
73+
3474
#endif // compiler(>=5.6)

0 commit comments

Comments
 (0)