Skip to content

Close client request stream on error and end #1410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
onError: { error in
asyncCall.responseParts.handleError(error)
asyncCall.responseSource.finish(throwing: error)
asyncCall.requestStream.asyncWriter.cancelAsynchronously()
},
onResponsePart: AsyncCall.makeResponsePartHandler(
responseParts: asyncCall.responseParts,
responseSource: asyncCall.responseSource
responseSource: asyncCall.responseSource,
requestStream: asyncCall.requestStream
)
)

Expand All @@ -102,9 +104,11 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
internal enum AsyncCall {
internal static func makeResponsePartHandler<Response>(
internal static func makeResponsePartHandler<Response, Request>(
responseParts: StreamingResponseParts<Response>,
responseSource: PassthroughMessageSource<Response, Error>
responseSource: PassthroughMessageSource<Response, Error>,
requestStream: GRPCAsyncRequestStreamWriter<Request>?,
requestType: Request.Type = Request.self
) -> (GRPCClientResponsePart<Response>) -> Void {
return { responsePart in
// Handle the metadata, trailers and status.
Expand All @@ -125,6 +129,28 @@ internal enum AsyncCall {
} else {
responseSource.finish(throwing: status)
}

requestStream?.asyncWriter.cancelAsynchronously()
}
}
}

internal static func makeResponsePartHandler<Response, Request>(
responseParts: UnaryResponseParts<Response>,
requestStream: GRPCAsyncRequestStreamWriter<Request>?,
requestType: Request.Type = Request.self,
responseType: Response.Type = Response.self
) -> (GRPCClientResponsePart<Response>) -> Void {
return { responsePart in
// Handle (most of) all parts.
responseParts.handle(responsePart)

// Handle the status.
switch responsePart {
case .metadata, .message:
()
case .end:
requestStream?.asyncWriter.cancelAsynchronously()
}
}
}
Expand Down
19 changes: 14 additions & 5 deletions Sources/GRPC/AsyncAwaitSupport/GRPCAsyncClientStreamingCall.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,26 @@ public struct GRPCAsyncClientStreamingCall<Request: Sendable, Response: Sendable
private init(call: Call<Request, Response>) {
self.call = call
self.responseParts = UnaryResponseParts(on: call.eventLoop)
self.call.invokeStreamingRequests(
onError: self.responseParts.handleError(_:),
onResponsePart: self.responseParts.handle(_:)
)
self.requestStream = call.makeRequestStreamWriter()
}

/// We expose this as the only non-private initializer so that the caller
/// knows that invocation is part of initialisation.
internal static func makeAndInvoke(call: Call<Request, Response>) -> Self {
Self(call: call)
let asyncCall = Self(call: call)

asyncCall.call.invokeStreamingRequests(
onError: { error in
asyncCall.responseParts.handleError(error)
asyncCall.requestStream.asyncWriter.cancelAsynchronously()
},
onResponsePart: AsyncCall.makeResponsePartHandler(
responseParts: asyncCall.responseParts,
requestStream: asyncCall.requestStream
)
)

return asyncCall
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable
},
onResponsePart: AsyncCall.makeResponsePartHandler(
responseParts: asyncCall.responseParts,
responseSource: asyncCall.responseSource
responseSource: asyncCall.responseSource,
requestStream: nil,
requestType: Request.self
)
)

Expand Down
177 changes: 177 additions & 0 deletions Tests/GRPCTests/AsyncAwaitSupport/AsyncClientTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Copyright 2022, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if compiler(>=5.6)
import EchoImplementation
import EchoModel
import GRPC
import NIOCore
import NIOPosix
import XCTest

