diff --git a/p2p/discover/v5_talk.go b/p2p/discover/v5_talk.go index 2246b47141c..dca09870d89 100644 --- a/p2p/discover/v5_talk.go +++ b/p2p/discover/v5_talk.go @@ -39,7 +39,7 @@ const talkHandlerLaunchTimeout = 400 * time.Millisecond // Note that talk handlers are expected to come up with a response very quickly, within at // most 200ms or so. If the handler takes longer than that, the remote end may time out // and wont receive the response. -type TalkRequestHandler func(enode.ID, *net.UDPAddr, []byte) []byte +type TalkRequestHandler func(*enode.Node, *net.UDPAddr, []byte) []byte type talkSystem struct { transport *UDPv5 @@ -72,13 +72,19 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) { // handleRequest handles a talk request. func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) { + n := t.transport.codec.SessionNode(id, addr.String()) + if n == nil { + // The node must be contained in the session here, since we wouldn't have + // received the request otherwise. + panic("missing node in session") + } t.mutex.Lock() handler, ok := t.handlers[req.Protocol] t.mutex.Unlock() if !ok { resp := &v5wire.TalkResponse{ReqID: req.ReqID} - t.transport.sendResponse(id, addr, resp) + t.transport.sendResponse(n.ID(), addr, resp) return } @@ -90,9 +96,9 @@ func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire go func() { defer func() { t.slots <- struct{}{} }() udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())} - respMessage := handler(id, udpAddr, req.Message) + respMessage := handler(n, udpAddr, req.Message) resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage} - t.transport.sendFromAnotherThread(id, addr, resp) + t.transport.sendFromAnotherThread(n.ID(), addr, resp) }() case <-timeout.C: // Couldn't get it in time, drop the request. diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 6f7c7971528..9679f5c61a9 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -64,6 +64,9 @@ type codecV5 interface { // CurrentChallenge returns the most recent WHOAREYOU challenge that was encoded to given node. // This will return a non-nil value if there is an active handshake attempt with the node, and nil otherwise. CurrentChallenge(id enode.ID, addr string) *v5wire.Whoareyou + + // SessionNode returns a node that has completed the handshake. + SessionNode(id enode.ID, addr string) *enode.Node } // UDPv5 is the implementation of protocol version 5. diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go index 3026dff5380..34450d803db 100644 --- a/p2p/discover/v5_udp_test.go +++ b/p2p/discover/v5_udp_test.go @@ -486,7 +486,7 @@ func TestUDPv5_talkHandling(t *testing.T) { defer test.close() var recvMessage []byte - test.udp.RegisterTalkHandler("test", func(id enode.ID, addr *net.UDPAddr, message []byte) []byte { + test.udp.RegisterTalkHandler("test", func(n *enode.Node, addr *net.UDPAddr, message []byte) []byte { recvMessage = message return []byte("test response") }) @@ -795,6 +795,10 @@ func (c *testCodec) Decode(input []byte, addr string) (enode.ID, *enode.Node, v5 return frame.NodeID, nil, p, nil } +func (c *testCodec) SessionNode(id enode.ID, addr string) *enode.Node { + return c.test.nodesByID[id].Node() +} + func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p v5wire.Packet, err error) { if err = rlp.DecodeBytes(input, &frame); err != nil { return frame, nil, fmt.Errorf("invalid frame: %v", err) diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go index e50b7cd16d1..2a6c8807cca 100644 --- a/p2p/discover/v5wire/encoding.go +++ b/p2p/discover/v5wire/encoding.go @@ -347,7 +347,7 @@ func (c *Codec) encodeHandshakeHeader(toID enode.ID, addr string, challenge *Who } // TODO: this should happen when the first authenticated message is received - c.sc.storeNewSession(toID, addr, session) + c.sc.storeNewSession(toID, addr, session, challenge.Node) // Encode the auth header. var ( @@ -522,7 +522,7 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData } // Handshake OK, drop the challenge and store the new session keys. - c.sc.storeNewSession(auth.h.SrcID, fromAddr, session) + c.sc.storeNewSession(auth.h.SrcID, fromAddr, session, node) c.sc.deleteHandshake(auth.h.SrcID, fromAddr) return node, msg, nil } @@ -644,6 +644,10 @@ func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet return DecodeMessage(msgdata[0], msgdata[1:]) } +func (c *Codec) SessionNode(id enode.ID, addr string) *enode.Node { + return c.sc.readNode(id, addr) +} + // checkValid performs some basic validity checks on the header. // The packetLen here is the length remaining after the static header. func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error { diff --git a/p2p/discover/v5wire/encoding_test.go b/p2p/discover/v5wire/encoding_test.go index df97e40e892..2304d0f1327 100644 --- a/p2p/discover/v5wire/encoding_test.go +++ b/p2p/discover/v5wire/encoding_test.go @@ -166,7 +166,7 @@ func TestHandshake_rekey(t *testing.T) { readKey: []byte("BBBBBBBBBBBBBBBB"), writeKey: []byte("AAAAAAAAAAAAAAAA"), } - net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session) + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n()) // A -> B FINDNODE (encrypted with zero keys) findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{}) @@ -209,8 +209,8 @@ func TestHandshake_rekey2(t *testing.T) { readKey: []byte("CCCCCCCCCCCCCCCC"), writeKey: []byte("DDDDDDDDDDDDDDDD"), } - net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA) - net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB) + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA, net.nodeB.n()) + net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB, net.nodeA.n()) // A -> B FINDNODE encrypted with initKeysA findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{Distances: []uint{3}}) @@ -362,8 +362,8 @@ func TestTestVectorsV5(t *testing.T) { ENRSeq: 2, }, prep: func(net *handshakeTest) { - net.nodeA.c.sc.storeNewSession(idB, addr, session) - net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped()) + net.nodeA.c.sc.storeNewSession(idB, addr, session, net.nodeB.n()) + net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped(), net.nodeA.n()) }, }, { @@ -499,8 +499,8 @@ func BenchmarkV5_DecodePing(b *testing.B) { readKey: []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17}, writeKey: []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134}, } - net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session) - net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped()) + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n()) + net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped(), net.nodeA.n()) addrB := net.nodeA.addr() ping := &Ping{ReqID: []byte("reqid"), ENRSeq: 5} enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), addrB, ping, nil) diff --git a/p2p/discover/v5wire/session.go b/p2p/discover/v5wire/session.go index 862c21fcee9..5a2166b1438 100644 --- a/p2p/discover/v5wire/session.go +++ b/p2p/discover/v5wire/session.go @@ -54,11 +54,12 @@ type session struct { writeKey []byte readKey []byte nonceCounter uint32 + node *enode.Node } // keysFlipped returns a copy of s with the read and write keys flipped. func (s *session) keysFlipped() *session { - return &session{s.readKey, s.writeKey, s.nonceCounter} + return &session{s.readKey, s.writeKey, s.nonceCounter, s.node} } func NewSessionCache(maxItems int, clock mclock.Clock) *SessionCache { @@ -103,8 +104,19 @@ func (sc *SessionCache) readKey(id enode.ID, addr string) []byte { return nil } +func (sc *SessionCache) readNode(id enode.ID, addr string) *enode.Node { + if s := sc.session(id, addr); s != nil { + return s.node + } + return nil +} + // storeNewSession stores new encryption keys in the cache. -func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session) { +func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session, n *enode.Node) { + if n == nil { + panic("nil node in storeNewSession") + } + s.node = n sc.sessions.Add(sessionID{id, addr}, s) }