Skip to content

Commit ed17aa9

Browse files
belakSeanLatimer
andauthored
Fix read loop not exiting when the server closes the connection (#73)
Co-authored-by: Sean Latimer <[email protected]>
1 parent ca43582 commit ed17aa9

File tree

3 files changed

+58
-15
lines changed

3 files changed

+58
-15
lines changed

client.go

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ type cap struct {
5050
// much simpler.
5151
type Client struct {
5252
*Conn
53+
rwc io.ReadWriteCloser
5354
config ClientConfig
5455

5556
// Internal state
@@ -63,9 +64,10 @@ type Client struct {
6364
}
6465

6566
// NewClient creates a client given an io stream and a client config.
66-
func NewClient(rw io.ReadWriter, config ClientConfig) *Client {
67+
func NewClient(rwc io.ReadWriteCloser, config ClientConfig) *Client {
6768
c := &Client{
68-
Conn: NewConn(rw),
69+
Conn: NewConn(rwc),
70+
rwc: rwc,
6971
config: config,
7072
errChan: make(chan error, 1),
7173
caps: make(map[string]cap),
@@ -238,25 +240,30 @@ func (c *Client) sendError(err error) {
238240
}
239241
}
240242

241-
func (c *Client) startReadLoop(wg *sync.WaitGroup) {
243+
func (c *Client) startReadLoop(wg *sync.WaitGroup, exiting chan struct{}) {
242244
wg.Add(1)
243245

244246
go func() {
245247
defer wg.Done()
246248

247249
for {
248-
m, err := c.ReadMessage()
249-
if err != nil {
250-
c.sendError(err)
251-
break
252-
}
250+
select {
251+
case <-exiting:
252+
return
253+
default:
254+
m, err := c.ReadMessage()
255+
if err != nil {
256+
c.sendError(err)
257+
break
258+
}
253259

254-
if f, ok := clientFilters[m.Command]; ok {
255-
f(c, m)
256-
}
260+
if f, ok := clientFilters[m.Command]; ok {
261+
f(c, m)
262+
}
257263

258-
if c.config.Handler != nil {
259-
c.config.Handler.Handle(c, m)
264+
if c.config.Handler != nil {
265+
c.config.Handler.Handle(c, m)
266+
}
260267
}
261268
}
262269

@@ -296,7 +303,7 @@ func (c *Client) RunContext(ctx context.Context) error {
296303

297304
// Now that the handshake is pretty much done, we can start listening for
298305
// messages.
299-
c.startReadLoop(&wg)
306+
c.startReadLoop(&wg, exiting)
300307

301308
// Wait for an error from any goroutine or for the context to time out, then
302309
// signal we're exiting and wait for the goroutines to exit.
@@ -307,6 +314,7 @@ func (c *Client) RunContext(ctx context.Context) error {
307314
}
308315

309316
close(exiting)
317+
c.rwc.Close()
310318
wg.Wait()
311319

312320
return err

client_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,4 +408,16 @@ func TestPingLoop(t *testing.T) {
408408
SendLine("001 :hello_world\r\n"),
409409
Delay(25 * time.Millisecond),
410410
})
411+
412+
// See if we can get the client to hang
413+
runClientTest(t, config, errors.New("test error"), nil, []TestAction{
414+
ExpectLine("PASS :test_pass\r\n"),
415+
ExpectLine("NICK :test_nick\r\n"),
416+
ExpectLine("USER test_user 0 * :test_name\r\n"),
417+
// We queue this up a line early because the next write will happen after the delay.
418+
QueueWriteError(errors.New("test error")),
419+
SendLine("001 :hello_world\r\n"),
420+
Delay(2 * time.Second),
421+
AssertClosed(),
422+
})
411423
}

stream_test.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ func SendLine(output string) TestAction {
1818
return SendLineWithTimeout(output, 1*time.Second)
1919
}
2020

21+
func AssertClosed() TestAction {
22+
return func(t *testing.T, rw *testReadWriter) {
23+
if !rw.closed {
24+
assert.Fail(t, "Expected conn to be closed")
25+
}
26+
}
27+
}
28+
2129
func SendLineWithTimeout(output string, timeout time.Duration) TestAction {
2230
return func(t *testing.T, rw *testReadWriter) {
2331
waitChan := time.After(timeout)
@@ -116,6 +124,7 @@ type testReadWriter struct {
116124
readEmptyChan chan struct{}
117125
exiting chan struct{}
118126
clientDone chan struct{}
127+
closed bool
119128
serverBuffer bytes.Buffer
120129
}
121130

@@ -182,6 +191,20 @@ func (rw *testReadWriter) Write(buf []byte) (int, error) {
182191
}
183192
}
184193

194+
func (rw *testReadWriter) Close() error {
195+
select {
196+
case <-rw.exiting:
197+
return errors.New("Connection closed")
198+
default:
199+
// Ensure no double close
200+
if !rw.closed {
201+
rw.closed = true
202+
close(rw.exiting)
203+
}
204+
return nil
205+
}
206+
}
207+
185208
func newTestReadWriter(actions []TestAction) *testReadWriter {
186209
return &testReadWriter{
187210
actions: actions,
@@ -223,7 +246,7 @@ func runTest(t *testing.T, rw *testReadWriter, actions []TestAction) {
223246
// TODO: Make sure there are no more incoming messages
224247

225248
// Ask everything to shut down
226-
close(rw.exiting)
249+
rw.Close()
227250

228251
// Wait for the client to stop
229252
select {

0 commit comments

Comments
 (0)