Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions p2p/discover/v5_talk.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ const talkHandlerLaunchTimeout = 400 * time.Millisecond
// Note that talk handlers are expected to come up with a response very quickly, within at
// most 200ms or so. If the handler takes longer than that, the remote end may time out
// and wont receive the response.
type TalkRequestHandler func(enode.ID, *net.UDPAddr, []byte) []byte
type TalkRequestHandler func(*enode.Node, *net.UDPAddr, []byte) []byte

type talkSystem struct {
transport *UDPv5
Expand Down Expand Up @@ -72,13 +72,19 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) {

// handleRequest handles a talk request.
func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) {
n := t.transport.codec.SessionNode(id, addr.String())
if n == nil {
// The node must be contained in the session here, since we wouldn't have
// received the request otherwise.
panic("missing node in session")
}
t.mutex.Lock()
handler, ok := t.handlers[req.Protocol]
t.mutex.Unlock()

if !ok {
resp := &v5wire.TalkResponse{ReqID: req.ReqID}
t.transport.sendResponse(id, addr, resp)
t.transport.sendResponse(n.ID(), addr, resp)
return
}

Expand All @@ -90,9 +96,9 @@ func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire
go func() {
defer func() { t.slots <- struct{}{} }()
udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())}
respMessage := handler(id, udpAddr, req.Message)
respMessage := handler(n, udpAddr, req.Message)
resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage}
t.transport.sendFromAnotherThread(id, addr, resp)
t.transport.sendFromAnotherThread(n.ID(), addr, resp)
}()
case <-timeout.C:
// Couldn't get it in time, drop the request.
Expand Down
3 changes: 3 additions & 0 deletions p2p/discover/v5_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ type codecV5 interface {
// CurrentChallenge returns the most recent WHOAREYOU challenge that was encoded to given node.
// This will return a non-nil value if there is an active handshake attempt with the node, and nil otherwise.
CurrentChallenge(id enode.ID, addr string) *v5wire.Whoareyou

// SessionNode returns a node that has completed the handshake.
SessionNode(id enode.ID, addr string) *enode.Node
}

// UDPv5 is the implementation of protocol version 5.
Expand Down
6 changes: 5 additions & 1 deletion p2p/discover/v5_udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ func TestUDPv5_talkHandling(t *testing.T) {
defer test.close()

var recvMessage []byte
test.udp.RegisterTalkHandler("test", func(id enode.ID, addr *net.UDPAddr, message []byte) []byte {
test.udp.RegisterTalkHandler("test", func(n *enode.Node, addr *net.UDPAddr, message []byte) []byte {
recvMessage = message
return []byte("test response")
})
Expand Down Expand Up @@ -795,6 +795,10 @@ func (c *testCodec) Decode(input []byte, addr string) (enode.ID, *enode.Node, v5
return frame.NodeID, nil, p, nil
}

func (c *testCodec) SessionNode(id enode.ID, addr string) *enode.Node {
return c.test.nodesByID[id].Node()
}

func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p v5wire.Packet, err error) {
if err = rlp.DecodeBytes(input, &frame); err != nil {
return frame, nil, fmt.Errorf("invalid frame: %v", err)
Expand Down
8 changes: 6 additions & 2 deletions p2p/discover/v5wire/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func (c *Codec) encodeHandshakeHeader(toID enode.ID, addr string, challenge *Who
}

// TODO: this should happen when the first authenticated message is received
c.sc.storeNewSession(toID, addr, session)
c.sc.storeNewSession(toID, addr, session, challenge.Node)

// Encode the auth header.
var (
Expand Down Expand Up @@ -522,7 +522,7 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData
}

// Handshake OK, drop the challenge and store the new session keys.
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session)
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session, node)
c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
return node, msg, nil
}
Expand Down Expand Up @@ -644,6 +644,10 @@ func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet
return DecodeMessage(msgdata[0], msgdata[1:])
}

func (c *Codec) SessionNode(id enode.ID, addr string) *enode.Node {
return c.sc.readNode(id, addr)
}

// checkValid performs some basic validity checks on the header.
// The packetLen here is the length remaining after the static header.
func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error {
Expand Down
14 changes: 7 additions & 7 deletions p2p/discover/v5wire/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func TestHandshake_rekey(t *testing.T) {
readKey: []byte("BBBBBBBBBBBBBBBB"),
writeKey: []byte("AAAAAAAAAAAAAAAA"),
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())

// A -> B FINDNODE (encrypted with zero keys)
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{})
Expand Down Expand Up @@ -209,8 +209,8 @@ func TestHandshake_rekey2(t *testing.T) {
readKey: []byte("CCCCCCCCCCCCCCCC"),
writeKey: []byte("DDDDDDDDDDDDDDDD"),
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA)
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB)
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA, net.nodeB.n())
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB, net.nodeA.n())

// A -> B FINDNODE encrypted with initKeysA
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{Distances: []uint{3}})
Expand Down Expand Up @@ -362,8 +362,8 @@ func TestTestVectorsV5(t *testing.T) {
ENRSeq: 2,
},
prep: func(net *handshakeTest) {
net.nodeA.c.sc.storeNewSession(idB, addr, session)
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped())
net.nodeA.c.sc.storeNewSession(idB, addr, session, net.nodeB.n())
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped(), net.nodeA.n())
},
},
{
Expand Down Expand Up @@ -499,8 +499,8 @@ func BenchmarkV5_DecodePing(b *testing.B) {
readKey: []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17},
writeKey: []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134},
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped())
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped(), net.nodeA.n())
addrB := net.nodeA.addr()
ping := &Ping{ReqID: []byte("reqid"), ENRSeq: 5}
enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), addrB, ping, nil)
Expand Down
16 changes: 14 additions & 2 deletions p2p/discover/v5wire/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ type session struct {
writeKey []byte
readKey []byte
nonceCounter uint32
node *enode.Node
}

// keysFlipped returns a copy of s with the read and write keys flipped.
func (s *session) keysFlipped() *session {
return &session{s.readKey, s.writeKey, s.nonceCounter}
return &session{s.readKey, s.writeKey, s.nonceCounter, s.node}
}

func NewSessionCache(maxItems int, clock mclock.Clock) *SessionCache {
Expand Down Expand Up @@ -103,8 +104,19 @@ func (sc *SessionCache) readKey(id enode.ID, addr string) []byte {
return nil
}

func (sc *SessionCache) readNode(id enode.ID, addr string) *enode.Node {
if s := sc.session(id, addr); s != nil {
return s.node
}
return nil
}

// storeNewSession stores new encryption keys in the cache.
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session) {
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session, n *enode.Node) {
if n == nil {
panic("nil node in storeNewSession")
}
s.node = n
sc.sessions.Add(sessionID{id, addr}, s)
}

Expand Down