@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
final class AsyncClientCancellationTests: GRPCTestCase {
private var server: Server!
private var group: EventLoopGroup!
private var pool: GRPCChannel!

override func setUpWithError() throws {
try super.setUpWithError()
self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
}

override func tearDown() async throws {
try self.pool.close().wait()
self.pool = nil

try self.server.close().wait()
self.server = nil

try self.group.syncShutdownGracefully()
self.group = nil

try await super.tearDown()
}

private func startServer(service: CallHandlerProvider) throws -> Echo_EchoAsyncClient {
precondition(self.server == nil)
precondition(self.pool == nil)

self.server = try Server.insecure(group: self.group)
.withServiceProviders([service])
.withLogger(self.serverLogger)
.bind(host: "127.0.0.1", port: 0)
.wait()

self.pool = try GRPCChannelPool.with(
target: .host("127.0.0.1", port: self.server.channel.localAddress!.port!),
transportSecurity: .plaintext,
eventLoopGroup: self.group
) {
$0.backgroundActivityLogger = self.clientLogger
}

return Echo_EchoAsyncClient(channel: self.pool)
}

func testCancelUnaryFailsResponse() async throws {
// We don't want the RPC to complete before we cancel it so use the never resolving service.
let echo = try self.startServer(service: NeverResolvingEchoProvider())

let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
try await get.cancel()

await XCTAssertThrowsError(try await get.response)

// Status should be 'cancelled'.
let status = await get.status
XCTAssertEqual(status.code, .cancelled)
}

func testCancelServerStreamingClosesResponseStream() async throws {
// We don't want the RPC to complete before we cancel it so use the never resolving service.
let echo = try self.startServer(service: NeverResolvingEchoProvider())

let expand = echo.makeExpandCall(.with { $0.text = "foo bar baz" })
try await expand.cancel()

var responseStream = expand.responseStream.makeAsyncIterator()
await XCTAssertThrowsError(try await responseStream.next())

// Status should be 'cancelled'.
let status = await expand.status
XCTAssertEqual(status.code, .cancelled)
}

func testCancelClientStreamingClosesRequestStreamAndFailsResponse() async throws {
let echo = try self.startServer(service: EchoProvider())

let collect = echo.makeCollectCall()
// Make sure the stream is up before we cancel it.
try await collect.requestStream.send(.with { $0.text = "foo" })
try await collect.cancel()

// The next send should fail.
await XCTAssertThrowsError(try await collect.requestStream.send(.with { $0.text = "foo" }))
// There should be no response.
await XCTAssertThrowsError(try await collect.response)

// Status should be 'cancelled'.
let status = await collect.status
XCTAssertEqual(status.code, .cancelled)
}

func testClientStreamingClosesRequestStreamOnEnd() async throws {
let echo = try self.startServer(service: EchoProvider())

let collect = echo.makeCollectCall()
// Send and close.
try await collect.requestStream.send(.with { $0.text = "foo" })
try await collect.requestStream.finish()

// Await the response and status.
_ = try await collect.response
let status = await collect.status
XCTAssert(status.isOk)

// Sending should fail.
await XCTAssertThrowsError(
try await collect.requestStream.send(.with { $0.text = "should throw" })
)
}

func testCancelBidiStreamingClosesRequestStreamAndResponseStream() async throws {
let echo = try self.startServer(service: EchoProvider())

let update = echo.makeUpdateCall()
// Make sure the stream is up before we cancel it.
try await update.requestStream.send(.with { $0.text = "foo" })
// Wait for the response.
var responseStream = update.responseStream.makeAsyncIterator()
_ = try await responseStream.next()

// Now cancel. The next send should fail and we shouldn't receive any more responses.
try await update.cancel()
await XCTAssertThrowsError(try await update.requestStream.send(.with { $0.text = "foo" }))
await XCTAssertThrowsError(try await responseStream.next())

// Status should be 'cancelled'.
let status = await update.status
XCTAssertEqual(status.code, .cancelled)
}

func testBidiStreamingClosesRequestStreamOnEnd() async throws {
let echo = try self.startServer(service: EchoProvider())

let update = echo.makeUpdateCall()
// Send and close.
try await update.requestStream.send(.with { $0.text = "foo" })
try await update.requestStream.finish()

// Await the response and status.
let responseCount = try await update.responseStream.count()
XCTAssertEqual(responseCount, 1)

let status = await update.status
XCTAssert(status.isOk)

// Sending should fail.
await XCTAssertThrowsError(
try await update.requestStream.send(.with { $0.text = "should throw" })
)
}
}

#endif // compiler(>=5.6)