Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion peer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ A quick overview of the major features peer provides are as follows:
- These could all be sent manually via the standard message output function,
but the helpers provide additional nice functionality such as duplicate
filtering and address randomization
- Ability to wait for shutdown/disconnect
- Context-aware Run method for all asynchronous I/O processing that blocks until disconnect
- Comprehensive test coverage

## Installation and Updating
Expand Down
12 changes: 5 additions & 7 deletions peer/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,12 @@ etc.

[NewOutboundPeer] and [NewInboundPeer] must be followed by calling
[Peer.Handshake] on the returned instance to perform the initial protocol
negotiation handshake process and finally [Peer.Start] to start all async I/O
goroutines.
negotiation handshake process and finally [Peer.Run] to start all async I/O
goroutines and block until peer disconnection and resource cleanup has
completed.

[Peer.WaitForDisconnect] can be used to block until peer disconnection and
resource cleanup has completed.

When finished with the peer call [Peer.Disconnect] to close the connection and
clean up all resources.
When finished with the peer, call [Peer.Disconnect] or cancel the context
provided to [Peer.Run] to close the connection and clean up all resources.

# Callbacks

Expand Down
12 changes: 5 additions & 7 deletions peer/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func mockRemotePeer(listenAddr string) (net.Listener, error) {
fmt.Printf("inbound handshake error: %v\n", err)
return
}
p.Start()
p.Run(context.Background())
}()
}()

