@@ -25,7 +25,6 @@ import (
25
25
"github.com/blinklabs-io/gouroboros/cbor"
26
26
"github.com/blinklabs-io/gouroboros/connection"
27
27
"github.com/blinklabs-io/gouroboros/muxer"
28
- "github.com/blinklabs-io/gouroboros/utils"
29
28
)
30
29
31
30
// This is completely arbitrary, but the line had to be drawn somewhere
@@ -34,15 +33,16 @@ const maxMessagesPerSegment = 20
34
33
// Protocol implements the base functionality of an Ouroboros mini-protocol
35
34
type Protocol struct {
36
35
config ProtocolConfig
36
+ doneChan chan struct {}
37
37
muxerSendChan chan * muxer.Segment
38
38
muxerRecvChan chan * muxer.Segment
39
39
muxerDoneChan chan bool
40
40
sendQueueChan chan Message
41
+ recvDoneChan chan struct {}
41
42
recvReadyChan chan bool
43
+ sendDoneChan chan struct {}
42
44
sendReadyChan chan bool
43
45
stateTransitionChan chan <- protocolStateTransition
44
- doneSignal * utils.DoneSignal
45
- waitGroup sync.WaitGroup
46
46
onceStart sync.Once
47
47
}
48
48
@@ -105,8 +105,10 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)
105
105
// New returns a new Protocol object
106
106
func New (config ProtocolConfig ) * Protocol {
107
107
p := & Protocol {
108
- config : config ,
109
- doneSignal : utils .NewDoneSignal (),
108
+ config : config ,
109
+ doneChan : make (chan struct {}),
110
+ recvDoneChan : make (chan struct {}),
111
+ sendDoneChan : make (chan struct {}),
110
112
}
111
113
return p
112
114
}
@@ -133,7 +135,11 @@ func (p *Protocol) Start() {
133
135
p .stateTransitionChan = stateTransitionChan
134
136
135
137
// Start our send and receive Goroutines
136
- p .waitGroup .Add (2 )
138
+ go func () {
139
+ <- p .recvDoneChan
140
+ <- p .sendDoneChan
141
+ close (p .doneChan )
142
+ }()
137
143
138
144
go p .stateLoop (stateTransitionChan )
139
145
go p .recvLoop ()
@@ -153,7 +159,7 @@ func (p *Protocol) Role() ProtocolRole {
153
159
154
160
// DoneChan returns the channel used to signal protocol shutdown
155
161
func (p * Protocol ) DoneChan () <- chan struct {} {
156
- return p .doneSignal . GetCh ()
162
+ return p .doneChan
157
163
}
158
164
159
165
// SendMessage appends a message to the send queue
@@ -176,17 +182,16 @@ func (p *Protocol) SendError(err error) {
176
182
177
183
func (p * Protocol ) sendLoop () {
178
184
defer func () {
179
- p .waitGroup .Done ()
180
185
// Close muxer send channel
181
186
// We are responsible for closing this channel as the sender, even through it
182
187
// was created by the muxer
183
188
close (p .muxerSendChan )
184
- p . doneSignal . Close ( )
189
+ close ( p . sendDoneChan )
185
190
}()
186
191
187
192
for {
188
193
select {
189
- case <- p .doneSignal . GetCh () :
194
+ case <- p .recvDoneChan :
190
195
// Break out of send loop if we're shutting down
191
196
return
192
197
case <- p .sendReadyChan :
@@ -200,7 +205,7 @@ func (p *Protocol) sendLoop() {
200
205
for {
201
206
// Get next message from send queue
202
207
select {
203
- case <- p .doneSignal . GetCh () :
208
+ case <- p .recvDoneChan :
204
209
// Break out of send loop if we're shutting down
205
210
return
206
211
case msg , ok := <- p .sendQueueChan :
@@ -285,8 +290,7 @@ func (p *Protocol) sendLoop() {
285
290
286
291
func (p * Protocol ) recvLoop () {
287
292
defer func () {
288
- p .waitGroup .Done ()
289
- p .doneSignal .Close ()
293
+ close (p .recvDoneChan )
290
294
}()
291
295
292
296
leftoverData := false
@@ -298,7 +302,7 @@ func (p *Protocol) recvLoop() {
298
302
if ! leftoverData {
299
303
// Wait for segment
300
304
select {
301
- case <- p .doneSignal . GetCh () :
305
+ case <- p .sendDoneChan :
302
306
// Break out of receive loop if we're shutting down
303
307
return
304
308
case <- p .muxerDoneChan :
@@ -314,7 +318,7 @@ func (p *Protocol) recvLoop() {
314
318
leftoverData = false
315
319
// Wait until ready to receive based on state map
316
320
select {
317
- case <- p .doneSignal . GetCh () :
321
+ case <- p .sendDoneChan :
318
322
// Break out of receive loop if we're shutting down
319
323
return
320
324
case <- p .muxerDoneChan :
@@ -431,9 +435,6 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) {
431
435
return transitionTimer .C
432
436
}
433
437
434
- protocolDoneChan := p .doneSignal .GetCh ()
435
- stateDoneChan := make (chan struct {})
436
-
437
438
setState (p .config .InitialState )
438
439
439
440
for {
@@ -467,24 +468,11 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) {
467
468
),
468
469
)
469
470
470
- case <- protocolDoneChan :
471
- // Disable this case so it doesn't block
472
- protocolDoneChan = nil
473
-
474
- // Wait for all other goroutines to finish before shutting down the state handler
475
- go func () {
476
- p .waitGroup .Wait ()
477
-
478
- close (stateDoneChan )
479
- }()
480
-
481
- case <- stateDoneChan :
482
- // All other goroutines have finished, so we can stop the timer and return
471
+ case <- p .doneChan :
472
+ // Disable any previous state transition timer, as they are no longer needed
483
473
if transitionTimer != nil && ! transitionTimer .Stop () {
484
474
<- transitionTimer .C
485
475
}
486
- transitionTimer = nil
487
-
488
476
return
489
477
}
490
478
}
0 commit comments