Skip to content

Commit d07827b

Browse files
authored
Fixed dequeuing of incoming queue (#1319)
* Fixed dequeuing of incoming queue. * Adjusted return of Expect to make sure it returns the full incoming queue.
1 parent bcaf354 commit d07827b

File tree

2 files changed

+81
-45
lines changed

2 files changed

+81
-45
lines changed

src/Renci.SshNet/ShellStream.cs

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -282,19 +282,14 @@ public void Expect(TimeSpan timeout, params ExpectAction[] expectActions)
282282

283283
if (match.Success)
284284
{
285-
var returnText = matchText.Substring(0, match.Index + match.Length);
286-
var returnLength = _encoding.GetByteCount(returnText);
285+
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
286+
var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
287+
#else
288+
var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
289+
#endif
287290

288291
// Remove processed items from the queue
289-
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
290-
{
291-
if (_expect.Count == _incoming.Count)
292-
{
293-
_ = _expect.Dequeue();
294-
}
295-
296-
_ = _incoming.Dequeue();
297-
}
292+
var returnText = SyncQueuesAndReturn(returnLength);
298293

299294
expectAction.Action(returnText);
300295
expectedFound = true;
@@ -385,19 +380,14 @@ public string Expect(Regex regex, TimeSpan timeout)
385380

386381
if (match.Success)
387382
{
388-
returnText = matchText.Substring(0, match.Index + match.Length);
389-
var returnLength = _encoding.GetByteCount(returnText);
383+
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
384+
var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
385+
#else
386+
var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
387+
#endif
390388

391389
// Remove processed items from the queue
392-
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
393-
{
394-
if (_expect.Count == _incoming.Count)
395-
{
396-
_ = _expect.Dequeue();
397-
}
398-
399-
_ = _incoming.Dequeue();
400-
}
390+
returnText = SyncQueuesAndReturn(returnLength);
401391

402392
break;
403393
}
@@ -501,19 +491,14 @@ public IAsyncResult BeginExpect(TimeSpan timeout, AsyncCallback callback, object
501491

502492
if (match.Success)
503493
{
504-
returnText = matchText.Substring(0, match.Index + match.Length);
505-
var returnLength = _encoding.GetByteCount(returnText);
494+
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
495+
var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
496+
#else
497+
var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
498+
#endif
506499

507500
// Remove processed items from the queue
508-
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
509-
{
510-
if (_expect.Count == _incoming.Count)
511-
{
512-
_ = _expect.Dequeue();
513-
}
514-
515-
_ = _incoming.Dequeue();
516-
}
501+
returnText = SyncQueuesAndReturn(returnLength);
517502

518503
expectAction.Action(returnText);
519504
callback?.Invoke(asyncResult);
@@ -614,15 +599,7 @@ public string ReadLine(TimeSpan timeout)
614599
var bytesProcessed = _encoding.GetByteCount(text + CrLf);
615600

616601
// remove processed bytes from the queue
617-
for (var i = 0; i < bytesProcessed; i++)
618-
{
619-
if (_expect.Count == _incoming.Count)
620-
{
621-
_ = _expect.Dequeue();
622-
}
623-
624-
_ = _incoming.Dequeue();
625-
}
602+
SyncQueuesAndDequeue(bytesProcessed);
626603

627604
break;
628605
}
@@ -687,7 +664,7 @@ public override int Read(byte[] buffer, int offset, int count)
687664
{
688665
for (; i < count && _incoming.Count > 0; i++)
689666
{
690-
if (_expect.Count == _incoming.Count)
667+
if (_incoming.Count == _expect.Count)
691668
{
692669
_ = _expect.Dequeue();
693670
}
@@ -869,5 +846,37 @@ private void OnDataReceived(byte[] data)
869846
{
870847
DataReceived?.Invoke(this, new ShellDataEventArgs(data));
871848
}
849+
850+
private string SyncQueuesAndReturn(int bytesToDequeue)
851+
{
852+
string incomingText;
853+
854+
lock (_incoming)
855+
{
856+
var incomingLength = _incoming.Count - _expect.Count + bytesToDequeue;
857+
incomingText = _encoding.GetString(_incoming.ToArray(), 0, incomingLength);
858+
859+
SyncQueuesAndDequeue(bytesToDequeue);
860+
}
861+
862+
return incomingText;
863+
}
864+
865+
private void SyncQueuesAndDequeue(int bytesToDequeue)
866+
{
867+
lock (_incoming)
868+
{
869+
while (_incoming.Count > _expect.Count)
870+
{
871+
_ = _incoming.Dequeue();
872+
}
873+
874+
for (var count = 0; count < bytesToDequeue && _incoming.Count > 0; count++)
875+
{
876+
_ = _incoming.Dequeue();
877+
_ = _expect.Dequeue();
878+
}
879+
}
880+
}
872881
}
873882
}

test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ namespace Renci.SshNet.Tests.Classes
1717
[TestClass]
1818
public class ShellStreamTest_ReadExpect
1919
{
20+
private const int BufferSize = 1024;
21+
private const int ExpectSize = BufferSize * 2;
2022
private ShellStream _shellStream;
2123
private ChannelSessionStub _channelSessionStub;
2224

@@ -42,8 +44,8 @@ public void Initialize()
4244
width: 800,
4345
height: 600,
4446
terminalModeValues: null,
45-
bufferSize: 1024,
46-
expectSize: 2048);
47+
bufferSize: BufferSize,
48+
expectSize: ExpectSize);
4749
}
4850

4951
[TestMethod]
@@ -244,6 +246,31 @@ public void Expect_String_LargeExpect()
244246
Assert.AreEqual($"{new string('c', 100)}", _shellStream.Read());
245247
}
246248

249+
[TestMethod]
250+
public void Expect_String_DequeueChecks()
251+
{
252+
const string expected = "ccccc";
253+
254+
// Prime buffer
255+
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string(' ', BufferSize)));
256+
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string(' ', ExpectSize)));
257+
258+
// Test data
259+
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('a', 100)));
260+
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('b', 100)));
261+
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(expected));
262+
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('d', 100)));
263+
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('e', 100)));
264+
265+
// Expected result
266+
var expectedResult = $"{new string(' ', BufferSize)}{new string(' ', ExpectSize)}{new string('a', 100)}{new string('b', 100)}{expected}";
267+
var expectedRead = $"{new string('d', 100)}{new string('e', 100)}";
268+
269+
Assert.AreEqual(expectedResult, _shellStream.Expect(expected));
270+
271+
Assert.AreEqual(expectedRead, _shellStream.Read());
272+
}
273+
247274
[TestMethod]
248275
public void Expect_Timeout()
249276
{

0 commit comments

Comments
 (0)