Skip to content

Commit 776c02e

Browse files
committed
session will cache node
Signed-off-by: thinkAfCod <[email protected]>
1 parent ee30681 commit 776c02e

File tree

6 files changed

+44
-16
lines changed

6 files changed

+44
-16
lines changed

p2p/discover/v5_talk.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ const talkHandlerLaunchTimeout = 400 * time.Millisecond
3939
// Note that talk handlers are expected to come up with a response very quickly, within at
4040
// most 200ms or so. If the handler takes longer than that, the remote end may time out
4141
// and wont receive the response.
42-
type TalkRequestHandler func(enode.ID, *net.UDPAddr, []byte) []byte
42+
type TalkRequestHandler func(*enode.Node, *net.UDPAddr, []byte) []byte
4343

4444
type talkSystem struct {
4545
transport *UDPv5
@@ -72,13 +72,18 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) {
7272

7373
// handleRequest handles a talk request.
7474
func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) {
75+
var n *enode.Node
76+
if n = t.transport.codec.SessionNode(id, addr.String()); n == nil {
77+
log.Error("Got a TALKREQ from a node that has not completed the handshake", "id", id, "addr", addr)
78+
return
79+
}
7580
t.mutex.Lock()
7681
handler, ok := t.handlers[req.Protocol]
7782
t.mutex.Unlock()
7883

7984
if !ok {
8085
resp := &v5wire.TalkResponse{ReqID: req.ReqID}
81-
t.transport.sendResponse(id, addr, resp)
86+
t.transport.sendResponse(n.ID(), addr, resp)
8287
return
8388
}
8489

@@ -90,9 +95,9 @@ func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire
9095
go func() {
9196
defer func() { t.slots <- struct{}{} }()
9297
udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())}
93-
respMessage := handler(id, udpAddr, req.Message)
98+
respMessage := handler(n, udpAddr, req.Message)
9499
resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage}
95-
t.transport.sendFromAnotherThread(id, addr, resp)
100+
t.transport.sendFromAnotherThread(n.ID(), addr, resp)
96101
}()
97102
case <-timeout.C:
98103
// Couldn't get it in time, drop the request.

p2p/discover/v5_udp.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ type codecV5 interface {
6464
// CurrentChallenge returns the most recent WHOAREYOU challenge that was encoded to given node.
6565
// This will return a non-nil value if there is an active handshake attempt with the node, and nil otherwise.
6666
CurrentChallenge(id enode.ID, addr string) *v5wire.Whoareyou
67+
68+
// SessionNode returns a node that has completed the handshake.
69+
SessionNode(enode.ID, string) *enode.Node
6770
}
6871

6972
// UDPv5 is the implementation of protocol version 5.

p2p/discover/v5_udp_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ func TestUDPv5_talkHandling(t *testing.T) {
486486
defer test.close()
487487

488488
var recvMessage []byte
489-
test.udp.RegisterTalkHandler("test", func(id enode.ID, addr *net.UDPAddr, message []byte) []byte {
489+
test.udp.RegisterTalkHandler("test", func(n *enode.Node, addr *net.UDPAddr, message []byte) []byte {
490490
recvMessage = message
491491
return []byte("test response")
492492
})
@@ -795,6 +795,10 @@ func (c *testCodec) Decode(input []byte, addr string) (enode.ID, *enode.Node, v5
795795
return frame.NodeID, nil, p, nil
796796
}
797797

798+
func (c *testCodec) SessionNode(id enode.ID, addr string) *enode.Node {
799+
return c.test.nodesByID[id].Node()
800+
}
801+
798802
func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p v5wire.Packet, err error) {
799803
if err = rlp.DecodeBytes(input, &frame); err != nil {
800804
return frame, nil, fmt.Errorf("invalid frame: %v", err)

p2p/discover/v5wire/encoding.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ func (c *Codec) encodeHandshakeHeader(toID enode.ID, addr string, challenge *Who
347347
}
348348

349349
// TODO: this should happen when the first authenticated message is received
350-
c.sc.storeNewSession(toID, addr, session)
350+
c.sc.storeNewSession(toID, addr, session, challenge.Node)
351351

352352
// Encode the auth header.
353353
var (
@@ -522,7 +522,7 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData
522522
}
523523

524524
// Handshake OK, drop the challenge and store the new session keys.
525-
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session)
525+
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session, node)
526526
c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
527527
return node, msg, nil
528528
}
@@ -644,6 +644,10 @@ func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet
644644
return DecodeMessage(msgdata[0], msgdata[1:])
645645
}
646646

