diff --git a/peer/README.md b/peer/README.md index fac986936..2ea5351f2 100644 --- a/peer/README.md +++ b/peer/README.md @@ -48,7 +48,7 @@ A quick overview of the major features peer provides are as follows: - These could all be sent manually via the standard message output function, but the helpers provide additional nice functionality such as duplicate filtering and address randomization - - Ability to wait for shutdown/disconnect + - Context-aware Run method for all asynchronous I/O processing that blocks until disconnect - Comprehensive test coverage ## Installation and Updating diff --git a/peer/doc.go b/peer/doc.go index ca5703007..03db3e6a7 100644 --- a/peer/doc.go +++ b/peer/doc.go @@ -63,14 +63,12 @@ etc. [NewOutboundPeer] and [NewInboundPeer] must be followed by calling [Peer.Handshake] on the returned instance to perform the initial protocol -negotiation handshake process and finally [Peer.Start] to start all async I/O -goroutines. +negotiation handshake process and finally [Peer.Run] to start all async I/O +goroutines and block until peer disconnection and resource cleanup has +completed. -[Peer.WaitForDisconnect] can be used to block until peer disconnection and -resource cleanup has completed. - -When finished with the peer call [Peer.Disconnect] to close the connection and -clean up all resources. +When finished with the peer, call [Peer.Disconnect] or cancel the context +provided to [Peer.Run] to close the connection and clean up all resources. # Callbacks diff --git a/peer/example_test.go b/peer/example_test.go index a4cead5ce..8be29df1e 100644 --- a/peer/example_test.go +++ b/peer/example_test.go @@ -47,7 +47,7 @@ func mockRemotePeer(listenAddr string) (net.Listener, error) { fmt.Printf("inbound handshake error: %v\n", err) return } - p.Start() + p.Run(context.Background()) }() }() @@ -104,11 +104,13 @@ func Example_newOutboundPeer() { IdleTimeout: time.Second * 120, } p := peer.NewOutboundPeer(peerCfg, conn.RemoteAddr(), conn) - if err := p.Handshake(context.Background(), nil); err != nil { + ctx := context.Background() + if err := p.Handshake(ctx, nil); err != nil { fmt.Printf("outbound peer handshake error: %v\n", err) return } - p.Start() + go p.Run(ctx) + defer p.Disconnect() // Ping the remote peer aysnchronously. p.QueueMessage(wire.NewMsgPing(rand.Uint64()), nil) @@ -121,10 +123,6 @@ func Example_newOutboundPeer() { fmt.Printf("Example_newOutboundPeer: pong timeout") } - // Disconnect the peer. - p.Disconnect() - p.WaitForDisconnect() - // Output: // outbound: received pong } diff --git a/peer/peer.go b/peer/peer.go index c227f1cf3..0ae69baa9 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -2289,7 +2289,7 @@ var errHandshakeTimeout = makeError(ErrHandshakeTimeout, // // This should only be called once when the peer is first connected. // -// The caller MUST only start the async I/O processing with [Peer.Start] after +// The caller MUST only start the async I/O processing with [Peer.Run] after // this function returns without error. func (p *Peer) Handshake(ctx context.Context, onVersion OnVersionCallback) error { handshakeErr := make(chan error, 1) @@ -2319,17 +2319,45 @@ func (p *Peer) Handshake(ctx context.Context, onVersion OnVersionCallback) error return nil } -// Start begins processing input and output messages. Callers MUST only call -// this after [Peer.Handshake] completes without error. -func (p *Peer) Start() { - log.Tracef("Starting peer %s", p) +// Run begins processing input and output messages. Callers MUST only call this +// after [Peer.Handshake] completes without error. +func (p *Peer) Run(ctx context.Context) { + log.Tracef("Running peer %s", p) // The protocol has been negotiated successfully so start processing input // and output messages. - go p.stallHandler() - go p.inHandler() - go p.queueHandler() - go p.outHandler() + var wg sync.WaitGroup + wg.Add(4) + go func() { + p.stallHandler() + wg.Done() + }() + go func() { + p.inHandler() + wg.Done() + }() + go func() { + p.queueHandler() + wg.Done() + }() + go func() { + p.outHandler() + wg.Done() + }() + + // Forcibly disconnect the peer when the context is cancelled which also + // closes the quit channel and thus ensures all of the above goroutines are + // shutdown. + // + // Select across the quit channel as well since the context is not cancelled + // when the connection is closed. + select { + case <-ctx.Done(): + p.Disconnect() + case <-p.quit: + } + + wg.Wait() } // WaitForDisconnect waits until the peer has completely disconnected and all @@ -2388,7 +2416,7 @@ func newPeerBase(cfgOrig *Config, conn net.Conn, inbound bool) *Peer { } // NewInboundPeer returns a new inbound Decred peer. Use [Peer.Handshake] to -// perform the initial version negotiation and then [Peer.Start] to begin +// perform the initial version negotiation and then [Peer.Run] to begin // processing incoming and outgoing messages when the handshake is successful. func NewInboundPeer(cfg *Config, conn net.Conn) *Peer { p := newPeerBase(cfg, conn, true) @@ -2397,7 +2425,7 @@ func NewInboundPeer(cfg *Config, conn net.Conn) *Peer { } // NewOutboundPeer returns a new outbound Decred peer. Use [Peer.Handshake] to -// perform the initial version negotiation and then [Peer.Start] to begin +// perform the initial version negotiation and then [Peer.Run] to begin // processing incoming and outgoing messages when the handshake is successful. func NewOutboundPeer(cfg *Config, addr net.Addr, conn net.Conn) *Peer { p := newPeerBase(cfg, conn, false) diff --git a/peer/peer_test.go b/peer/peer_test.go index 2daf67d01..0f42c30ee 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -118,7 +118,7 @@ type peerStats struct { wantBytesReceived uint64 } -// runPeersAsync invokes the [Peer.Start] method on the passed peers in separate +// runPeersAsync invokes the [Peer.Run] method on the passed peers in separate // goroutines and returns a cancelable context and wait group the caller can use // to shutdown the peers and wait for clean shutdown. func runPeersAsync(peers ...*Peer) (context.CancelFunc, *sync.WaitGroup) { @@ -127,12 +127,7 @@ func runPeersAsync(peers ...*Peer) (context.CancelFunc, *sync.WaitGroup) { wg.Add(len(peers)) for _, peer := range peers { go func(peer *Peer) { - peer.Start() - select { - case <-ctx.Done(): - peer.Disconnect() - case <-peer.quit: - } + peer.Run(ctx) wg.Done() }(peer) } diff --git a/server.go b/server.go index 58a2df9b4..a3bcab71a 100644 --- a/server.go +++ b/server.go @@ -910,15 +910,19 @@ func (sp *serverPeer) serveGetData() { // evicting any remaining orphans sent by the peer and shutting down all // goroutines. func (sp *serverPeer) Run(ctx context.Context) { + // Start processing async I/O. + disconnected := make(chan struct{}) var wg sync.WaitGroup - wg.Add(1) + wg.Add(2) go func() { sp.serveGetData() wg.Done() }() - - // Start processing async I/O. - sp.Start() + go func() { + sp.Peer.Run(ctx) + close(disconnected) + wg.Done() + }() // Request all block announcements via full headers instead of the inv // message. @@ -929,7 +933,7 @@ func (sp *serverPeer) Run(ctx context.Context) { // Wait for the peer to disconnect and notify the net sync manager and // server accordingly. - sp.WaitForDisconnect() + <-disconnected srvr := sp.server srvr.DonePeer(sp) srvr.syncManager.OnPeerDisconnected(sp.syncMgrPeer) @@ -2431,13 +2435,8 @@ func connToNetAddr(conn net.Conn) (*addrmgr.NetAddress, error) { // established prior to any further peer setup. // // This function is safe for concurrent access. -func (s *server) handleBannedConn(conn net.Conn) bool { - host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) - if err != nil { - srvrLog.Debugf("can't split hostport %v", err) - conn.Close() - return true - } +func (s *server) handleBannedConn(remoteAddr *addrmgr.NetAddress, conn net.Conn) bool { + host := net.IP(remoteAddr.IP).String() s.peerState.Lock() defer s.peerState.Unlock() @@ -2533,12 +2532,12 @@ func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) { } // Disconnect banned connections. - if disconnected := s.handleBannedConn(conn); disconnected { + if disconnected := s.handleBannedConn(remoteNetAddr, conn); disconnected { return } sp := newServerPeer(s, remoteNetAddr, false) - sp.isWhitelisted = isWhitelisted(conn.RemoteAddr()) + sp.isWhitelisted = isWhitelisted(remoteNetAddr) sp.Peer = peer.NewInboundPeer(newPeerConfig(sp), conn) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { srvrLog.Debugf("Failed handshake for inbound peer %s: %v", @@ -2567,7 +2566,7 @@ func (s *server) outboundPeerConnected(ctx context.Context, c *connmgr.ConnReq, // Disconnect banned connections. Ideally we would never connect to a // banned peer, but the connection manager is currently unaware of banned // addresses, so this is needed. - if disconnected := s.handleBannedConn(conn); disconnected { + if disconnected := s.handleBannedConn(remoteNetAddr, conn); disconnected { s.connManager.Disconnect(c.ID()) return } @@ -2576,7 +2575,7 @@ func (s *server) outboundPeerConnected(ctx context.Context, c *connmgr.ConnReq, p := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr, conn) sp.Peer = p sp.connReq.Store(c) - sp.isWhitelisted = isWhitelisted(conn.RemoteAddr()) + sp.isWhitelisted = isWhitelisted(remoteNetAddr) if err := sp.Handshake(ctx, sp.OnVersion); err != nil { srvrLog.Debugf("Failed handshake for outbound peer %s: %v", c.Addr, err) s.connManager.Disconnect(c.ID()) @@ -4691,22 +4690,12 @@ func addLocalAddress(addrMgr *addrmgr.AddrManager, addr string, services wire.Se // isWhitelisted returns whether the IP address is included in the whitelisted // networks and IPs. -func isWhitelisted(addr net.Addr) bool { +func isWhitelisted(addr *addrmgr.NetAddress) bool { if len(cfg.whitelists) == 0 { return false } - host, _, err := net.SplitHostPort(addr.String()) - if err != nil { - srvrLog.Warnf("Unable to SplitHostPort on '%s': %v", addr, err) - return false - } - ip := net.ParseIP(host) - if ip == nil { - srvrLog.Warnf("Unable to parse IP '%s'", addr) - return false - } - + ip := net.IP(addr.IP) for _, ipnet := range cfg.whitelists { if ipnet.Contains(ip) { return true