Skip to content

Commit 97384c1

Browse files
committed
quic: remove streams from the conn when done
When a stream has been fully shut down--the peer has closed its end and acked every frame we will send for it--remove it from the Conn's set of active streams. We do the actual removal on the conn's loop, so stream cleanup can access conn state without worrying about locking. For golang/go#58547 Change-Id: Id9715693649929b07d303f0c4b3a782d135f0326 Reviewed-on: https://go-review.googlesource.com/c/net/+/524296 Reviewed-by: Jonathan Amsterdam <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent 03d5e62 commit 97384c1

File tree

6 files changed

+315
-44
lines changed

6 files changed

+315
-44
lines changed

internal/quic/atomic_bits.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright 2023 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build go1.21
6+
7+
package quic
8+
9+
import "sync/atomic"
10+
11+
// atomicBits is an atomic uint32 that supports setting individual bits.
12+
type atomicBits[T ~uint32] struct {
13+
bits atomic.Uint32
14+
}
15+
16+
// set sets the bits in mask to the corresponding bits in v.
17+
// It returns the new value.
18+
func (a *atomicBits[T]) set(v, mask T) T {
19+
if v&^mask != 0 {
20+
panic("BUG: bits in v are not in mask")
21+
}
22+
for {
23+
o := a.bits.Load()
24+
n := (o &^ uint32(mask)) | uint32(v)
25+
if a.bits.CompareAndSwap(o, n) {
26+
return T(n)
27+
}
28+
}
29+
}
30+
31+
func (a *atomicBits[T]) load() T {
32+
return T(a.bits.Load())
33+
}

