Skip to content
This repository was archived by the owner on Jan 6, 2025. It is now read-only.

Commit 5983ec0

Browse files
check if fee has been collected
1 parent 0513e0c commit 5983ec0

File tree

1 file changed

+95
-23
lines changed

1 file changed

+95
-23
lines changed

src/lsps2/service.rs

+95-23
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ enum HTLCInterceptedAction {
8181
ForwardPayment(ForwardPaymentAction),
8282
}
8383

84+
/// Possible actions that need to be taken when a payment is forwarded.
85+
#[derive(Debug, PartialEq)]
86+
enum PaymentForwardedAction {
87+
ForwardPayment(ForwardPaymentAction),
88+
ForwardHTLCs(ForwardHTLCsAction),
89+
}
90+
8491
/// The forwarding of a payment while skimming the JIT channel opening fee.
8592
#[derive(Debug, PartialEq)]
8693
struct ForwardPaymentAction(ChannelId, FeePayment);
@@ -318,23 +325,42 @@ impl OutboundJITChannelState {
318325
}
319326

320327
fn payment_forwarded(
321-
&mut self,
322-
) -> Result<(Self, Option<ForwardHTLCsAction>), ChannelStateError> {
328+
&mut self, skimmed_fee_msat: Option<u64>,
329+
) -> Result<(Self, Option<PaymentForwardedAction>), ChannelStateError> {
323330
match self {
324331
OutboundJITChannelState::PendingPaymentForward {
325-
payment_queue, channel_id, ..
332+
payment_queue,
333+
channel_id,
334+
opening_fee_msat,
326335
} => {
327336
let mut payment_queue_lock = payment_queue.lock().unwrap();
328-
let payment_forwarded =
329-
OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
330-
let htlcs = payment_queue_lock
331-
.clear()
332-
.into_iter()
333-
.map(|(_, htlcs)| htlcs)
334-
.flatten()
335-
.collect();
336-
let forward_htlcs = ForwardHTLCsAction(*channel_id, htlcs);
337-
Ok((payment_forwarded, Some(forward_htlcs)))
337+
338+
let skimmed_fee_msat = skimmed_fee_msat.unwrap_or(0);
339+
let remaining_fee = opening_fee_msat.saturating_sub(skimmed_fee_msat);
340+
341+
if remaining_fee > 0 {
342+
let (state, payment_action) = try_get_payment(
343+
Arc::clone(payment_queue),
344+
payment_queue_lock,
345+
*channel_id,
346+
remaining_fee,
347+
);
348+
Ok((state, payment_action.map(|pa| PaymentForwardedAction::ForwardPayment(pa))))
349+
} else {
350+
let payment_forwarded =
351+
OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
352+
let htlcs = payment_queue_lock
353+
.clear()
354+
.into_iter()
355+
.map(|(_, htlcs)| htlcs)
356+
.flatten()
357+
.collect();
358+
let forward_htlcs = ForwardHTLCsAction(*channel_id, htlcs);
359+
Ok((
360+
payment_forwarded,
361+
Some(PaymentForwardedAction::ForwardHTLCs(forward_htlcs)),
362+
))
363+
}
338364
},
339365
OutboundJITChannelState::PaymentForwarded { channel_id } => {
340366
let payment_forwarded =
@@ -368,6 +394,10 @@ impl OutboundJITChannel {
368394
}
369395
}
370396

397+
pub fn has_paid_fee(&self) -> bool {
398+
matches!(self.state, OutboundJITChannelState::PaymentForwarded { .. })
399+
}
400+
371401
fn htlc_intercepted(
372402
&mut self, htlc: InterceptedHTLC,
373403
) -> Result<Option<HTLCInterceptedAction>, LightningError> {
@@ -391,8 +421,10 @@ impl OutboundJITChannel {
391421
Ok(action)
392422
}
393423

394-
fn payment_forwarded(&mut self) -> Result<Option<ForwardHTLCsAction>, LightningError> {
395-
let (new_state, action) = self.state.payment_forwarded()?;
424+
fn payment_forwarded(
425+
&mut self, skimmed_fee_msat: Option<u64>,
426+
) -> Result<Option<PaymentForwardedAction>, LightningError> {
427+
let (new_state, action) = self.state.payment_forwarded(skimmed_fee_msat)?;
396428
self.state = new_state;
397429
Ok(action)
398430
}
@@ -818,7 +850,9 @@ where
818850
/// greater or equal to 0.0.107.
819851
///
820852
/// [`Event::PaymentForwarded`]: lightning::events::Event::PaymentForwarded
821-
pub fn payment_forwarded(&self, next_channel_id: ChannelId) -> Result<(), APIError> {
853+
pub fn payment_forwarded(
854+
&self, next_channel_id: ChannelId, skimmed_fee_msat: Option<u64>,
855+
) -> Result<bool, APIError> {
822856
if let Some(counterparty_node_id) =
823857
self.peer_by_channel_id.read().unwrap().get(&next_channel_id)
824858
{
@@ -832,8 +866,10 @@ where
832866
if let Some(jit_channel) =
833867
peer_state.outbound_channels_by_intercept_scid.get_mut(&intercept_scid)
834868
{
835-
match jit_channel.payment_forwarded() {
836-
Ok(Some(ForwardHTLCsAction(channel_id, htlcs))) => {
869+
match jit_channel.payment_forwarded(skimmed_fee_msat) {
870+
Ok(Some(PaymentForwardedAction::ForwardHTLCs(
871+
ForwardHTLCsAction(channel_id, htlcs),
872+
))) => {
837873
for htlc in htlcs {
838874
self.channel_manager.get_cm().forward_intercepted_htlc(
839875
htlc.intercept_id,
@@ -843,6 +879,29 @@ where
843879
)?;
844880
}
845881
},
882+
Ok(Some(PaymentForwardedAction::ForwardPayment(
883+
ForwardPaymentAction(
884+
channel_id,
885+
FeePayment { htlcs, opening_fee_msat },
886+
),
887+
))) => {
888+
let amounts_to_forward_msat =
889+
calculate_amount_to_forward_per_htlc(
890+
&htlcs,
891+
opening_fee_msat,
892+
);
893+
894+
for (intercept_id, amount_to_forward_msat) in
895+
amounts_to_forward_msat
896+
{
897+
self.channel_manager.get_cm().forward_intercepted_htlc(
898+
intercept_id,
899+
&channel_id,
900+
*counterparty_node_id,
901+
amount_to_forward_msat,
902+
)?;
903+
}
904+
},
846905
Ok(None) => {},
847906
Err(e) => {
848907
return Err(APIError::APIMisuseError {
@@ -853,6 +912,7 @@ where
853912
})
854913
},
855914
}
915+
return Ok(jit_channel.has_paid_fee());
856916
}
857917
} else {
858918
return Err(APIError::APIMisuseError {
@@ -868,7 +928,7 @@ where
868928
}
869929
}
870930

871-
Ok(())
931+
Ok(false)
872932
}
873933

874934
/// Used by LSP to fail intercepted htlcs backwards when the channel open fails for any reason.
@@ -1476,12 +1536,18 @@ mod tests {
14761536
}
14771537
state = new_state;
14781538
}
1539+
1540+
// TODO: how do I get the expected skimmed amount here
1541+
14791542
// Payment completes, queued payments get forwarded.
14801543
{
1481-
let (new_state, action) = state.payment_forwarded().unwrap();
1544+
let (new_state, action) = state.payment_forwarded(Some(100_000_000)).unwrap();
14821545
assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
14831546
match action {
1484-
Some(ForwardHTLCsAction(channel_id, htlcs)) => {
1547+
Some(PaymentForwardedAction::ForwardHTLCs(ForwardHTLCsAction(
1548+
channel_id,
1549+
htlcs,
1550+
))) => {
14851551
assert_eq!(channel_id, ChannelId([200; 32]));
14861552
assert_eq!(
14871553
htlcs,
@@ -1617,12 +1683,18 @@ mod tests {
16171683
}
16181684
state = new_state;
16191685
}
1686+
1687+
// TODO: how do I grab the expected skimmed fee amount here.
1688+
16201689
// Payment completes, queued payments get forwarded.
16211690
{
1622-
let (new_state, action) = state.payment_forwarded().unwrap();
1691+
let (new_state, action) = state.payment_forwarded(Some(100_000_000)).unwrap();
16231692
assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
16241693
match action {
1625-
Some(ForwardHTLCsAction(channel_id, htlcs)) => {
1694+
Some(PaymentForwardedAction::ForwardHTLCs(ForwardHTLCsAction(
1695+
channel_id,
1696+
htlcs,
1697+
))) => {
16261698
assert_eq!(channel_id, ChannelId([200; 32]));
16271699
assert_eq!(
16281700
htlcs,

0 commit comments

Comments
 (0)