@@ -82,6 +82,13 @@ enum HTLCInterceptedAction {
82
82
ForwardPayment ( ForwardPaymentAction ) ,
83
83
}
84
84
85
+ /// Possible actions that need to be taken when a payment is forwarded.
86
+ #[ derive( Debug , PartialEq ) ]
87
+ enum PaymentForwardedAction {
88
+ ForwardPayment ( ForwardPaymentAction ) ,
89
+ ForwardHTLCs ( ForwardHTLCsAction ) ,
90
+ }
91
+
85
92
/// The forwarding of a payment while skimming the JIT channel opening fee.
86
93
#[ derive( Debug , PartialEq ) ]
87
94
struct ForwardPaymentAction ( ChannelId , FeePayment ) ;
@@ -319,17 +326,36 @@ impl OutboundJITChannelState {
319
326
}
320
327
321
328
fn payment_forwarded (
322
- & mut self ,
323
- ) -> Result < ( Self , Option < ForwardHTLCsAction > ) , ChannelStateError > {
329
+ & mut self , skimmed_fee_msat : Option < u64 > ,
330
+ ) -> Result < ( Self , Option < PaymentForwardedAction > ) , ChannelStateError > {
324
331
match self {
325
332
OutboundJITChannelState :: PendingPaymentForward {
326
- payment_queue, channel_id, ..
333
+ payment_queue,
334
+ channel_id,
335
+ opening_fee_msat,
327
336
} => {
328
337
let mut payment_queue_lock = payment_queue. lock ( ) . unwrap ( ) ;
329
- let payment_forwarded =
330
- OutboundJITChannelState :: PaymentForwarded { channel_id : * channel_id } ;
331
- let forward_htlcs = ForwardHTLCsAction ( * channel_id, payment_queue_lock. clear ( ) ) ;
332
- Ok ( ( payment_forwarded, Some ( forward_htlcs) ) )
338
+
339
+ let skimmed_fee_msat = skimmed_fee_msat. unwrap_or ( 0 ) ;
340
+ let remaining_fee = opening_fee_msat. saturating_sub ( skimmed_fee_msat) ;
341
+
342
+ if remaining_fee > 0 {
343
+ let ( state, payment_action) = try_get_payment (
344
+ Arc :: clone ( payment_queue) ,
345
+ payment_queue_lock,
346
+ * channel_id,
347
+ remaining_fee,
348
+ ) ;
349
+ Ok ( ( state, payment_action. map ( |pa| PaymentForwardedAction :: ForwardPayment ( pa) ) ) )
350
+ } else {
351
+ let payment_forwarded =
352
+ OutboundJITChannelState :: PaymentForwarded { channel_id : * channel_id } ;
353
+ let forward_htlcs = ForwardHTLCsAction ( * channel_id, payment_queue_lock. clear ( ) ) ;
354
+ Ok ( (
355
+ payment_forwarded,
356
+ Some ( PaymentForwardedAction :: ForwardHTLCs ( forward_htlcs) ) ,
357
+ ) )
358
+ }
333
359
} ,
334
360
OutboundJITChannelState :: PaymentForwarded { channel_id } => {
335
361
let payment_forwarded =
@@ -363,6 +389,10 @@ impl OutboundJITChannel {
363
389
}
364
390
}
365
391
392
+ pub fn has_paid_fee ( & self ) -> bool {
393
+ matches ! ( self . state, OutboundJITChannelState :: PaymentForwarded { .. } )
394
+ }
395
+
366
396
fn htlc_intercepted (
367
397
& mut self , htlc : InterceptedHTLC ,
368
398
) -> Result < Option < HTLCInterceptedAction > , LightningError > {
@@ -386,8 +416,10 @@ impl OutboundJITChannel {
386
416
Ok ( action)
387
417
}
388
418
389
- fn payment_forwarded ( & mut self ) -> Result < Option < ForwardHTLCsAction > , LightningError > {
390
- let ( new_state, action) = self . state . payment_forwarded ( ) ?;
419
+ fn payment_forwarded (
420
+ & mut self , skimmed_fee_msat : Option < u64 > ,
421
+ ) -> Result < Option < PaymentForwardedAction > , LightningError > {
422
+ let ( new_state, action) = self . state . payment_forwarded ( skimmed_fee_msat) ?;
391
423
self . state = new_state;
392
424
Ok ( action)
393
425
}
@@ -813,7 +845,9 @@ where
813
845
/// greater or equal to 0.0.107.
814
846
///
815
847
/// [`Event::PaymentForwarded`]: lightning::events::Event::PaymentForwarded
816
- pub fn payment_forwarded ( & self , next_channel_id : ChannelId ) -> Result < ( ) , APIError > {
848
+ pub fn payment_forwarded (
849
+ & self , next_channel_id : ChannelId , skimmed_fee_msat : Option < u64 > ,
850
+ ) -> Result < bool , APIError > {
817
851
if let Some ( counterparty_node_id) =
818
852
self . peer_by_channel_id . read ( ) . unwrap ( ) . get ( & next_channel_id)
819
853
{
@@ -827,8 +861,10 @@ where
827
861
if let Some ( jit_channel) =
828
862
peer_state. outbound_channels_by_intercept_scid . get_mut ( & intercept_scid)
829
863
{
830
- match jit_channel. payment_forwarded ( ) {
831
- Ok ( Some ( ForwardHTLCsAction ( channel_id, htlcs) ) ) => {
864
+ match jit_channel. payment_forwarded ( skimmed_fee_msat) {
865
+ Ok ( Some ( PaymentForwardedAction :: ForwardHTLCs (
866
+ ForwardHTLCsAction ( channel_id, htlcs) ,
867
+ ) ) ) => {
832
868
for htlc in htlcs {
833
869
self . channel_manager . get_cm ( ) . forward_intercepted_htlc (
834
870
htlc. intercept_id ,
@@ -838,6 +874,29 @@ where
838
874
) ?;
839
875
}
840
876
} ,
877
+ Ok ( Some ( PaymentForwardedAction :: ForwardPayment (
878
+ ForwardPaymentAction (
879
+ channel_id,
880
+ FeePayment { htlcs, opening_fee_msat } ,
881
+ ) ,
882
+ ) ) ) => {
883
+ let amounts_to_forward_msat =
884
+ calculate_amount_to_forward_per_htlc (
885
+ & htlcs,
886
+ opening_fee_msat,
887
+ ) ;
888
+
889
+ for ( intercept_id, amount_to_forward_msat) in
890
+ amounts_to_forward_msat
891
+ {
892
+ self . channel_manager . get_cm ( ) . forward_intercepted_htlc (
893
+ intercept_id,
894
+ & channel_id,
895
+ * counterparty_node_id,
896
+ amount_to_forward_msat,
897
+ ) ?;
898
+ }
899
+ } ,
841
900
Ok ( None ) => { } ,
842
901
Err ( e) => {
843
902
return Err ( APIError :: APIMisuseError {
@@ -848,6 +907,7 @@ where
848
907
} )
849
908
} ,
850
909
}
910
+ return Ok ( jit_channel. has_paid_fee ( ) ) ;
851
911
}
852
912
} else {
853
913
return Err ( APIError :: APIMisuseError {
@@ -863,7 +923,7 @@ where
863
923
}
864
924
}
865
925
866
- Ok ( ( ) )
926
+ Ok ( false )
867
927
}
868
928
869
929
/// Forward [`Event::ChannelReady`] event parameters into this function.
@@ -1410,12 +1470,18 @@ mod tests {
1410
1470
}
1411
1471
state = new_state;
1412
1472
}
1473
+
1474
+ // TODO: how do I get the expected skimmed amount here
1475
+
1413
1476
// Payment completes, queued payments get forwarded.
1414
1477
{
1415
- let ( new_state, action) = state. payment_forwarded ( ) . unwrap ( ) ;
1478
+ let ( new_state, action) = state. payment_forwarded ( Some ( 100_000_000 ) ) . unwrap ( ) ;
1416
1479
assert ! ( matches!( new_state, OutboundJITChannelState :: PaymentForwarded { .. } ) ) ;
1417
1480
match action {
1418
- Some ( ForwardHTLCsAction ( channel_id, htlcs) ) => {
1481
+ Some ( PaymentForwardedAction :: ForwardHTLCs ( ForwardHTLCsAction (
1482
+ channel_id,
1483
+ htlcs,
1484
+ ) ) ) => {
1419
1485
assert_eq ! ( channel_id, ChannelId ( [ 200 ; 32 ] ) ) ;
1420
1486
assert_eq ! (
1421
1487
htlcs,
@@ -1551,12 +1617,18 @@ mod tests {
1551
1617
}
1552
1618
state = new_state;
1553
1619
}
1620
+
1621
+ // TODO: how do I grab the expected skimmed fee amount here.
1622
+
1554
1623
// Payment completes, queued payments get forwarded.
1555
1624
{
1556
- let ( new_state, action) = state. payment_forwarded ( ) . unwrap ( ) ;
1625
+ let ( new_state, action) = state. payment_forwarded ( Some ( 100_000_000 ) ) . unwrap ( ) ;
1557
1626
assert ! ( matches!( new_state, OutboundJITChannelState :: PaymentForwarded { .. } ) ) ;
1558
1627
match action {
1559
- Some ( ForwardHTLCsAction ( channel_id, htlcs) ) => {
1628
+ Some ( PaymentForwardedAction :: ForwardHTLCs ( ForwardHTLCsAction (
1629
+ channel_id,
1630
+ htlcs,
1631
+ ) ) ) => {
1560
1632
assert_eq ! ( channel_id, ChannelId ( [ 200 ; 32 ] ) ) ;
1561
1633
assert_eq ! (
1562
1634
htlcs,
0 commit comments