@@ -81,6 +81,13 @@ enum HTLCInterceptedAction {
81
81
ForwardPayment ( ForwardPaymentAction ) ,
82
82
}
83
83
84
+ /// Possible actions that need to be taken when a payment is forwarded.
85
+ #[ derive( Debug , PartialEq ) ]
86
+ enum PaymentForwardedAction {
87
+ ForwardPayment ( ForwardPaymentAction ) ,
88
+ ForwardHTLCs ( ForwardHTLCsAction ) ,
89
+ }
90
+
84
91
/// The forwarding of a payment while skimming the JIT channel opening fee.
85
92
#[ derive( Debug , PartialEq ) ]
86
93
struct ForwardPaymentAction ( ChannelId , FeePayment ) ;
@@ -318,23 +325,42 @@ impl OutboundJITChannelState {
318
325
}
319
326
320
327
fn payment_forwarded (
321
- & mut self ,
322
- ) -> Result < ( Self , Option < ForwardHTLCsAction > ) , ChannelStateError > {
328
+ & mut self , skimmed_fee_msat : Option < u64 > ,
329
+ ) -> Result < ( Self , Option < PaymentForwardedAction > ) , ChannelStateError > {
323
330
match self {
324
331
OutboundJITChannelState :: PendingPaymentForward {
325
- payment_queue, channel_id, ..
332
+ payment_queue,
333
+ channel_id,
334
+ opening_fee_msat,
326
335
} => {
327
336
let mut payment_queue_lock = payment_queue. lock ( ) . unwrap ( ) ;
328
- let payment_forwarded =
329
- OutboundJITChannelState :: PaymentForwarded { channel_id : * channel_id } ;
330
- let htlcs = payment_queue_lock
331
- . clear ( )
332
- . into_iter ( )
333
- . map ( |( _, htlcs) | htlcs)
334
- . flatten ( )
335
- . collect ( ) ;
336
- let forward_htlcs = ForwardHTLCsAction ( * channel_id, htlcs) ;
337
- Ok ( ( payment_forwarded, Some ( forward_htlcs) ) )
337
+
338
+ let skimmed_fee_msat = skimmed_fee_msat. unwrap_or ( 0 ) ;
339
+ let remaining_fee = opening_fee_msat. saturating_sub ( skimmed_fee_msat) ;
340
+
341
+ if remaining_fee > 0 {
342
+ let ( state, payment_action) = try_get_payment (
343
+ Arc :: clone ( payment_queue) ,
344
+ payment_queue_lock,
345
+ * channel_id,
346
+ remaining_fee,
347
+ ) ;
348
+ Ok ( ( state, payment_action. map ( |pa| PaymentForwardedAction :: ForwardPayment ( pa) ) ) )
349
+ } else {
350
+ let payment_forwarded =
351
+ OutboundJITChannelState :: PaymentForwarded { channel_id : * channel_id } ;
352
+ let htlcs = payment_queue_lock
353
+ . clear ( )
354
+ . into_iter ( )
355
+ . map ( |( _, htlcs) | htlcs)
356
+ . flatten ( )
357
+ . collect ( ) ;
358
+ let forward_htlcs = ForwardHTLCsAction ( * channel_id, htlcs) ;
359
+ Ok ( (
360
+ payment_forwarded,
361
+ Some ( PaymentForwardedAction :: ForwardHTLCs ( forward_htlcs) ) ,
362
+ ) )
363
+ }
338
364
} ,
339
365
OutboundJITChannelState :: PaymentForwarded { channel_id } => {
340
366
let payment_forwarded =
@@ -368,6 +394,10 @@ impl OutboundJITChannel {
368
394
}
369
395
}
370
396
397
+ pub fn has_paid_fee ( & self ) -> bool {
398
+ matches ! ( self . state, OutboundJITChannelState :: PaymentForwarded { .. } )
399
+ }
400
+
371
401
fn htlc_intercepted (
372
402
& mut self , htlc : InterceptedHTLC ,
373
403
) -> Result < Option < HTLCInterceptedAction > , LightningError > {
@@ -391,8 +421,10 @@ impl OutboundJITChannel {
391
421
Ok ( action)
392
422
}
393
423
394
- fn payment_forwarded ( & mut self ) -> Result < Option < ForwardHTLCsAction > , LightningError > {
395
- let ( new_state, action) = self . state . payment_forwarded ( ) ?;
424
+ fn payment_forwarded (
425
+ & mut self , skimmed_fee_msat : Option < u64 > ,
426
+ ) -> Result < Option < PaymentForwardedAction > , LightningError > {
427
+ let ( new_state, action) = self . state . payment_forwarded ( skimmed_fee_msat) ?;
396
428
self . state = new_state;
397
429
Ok ( action)
398
430
}
@@ -818,7 +850,9 @@ where
818
850
/// greater or equal to 0.0.107.
819
851
///
820
852
/// [`Event::PaymentForwarded`]: lightning::events::Event::PaymentForwarded
821
- pub fn payment_forwarded ( & self , next_channel_id : ChannelId ) -> Result < ( ) , APIError > {
853
+ pub fn payment_forwarded (
854
+ & self , next_channel_id : ChannelId , skimmed_fee_msat : Option < u64 > ,
855
+ ) -> Result < bool , APIError > {
822
856
if let Some ( counterparty_node_id) =
823
857
self . peer_by_channel_id . read ( ) . unwrap ( ) . get ( & next_channel_id)
824
858
{
@@ -832,8 +866,10 @@ where
832
866
if let Some ( jit_channel) =
833
867
peer_state. outbound_channels_by_intercept_scid . get_mut ( & intercept_scid)
834
868
{
835
- match jit_channel. payment_forwarded ( ) {
836
- Ok ( Some ( ForwardHTLCsAction ( channel_id, htlcs) ) ) => {
869
+ match jit_channel. payment_forwarded ( skimmed_fee_msat) {
870
+ Ok ( Some ( PaymentForwardedAction :: ForwardHTLCs (
871
+ ForwardHTLCsAction ( channel_id, htlcs) ,
872
+ ) ) ) => {
837
873
for htlc in htlcs {
838
874
self . channel_manager . get_cm ( ) . forward_intercepted_htlc (
839
875
htlc. intercept_id ,
@@ -843,6 +879,29 @@ where
843
879
) ?;
844
880
}
845
881
} ,
882
+ Ok ( Some ( PaymentForwardedAction :: ForwardPayment (
883
+ ForwardPaymentAction (
884
+ channel_id,
885
+ FeePayment { htlcs, opening_fee_msat } ,
886
+ ) ,
887
+ ) ) ) => {
888
+ let amounts_to_forward_msat =
889
+ calculate_amount_to_forward_per_htlc (
890
+ & htlcs,
891
+ opening_fee_msat,
892
+ ) ;
893
+
894
+ for ( intercept_id, amount_to_forward_msat) in
895
+ amounts_to_forward_msat
896
+ {
897
+ self . channel_manager . get_cm ( ) . forward_intercepted_htlc (
898
+ intercept_id,
899
+ & channel_id,
900
+ * counterparty_node_id,
901
+ amount_to_forward_msat,
902
+ ) ?;
903
+ }
904
+ } ,
846
905
Ok ( None ) => { } ,
847
906
Err ( e) => {
848
907
return Err ( APIError :: APIMisuseError {
@@ -853,6 +912,7 @@ where
853
912
} )
854
913
} ,
855
914
}
915
+ return Ok ( jit_channel. has_paid_fee ( ) ) ;
856
916
}
857
917
} else {
858
918
return Err ( APIError :: APIMisuseError {
@@ -868,7 +928,7 @@ where
868
928
}
869
929
}
870
930
871
- Ok ( ( ) )
931
+ Ok ( false )
872
932
}
873
933
874
934
/// Used by LSP to fail intercepted htlcs backwards when the channel open fails for any reason.
@@ -1476,12 +1536,18 @@ mod tests {
1476
1536
}
1477
1537
state = new_state;
1478
1538
}
1539
+
1540
+ // TODO: how do I get the expected skimmed amount here
1541
+
1479
1542
// Payment completes, queued payments get forwarded.
1480
1543
{
1481
- let ( new_state, action) = state. payment_forwarded ( ) . unwrap ( ) ;
1544
+ let ( new_state, action) = state. payment_forwarded ( Some ( 100_000_000 ) ) . unwrap ( ) ;
1482
1545
assert ! ( matches!( new_state, OutboundJITChannelState :: PaymentForwarded { .. } ) ) ;
1483
1546
match action {
1484
- Some ( ForwardHTLCsAction ( channel_id, htlcs) ) => {
1547
+ Some ( PaymentForwardedAction :: ForwardHTLCs ( ForwardHTLCsAction (
1548
+ channel_id,
1549
+ htlcs,
1550
+ ) ) ) => {
1485
1551
assert_eq ! ( channel_id, ChannelId ( [ 200 ; 32 ] ) ) ;
1486
1552
assert_eq ! (
1487
1553
htlcs,
@@ -1617,12 +1683,18 @@ mod tests {
1617
1683
}
1618
1684
state = new_state;
1619
1685
}
1686
+
1687
+ // TODO: how do I grab the expected skimmed fee amount here.
1688
+
1620
1689
// Payment completes, queued payments get forwarded.
1621
1690
{
1622
- let ( new_state, action) = state. payment_forwarded ( ) . unwrap ( ) ;
1691
+ let ( new_state, action) = state. payment_forwarded ( Some ( 100_000_000 ) ) . unwrap ( ) ;
1623
1692
assert ! ( matches!( new_state, OutboundJITChannelState :: PaymentForwarded { .. } ) ) ;
1624
1693
match action {
1625
- Some ( ForwardHTLCsAction ( channel_id, htlcs) ) => {
1694
+ Some ( PaymentForwardedAction :: ForwardHTLCs ( ForwardHTLCsAction (
1695
+ channel_id,
1696
+ htlcs,
1697
+ ) ) ) => {
1626
1698
assert_eq ! ( channel_id, ChannelId ( [ 200 ; 32 ] ) ) ;
1627
1699
assert_eq ! (
1628
1700
htlcs,
0 commit comments