internal/quic/conn_streams.go

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -185,24 +185,46 @@ func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool)
185185
for {
186186
s := c.streams.sendHead
187187
const pto = false
188-
if !s.appendInFrames(w, pnum, pto) {
189-
return false
188+
189+
state := s.state.load()
190+
if state&streamInSend != 0 {
191+
s.ingate.lock()
192+
ok := s.appendInFramesLocked(w, pnum, pto)
193+
state = s.inUnlockNoQueue()
194+
if !ok {
195+
return false
196+
}
190197
}
191-
avail := w.avail()
192-
if !s.appendOutFrames(w, pnum, pto) {
193-
// We've sent some data for this stream, but it still has more to send.
194-
// If the stream got a reasonable chance to put data in a packet,
195-
// advance sendHead to the next stream in line, to avoid starvation.
196-
// We'll come back to this stream after going through the others.
197-
//
198-
// If the packet was already mostly out of space, leave sendHead alone
199-
// and come back to this stream again on the next packet.
200-
if avail > 512 {
201-
c.streams.sendHead = s.next
202-
c.streams.sendTail = s
198+
199+
if state&streamOutSend != 0 {
200+
avail := w.avail()
201+
s.outgate.lock()
202+
ok := s.appendOutFramesLocked(w, pnum, pto)
203+
state = s.outUnlockNoQueue()
204+
if !ok {
205+
// We've sent some data for this stream, but it still has more to send.
206+
// If the stream got a reasonable chance to put data in a packet,
207+
// advance sendHead to the next stream in line, to avoid starvation.
208+
// We'll come back to this stream after going through the others.
209+
//
210+
// If the packet was already mostly out of space, leave sendHead alone
211+
// and come back to this stream again on the next packet.
212+
if avail > 512 {
213+
c.streams.sendHead = s.next
214+
c.streams.sendTail = s
215+
}
216+
return false
203217
}
204-
return false
205218
}
219+
220+
if state == streamInDone|streamOutDone {
221+
// Stream is finished, remove it from the conn.
222+
s.state.set(streamConnRemoved, streamConnRemoved)
223+
delete(c.streams.streams, s.id)
224+
225+
// TODO: Provide the peer with additional stream quota (MAX_STREAMS).
226+
}
227+
206228
next := s.next
207229
s.next = nil
208230
if (next == s) != (s == c.streams.sendTail) {
@@ -231,10 +253,16 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool {
231253
defer c.streams.sendMu.Unlock()
232254
for _, s := range c.streams.streams {
233255
const pto = true
234-
if !s.appendInFrames(w, pnum, pto) {
256+
s.ingate.lock()
257+
inOK := s.appendInFramesLocked(w, pnum, pto)
258+
s.inUnlockNoQueue()
259+
if !inOK {
235260
return false
236261
}
237-
if !s.appendOutFrames(w, pnum, pto) {
262+
s.outgate.lock()
263+
outOK := s.appendOutFramesLocked(w, pnum, pto)
264+
s.outUnlockNoQueue()
265+
if !outOK {
238266
return false
239267
}
240268
}

internal/quic/conn_streams_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package quic
88

99
import (
1010
"context"
11+
"fmt"
12+
"io"
1113
"testing"
1214
)
1315

@@ -253,3 +255,90 @@ func TestStreamsWriteQueueFairness(t *testing.T) {
253255
}
254256
}
255257
}
258+
259+
func TestStreamsShutdown(t *testing.T) {
260+
// These tests verify that a stream is removed from the Conn's map of live streams
261+
// after it is fully shut down.
262+
//
263+
// Each case consists of a setup step, after which one stream should exist,
264+
// and a shutdown step, after which no streams should remain in the Conn.
265+
for _, test := range []struct {
266+
name string
267+
side streamSide
268+
styp streamType
269+
setup func(*testing.T, *testConn, *Stream)
270+
shutdown func(*testing.T, *testConn, *Stream)
271+
}{{
272+
name: "closed",
273+
side: localStream,
274+
styp: uniStream,
275+
setup: func(t *testing.T, tc *testConn, s *Stream) {
276+
s.CloseContext(canceledContext())
277+
},
278+
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
279+
tc.writeAckForAll()
280+
},
281+
}, {
282+
name: "local close",
283+
side: localStream,
284+
styp: bidiStream,
285+
setup: func(t *testing.T, tc *testConn, s *Stream) {
286+
tc.writeFrames(packetType1RTT, debugFrameResetStream{
287+
id: s.id,
288+
})
289+
s.CloseContext(canceledContext())
290+
},
291+
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
292+
tc.writeAckForAll()
293+
},
294+
}, {
295+
name: "remote reset",
296+
side: localStream,
297+
styp: bidiStream,
298+
setup: func(t *testing.T, tc *testConn, s *Stream) {
299+
s.CloseContext(canceledContext())
300+
tc.wantIdle("all frames after CloseContext are ignored")
301+
tc.writeAckForAll()
302+
},
303+
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
304+
tc.writeFrames(packetType1RTT, debugFrameResetStream{
305+
id: s.id,
306+
})
307+
},
308+
}, {
309+
name: "local close",
310+
side: remoteStream,
311+
styp: uniStream,
312+
setup: func(t *testing.T, tc *testConn, s *Stream) {
313+
ctx := canceledContext()
314+
tc.writeFrames(packetType1RTT, debugFrameStream{
315+
id: s.id,
316+
fin: true,
317+
})
318+
if n, err := s.ReadContext(ctx, make([]byte, 16)); n != 0 || err != io.EOF {
319+
t.Errorf("ReadContext() = %v, %v; want 0, io.EOF", n, err)
320+
}
321+
},
322+
shutdown: func(t *testing.T, tc *testConn, s *Stream) {
323+
s.CloseRead()
324+
},
325+
}} {
326+
name := fmt.Sprintf("%v/%v/%v", test.side, test.styp, test.name)
327+
t.Run(name, func(t *testing.T) {
328+
tc, s := newTestConnAndStream(t, serverSide, test.side, test.styp,
329+
permissiveTransportParameters)
330+
tc.ignoreFrame(frameTypeStreamBase)
331+
tc.ignoreFrame(frameTypeStopSending)
332+
test.setup(t, tc, s)
333+
tc.wantIdle("conn should be idle after setup")
334+
if got, want := len(tc.conn.streams.streams), 1; got != want {
335+
t.Fatalf("after setup: %v streams in Conn's map; want %v", got, want)
336+
}
337+
test.shutdown(t, tc, s)
338+
tc.wantIdle("conn should be idle after shutdown")
339+
if got, want := len(tc.conn.streams.streams), 0; got != want {
340+
t.Fatalf("after shutdown: %v streams in Conn's map; want %v", got, want)
341+
}
342+
})
343+
}
344+
}

internal/quic/conn_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
394394
// writeAckForAll sends the Conn a datagram containing an ack for all packets up to the
395395
// last one received.
396396
func (tc *testConn) writeAckForAll() {
397+
tc.t.Helper()
397398
if tc.lastPacket == nil {
398399
return
399400
}
@@ -405,6 +406,7 @@ func (tc *testConn) writeAckForAll() {
405406
// writeAckForLatest sends the Conn a datagram containing an ack for the
406407
// most recent packet received.
407408
func (tc *testConn) writeAckForLatest() {
409+
tc.t.Helper()
408410
if tc.lastPacket == nil {
409411
return
410412
}

0 commit comments

Comments
 (0)