From 3754db92483a97d186d31f9cc60b5dbc169619e9 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Fri, 27 Sep 2024 17:49:38 +0200 Subject: [PATCH 1/2] Fix sftp async methods not observing error conditions --- src/Renci.SshNet/ISubsystemSession.cs | 6 +- src/Renci.SshNet/Sftp/SftpSession.cs | 533 ++++++++---------- src/Renci.SshNet/SftpClient.cs | 10 +- src/Renci.SshNet/SubsystemSession.cs | 63 ++- .../Classes/SftpClientTest.cs | 32 -- .../Classes/SftpClientTest_AsyncExceptions.cs | 268 +++++++++ 6 files changed, 570 insertions(+), 342 deletions(-) create mode 100644 test/Renci.SshNet.Tests/Classes/SftpClientTest_AsyncExceptions.cs diff --git a/src/Renci.SshNet/ISubsystemSession.cs b/src/Renci.SshNet/ISubsystemSession.cs index 44e190825..4fa6b28b2 100644 --- a/src/Renci.SshNet/ISubsystemSession.cs +++ b/src/Renci.SshNet/ISubsystemSession.cs @@ -11,12 +11,12 @@ namespace Renci.SshNet internal interface ISubsystemSession : IDisposable { /// - /// Gets or set the number of seconds to wait for an operation to complete. + /// Gets or sets the number of milliseconds to wait for an operation to complete. /// /// - /// The number of seconds to wait for an operation to complete, or -1 to wait indefinitely. + /// The number of milliseconds to wait for an operation to complete, or -1 to wait indefinitely. /// - int OperationTimeout { get; } + int OperationTimeout { get; set; } /// /// Gets a value indicating whether this session is open. diff --git a/src/Renci.SshNet/Sftp/SftpSession.cs b/src/Renci.SshNet/Sftp/SftpSession.cs index a023854bc..7fd8888dd 100644 --- a/src/Renci.SshNet/Sftp/SftpSession.cs +++ b/src/Renci.SshNet/Sftp/SftpSession.cs @@ -523,28 +523,24 @@ public byte[] RequestOpen(string path, Flags flags, bool nullOnError = false) /// A task that represents the asynchronous SSH_FXP_OPEN request. The value of its /// contains the file handle of the specified path. /// - public async Task RequestOpenAsync(string path, Flags flags, CancellationToken cancellationToken) + public Task RequestOpenAsync(string path, Flags flags, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpOpenRequest(ProtocolVersion, - NextRequestId, - path, - _encoding, - flags, - response => tcs.TrySetResult(response.Handle), - response => tcs.TrySetException(GetSftpException(response)))); + SendRequest(new SftpOpenRequest(ProtocolVersion, + NextRequestId, + path, + _encoding, + flags, + response => tcs.TrySetResult(response.Handle), + response => tcs.TrySetException(GetSftpException(response)))); - return await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -651,8 +647,13 @@ public void RequestClose(byte[] handle) /// /// A task that represents the asynchronous SSH_FXP_CLOSE request. /// - public async Task RequestCloseAsync(byte[] handle, CancellationToken cancellationToken) + public Task RequestCloseAsync(byte[] handle, CancellationToken cancellationToken) { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); SendRequest(new SftpCloseRequest(ProtocolVersion, @@ -670,17 +671,7 @@ public async Task RequestCloseAsync(byte[] handle, CancellationToken cancellatio } })); - // Only check for cancellation after the SftpCloseRequest was sent - cancellationToken.ThrowIfCancellationRequested(); - -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - _ = await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -875,38 +866,34 @@ public byte[] RequestRead(byte[] handle, ulong offset, uint length) /// its contains the data read from the file, or an empty /// array when the end of the file is reached. /// - public async Task RequestReadAsync(byte[] handle, ulong offset, uint length, CancellationToken cancellationToken) + public Task RequestReadAsync(byte[] handle, ulong offset, uint length, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpReadRequest(ProtocolVersion, - NextRequestId, - handle, - offset, - length, - response => tcs.TrySetResult(response.Data), - response => + SendRequest(new SftpReadRequest(ProtocolVersion, + NextRequestId, + handle, + offset, + length, + response => tcs.TrySetResult(response.Data), + response => + { + if (response.StatusCode == StatusCodes.Eof) { - if (response.StatusCode == StatusCodes.Eof) - { - _ = tcs.TrySetResult(Array.Empty()); - } - else - { - _ = tcs.TrySetException(GetSftpException(response)); - } - })); + _ = tcs.TrySetResult(Array.Empty()); + } + else + { + _ = tcs.TrySetException(GetSftpException(response)); + } + })); - return await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -972,39 +959,35 @@ public void RequestWrite(byte[] handle, /// /// A task that represents the asynchronous SSH_FXP_WRITE request. /// - public async Task RequestWriteAsync(byte[] handle, ulong serverOffset, byte[] data, int offset, int length, CancellationToken cancellationToken) + public Task RequestWriteAsync(byte[] handle, ulong serverOffset, byte[] data, int offset, int length, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpWriteRequest(ProtocolVersion, - NextRequestId, - handle, - serverOffset, - data, - offset, - length, - response => - { - if (response.StatusCode == StatusCodes.Ok) - { - _ = tcs.TrySetResult(true); - } - else - { - _ = tcs.TrySetException(GetSftpException(response)); - } - })); + SendRequest(new SftpWriteRequest(ProtocolVersion, + NextRequestId, + handle, + serverOffset, + data, + offset, + length, + response => + { + if (response.StatusCode == StatusCodes.Ok) + { + _ = tcs.TrySetResult(true); + } + else + { + _ = tcs.TrySetException(GetSftpException(response)); + } + })); - _ = await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -1058,27 +1041,23 @@ public SftpFileAttributes RequestLStat(string path) /// A task the represents the asynchronous SSH_FXP_LSTAT request. The value of its /// contains the file attributes of the specified path. /// - public async Task RequestLStatAsync(string path, CancellationToken cancellationToken) + public Task RequestLStatAsync(string path, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpLStatRequest(ProtocolVersion, - NextRequestId, - path, - _encoding, - response => tcs.TrySetResult(response.Attributes), - response => tcs.TrySetException(GetSftpException(response)))); + SendRequest(new SftpLStatRequest(ProtocolVersion, + NextRequestId, + path, + _encoding, + response => tcs.TrySetResult(response.Attributes), + response => tcs.TrySetException(GetSftpException(response)))); - return await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -1191,26 +1170,22 @@ public SftpFileAttributes RequestFStat(byte[] handle, bool nullOnError) /// A task that represents the asynchronous SSH_FXP_FSTAT request. The value of its /// contains the file attributes of the specified handle. /// - public async Task RequestFStatAsync(byte[] handle, CancellationToken cancellationToken) + public Task RequestFStatAsync(byte[] handle, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpFStatRequest(ProtocolVersion, - NextRequestId, - handle, - response => tcs.TrySetResult(response.Attributes), - response => tcs.TrySetException(GetSftpException(response)))); + SendRequest(new SftpFStatRequest(ProtocolVersion, + NextRequestId, + handle, + response => tcs.TrySetResult(response.Attributes), + response => tcs.TrySetException(GetSftpException(response)))); - return await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -1329,27 +1304,23 @@ public byte[] RequestOpenDir(string path, bool nullOnError = false) /// A task that represents the asynchronous SSH_FXP_OPENDIR request. The value of its /// contains the handle of the specified path. /// - public async Task RequestOpenDirAsync(string path, CancellationToken cancellationToken) + public Task RequestOpenDirAsync(string path, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpOpenDirRequest(ProtocolVersion, - NextRequestId, - path, - _encoding, - response => tcs.TrySetResult(response.Handle), - response => tcs.TrySetException(GetSftpException(response)))); + SendRequest(new SftpOpenDirRequest(ProtocolVersion, + NextRequestId, + path, + _encoding, + response => tcs.TrySetResult(response.Handle), + response => tcs.TrySetException(GetSftpException(response)))); - return await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -1410,36 +1381,32 @@ public KeyValuePair[] RequestReadDir(byte[] handle) /// key is the name of a file in the directory and the value is the /// of the file. /// - public async Task[]> RequestReadDirAsync(byte[] handle, CancellationToken cancellationToken) + public Task[]> RequestReadDirAsync(byte[] handle, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled[]>(cancellationToken); + } var tcs = new TaskCompletionSource[]>(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource[]>)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource[]>)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpReadDirRequest(ProtocolVersion, - NextRequestId, - handle, - response => tcs.TrySetResult(response.Files), - response => + SendRequest(new SftpReadDirRequest(ProtocolVersion, + NextRequestId, + handle, + response => tcs.TrySetResult(response.Files), + response => + { + if (response.StatusCode == StatusCodes.Eof) { - if (response.StatusCode == StatusCodes.Eof) - { - _ = tcs.TrySetResult(null); - } - else - { - _ = tcs.TrySetException(GetSftpException(response)); - } - })); + _ = tcs.TrySetResult(null); + } + else + { + _ = tcs.TrySetException(GetSftpException(response)); + } + })); - return await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -1481,36 +1448,32 @@ public void RequestRemove(string path) /// /// A task that represents the asynchronous SSH_FXP_REMOVE request. /// - public async Task RequestRemoveAsync(string path, CancellationToken cancellationToken) + public Task RequestRemoveAsync(string path, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpRemoveRequest(ProtocolVersion, - NextRequestId, - path, - _encoding, - response => - { - if (response.StatusCode == StatusCodes.Ok) - { - _ = tcs.TrySetResult(true); - } - else - { - _ = tcs.TrySetException(GetSftpException(response)); - } - })); + SendRequest(new SftpRemoveRequest(ProtocolVersion, + NextRequestId, + path, + _encoding, + response => + { + if (response.StatusCode == StatusCodes.Ok) + { + _ = tcs.TrySetResult(true); + } + else + { + _ = tcs.TrySetException(GetSftpException(response)); + } + })); - _ = await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -1550,36 +1513,32 @@ public void RequestMkDir(string path) /// The path. /// The to observe. /// A that represents the asynchronous SSH_FXP_MKDIR operation. - public async Task RequestMkDirAsync(string path, CancellationToken cancellationToken = default) + public Task RequestMkDirAsync(string path, CancellationToken cancellationToken = default) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpMkDirRequest(ProtocolVersion, - NextRequestId, - path, - _encoding, - response => + SendRequest(new SftpMkDirRequest(ProtocolVersion, + NextRequestId, + path, + _encoding, + response => + { + if (response.StatusCode == StatusCodes.Ok) { - if (response.StatusCode == StatusCodes.Ok) - { - _ = tcs.TrySetResult(true); - } - else - { - tcs.TrySetException(GetSftpException(response)); - } - })); + _ = tcs.TrySetResult(true); + } + else + { + _ = tcs.TrySetException(GetSftpException(response)); + } + })); - _ = await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -1614,37 +1573,33 @@ public void RequestRmDir(string path) } /// - public async Task RequestRmDirAsync(string path, CancellationToken cancellationToken = default) + public Task RequestRmDirAsync(string path, CancellationToken cancellationToken = default) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpRmDirRequest(ProtocolVersion, - NextRequestId, - path, - _encoding, - response => + SendRequest(new SftpRmDirRequest(ProtocolVersion, + NextRequestId, + path, + _encoding, + response => + { + var exception = GetSftpException(response); + if (exception is not null) { - var exception = GetSftpException(response); - if (exception is not null) - { - tcs.TrySetException(exception); - } - else - { - tcs.TrySetResult(true); - } - })); + _ = tcs.TrySetException(exception); + } + else + { + _ = tcs.TrySetResult(true); + } + })); - _ = await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -1691,37 +1646,33 @@ internal KeyValuePair[] RequestRealPath(string path, return result; } - internal async Task[]> RequestRealPathAsync(string path, bool nullOnError, CancellationToken cancellationToken) + internal Task[]> RequestRealPathAsync(string path, bool nullOnError, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled[]>(cancellationToken); + } var tcs = new TaskCompletionSource[]>(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource[]>)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource[]>)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpRealPathRequest(ProtocolVersion, - NextRequestId, - path, - _encoding, - response => tcs.TrySetResult(response.Files), - response => + SendRequest(new SftpRealPathRequest(ProtocolVersion, + NextRequestId, + path, + _encoding, + response => tcs.TrySetResult(response.Files), + response => + { + if (nullOnError) { - if (nullOnError) - { - _ = tcs.TrySetResult(null); - } - else - { - _ = tcs.TrySetException(GetSftpException(response)); - } - })); + _ = tcs.TrySetResult(null); + } + else + { + _ = tcs.TrySetException(GetSftpException(response)); + } + })); - return await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -1921,37 +1872,33 @@ public void RequestRename(string oldPath, string newPath) /// /// A task that represents the asynchronous SSH_FXP_RENAME request. /// - public async Task RequestRenameAsync(string oldPath, string newPath, CancellationToken cancellationToken) + public Task RequestRenameAsync(string oldPath, string newPath, CancellationToken cancellationToken) { - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new SftpRenameRequest(ProtocolVersion, - NextRequestId, - oldPath, - newPath, - _encoding, - response => - { - if (response.StatusCode == StatusCodes.Ok) - { - _ = tcs.TrySetResult(true); - } - else - { - _ = tcs.TrySetException(GetSftpException(response)); - } - })); + SendRequest(new SftpRenameRequest(ProtocolVersion, + NextRequestId, + oldPath, + newPath, + _encoding, + response => + { + if (response.StatusCode == StatusCodes.Ok) + { + _ = tcs.TrySetResult(true); + } + else + { + _ = tcs.TrySetException(GetSftpException(response)); + } + })); - _ = await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// @@ -2149,32 +2096,28 @@ public SftpFileSystemInformation RequestStatVfs(string path, bool nullOnError = /// contains the file system information for the specified /// path. /// - public async Task RequestStatVfsAsync(string path, CancellationToken cancellationToken) + public Task RequestStatVfsAsync(string path, CancellationToken cancellationToken) { if (ProtocolVersion < 3) { throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "SSH_FXP_EXTENDED operation is not supported in {0} version that server operates in.", ProtocolVersion)); } - cancellationToken.ThrowIfCancellationRequested(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); -#if NET || NETSTANDARD2_1_OR_GREATER - await using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false)) -#else - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetCanceled(cancellationToken), tcs, useSynchronizationContext: false)) -#endif // NET || NETSTANDARD2_1_OR_GREATER - { - SendRequest(new StatVfsRequest(ProtocolVersion, - NextRequestId, - path, - _encoding, - response => tcs.TrySetResult(response.GetReply().Information), - response => tcs.TrySetException(GetSftpException(response)))); + SendRequest(new StatVfsRequest(ProtocolVersion, + NextRequestId, + path, + _encoding, + response => tcs.TrySetResult(response.GetReply().Information), + response => tcs.TrySetException(GetSftpException(response)))); - return await tcs.Task.ConfigureAwait(false); - } + return WaitOnHandleAsync(tcs, OperationTimeout, cancellationToken); } /// diff --git a/src/Renci.SshNet/SftpClient.cs b/src/Renci.SshNet/SftpClient.cs index d20a8ad0a..4a5d7dbe7 100644 --- a/src/Renci.SshNet/SftpClient.cs +++ b/src/Renci.SshNet/SftpClient.cs @@ -46,21 +46,21 @@ public class SftpClient : BaseClient, ISftpClient /// The timeout to wait until an operation completes. The default value is negative /// one (-1) milliseconds, which indicates an infinite timeout period. /// - /// The method was called after the client was disposed. /// represents a value that is less than -1 or greater than milliseconds. public TimeSpan OperationTimeout { get { - CheckDisposed(); - return TimeSpan.FromMilliseconds(_operationTimeout); } set { - CheckDisposed(); - _operationTimeout = value.AsTimeout(nameof(OperationTimeout)); + + if (_sftpSession is { } sftpSession) + { + sftpSession.OperationTimeout = _operationTimeout; + } } } diff --git a/src/Renci.SshNet/SubsystemSession.cs b/src/Renci.SshNet/SubsystemSession.cs index b5912a805..5ddd80718 100644 --- a/src/Renci.SshNet/SubsystemSession.cs +++ b/src/Renci.SshNet/SubsystemSession.cs @@ -2,6 +2,7 @@ using System.Globalization; using System.Runtime.ExceptionServices; using System.Threading; +using System.Threading.Tasks; using Renci.SshNet.Abstractions; using Renci.SshNet.Channels; @@ -29,13 +30,8 @@ internal abstract class SubsystemSession : ISubsystemSession private EventWaitHandle _channelClosedWaitHandle = new ManualResetEvent(initialState: false); private bool _isDisposed; - /// - /// Gets or set the number of seconds to wait for an operation to complete. - /// - /// - /// The number of seconds to wait for an operation to complete, or -1 to wait indefinitely. - /// - public int OperationTimeout { get; private set; } + /// + public int OperationTimeout { get; set; } /// /// Occurs when an error occurred. @@ -250,6 +246,59 @@ public void WaitOnHandle(WaitHandle waitHandle, int millisecondsTimeout) } } + protected async Task WaitOnHandleAsync(TaskCompletionSource tcs, int millisecondsTimeout, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + var errorOccuredReg = ThreadPool.RegisterWaitForSingleObject( + _errorOccuredWaitHandle, + (tcs, _) => ((TaskCompletionSource)tcs).TrySetException(_exception), + state: tcs, + millisecondsTimeOutInterval: -1, + executeOnlyOnce: true); + + var sessionDisconnectedReg = ThreadPool.RegisterWaitForSingleObject( + _sessionDisconnectedWaitHandle, + static (tcs, _) => ((TaskCompletionSource)tcs).TrySetException(new SshException("Connection was closed by the server.")), + state: tcs, + millisecondsTimeOutInterval: -1, + executeOnlyOnce: true); + + var channelClosedReg = ThreadPool.RegisterWaitForSingleObject( + _channelClosedWaitHandle, + static (tcs, _) => ((TaskCompletionSource)tcs).TrySetException(new SshException("Channel was closed.")), + state: tcs, + millisecondsTimeOutInterval: -1, + executeOnlyOnce: true); + + using var timeoutCts = new CancellationTokenSource(millisecondsTimeout); + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token); + + using var tokenReg = linkedCts.Token.Register( + static s => + { + (var tcs, var cancellationToken) = ((TaskCompletionSource, CancellationToken))s; + _ = tcs.TrySetCanceled(cancellationToken); + }, + state: (tcs, cancellationToken), + useSynchronizationContext: false); + + try + { + return await tcs.Task.ConfigureAwait(false); + } + catch (OperationCanceledException oce) when (timeoutCts.IsCancellationRequested) + { + throw new SshOperationTimeoutException("Operation has timed out.", oce); + } + finally + { + _ = errorOccuredReg.Unregister(waitObject: null); + _ = sessionDisconnectedReg.Unregister(waitObject: null); + _ = channelClosedReg.Unregister(waitObject: null); + } + } + /// /// Blocks the current thread until the specified gets signaled, using a /// 32-bit signed integer to specify the time interval in milliseconds. diff --git a/test/Renci.SshNet.Tests/Classes/SftpClientTest.cs b/test/Renci.SshNet.Tests/Classes/SftpClientTest.cs index 8e8ecd3a2..21c2f2758 100644 --- a/test/Renci.SshNet.Tests/Classes/SftpClientTest.cs +++ b/test/Renci.SshNet.Tests/Classes/SftpClientTest.cs @@ -115,37 +115,5 @@ public void OperationTimeout_GreaterThanLowerLimit() Assert.AreEqual("OperationTimeout", ex.ParamName); } } - - [TestMethod] - public void OperationTimeout_Disposed() - { - var connectionInfo = new PasswordConnectionInfo("host", 22, "admin", "pwd"); - var target = new SftpClient(connectionInfo); - target.Dispose(); - - // getter - try - { - var actual = target.OperationTimeout; - Assert.Fail("Should have failed, but returned: " + actual); - } - catch (ObjectDisposedException ex) - { - Assert.IsNull(ex.InnerException); - Assert.AreEqual(typeof(SftpClient).FullName, ex.ObjectName); - } - - // setter - try - { - target.OperationTimeout = TimeSpan.FromMilliseconds(5); - Assert.Fail(); - } - catch (ObjectDisposedException ex) - { - Assert.IsNull(ex.InnerException); - Assert.AreEqual(typeof(SftpClient).FullName, ex.ObjectName); - } - } } } diff --git a/test/Renci.SshNet.Tests/Classes/SftpClientTest_AsyncExceptions.cs b/test/Renci.SshNet.Tests/Classes/SftpClientTest_AsyncExceptions.cs new file mode 100644 index 000000000..791e3b3a2 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SftpClientTest_AsyncExceptions.cs @@ -0,0 +1,268 @@ +using System; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Moq; + +#if !NET8_0_OR_GREATER +using Renci.SshNet.Abstractions; +#endif +using Renci.SshNet.Channels; +using Renci.SshNet.Common; +using Renci.SshNet.Connection; +using Renci.SshNet.Messages; +using Renci.SshNet.Messages.Authentication; +using Renci.SshNet.Messages.Connection; +using Renci.SshNet.Sftp; +using Renci.SshNet.Sftp.Responses; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SftpClientTest_AsyncExceptions + { + private MySession _session; + private SftpClient _client; + + [TestInitialize] + public void Init() + { + var socketFactoryMock = new Mock(MockBehavior.Strict); + var serviceFactoryMock = new Mock(MockBehavior.Strict); + + var connInfo = new PasswordConnectionInfo("host", "user", "pwd"); + + _session = new MySession(connInfo); + + var concreteServiceFactory = new ServiceFactory(); + + serviceFactoryMock + .Setup(p => p.CreateSocketFactory()) + .Returns(socketFactoryMock.Object); + + serviceFactoryMock + .Setup(p => p.CreateSession(It.IsAny(), socketFactoryMock.Object)) + .Returns(_session); + + serviceFactoryMock + .Setup(p => p.CreateSftpResponseFactory()) + .Returns(concreteServiceFactory.CreateSftpResponseFactory); + + serviceFactoryMock + .Setup(p => p.CreateSftpSession(_session, It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(concreteServiceFactory.CreateSftpSession); + + _client = new SftpClient(connInfo, false, serviceFactoryMock.Object); + _client.Connect(); + } + + [TestMethod] + public async Task Async_ObservesSessionDisconnected() + { + Task openTask = _client.OpenAsync("path", FileMode.Create, FileAccess.Write, CancellationToken.None); + + Assert.IsFalse(openTask.IsCompleted); + + _session.InvokeDisconnected(); + + var ex = await Assert.ThrowsExceptionAsync(() => openTask); + Assert.AreEqual("Connection was closed by the server.", ex.Message); + } + + [TestMethod] + public async Task Async_ObservesChannelClosed() + { + Task openTask = _client.OpenAsync("path", FileMode.Create, FileAccess.Write, CancellationToken.None); + + Assert.IsFalse(openTask.IsCompleted); + + _session.InvokeChannelCloseReceived(); + + var ex = await Assert.ThrowsExceptionAsync(() => openTask); + Assert.AreEqual("Channel was closed.", ex.Message); + } + + [TestMethod] + public async Task Async_ObservesCancellationToken() + { + using CancellationTokenSource cts = new(); + + Task openTask = _client.OpenAsync("path", FileMode.Create, FileAccess.Write, cts.Token); + + Assert.IsFalse(openTask.IsCompleted); + + await cts.CancelAsync(); + + var ex = await Assert.ThrowsExceptionAsync(() => openTask); + Assert.AreEqual(cts.Token, ex.CancellationToken); + } + + [TestMethod] + public async Task Async_ObservesOperationTimeout() + { + _client.OperationTimeout = TimeSpan.FromMilliseconds(250); + + Task openTask = _client.OpenAsync("path", FileMode.Create, FileAccess.Write, CancellationToken.None); + + var ex = await Assert.ThrowsExceptionAsync(() => openTask); + } + + [TestMethod] + public async Task Async_ObservesErrorOccurred() + { + Task openTask = _client.OpenAsync("path", FileMode.Create, FileAccess.Write, CancellationToken.None); + + Assert.IsFalse(openTask.IsCompleted); + + MyException ex = new("my exception"); + + _session.InvokeErrorOccurred(ex); + + var ex2 = await Assert.ThrowsExceptionAsync(() => openTask); + Assert.AreEqual(ex.Message, ex2.Message); + } + +#pragma warning disable IDE0022 // Use block body for method +#pragma warning disable IDE0025 // Use block body for property +#pragma warning disable CS0067 // event is unused + private class MySession(ConnectionInfo connectionInfo) : ISession + { + public IConnectionInfo ConnectionInfo => connectionInfo; + + public event EventHandler> ChannelCloseReceived; + public event EventHandler> ChannelDataReceived; + public event EventHandler> ChannelEofReceived; + public event EventHandler> ChannelExtendedDataReceived; + public event EventHandler> ChannelFailureReceived; + public event EventHandler> ChannelOpenConfirmationReceived; + public event EventHandler> ChannelOpenFailureReceived; + public event EventHandler> ChannelOpenReceived; + public event EventHandler> ChannelRequestReceived; + public event EventHandler> ChannelSuccessReceived; + public event EventHandler> ChannelWindowAdjustReceived; + public event EventHandler Disconnected; + public event EventHandler ErrorOccured; + public event EventHandler ServerIdentificationReceived; + public event EventHandler HostKeyReceived; + public event EventHandler> RequestSuccessReceived; + public event EventHandler> RequestFailureReceived; + public event EventHandler> UserAuthenticationBannerReceived; + + public void InvokeDisconnected() + { + Disconnected?.Invoke(this, new EventArgs()); + } + + public void InvokeChannelCloseReceived() + { + ChannelCloseReceived?.Invoke( + this, + new MessageEventArgs(new ChannelCloseMessage(0))); + } + + public void InvokeErrorOccurred(Exception ex) + { + ErrorOccured?.Invoke(this, new ExceptionEventArgs(ex)); + } + + public void SendMessage(Message message) + { + if (message is ChannelOpenMessage) + { + ChannelOpenConfirmationReceived?.Invoke( + this, + new MessageEventArgs( + new ChannelOpenConfirmationMessage(0, int.MaxValue, int.MaxValue, 0))); + } + else if (message is ChannelRequestMessage) + { + ChannelSuccessReceived?.Invoke( + this, + new MessageEventArgs(new ChannelSuccessMessage(0))); + } + else if (message is ChannelDataMessage dataMsg) + { + if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.Init) + { + ChannelDataReceived?.Invoke( + this, + new MessageEventArgs( + new ChannelDataMessage(0, new SftpVersionResponse() { Version = 3 }.GetBytes()))); + } + else if (dataMsg.Data[sizeof(uint)] == (byte)SftpMessageTypes.RealPath) + { + ChannelDataReceived?.Invoke( + this, + new MessageEventArgs( + new ChannelDataMessage(0, + new SftpNameResponse(3, Encoding.UTF8) + { + ResponseId = 1, + Files = [new("thepath", new SftpFileAttributes(default, default, default, default, default, default, default))] + }.GetBytes()))); + } + } + } + + public bool IsConnected => false; + + public SemaphoreSlim SessionSemaphore { get; } = new(1); + + public IChannelSession CreateChannelSession() => new ChannelSession(this, 0, int.MaxValue, int.MaxValue); + + public WaitHandle MessageListenerCompleted => throw new NotImplementedException(); + + public void Connect() + { + } + + public Task ConnectAsync(CancellationToken cancellationToken) => throw new NotImplementedException(); + + public IChannelDirectTcpip CreateChannelDirectTcpip() => throw new NotImplementedException(); + + public IChannelForwardedTcpip CreateChannelForwardedTcpip(uint remoteChannelNumber, uint remoteWindowSize, uint remoteChannelDataPacketSize) + => throw new NotImplementedException(); + + public void Dispose() + { + } + + public void OnDisconnecting() + { + } + + public void Disconnect() => throw new NotImplementedException(); + + public void RegisterMessage(string messageName) => throw new NotImplementedException(); + + public bool TrySendMessage(Message message) => throw new NotImplementedException(); + + public WaitResult TryWait(WaitHandle waitHandle, TimeSpan timeout, out Exception exception) => throw new NotImplementedException(); + + public WaitResult TryWait(WaitHandle waitHandle, TimeSpan timeout) => throw new NotImplementedException(); + + public void UnRegisterMessage(string messageName) => throw new NotImplementedException(); + + public void WaitOnHandle(WaitHandle waitHandle) + { + } + + public void WaitOnHandle(WaitHandle waitHandle, TimeSpan timeout) => throw new NotImplementedException(); + } + + [TestCleanup] + public void Cleanup() => _client?.Dispose(); + +#pragma warning disable + private class MyException : Exception + { + public MyException(string message) : base(message) + { + } + } + } +} From 966c7c3999dcd538d18fde668e15eb6c1a9dae75 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Mon, 7 Oct 2024 21:55:14 +0200 Subject: [PATCH 2/2] Update ISftpClient --- src/Renci.SshNet/ISftpClient.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Renci.SshNet/ISftpClient.cs b/src/Renci.SshNet/ISftpClient.cs index f91ba7df1..d3c4b3b6a 100644 --- a/src/Renci.SshNet/ISftpClient.cs +++ b/src/Renci.SshNet/ISftpClient.cs @@ -56,7 +56,6 @@ public interface ISftpClient : IBaseClient /// The timeout to wait until an operation completes. The default value is negative /// one (-1) milliseconds, which indicates an infinite timeout period. /// - /// The method was called after the client was disposed. /// represents a value that is less than -1 or greater than milliseconds. TimeSpan OperationTimeout { get; set; }