From 69c6a0d2574c72d2e302d1e2a5cfd2080282bc1e Mon Sep 17 00:00:00 2001 From: Marius Thesing Date: Sun, 22 Sep 2024 11:37:36 +0200 Subject: [PATCH] fix ConnectAsync not respecting the connection timeout --- src/Renci.SshNet/BaseClient.cs | 12 ++- .../BaseClientTest_ConnectAsync_Timeout.cs | 73 +++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 test/Renci.SshNet.Tests/Classes/BaseClientTest_ConnectAsync_Timeout.cs diff --git a/src/Renci.SshNet/BaseClient.cs b/src/Renci.SshNet/BaseClient.cs index 3332bffe4..3757878b4 100644 --- a/src/Renci.SshNet/BaseClient.cs +++ b/src/Renci.SshNet/BaseClient.cs @@ -307,7 +307,17 @@ public async Task ConnectAsync(CancellationToken cancellationToken) DisposeSession(session); } - Session = await CreateAndConnectSessionAsync(cancellationToken).ConfigureAwait(false); + using var timeoutCancellationTokenSource = new CancellationTokenSource(ConnectionInfo.Timeout); + using var linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCancellationTokenSource.Token); + + try + { + Session = await CreateAndConnectSessionAsync(linkedCancellationTokenSource.Token).ConfigureAwait(false); + } + catch (OperationCanceledException ex) when (timeoutCancellationTokenSource.IsCancellationRequested) + { + throw new SshOperationTimeoutException("Connection has timed out.", ex); + } } try diff --git a/test/Renci.SshNet.Tests/Classes/BaseClientTest_ConnectAsync_Timeout.cs b/test/Renci.SshNet.Tests/Classes/BaseClientTest_ConnectAsync_Timeout.cs new file mode 100644 index 000000000..d1d9df3bd --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/BaseClientTest_ConnectAsync_Timeout.cs @@ -0,0 +1,73 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Moq; + +#if !NET8_0_OR_GREATER +using Renci.SshNet.Abstractions; +#endif +using Renci.SshNet.Common; +using Renci.SshNet.Connection; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class BaseClientTest_ConnectAsync_Timeout + { + private BaseClient _client; + + [TestInitialize] + public void Init() + { + var sessionMock = new Mock(); + var serviceFactoryMock = new Mock(); + var socketFactoryMock = new Mock(); + + sessionMock.Setup(p => p.ConnectAsync(It.IsAny())) + .Returns(c => Task.Delay(Timeout.Infinite, c)); + + serviceFactoryMock.Setup(p => p.CreateSocketFactory()) + .Returns(socketFactoryMock.Object); + + var connectionInfo = new ConnectionInfo("host", "user", new PasswordAuthenticationMethod("user", "pwd")) + { + Timeout = TimeSpan.FromSeconds(1) + }; + + serviceFactoryMock.Setup(p => p.CreateSession(connectionInfo, socketFactoryMock.Object)) + .Returns(sessionMock.Object); + + _client = new MyClient(connectionInfo, false, serviceFactoryMock.Object); + } + + [TestMethod] + public async Task ConnectAsyncWithTimeoutThrowsSshTimeoutException() + { + await Assert.ThrowsExceptionAsync(() => _client.ConnectAsync(CancellationToken.None)); + } + + [TestMethod] + public async Task ConnectAsyncWithCancelledTokenThrowsOperationCancelledException() + { + using var cancellationTokenSource = new CancellationTokenSource(); + await cancellationTokenSource.CancelAsync(); + await Assert.ThrowsExceptionAsync(() => _client.ConnectAsync(cancellationTokenSource.Token)); + } + + [TestCleanup] + public void Cleanup() + { + _client?.Dispose(); + } + + private class MyClient : BaseClient + { + public MyClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory) : base(connectionInfo, ownsConnectionInfo, serviceFactory) + { + } + } + } +}