From 7d1d46a876ae1063ddd27eb5c5b5a0bc0d08187d Mon Sep 17 00:00:00 2001 From: Sean Latimer Date: Thu, 4 Jul 2019 14:54:52 +0100 Subject: [PATCH 1/2] fixed readLoop not exiting when exiting channel is closed --- client.go | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 575e306..91ee565 100644 --- a/client.go +++ b/client.go @@ -238,25 +238,30 @@ func (c *Client) sendError(err error) { } } -func (c *Client) startReadLoop(wg *sync.WaitGroup) { +func (c *Client) startReadLoop(wg *sync.WaitGroup, exiting chan struct{}) { wg.Add(1) go func() { defer wg.Done() for { - m, err := c.ReadMessage() - if err != nil { - c.sendError(err) - break - } + select { + case <-exiting: + return + default: + m, err := c.ReadMessage() + if err != nil { + c.sendError(err) + break + } - if f, ok := clientFilters[m.Command]; ok { - f(c, m) - } + if f, ok := clientFilters[m.Command]; ok { + f(c, m) + } - if c.config.Handler != nil { - c.config.Handler.Handle(c, m) + if c.config.Handler != nil { + c.config.Handler.Handle(c, m) + } } } @@ -296,7 +301,7 @@ func (c *Client) RunContext(ctx context.Context) error { // Now that the handshake is pretty much done, we can start listening for // messages. - c.startReadLoop(&wg) + c.startReadLoop(&wg, exiting) // Wait for an error from any goroutine or for the context to time out, then // signal we're exiting and wait for the goroutines to exit. From 806fddcb59c7ec3ee6ac2510f6794f91cd6a0505 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Mon, 15 Jul 2019 16:27:13 -0700 Subject: [PATCH 2/2] Fix client exit failure with an explicit Close --- client.go | 7 +++++-- client_test.go | 12 ++++++++++++ stream_test.go | 25 ++++++++++++++++++++++++- 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 91ee565..ec7c560 100644 --- a/client.go +++ b/client.go @@ -50,6 +50,7 @@ type cap struct { // much simpler. type Client struct { *Conn + rwc io.ReadWriteCloser config ClientConfig // Internal state @@ -63,9 +64,10 @@ type Client struct { } // NewClient creates a client given an io stream and a client config. -func NewClient(rw io.ReadWriter, config ClientConfig) *Client { +func NewClient(rwc io.ReadWriteCloser, config ClientConfig) *Client { c := &Client{ - Conn: NewConn(rw), + Conn: NewConn(rwc), + rwc: rwc, config: config, errChan: make(chan error, 1), caps: make(map[string]cap), @@ -312,6 +314,7 @@ func (c *Client) RunContext(ctx context.Context) error { } close(exiting) + c.rwc.Close() wg.Wait() return err diff --git a/client_test.go b/client_test.go index ee27808..6f11702 100644 --- a/client_test.go +++ b/client_test.go @@ -408,4 +408,16 @@ func TestPingLoop(t *testing.T) { SendLine("001 :hello_world\r\n"), Delay(25 * time.Millisecond), }) + + // See if we can get the client to hang + runClientTest(t, config, errors.New("test error"), nil, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0 * :test_name\r\n"), + // We queue this up a line early because the next write will happen after the delay. + QueueWriteError(errors.New("test error")), + SendLine("001 :hello_world\r\n"), + Delay(2 * time.Second), + AssertClosed(), + }) } diff --git a/stream_test.go b/stream_test.go index de0cbf6..d4e2f73 100644 --- a/stream_test.go +++ b/stream_test.go @@ -18,6 +18,14 @@ func SendLine(output string) TestAction { return SendLineWithTimeout(output, 1*time.Second) } +func AssertClosed() TestAction { + return func(t *testing.T, rw *testReadWriter) { + if !rw.closed { + assert.Fail(t, "Expected conn to be closed") + } + } +} + func SendLineWithTimeout(output string, timeout time.Duration) TestAction { return func(t *testing.T, rw *testReadWriter) { waitChan := time.After(timeout) @@ -116,6 +124,7 @@ type testReadWriter struct { readEmptyChan chan struct{} exiting chan struct{} clientDone chan struct{} + closed bool serverBuffer bytes.Buffer } @@ -182,6 +191,20 @@ func (rw *testReadWriter) Write(buf []byte) (int, error) { } } +func (rw *testReadWriter) Close() error { + select { + case <-rw.exiting: + return errors.New("Connection closed") + default: + // Ensure no double close + if !rw.closed { + rw.closed = true + close(rw.exiting) + } + return nil + } +} + func newTestReadWriter(actions []TestAction) *testReadWriter { return &testReadWriter{ actions: actions, @@ -223,7 +246,7 @@ func runTest(t *testing.T, rw *testReadWriter, actions []TestAction) { // TODO: Make sure there are no more incoming messages // Ask everything to shut down - close(rw.exiting) + rw.Close() // Wait for the client to stop select {