647+
func (c *Codec) SessionNode(id enode.ID, addr string) *enode.Node {
648+
return c.sc.readNode(id, addr)
649+
}
650+
647651
// checkValid performs some basic validity checks on the header.
648652
// The packetLen here is the length remaining after the static header.
649653
func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error {

p2p/discover/v5wire/encoding_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func TestHandshake_rekey(t *testing.T) {
166166
readKey: []byte("BBBBBBBBBBBBBBBB"),
167167
writeKey: []byte("AAAAAAAAAAAAAAAA"),
168168
}
169-
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
169+
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
170170

171171
// A -> B FINDNODE (encrypted with zero keys)
172172
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{})
@@ -209,8 +209,8 @@ func TestHandshake_rekey2(t *testing.T) {
209209
readKey: []byte("CCCCCCCCCCCCCCCC"),
210210
writeKey: []byte("DDDDDDDDDDDDDDDD"),
211211
}
212-
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA)
213-
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB)
212+
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA, net.nodeB.n())
213+
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB, net.nodeA.n())
214214

215215
// A -> B FINDNODE encrypted with initKeysA
216216
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{Distances: []uint{3}})
@@ -362,8 +362,8 @@ func TestTestVectorsV5(t *testing.T) {
362362
ENRSeq: 2,
363363
},
364364
prep: func(net *handshakeTest) {
365-
net.nodeA.c.sc.storeNewSession(idB, addr, session)
366-
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped())
365+
net.nodeA.c.sc.storeNewSession(idB, addr, session, net.nodeB.n())
366+
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped(), net.nodeA.n())
367367
},
368368
},
369369
{
@@ -499,8 +499,8 @@ func BenchmarkV5_DecodePing(b *testing.B) {
499499
readKey: []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17},
500500
writeKey: []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134},
501501
}
502-
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
503-
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped())
502+
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
503+
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped(), net.nodeA.n())
504504
addrB := net.nodeA.addr()
505505
ping := &Ping{ReqID: []byte("reqid"), ENRSeq: 5}
506506
enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), addrB, ping, nil)

p2p/discover/v5wire/session.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ type session struct {
5454
writeKey []byte
5555
readKey []byte
5656
nonceCounter uint32
57+
node *enode.Node
5758
}
5859

5960
// keysFlipped returns a copy of s with the read and write keys flipped.
6061
func (s *session) keysFlipped() *session {
61-
return &session{s.readKey, s.writeKey, s.nonceCounter}
62+
return &session{s.readKey, s.writeKey, s.nonceCounter, s.node}
6263
}
6364

6465
func NewSessionCache(maxItems int, clock mclock.Clock) *SessionCache {
@@ -103,8 +104,19 @@ func (sc *SessionCache) readKey(id enode.ID, addr string) []byte {
103104
return nil
104105
}
105106

107+
func (sc *SessionCache) readNode(id enode.ID, addr string) *enode.Node {
108+
if s := sc.session(id, addr); s != nil {
109+
return s.node
110+
}
111+
return nil
112+
}
113+
106114
// storeNewSession stores new encryption keys in the cache.
107-
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session) {
115+
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session, n *enode.Node) {
116+
if n == nil {
117+
panic("session must caches a non-nil node")
118+
}
119+
s.node = n
108120
sc.sessions.Add(sessionID{id, addr}, s)
109121
}
110122

0 commit comments

Comments
 (0)