Expand Down Expand Up @@ -104,11 +104,13 @@ func Example_newOutboundPeer() {
IdleTimeout: time.Second * 120,
}
p := peer.NewOutboundPeer(peerCfg, conn.RemoteAddr(), conn)
if err := p.Handshake(context.Background(), nil); err != nil {
ctx := context.Background()
if err := p.Handshake(ctx, nil); err != nil {
fmt.Printf("outbound peer handshake error: %v\n", err)
return
}
p.Start()
go p.Run(ctx)
defer p.Disconnect()

// Ping the remote peer aysnchronously.
p.QueueMessage(wire.NewMsgPing(rand.Uint64()), nil)
Expand All @@ -121,10 +123,6 @@ func Example_newOutboundPeer() {
fmt.Printf("Example_newOutboundPeer: pong timeout")
}

// Disconnect the peer.
p.Disconnect()
p.WaitForDisconnect()

// Output:
// outbound: received pong
}
50 changes: 39 additions & 11 deletions peer/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2289,7 +2289,7 @@ var errHandshakeTimeout = makeError(ErrHandshakeTimeout,
//
// This should only be called once when the peer is first connected.
//
// The caller MUST only start the async I/O processing with [Peer.Start] after
// The caller MUST only start the async I/O processing with [Peer.Run] after
// this function returns without error.
func (p *Peer) Handshake(ctx context.Context, onVersion OnVersionCallback) error {
handshakeErr := make(chan error, 1)
Expand Down Expand Up @@ -2319,17 +2319,45 @@ func (p *Peer) Handshake(ctx context.Context, onVersion OnVersionCallback) error
return nil
}

// Start begins processing input and output messages. Callers MUST only call
// this after [Peer.Handshake] completes without error.
func (p *Peer) Start() {
log.Tracef("Starting peer %s", p)
// Run begins processing input and output messages. Callers MUST only call this
// after [Peer.Handshake] completes without error.
func (p *Peer) Run(ctx context.Context) {
log.Tracef("Running peer %s", p)

// The protocol has been negotiated successfully so start processing input
// and output messages.
go p.stallHandler()
go p.inHandler()
go p.queueHandler()
go p.outHandler()
var wg sync.WaitGroup
wg.Add(4)
go func() {
p.stallHandler()
wg.Done()
}()
go func() {
p.inHandler()
wg.Done()
}()
go func() {
p.queueHandler()
wg.Done()
}()
go func() {
p.outHandler()
wg.Done()
}()

// Forcibly disconnect the peer when the context is cancelled which also
// closes the quit channel and thus ensures all of the above goroutines are
// shutdown.
//
// Select across the quit channel as well since the context is not cancelled
// when the connection is closed.
select {
case <-ctx.Done():
p.Disconnect()
case <-p.quit:
}

wg.Wait()
}

// WaitForDisconnect waits until the peer has completely disconnected and all
Expand Down Expand Up @@ -2388,7 +2416,7 @@ func newPeerBase(cfgOrig *Config, conn net.Conn, inbound bool) *Peer {
}

// NewInboundPeer returns a new inbound Decred peer. Use [Peer.Handshake] to
// perform the initial version negotiation and then [Peer.Start] to begin
// perform the initial version negotiation and then [Peer.Run] to begin
// processing incoming and outgoing messages when the handshake is successful.
func NewInboundPeer(cfg *Config, conn net.Conn) *Peer {
p := newPeerBase(cfg, conn, true)
Expand All @@ -2397,7 +2425,7 @@ func NewInboundPeer(cfg *Config, conn net.Conn) *Peer {
}

// NewOutboundPeer returns a new outbound Decred peer. Use [Peer.Handshake] to
// perform the initial version negotiation and then [Peer.Start] to begin
// perform the initial version negotiation and then [Peer.Run] to begin
// processing incoming and outgoing messages when the handshake is successful.
func NewOutboundPeer(cfg *Config, addr net.Addr, conn net.Conn) *Peer {
p := newPeerBase(cfg, conn, false)
Expand Down
9 changes: 2 additions & 7 deletions peer/peer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ type peerStats struct {
wantBytesReceived uint64
}

// runPeersAsync invokes the [Peer.Start] method on the passed peers in separate
// runPeersAsync invokes the [Peer.Run] method on the passed peers in separate
// goroutines and returns a cancelable context and wait group the caller can use
// to shutdown the peers and wait for clean shutdown.
func runPeersAsync(peers ...*Peer) (context.CancelFunc, *sync.WaitGroup) {
Expand All @@ -127,12 +127,7 @@ func runPeersAsync(peers ...*Peer) (context.CancelFunc, *sync.WaitGroup) {
wg.Add(len(peers))
for _, peer := range peers {
go func(peer *Peer) {
peer.Start()
select {
case <-ctx.Done():
peer.Disconnect()
case <-peer.quit:
}
peer.Run(ctx)
wg.Done()
}(peer)
}
Expand Down
45 changes: 17 additions & 28 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -910,15 +910,19 @@ func (sp *serverPeer) serveGetData() {
// evicting any remaining orphans sent by the peer and shutting down all
// goroutines.
func (sp *serverPeer) Run(ctx context.Context) {
// Start processing async I/O.
disconnected := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
wg.Add(2)
go func() {
sp.serveGetData()
wg.Done()
}()

// Start processing async I/O.
sp.Start()
go func() {
sp.Peer.Run(ctx)
close(disconnected)
wg.Done()
}()

// Request all block announcements via full headers instead of the inv
// message.
Expand All @@ -929,7 +933,7 @@ func (sp *serverPeer) Run(ctx context.Context) {

// Wait for the peer to disconnect and notify the net sync manager and
// server accordingly.
sp.WaitForDisconnect()
<-disconnected
srvr := sp.server
srvr.DonePeer(sp)
srvr.syncManager.OnPeerDisconnected(sp.syncMgrPeer)
Expand Down Expand Up @@ -2431,13 +2435,8 @@ func connToNetAddr(conn net.Conn) (*addrmgr.NetAddress, error) {
// established prior to any further peer setup.
//
// This function is safe for concurrent access.
func (s *server) handleBannedConn(conn net.Conn) bool {
host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
srvrLog.Debugf("can't split hostport %v", err)
conn.Close()
return true
}
func (s *server) handleBannedConn(remoteAddr *addrmgr.NetAddress, conn net.Conn) bool {
host := net.IP(remoteAddr.IP).String()

s.peerState.Lock()
defer s.peerState.Unlock()
Expand Down Expand Up @@ -2533,12 +2532,12 @@ func (s *server) inboundPeerConnected(ctx context.Context, conn net.Conn) {
}

// Disconnect banned connections.
if disconnected := s.handleBannedConn(conn); disconnected {
if disconnected := s.handleBannedConn(remoteNetAddr, conn); disconnected {
return
}

sp := newServerPeer(s, remoteNetAddr, false)
sp.isWhitelisted = isWhitelisted(conn.RemoteAddr())
sp.isWhitelisted = isWhitelisted(remoteNetAddr)
sp.Peer = peer.NewInboundPeer(newPeerConfig(sp), conn)
if err := sp.Handshake(ctx, sp.OnVersion); err != nil {
srvrLog.Debugf("Failed handshake for inbound peer %s: %v",
Expand Down Expand Up @@ -2567,7 +2566,7 @@ func (s *server) outboundPeerConnected(ctx context.Context, c *connmgr.ConnReq,
// Disconnect banned connections. Ideally we would never connect to a
// banned peer, but the connection manager is currently unaware of banned
// addresses, so this is needed.
if disconnected := s.handleBannedConn(conn); disconnected {
if disconnected := s.handleBannedConn(remoteNetAddr, conn); disconnected {
s.connManager.Disconnect(c.ID())
return
}
Expand All @@ -2576,7 +2575,7 @@ func (s *server) outboundPeerConnected(ctx context.Context, c *connmgr.ConnReq,
p := peer.NewOutboundPeer(newPeerConfig(sp), c.Addr, conn)
sp.Peer = p
sp.connReq.Store(c)
sp.isWhitelisted = isWhitelisted(conn.RemoteAddr())
sp.isWhitelisted = isWhitelisted(remoteNetAddr)
if err := sp.Handshake(ctx, sp.OnVersion); err != nil {
srvrLog.Debugf("Failed handshake for outbound peer %s: %v", c.Addr, err)
s.connManager.Disconnect(c.ID())
Expand Down Expand Up @@ -4691,22 +4690,12 @@ func addLocalAddress(addrMgr *addrmgr.AddrManager, addr string, services wire.Se

// isWhitelisted returns whether the IP address is included in the whitelisted
// networks and IPs.
func isWhitelisted(addr net.Addr) bool {
func isWhitelisted(addr *addrmgr.NetAddress) bool {
if len(cfg.whitelists) == 0 {
return false
}

host, _, err := net.SplitHostPort(addr.String())
if err != nil {
srvrLog.Warnf("Unable to SplitHostPort on '%s': %v", addr, err)
return false
}
ip := net.ParseIP(host)
if ip == nil {
srvrLog.Warnf("Unable to parse IP '%s'", addr)
return false
}

ip := net.IP(addr.IP)
for _, ipnet := range cfg.whitelists {
if ipnet.Contains(ip) {
return true
Expand Down