Skip to content
This repository was archived by the owner on Jan 6, 2025. It is now read-only.

Commit 34d736e

Browse files
check if fee has been collected
1 parent 6658675 commit 34d736e

File tree

1 file changed

+89
-17
lines changed

1 file changed

+89
-17
lines changed

src/lsps2/service.rs

+89-17
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ enum HTLCInterceptedAction {
8282
ForwardPayment(ForwardPaymentAction),
8383
}
8484

85+
/// Possible actions that need to be taken when a payment is forwarded.
86+
#[derive(Debug, PartialEq)]
87+
enum PaymentForwardedAction {
88+
ForwardPayment(ForwardPaymentAction),
89+
ForwardHTLCs(ForwardHTLCsAction),
90+
}
91+
8592
/// The forwarding of a payment while skimming the JIT channel opening fee.
8693
#[derive(Debug, PartialEq)]
8794
struct ForwardPaymentAction(ChannelId, FeePayment);
@@ -319,17 +326,36 @@ impl OutboundJITChannelState {
319326
}
320327

321328
fn payment_forwarded(
322-
&mut self,
323-
) -> Result<(Self, Option<ForwardHTLCsAction>), ChannelStateError> {
329+
&mut self, skimmed_fee_msat: Option<u64>,
330+
) -> Result<(Self, Option<PaymentForwardedAction>), ChannelStateError> {
324331
match self {
325332
OutboundJITChannelState::PendingPaymentForward {
326-
payment_queue, channel_id, ..
333+
payment_queue,
334+
channel_id,
335+
opening_fee_msat,
327336
} => {
328337
let mut payment_queue_lock = payment_queue.lock().unwrap();
329-
let payment_forwarded =
330-
OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
331-
let forward_htlcs = ForwardHTLCsAction(*channel_id, payment_queue_lock.clear());
332-
Ok((payment_forwarded, Some(forward_htlcs)))
338+
339+
let skimmed_fee_msat = skimmed_fee_msat.unwrap_or(0);
340+
let remaining_fee = opening_fee_msat.saturating_sub(skimmed_fee_msat);
341+
342+
if remaining_fee > 0 {
343+
let (state, payment_action) = try_get_payment(
344+
Arc::clone(payment_queue),
345+
payment_queue_lock,
346+
*channel_id,
347+
remaining_fee,
348+
);
349+
Ok((state, payment_action.map(|pa| PaymentForwardedAction::ForwardPayment(pa))))
350+
} else {
351+
let payment_forwarded =
352+
OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
353+
let forward_htlcs = ForwardHTLCsAction(*channel_id, payment_queue_lock.clear());
354+
Ok((
355+
payment_forwarded,
356+
Some(PaymentForwardedAction::ForwardHTLCs(forward_htlcs)),
357+
))
358+
}
333359
},
334360
OutboundJITChannelState::PaymentForwarded { channel_id } => {
335361
let payment_forwarded =
@@ -363,6 +389,10 @@ impl OutboundJITChannel {
363389
}
364390
}
365391

392+
pub fn has_paid_fee(&self) -> bool {
393+
matches!(self.state, OutboundJITChannelState::PaymentForwarded { .. })
394+
}
395+
366396
fn htlc_intercepted(
367397
&mut self, htlc: InterceptedHTLC,
368398
) -> Result<Option<HTLCInterceptedAction>, LightningError> {
@@ -386,8 +416,10 @@ impl OutboundJITChannel {
386416
Ok(action)
387417
}
388418

389-
fn payment_forwarded(&mut self) -> Result<Option<ForwardHTLCsAction>, LightningError> {
390-
let (new_state, action) = self.state.payment_forwarded()?;
419+
fn payment_forwarded(
420+
&mut self, skimmed_fee_msat: Option<u64>,
421+
) -> Result<Option<PaymentForwardedAction>, LightningError> {
422+
let (new_state, action) = self.state.payment_forwarded(skimmed_fee_msat)?;
391423
self.state = new_state;
392424
Ok(action)
393425
}
@@ -813,7 +845,9 @@ where
813845
/// greater or equal to 0.0.107.
814846
///
815847
/// [`Event::PaymentForwarded`]: lightning::events::Event::PaymentForwarded
816-
pub fn payment_forwarded(&self, next_channel_id: ChannelId) -> Result<(), APIError> {
848+
pub fn payment_forwarded(
849+
&self, next_channel_id: ChannelId, skimmed_fee_msat: Option<u64>,
850+
) -> Result<bool, APIError> {
817851
if let Some(counterparty_node_id) =
818852
self.peer_by_channel_id.read().unwrap().get(&next_channel_id)
819853
{
@@ -827,8 +861,10 @@ where
827861
if let Some(jit_channel) =
828862
peer_state.outbound_channels_by_intercept_scid.get_mut(&intercept_scid)
829863
{
830-
match jit_channel.payment_forwarded() {
831-
Ok(Some(ForwardHTLCsAction(channel_id, htlcs))) => {
864+
match jit_channel.payment_forwarded(skimmed_fee_msat) {
865+
Ok(Some(PaymentForwardedAction::ForwardHTLCs(
866+
ForwardHTLCsAction(channel_id, htlcs),
867+
))) => {
832868
for htlc in htlcs {
833869
self.channel_manager.get_cm().forward_intercepted_htlc(
834870
htlc.intercept_id,
@@ -838,6 +874,29 @@ where
838874
)?;
839875
}
840876
},
877+
Ok(Some(PaymentForwardedAction::ForwardPayment(
878+
ForwardPaymentAction(
879+
channel_id,
880+
FeePayment { htlcs, opening_fee_msat },
881+
),
882+
))) => {
883+
let amounts_to_forward_msat =
884+
calculate_amount_to_forward_per_htlc(
885+
&htlcs,
886+
opening_fee_msat,
887+
);
888+
889+
for (intercept_id, amount_to_forward_msat) in
890+
amounts_to_forward_msat
891+
{
892+
self.channel_manager.get_cm().forward_intercepted_htlc(
893+
intercept_id,
894+
&channel_id,
895+
*counterparty_node_id,
896+
amount_to_forward_msat,
897+
)?;
898+
}
899+
},
841900
Ok(None) => {},
842901
Err(e) => {
843902
return Err(APIError::APIMisuseError {
@@ -848,6 +907,7 @@ where
848907
})
849908
},
850909
}
910+
return Ok(jit_channel.has_paid_fee());
851911
}
852912
} else {
853913
return Err(APIError::APIMisuseError {
@@ -863,7 +923,7 @@ where
863923
}
864924
}
865925

866-
Ok(())
926+
Ok(false)
867927
}
868928

869929
/// Forward [`Event::ChannelReady`] event parameters into this function.
@@ -1410,12 +1470,18 @@ mod tests {
14101470
}
14111471
state = new_state;
14121472
}
1473+
1474+
// TODO: how do I get the expected skimmed amount here
1475+
14131476
// Payment completes, queued payments get forwarded.
14141477
{
1415-
let (new_state, action) = state.payment_forwarded().unwrap();
1478+
let (new_state, action) = state.payment_forwarded(Some(100_000_000)).unwrap();
14161479
assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
14171480
match action {
1418-
Some(ForwardHTLCsAction(channel_id, htlcs)) => {
1481+
Some(PaymentForwardedAction::ForwardHTLCs(ForwardHTLCsAction(
1482+
channel_id,
1483+
htlcs,
1484+
))) => {
14191485
assert_eq!(channel_id, ChannelId([200; 32]));
14201486
assert_eq!(
14211487
htlcs,
@@ -1551,12 +1617,18 @@ mod tests {
15511617
}
15521618
state = new_state;
15531619
}
1620+
1621+
// TODO: how do I grab the expected skimmed fee amount here.
1622+
15541623
// Payment completes, queued payments get forwarded.
15551624
{
1556-
let (new_state, action) = state.payment_forwarded().unwrap();
1625+
let (new_state, action) = state.payment_forwarded(Some(100_000_000)).unwrap();
15571626
assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
15581627
match action {
1559-
Some(ForwardHTLCsAction(channel_id, htlcs)) => {
1628+
Some(PaymentForwardedAction::ForwardHTLCs(ForwardHTLCsAction(
1629+
channel_id,
1630+
htlcs,
1631+
))) => {
15601632
assert_eq!(channel_id, ChannelId([200; 32]));
15611633
assert_eq!(
15621634
htlcs,

0 commit comments

Comments
 (0)