Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/release-notes/release-notes-0.19.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@
* [Fixed a bug](https://github.com/lightningnetwork/lnd/pull/9322) that caused
estimateroutefee to ignore the default payment timeout.

* [Fixed a possible crash of htlcswitch upon shutdown](https://github.com/lightningnetwork/lnd/pull/9140)
caused by a race condition in the goroutine tracking mechanism.

# New Features

* [Support](https://github.com/lightningnetwork/lnd/pull/8390) for
Expand Down
138 changes: 88 additions & 50 deletions htlcswitch/switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package htlcswitch

import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
Expand Down Expand Up @@ -245,8 +246,8 @@ type Switch struct {
// This will be retrieved by the registered links atomically.
bestHeight uint32

wg sync.WaitGroup
quit chan struct{}
// gm starts and stops tasks in goroutines and waits for them.
gm *fn.GoroutineManager

// cfg is a copy of the configuration struct that the htlc switch
// service was initialized with.
Expand Down Expand Up @@ -368,8 +369,11 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {
return nil, err
}

gm := fn.NewGoroutineManager()

s := &Switch{
bestHeight: currentHeight,
gm: gm,
cfg: &cfg,
circuits: circuitMap,
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
Expand All @@ -382,7 +386,6 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {
chanCloseRequests: make(chan *ChanClose),
resolutionMsgs: make(chan *resolutionMsg),
resMsgStore: resStore,
quit: make(chan struct{}),
}

s.aliasToReal = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID)
Expand Down Expand Up @@ -420,14 +423,14 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro
ResolutionMsg: msg,
errChan: errChan,
}:
case <-s.quit:
case <-s.gm.Done():
return ErrSwitchExiting
}

select {
case err := <-errChan:
return err
case <-s.quit:
case <-s.gm.Done():
return ErrSwitchExiting
}
}
Expand Down Expand Up @@ -493,14 +496,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash,
// Since the attempt was known, we can start a goroutine that can
// extract the result when it is available, and pass it on to the
// caller.
s.wg.Add(1)
go func() {
defer s.wg.Done()

ok := s.gm.Go(context.TODO(), func(ctx context.Context) {
var n *networkResult
select {
case n = <-nChan:
case <-s.quit:
case <-ctx.Done():
// We close the result channel to signal a shutdown. We
// don't send any result in this case since the HTLC is
// still in flight.
Expand All @@ -524,7 +524,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash,
return
}
resultChan <- result
}()
})
// The switch shutting down is signaled by closing the channel.
if !ok {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we still need this check? Won't the line <-ctx.Done() be hit when it's shutting down?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If GoroutineManager.Stop is called before the Go method (i.e., the switch is in the process of stopping), the Go method will return false without launching a new goroutine. In such cases, we should perform the same action as if it had stopped after launching the goroutine - specifically, closing resultChan. Failing to close the channel and simply returning it could cause the caller to get stuck indefinitely while waiting to receive from the channel.

close(resultChan)
}

return resultChan, nil
}
Expand Down Expand Up @@ -704,12 +708,19 @@ func (s *Switch) ForwardPackets(linkQuit <-chan struct{},
select {
case <-linkQuit:
return nil
case <-s.quit:

case <-s.gm.Done():
return nil

default:
// Spawn a goroutine to log the errors returned from failed packets.
s.wg.Add(1)
go s.logFwdErrs(&numSent, &wg, fwdChan)
// Spawn a goroutine to log the errors returned from failed
// packets.
ok := s.gm.Go(context.TODO(), func(ctx context.Context) {
s.logFwdErrs(ctx, &numSent, &wg, fwdChan)
})
if !ok {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here - why do we need this check? I think s.logFwdErrs will listen on <-s.gm.Done(): and quit?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a similar situation. We should handle it the same way if the goroutine manager is stopped before the Go method is executed.

return nil
}
}

// Make a first pass over the packets, forwarding any settles or fails.
Expand Down Expand Up @@ -820,8 +831,8 @@ func (s *Switch) ForwardPackets(linkQuit <-chan struct{},
}

// logFwdErrs logs any errors received on `fwdChan`.
func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) {
defer s.wg.Done()
func (s *Switch) logFwdErrs(ctx context.Context, num *int, wg *sync.WaitGroup,
fwdChan chan error) {

// Wait here until the outer function has finished persisting
// and routing the packets. This guarantees we don't read from num until
Expand All @@ -836,7 +847,8 @@ func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) {
log.Errorf("Unhandled error while reforwarding htlc "+
"settle/fail over htlcswitch: %v", err)
}
case <-s.quit:

case <-s.gm.Done():
log.Errorf("unable to forward htlc packet " +
"htlc switch was stopped")
return
Expand All @@ -862,7 +874,7 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error,
return nil
case <-linkQuit:
return ErrLinkShuttingDown
case <-s.quit:
case <-s.gm.Done():
return errors.New("htlc switch was stopped")
}
}
Expand Down Expand Up @@ -940,8 +952,6 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) (
//
// NOTE: This method MUST be spawned as a goroutine.
func (s *Switch) handleLocalResponse(pkt *htlcPacket) {
defer s.wg.Done()

attemptID := pkt.incomingHTLCID

// The error reason will be unencrypted in case this is a local
Expand Down Expand Up @@ -1114,7 +1124,9 @@ func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter,
// handlePacketForward is used in cases when we need forward the htlc update
// from one channel link to another and be able to propagate the settle/fail
// updates back. This behaviour is achieved by creation of payment circuits.
func (s *Switch) handlePacketForward(packet *htlcPacket) error {
func (s *Switch) handlePacketForward(ctx context.Context,
packet *htlcPacket) error {

switch htlc := packet.htlc.(type) {
// Channel link forwarded us a new htlc, therefore we initiate the
// payment circuit within our internal state so we can properly forward
Expand All @@ -1123,7 +1135,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
return s.handlePacketAdd(packet, htlc)

case *lnwire.UpdateFulfillHTLC:
return s.handlePacketSettle(packet)
return s.handlePacketSettle(ctx, packet)

// Channel link forwarded us an update_fail_htlc message.
//
Expand All @@ -1132,7 +1144,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
// forward it. Thus there's no need to catch `UpdateFailMalformedHTLC`
// here.
case *lnwire.UpdateFailHTLC:
return s.handlePacketFail(packet, htlc)
return s.handlePacketFail(ctx, packet, htlc)

default:
return fmt.Errorf("wrong update type: %T", htlc)
Expand Down Expand Up @@ -1436,7 +1448,7 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint,
case s.chanCloseRequests <- command:
return updateChan, errChan

case <-s.quit:
case <-s.gm.Done():
errChan <- ErrSwitchExiting
close(updateChan)
return updateChan, errChan
Expand All @@ -1453,9 +1465,9 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint,
// total link capacity.
//
// NOTE: This MUST be run as a goroutine.
func (s *Switch) htlcForwarder() {
defer s.wg.Done()

//
//nolint:funlen
func (s *Switch) htlcForwarder(ctx context.Context) {
defer func() {
s.blockEpochStream.Cancel()

Expand Down Expand Up @@ -1489,6 +1501,8 @@ func (s *Switch) htlcForwarder() {
var wg sync.WaitGroup
for _, link := range linksToStop {
wg.Add(1)
// Here it is ok to start a goroutine directly, bypassing
// s.gm, because we wait for them to complete here.
go func(l ChannelLink) {
defer wg.Done()

Expand Down Expand Up @@ -1613,7 +1627,7 @@ out:
// encounter is due to the circuit already being
// closed. This is fine, as processing this message is
// meant to be idempotent.
err = s.handlePacketForward(pkt)
err = s.handlePacketForward(ctx, pkt)
if err != nil {
log.Errorf("Unable to forward resolution msg: %v", err)
}
Expand All @@ -1622,21 +1636,22 @@ out:
// packet concretely, then either forward it along, or
// interpret a return packet to a locally initialized one.
case cmd := <-s.htlcPlex:
cmd.err <- s.handlePacketForward(cmd.pkt)
cmd.err <- s.handlePacketForward(ctx, cmd.pkt)

// When this time ticks, then it indicates that we should
// collect all the forwarding events since the last interval,
// and write them out to our log.
case <-s.cfg.FwdEventTicker.Ticks():
s.wg.Add(1)
go func() {
defer s.wg.Done()

if err := s.FlushForwardingEvents(); err != nil {
// The result of Go is ignored: if the switch is shutting
// down, the loop will terminate on the next iteration, in
// the s.gm.Done case.
_ = s.gm.Go(ctx, func(ctx context.Context) {
err := s.FlushForwardingEvents()
if err != nil {
log.Errorf("Unable to flush "+
"forwarding events: %v", err)
}
}()
})

// The log ticker has fired, so we'll calculate some forwarding
// stats for the last 10 seconds to display within the logs to
Expand Down Expand Up @@ -1739,7 +1754,7 @@ out:
// memory.
s.pendingSettleFails = s.pendingSettleFails[:0]

case <-s.quit:
case <-s.gm.Done():
return
}
}
Expand All @@ -1749,6 +1764,7 @@ out:
func (s *Switch) Start() error {
if !atomic.CompareAndSwapInt32(&s.started, 0, 1) {
log.Warn("Htlc Switch already started")

return errors.New("htlc switch already started")
}

Expand All @@ -1760,19 +1776,32 @@ func (s *Switch) Start() error {
}
s.blockEpochStream = blockEpochStream

s.wg.Add(1)
go s.htlcForwarder()
ok := s.gm.Go(context.TODO(), func(ctx context.Context) {
s.htlcForwarder(ctx)
})
if !ok {
// We are already stopping so we can ignore the error.
_ = s.Stop()
err = fmt.Errorf("unable to start htlc forwarder: %w",
ErrSwitchExiting)
log.Errorf("%v", err)

return err
}

if err := s.reforwardResponses(); err != nil {
s.Stop()
// We are already stopping so we can ignore the error.
_ = s.Stop()
log.Errorf("unable to reforward responses: %v", err)

return err
}

if err := s.reforwardResolutions(); err != nil {
// We are already stopping so we can ignore the error.
_ = s.Stop()
log.Errorf("unable to reforward resolutions: %v", err)

return err
}

Expand Down Expand Up @@ -1991,9 +2020,8 @@ func (s *Switch) Stop() error {
log.Info("HTLC Switch shutting down...")
defer log.Debug("HTLC Switch shutdown complete")

close(s.quit)

s.wg.Wait()
// Ask running goroutines to stop and wait for them.
s.gm.Stop()

// Wait until all active goroutines have finished exiting before
// stopping the mailboxes, otherwise the mailbox map could still be
Expand Down Expand Up @@ -2349,7 +2377,7 @@ func (s *Switch) RemoveLink(chanID lnwire.ChannelID) {
select {
case <-stopChan:
return
case <-s.quit:
case <-s.gm.Done():
return
}
}
Expand Down Expand Up @@ -2990,7 +3018,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket,
}

// handlePacketSettle handles forwarding a settle packet.
func (s *Switch) handlePacketSettle(packet *htlcPacket) error {
func (s *Switch) handlePacketSettle(ctx context.Context,
packet *htlcPacket) error {

// If the source of this packet has not been set, use the circuit map
// to lookup the origin.
circuit, err := s.closeCircuit(packet)
Expand Down Expand Up @@ -3029,8 +3059,12 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error {
// NOTE: `closeCircuit` modifies the state of `packet`.
if localHTLC {
// TODO(yy): remove the goroutine and send back the error here.
s.wg.Add(1)
go s.handleLocalResponse(packet)
ok := s.gm.Go(ctx, func(ctx context.Context) {
s.handleLocalResponse(packet)
})
if !ok {
return ErrSwitchExiting
}

// If this is a locally initiated HTLC, there's no need to
// forward it so we exit.
Expand Down Expand Up @@ -3066,7 +3100,7 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error {
}

// handlePacketFail handles forwarding a fail packet.
func (s *Switch) handlePacketFail(packet *htlcPacket,
func (s *Switch) handlePacketFail(ctx context.Context, packet *htlcPacket,
htlc *lnwire.UpdateFailHTLC) error {

// If the source of this packet has not been set, use the circuit map
Expand All @@ -3085,8 +3119,12 @@ func (s *Switch) handlePacketFail(packet *htlcPacket,
// NOTE: `closeCircuit` modifies the state of `packet`.
if packet.incomingChanID == hop.Source {
// TODO(yy): remove the goroutine and send back the error here.
s.wg.Add(1)
go s.handleLocalResponse(packet)
ok := s.gm.Go(ctx, func(ctx context.Context) {
s.handleLocalResponse(packet)
})
if !ok {
return ErrSwitchExiting
}

// If this is a locally initiated HTLC, there's no need to
// forward it so we exit.
Expand Down
Loading
Loading