-
Notifications
You must be signed in to change notification settings - Fork 2.2k
htlcswitch: use fn.GoroutineManager #9140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ package htlcswitch | |
|
|
||
| import ( | ||
| "bytes" | ||
| "context" | ||
| "errors" | ||
| "fmt" | ||
| "math/rand" | ||
|
|
@@ -245,8 +246,8 @@ type Switch struct { | |
| // This will be retrieved by the registered links atomically. | ||
| bestHeight uint32 | ||
|
|
||
| wg sync.WaitGroup | ||
| quit chan struct{} | ||
| // gm starts and stops tasks in goroutines and waits for them. | ||
| gm *fn.GoroutineManager | ||
|
|
||
| // cfg is a copy of the configuration struct that the htlc switch | ||
| // service was initialized with. | ||
|
|
@@ -368,8 +369,11 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { | |
| return nil, err | ||
| } | ||
|
|
||
| gm := fn.NewGoroutineManager() | ||
|
|
||
| s := &Switch{ | ||
| bestHeight: currentHeight, | ||
| gm: gm, | ||
| cfg: &cfg, | ||
| circuits: circuitMap, | ||
| linkIndex: make(map[lnwire.ChannelID]ChannelLink), | ||
|
|
@@ -382,7 +386,6 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { | |
| chanCloseRequests: make(chan *ChanClose), | ||
| resolutionMsgs: make(chan *resolutionMsg), | ||
| resMsgStore: resStore, | ||
| quit: make(chan struct{}), | ||
| } | ||
|
|
||
| s.aliasToReal = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) | ||
|
|
@@ -420,14 +423,14 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro | |
| ResolutionMsg: msg, | ||
| errChan: errChan, | ||
| }: | ||
| case <-s.quit: | ||
| case <-s.gm.Done(): | ||
| return ErrSwitchExiting | ||
| } | ||
|
|
||
| select { | ||
| case err := <-errChan: | ||
| return err | ||
| case <-s.quit: | ||
| case <-s.gm.Done(): | ||
| return ErrSwitchExiting | ||
| } | ||
| } | ||
|
|
@@ -493,14 +496,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, | |
| // Since the attempt was known, we can start a goroutine that can | ||
| // extract the result when it is available, and pass it on to the | ||
| // caller. | ||
| s.wg.Add(1) | ||
| go func() { | ||
| defer s.wg.Done() | ||
|
|
||
| ok := s.gm.Go(context.TODO(), func(ctx context.Context) { | ||
| var n *networkResult | ||
| select { | ||
| case n = <-nChan: | ||
| case <-s.quit: | ||
| case <-ctx.Done(): | ||
| // We close the result channel to signal a shutdown. We | ||
| // don't send any result in this case since the HTLC is | ||
| // still in flight. | ||
|
|
@@ -524,7 +524,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, | |
| return | ||
| } | ||
| resultChan <- result | ||
| }() | ||
| }) | ||
| // The switch shutting down is signaled by closing the channel. | ||
| if !ok { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we still need this check? Won't the line
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If |
||
| close(resultChan) | ||
| } | ||
|
|
||
| return resultChan, nil | ||
| } | ||
|
|
@@ -704,12 +708,19 @@ func (s *Switch) ForwardPackets(linkQuit <-chan struct{}, | |
| select { | ||
| case <-linkQuit: | ||
| return nil | ||
| case <-s.quit: | ||
|
|
||
| case <-s.gm.Done(): | ||
| return nil | ||
|
|
||
| default: | ||
| // Spawn a goroutine to log the errors returned from failed packets. | ||
| s.wg.Add(1) | ||
| go s.logFwdErrs(&numSent, &wg, fwdChan) | ||
| // Spawn a goroutine to log the errors returned from failed | ||
| // packets. | ||
| ok := s.gm.Go(context.TODO(), func(ctx context.Context) { | ||
| s.logFwdErrs(ctx, &numSent, &wg, fwdChan) | ||
| }) | ||
| if !ok { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. same here - why do we need this check? I think
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a similar situation. We should handle it the same way if the goroutine manager is stopped before the Go method is executed. |
||
| return nil | ||
ellemouton marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
|
|
||
| // Make a first pass over the packets, forwarding any settles or fails. | ||
|
|
@@ -820,8 +831,8 @@ func (s *Switch) ForwardPackets(linkQuit <-chan struct{}, | |
| } | ||
|
|
||
| // logFwdErrs logs any errors received on `fwdChan`. | ||
| func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) { | ||
| defer s.wg.Done() | ||
| func (s *Switch) logFwdErrs(ctx context.Context, num *int, wg *sync.WaitGroup, | ||
| fwdChan chan error) { | ||
|
|
||
| // Wait here until the outer function has finished persisting | ||
| // and routing the packets. This guarantees we don't read from num until | ||
|
|
@@ -836,7 +847,8 @@ func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) { | |
| log.Errorf("Unhandled error while reforwarding htlc "+ | ||
| "settle/fail over htlcswitch: %v", err) | ||
| } | ||
| case <-s.quit: | ||
|
|
||
| case <-s.gm.Done(): | ||
| log.Errorf("unable to forward htlc packet " + | ||
| "htlc switch was stopped") | ||
| return | ||
|
|
@@ -862,7 +874,7 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error, | |
| return nil | ||
| case <-linkQuit: | ||
| return ErrLinkShuttingDown | ||
| case <-s.quit: | ||
| case <-s.gm.Done(): | ||
| return errors.New("htlc switch was stopped") | ||
| } | ||
| } | ||
|
|
@@ -940,8 +952,6 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) ( | |
| // | ||
| // NOTE: This method MUST be spawned as a goroutine. | ||
| func (s *Switch) handleLocalResponse(pkt *htlcPacket) { | ||
| defer s.wg.Done() | ||
|
|
||
| attemptID := pkt.incomingHTLCID | ||
|
|
||
| // The error reason will be unencrypted in case this is a local | ||
|
|
@@ -1114,7 +1124,9 @@ func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter, | |
| // handlePacketForward is used in cases when we need to forward the htlc update | ||
| // from one channel link to another and be able to propagate the settle/fail | ||
| // updates back. This behaviour is achieved by creation of payment circuits. | ||
| func (s *Switch) handlePacketForward(packet *htlcPacket) error { | ||
| func (s *Switch) handlePacketForward(ctx context.Context, | ||
| packet *htlcPacket) error { | ||
|
|
||
| switch htlc := packet.htlc.(type) { | ||
| // Channel link forwarded us a new htlc, therefore we initiate the | ||
| // payment circuit within our internal state so we can properly forward | ||
|
|
@@ -1123,7 +1135,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { | |
| return s.handlePacketAdd(packet, htlc) | ||
|
|
||
| case *lnwire.UpdateFulfillHTLC: | ||
| return s.handlePacketSettle(packet) | ||
| return s.handlePacketSettle(ctx, packet) | ||
|
|
||
| // Channel link forwarded us an update_fail_htlc message. | ||
| // | ||
|
|
@@ -1132,7 +1144,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { | |
| // forward it. Thus there's no need to catch `UpdateFailMalformedHTLC` | ||
| // here. | ||
| case *lnwire.UpdateFailHTLC: | ||
| return s.handlePacketFail(packet, htlc) | ||
| return s.handlePacketFail(ctx, packet, htlc) | ||
|
|
||
| default: | ||
| return fmt.Errorf("wrong update type: %T", htlc) | ||
|
|
@@ -1436,7 +1448,7 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, | |
| case s.chanCloseRequests <- command: | ||
| return updateChan, errChan | ||
|
|
||
| case <-s.quit: | ||
| case <-s.gm.Done(): | ||
| errChan <- ErrSwitchExiting | ||
| close(updateChan) | ||
| return updateChan, errChan | ||
|
|
@@ -1453,9 +1465,9 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, | |
| // total link capacity. | ||
| // | ||
| // NOTE: This MUST be run as a goroutine. | ||
| func (s *Switch) htlcForwarder() { | ||
| defer s.wg.Done() | ||
|
|
||
| // | ||
| //nolint:funlen | ||
| func (s *Switch) htlcForwarder(ctx context.Context) { | ||
| defer func() { | ||
| s.blockEpochStream.Cancel() | ||
|
|
||
|
|
@@ -1489,6 +1501,8 @@ func (s *Switch) htlcForwarder() { | |
| var wg sync.WaitGroup | ||
| for _, link := range linksToStop { | ||
| wg.Add(1) | ||
| // Here it is ok to start a goroutine directly bypassing | ||
| // s.gm, because we wait for them to complete here. | ||
| go func(l ChannelLink) { | ||
| defer wg.Done() | ||
|
|
||
|
|
@@ -1613,7 +1627,7 @@ out: | |
| // encounter is due to the circuit already being | ||
| // closed. This is fine, as processing this message is | ||
| // meant to be idempotent. | ||
| err = s.handlePacketForward(pkt) | ||
| err = s.handlePacketForward(ctx, pkt) | ||
| if err != nil { | ||
| log.Errorf("Unable to forward resolution msg: %v", err) | ||
| } | ||
|
|
@@ -1622,21 +1636,22 @@ out: | |
| // packet concretely, then either forward it along, or | ||
| // interpret a return packet to a locally initialized one. | ||
| case cmd := <-s.htlcPlex: | ||
| cmd.err <- s.handlePacketForward(cmd.pkt) | ||
| cmd.err <- s.handlePacketForward(ctx, cmd.pkt) | ||
|
|
||
| // When this time ticks, then it indicates that we should | ||
| // collect all the forwarding events since the last interval, | ||
| // and write them out to our log. | ||
| case <-s.cfg.FwdEventTicker.Ticks(): | ||
| s.wg.Add(1) | ||
| go func() { | ||
| defer s.wg.Done() | ||
|
|
||
| if err := s.FlushForwardingEvents(); err != nil { | ||
| // The error of Go is ignored: if it is shutting down, | ||
| // the loop will terminate on the next iteration, in | ||
| // s.gm.Done case. | ||
ellemouton marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| _ = s.gm.Go(ctx, func(ctx context.Context) { | ||
| err := s.FlushForwardingEvents() | ||
| if err != nil { | ||
| log.Errorf("Unable to flush "+ | ||
| "forwarding events: %v", err) | ||
| } | ||
| }() | ||
| }) | ||
|
|
||
| // The log ticker has fired, so we'll calculate some forwarding | ||
| // stats for the last 10 seconds to display within the logs to | ||
|
|
@@ -1739,7 +1754,7 @@ out: | |
| // memory. | ||
| s.pendingSettleFails = s.pendingSettleFails[:0] | ||
|
|
||
| case <-s.quit: | ||
| case <-s.gm.Done(): | ||
| return | ||
| } | ||
| } | ||
|
|
@@ -1749,6 +1764,7 @@ out: | |
| func (s *Switch) Start() error { | ||
| if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { | ||
| log.Warn("Htlc Switch already started") | ||
|
|
||
| return errors.New("htlc switch already started") | ||
| } | ||
|
|
||
|
|
@@ -1760,19 +1776,32 @@ func (s *Switch) Start() error { | |
| } | ||
| s.blockEpochStream = blockEpochStream | ||
|
|
||
| s.wg.Add(1) | ||
| go s.htlcForwarder() | ||
| ok := s.gm.Go(context.TODO(), func(ctx context.Context) { | ||
| s.htlcForwarder(ctx) | ||
| }) | ||
| if !ok { | ||
| // We are already stopping so we can ignore the error. | ||
| _ = s.Stop() | ||
| err = fmt.Errorf("unable to start htlc forwarder: %w", | ||
| ErrSwitchExiting) | ||
| log.Errorf("%v", err) | ||
|
|
||
| return err | ||
| } | ||
|
|
||
| if err := s.reforwardResponses(); err != nil { | ||
| s.Stop() | ||
| // We are already stopping so we can ignore the error. | ||
| _ = s.Stop() | ||
| log.Errorf("unable to reforward responses: %v", err) | ||
|
|
||
| return err | ||
| } | ||
|
|
||
| if err := s.reforwardResolutions(); err != nil { | ||
| // We are already stopping so we can ignore the error. | ||
| _ = s.Stop() | ||
| log.Errorf("unable to reforward resolutions: %v", err) | ||
|
|
||
| return err | ||
| } | ||
|
|
||
|
|
@@ -1991,9 +2020,8 @@ func (s *Switch) Stop() error { | |
| log.Info("HTLC Switch shutting down...") | ||
| defer log.Debug("HTLC Switch shutdown complete") | ||
|
|
||
| close(s.quit) | ||
|
|
||
| s.wg.Wait() | ||
| // Ask running goroutines to stop and wait for them. | ||
| s.gm.Stop() | ||
|
|
||
| // Wait until all active goroutines have finished exiting before | ||
| // stopping the mailboxes, otherwise the mailbox map could still be | ||
|
|
@@ -2349,7 +2377,7 @@ func (s *Switch) RemoveLink(chanID lnwire.ChannelID) { | |
| select { | ||
| case <-stopChan: | ||
| return | ||
| case <-s.quit: | ||
| case <-s.gm.Done(): | ||
| return | ||
| } | ||
| } | ||
|
|
@@ -2990,7 +3018,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket, | |
| } | ||
|
|
||
| // handlePacketSettle handles forwarding a settle packet. | ||
| func (s *Switch) handlePacketSettle(packet *htlcPacket) error { | ||
| func (s *Switch) handlePacketSettle(ctx context.Context, | ||
| packet *htlcPacket) error { | ||
|
|
||
| // If the source of this packet has not been set, use the circuit map | ||
| // to lookup the origin. | ||
| circuit, err := s.closeCircuit(packet) | ||
|
|
@@ -3029,8 +3059,12 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error { | |
| // NOTE: `closeCircuit` modifies the state of `packet`. | ||
| if localHTLC { | ||
| // TODO(yy): remove the goroutine and send back the error here. | ||
| s.wg.Add(1) | ||
| go s.handleLocalResponse(packet) | ||
| ok := s.gm.Go(ctx, func(ctx context.Context) { | ||
| s.handleLocalResponse(packet) | ||
| }) | ||
| if !ok { | ||
| return ErrSwitchExiting | ||
| } | ||
|
|
||
| // If this is a locally initiated HTLC, there's no need to | ||
| // forward it so we exit. | ||
|
|
@@ -3066,7 +3100,7 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error { | |
| } | ||
|
|
||
| // handlePacketFail handles forwarding a fail packet. | ||
| func (s *Switch) handlePacketFail(packet *htlcPacket, | ||
| func (s *Switch) handlePacketFail(ctx context.Context, packet *htlcPacket, | ||
| htlc *lnwire.UpdateFailHTLC) error { | ||
|
|
||
| // If the source of this packet has not been set, use the circuit map | ||
|
|
@@ -3085,8 +3119,12 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, | |
| // NOTE: `closeCircuit` modifies the state of `packet`. | ||
| if packet.incomingChanID == hop.Source { | ||
| // TODO(yy): remove the goroutine and send back the error here. | ||
| s.wg.Add(1) | ||
| go s.handleLocalResponse(packet) | ||
| ok := s.gm.Go(ctx, func(ctx context.Context) { | ||
| s.handleLocalResponse(packet) | ||
| }) | ||
| if !ok { | ||
| return ErrSwitchExiting | ||
| } | ||
|
|
||
| // If this is a locally initiated HTLC, there's no need to | ||
| // forward it so we exit. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.