Skip to content

Commit c174202

Browse files
authored
feat: remove unnecessary use of WaitGroup for protocol done signal (#608)
1 parent 2e915ed commit c174202

File tree

2 files changed

+21
-75
lines changed

2 files changed

+21
-75
lines changed

protocol/protocol.go

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"github.com/blinklabs-io/gouroboros/cbor"
2626
"github.com/blinklabs-io/gouroboros/connection"
2727
"github.com/blinklabs-io/gouroboros/muxer"
28-
"github.com/blinklabs-io/gouroboros/utils"
2928
)
3029

3130
// This is completely arbitrary, but the line had to be drawn somewhere
@@ -34,15 +33,16 @@ const maxMessagesPerSegment = 20
3433
// Protocol implements the base functionality of an Ouroboros mini-protocol
3534
type Protocol struct {
3635
config ProtocolConfig
36+
doneChan chan struct{}
3737
muxerSendChan chan *muxer.Segment
3838
muxerRecvChan chan *muxer.Segment
3939
muxerDoneChan chan bool
4040
sendQueueChan chan Message
41+
recvDoneChan chan struct{}
4142
recvReadyChan chan bool
43+
sendDoneChan chan struct{}
4244
sendReadyChan chan bool
4345
stateTransitionChan chan<- protocolStateTransition
44-
doneSignal *utils.DoneSignal
45-
waitGroup sync.WaitGroup
4646
onceStart sync.Once
4747
}
4848

@@ -105,8 +105,10 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)
105105
// New returns a new Protocol object
106106
func New(config ProtocolConfig) *Protocol {
107107
p := &Protocol{
108-
config: config,
109-
doneSignal: utils.NewDoneSignal(),
108+
config: config,
109+
doneChan: make(chan struct{}),
110+
recvDoneChan: make(chan struct{}),
111+
sendDoneChan: make(chan struct{}),
110112
}
111113
return p
112114
}
@@ -133,7 +135,11 @@ func (p *Protocol) Start() {
133135
p.stateTransitionChan = stateTransitionChan
134136

135137
// Start our send and receive Goroutines
136-
p.waitGroup.Add(2)
138+
go func() {
139+
<-p.recvDoneChan
140+
<-p.sendDoneChan
141+
close(p.doneChan)
142+
}()
137143

138144
go p.stateLoop(stateTransitionChan)
139145
go p.recvLoop()
@@ -153,7 +159,7 @@ func (p *Protocol) Role() ProtocolRole {
153159

154160
// DoneChan returns the channel used to signal protocol shutdown
155161
func (p *Protocol) DoneChan() <-chan struct{} {
156-
return p.doneSignal.GetCh()
162+
return p.doneChan
157163
}
158164

159165
// SendMessage appends a message to the send queue
@@ -176,17 +182,16 @@ func (p *Protocol) SendError(err error) {
176182

177183
func (p *Protocol) sendLoop() {
178184
defer func() {
179-
p.waitGroup.Done()
180185
// Close muxer send channel
181186
// We are responsible for closing this channel as the sender, even through it
182187
// was created by the muxer
183188
close(p.muxerSendChan)
184-
p.doneSignal.Close()
189+
close(p.sendDoneChan)
185190
}()
186191

187192
for {
188193
select {
189-
case <-p.doneSignal.GetCh():
194+
case <-p.recvDoneChan:
190195
// Break out of send loop if we're shutting down
191196
return
192197
case <-p.sendReadyChan:
@@ -200,7 +205,7 @@ func (p *Protocol) sendLoop() {
200205
for {
201206
// Get next message from send queue
202207
select {
203-
case <-p.doneSignal.GetCh():
208+
case <-p.recvDoneChan:
204209
// Break out of send loop if we're shutting down
205210
return
206211
case msg, ok := <-p.sendQueueChan:
@@ -285,8 +290,7 @@ func (p *Protocol) sendLoop() {
285290

286291
func (p *Protocol) recvLoop() {
287292
defer func() {
288-
p.waitGroup.Done()
289-
p.doneSignal.Close()
293+
close(p.recvDoneChan)
290294
}()
291295

292296
leftoverData := false
@@ -298,7 +302,7 @@ func (p *Protocol) recvLoop() {
298302
if !leftoverData {
299303
// Wait for segment
300304
select {
301-
case <-p.doneSignal.GetCh():
305+
case <-p.sendDoneChan:
302306
// Break out of receive loop if we're shutting down
303307
return
304308
case <-p.muxerDoneChan:
@@ -314,7 +318,7 @@ func (p *Protocol) recvLoop() {
314318
leftoverData = false
315319
// Wait until ready to receive based on state map
316320
select {
317-
case <-p.doneSignal.GetCh():
321+
case <-p.sendDoneChan:
318322
// Break out of receive loop if we're shutting down
319323
return
320324
case <-p.muxerDoneChan:
@@ -431,9 +435,6 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) {
431435
return transitionTimer.C
432436
}
433437

434-
protocolDoneChan := p.doneSignal.GetCh()
435-
stateDoneChan := make(chan struct{})
436-
437438
setState(p.config.InitialState)
438439

439440
for {
@@ -467,24 +468,11 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) {
467468
),
468469
)
469470

470-
case <-protocolDoneChan:
471-
// Disable this case so it doesn't block
472-
protocolDoneChan = nil
473-
474-
// Wait for all other goroutines to finish before shutting down the state handler
475-
go func() {
476-
p.waitGroup.Wait()
477-
478-
close(stateDoneChan)
479-
}()
480-
481-
case <-stateDoneChan:
482-
// All other goroutines have finished, so we can stop the timer and return
471+
case <-p.doneChan:
472+
// Disable any previous state transition timer, as they are no longer needed
483473
if transitionTimer != nil && !transitionTimer.Stop() {
484474
<-transitionTimer.C
485475
}
486-
transitionTimer = nil
487-
488476
return
489477
}
490478
}

utils/utils.go

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)