diff --git a/conn.go b/conn.go index e27a971a..f8571102 100644 --- a/conn.go +++ b/conn.go @@ -454,13 +454,18 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er } } - timer := time.NewTimer(d) select { case <-c.mu: - timer.Stop() - case <-timer.C: - return errWriteTimeout + default: + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } } + defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() diff --git a/conn_test.go b/conn_test.go index fe95b552..a3dde2de 100644 --- a/conn_test.go +++ b/conn_test.go @@ -148,6 +148,31 @@ func TestFraming(t *testing.T) { } } +func TestConcurrencyWriteControl(t *testing.T) { + const message = "this is a ping/pong messsage" + loop := 10 + workers := 10 + for i := 0; i < loop; i++ { + var connBuf bytes.Buffer + + wg := sync.WaitGroup{} + wc := newTestConn(nil, &connBuf, true) + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { + t.Errorf("concurrently wc.WriteControl() returned %v", err) + } + }() + } + + wg.Wait() + wc.Close() + } +} + func TestControl(t *testing.T) { t.Parallel() const message = "this is a ping/pong messsage"