diff --git a/lightning-liquidity/Cargo.toml b/lightning-liquidity/Cargo.toml index eebb0e8a076..1cc0d988544 100644 --- a/lightning-liquidity/Cargo.toml +++ b/lightning-liquidity/Cargo.toml @@ -37,6 +37,7 @@ lightning-persister = { version = "0.2.0", path = "../lightning-persister", defa proptest = "1.0.0" tokio = { version = "1.35", default-features = false, features = [ "rt-multi-thread", "time", "sync", "macros" ] } +parking_lot = { version = "0.12", default-features = false } [lints.rust.unexpected_cfgs] level = "forbid" diff --git a/lightning-persister/src/test_utils.rs b/lightning-persister/src/test_utils.rs index e6ad42e5bcd..8af33cef55b 100644 --- a/lightning-persister/src/test_utils.rs +++ b/lightning-persister/src/test_utils.rs @@ -113,7 +113,7 @@ pub(crate) fn do_test_data_migration // Integration-test the given KVStore implementation. Test relaying a few payments and check that // the persisted data is updated the appropriate number of times. -pub(crate) fn do_test_store(store_0: &K, store_1: &K) { +pub(crate) fn do_test_store(store_0: &K, store_1: &K) { let chanmon_cfgs = create_chanmon_cfgs(2); let mut node_cfgs = create_node_cfgs(2, &chanmon_cfgs); let chain_mon_0 = test_utils::TestChainMonitor::new( diff --git a/lightning/Cargo.toml b/lightning/Cargo.toml index 5fc763f7e5f..8023ad93693 100644 --- a/lightning/Cargo.toml +++ b/lightning/Cargo.toml @@ -54,6 +54,7 @@ inventory = { version = "0.3", optional = true } regex = "1.5.6" lightning-types = { version = "0.3.0", path = "../lightning-types", features = ["_test_utils"] } lightning-macros = { path = "../lightning-macros" } +parking_lot = { version = "0.12", default-features = false } [dev-dependencies.bitcoin] version = "0.32.2" diff --git a/lightning/src/lib.rs b/lightning/src/lib.rs index 392423b4137..feb4f0ac785 100644 --- a/lightning/src/lib.rs +++ b/lightning/src/lib.rs @@ -63,6 +63,8 @@ extern crate core; #[cfg(ldk_bench)] extern crate criterion; +#[cfg(all(feature = "std", test))] extern crate parking_lot; + #[macro_use] pub mod util; pub mod chain; diff --git a/lightning/src/ln/chanmon_update_fail_tests.rs b/lightning/src/ln/chanmon_update_fail_tests.rs index e99cf017b66..e1c95def431 100644 --- a/lightning/src/ln/chanmon_update_fail_tests.rs +++ b/lightning/src/ln/chanmon_update_fail_tests.rs @@ -3860,3 +3860,225 @@ fn test_claim_to_closed_channel_blocks_claimed_event() { nodes[1].chain_monitor.complete_sole_pending_chan_update(&chan_a.2); expect_payment_claimed!(nodes[1], payment_hash, 1_000_000); } + +#[test] +#[cfg(all(feature = "std", not(target_os = "windows")))] +fn test_single_channel_multiple_mpp() { + use std::sync::atomic::{AtomicBool, Ordering}; + + // Test what happens when we attempt to claim an MPP with many parts that came to us through + // the same channel with a synchronous persistence interface which has very high latency. + // + // Previously, if a `revoke_and_ack` came in while we were still running in + // `ChannelManager::claim_payment` we'd end up hanging waiting to apply a + // `ChannelMonitorUpdate` until after it completed. See the commit which introduced this test + // for more info. 
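For context on the hang described in the comment above, here is a minimal, self-contained sketch (not LDK code; all names are illustrative) of the shape of the problem: one thread holds a lock while blocking on a slow synchronous persistence step, while a second thread, standing in for the `revoke_and_ack` handler, needs that same lock to make progress.

```rust
// Minimal illustration (not LDK code) of the hang described above: one thread
// holds a lock while blocking on a synchronous "persist" step, while another
// thread needs that same lock.
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::Duration;

fn main() {
    let peer_state = Arc::new(Mutex::new(0u32));
    let (persist_done_tx, persist_done_rx) = mpsc::channel::<()>();

    let claim_state = Arc::clone(&peer_state);
    let claimer = thread::spawn(move || {
        let mut guard = claim_state.lock().unwrap();
        // "Persist" synchronously while still holding the lock; nothing else
        // that needs `peer_state` can run until this returns.
        persist_done_rx.recv().unwrap();
        *guard += 1;
    });

    let handler_state = Arc::clone(&peer_state);
    let handler = thread::spawn(move || {
        // Simulates a message handler (e.g. a revoke_and_ack arriving) that
        // blocks here until the claimer releases the lock.
        let mut guard = handler_state.lock().unwrap();
        *guard += 1;
    });

    // Unblock the "persistence" after a while so the example terminates.
    thread::sleep(Duration::from_millis(100));
    persist_done_tx.send(()).unwrap();

    claimer.join().unwrap();
    handler.join().unwrap();
    assert_eq!(*peer_state.lock().unwrap(), 2);
}
```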
+	let chanmon_cfgs = create_chanmon_cfgs(9);
+	let node_cfgs = create_node_cfgs(9, &chanmon_cfgs);
+	let configs = [None, None, None, None, None, None, None, None, None];
+	let node_chanmgrs = create_node_chanmgrs(9, &node_cfgs, &configs);
+	let mut nodes = create_network(9, &node_cfgs, &node_chanmgrs);
+
+	let node_7_id = nodes[7].node.get_our_node_id();
+	let node_8_id = nodes[8].node.get_our_node_id();
+
+	// Send an MPP payment in six parts along the path shown from top to bottom
+	//        0
+	//  1 2 3 4 5 6
+	//        7
+	//        8
+	//
+	// We can in theory reproduce this issue with fewer channels/HTLCs, but getting this test
+	// robust is rather challenging. We rely on having the main test thread wait on locks held in
+	// the background `claim_funds` thread and unlocking when the `claim_funds` thread completes a
+	// single `ChannelMonitorUpdate`.
+	// This thread calls `get_and_clear_pending_msg_events()` and `handle_revoke_and_ack()`, both
+	// of which require `ChannelManager` locks, but we have to make sure this thread gets a chance
+	// to be blocked on the mutexes before we let the background thread wake `claim_funds` so that
+	// the mutex can switch to this main thread.
+	// This relies on our locks being fair, but also on our threads getting runtime during the test
+	// run, which can be pretty competitive. Thus we do a dumb dance to be as conservative as
+	// possible - we have a background thread which completes a `ChannelMonitorUpdate` (by sending
+	// into the `write_blocker` mpsc) but it doesn't run until an mpsc channel sends from this main
+	// thread to the background thread, and then we let it sleep a while before we send the
+	// `ChannelMonitorUpdate` unblocker.
+	// Further, we give ourselves two chances each time, needing 4 HTLCs just to unlock our two
+	// `ChannelManager` calls. We then need a few remaining HTLCs to actually trigger the bug, so
+	// we use 6 HTLCs.
+	// Finally, we do not run this test on Windows, as its thread scheduling does not reliably
+	// give us the preemption and fairness this test depends on.
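Since the "dumb dance" above can be hard to picture from prose alone, here is a minimal sketch of the gating mechanism, assuming nothing beyond std. The names mirror the test's `write_blocker`/`do_a_write` channel pair but are otherwise illustrative:

```rust
// Minimal sketch (not LDK code) of the gating pattern described above: a
// zero-capacity sync_channel acts as a rendezvous, so each "persist" in the
// worker thread blocks until the controller explicitly releases one step.
use std::sync::mpsc;
use std::thread;

fn main() {
    // `sync_channel(0)` gives a rendezvous channel: `send` blocks until `recv`.
    let (do_a_write, blocker) = mpsc::sync_channel::<()>(0);

    let worker = thread::spawn(move || {
        for step in 0..3 {
            // Stand-in for a ChannelMonitorUpdate persist that is gated on the
            // test harness releasing it.
            blocker.recv().unwrap();
            println!("completed persist step {step}");
        }
    });

    // The controller releases the persists one at a time, interleaving its own
    // work (message handling, in the real test) between them.
    for _ in 0..3 {
        do_a_write.send(()).unwrap();
    }

    worker.join().unwrap();
}
```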
+ const MAX_THREAD_INIT_TIME: std::time::Duration = std::time::Duration::from_secs(1); + + create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 0, 2, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 0, 3, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 0, 4, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 0, 5, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 0, 6, 100_000, 0); + + create_announced_chan_between_nodes_with_value(&nodes, 1, 7, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 2, 7, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 3, 7, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 4, 7, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 5, 7, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 6, 7, 100_000, 0); + create_announced_chan_between_nodes_with_value(&nodes, 7, 8, 1_000_000, 0); + + let (mut route, payment_hash, payment_preimage, payment_secret) = get_route_and_payment_hash!(&nodes[0], nodes[8], 50_000_000); + + send_along_route_with_secret(&nodes[0], route, &[&[&nodes[1], &nodes[7], &nodes[8]], &[&nodes[2], &nodes[7], &nodes[8]], &[&nodes[3], &nodes[7], &nodes[8]], &[&nodes[4], &nodes[7], &nodes[8]], &[&nodes[5], &nodes[7], &nodes[8]], &[&nodes[6], &nodes[7], &nodes[8]]], 50_000_000, payment_hash, payment_secret); + + let (do_a_write, blocker) = std::sync::mpsc::sync_channel(0); + *nodes[8].chain_monitor.write_blocker.lock().unwrap() = Some(blocker); + + // Until we have std::thread::scoped we have to unsafe { turn off the borrow checker }. + // We do this by casting a pointer to a `TestChannelManager` to a pointer to a + // `TestChannelManager` with different (in this case 'static) lifetime. + // This is even suggested in the second example at + // https://doc.rust-lang.org/std/mem/fn.transmute.html#examples + let claim_node: &'static TestChannelManager<'static, 'static> = + unsafe { std::mem::transmute(nodes[8].node as &TestChannelManager) }; + let thrd = std::thread::spawn(move || { + // Initiate the claim in a background thread as it will immediately block waiting on the + // `write_blocker` we set above. + claim_node.claim_funds(payment_preimage); + }); + + // First unlock one monitor so that we have a pending + // `update_fulfill_htlc`/`commitment_signed` pair to pass to our counterparty. + do_a_write.send(()).unwrap(); + + // Then fetch the `update_fulfill_htlc`/`commitment_signed`. Note that the + // `get_and_clear_pending_msg_events` will immediately hang trying to take a peer lock which + // `claim_funds` is holding. Thus, we release a second write after a small sleep in the + // background to give `claim_funds` a chance to step forward, unblocking + // `get_and_clear_pending_msg_events`. 
+	let do_a_write_background = do_a_write.clone();
+	let block_thrd2 = AtomicBool::new(true);
+	let block_thrd2_read: &'static AtomicBool = unsafe { std::mem::transmute(&block_thrd2) };
+	let thrd2 = std::thread::spawn(move || {
+		while block_thrd2_read.load(Ordering::Acquire) {
+			std::thread::yield_now();
+		}
+		std::thread::sleep(MAX_THREAD_INIT_TIME);
+		do_a_write_background.send(()).unwrap();
+		std::thread::sleep(MAX_THREAD_INIT_TIME);
+		do_a_write_background.send(()).unwrap();
+	});
+	block_thrd2.store(false, Ordering::Release);
+	let first_updates = get_htlc_update_msgs(&nodes[8], &nodes[7].node.get_our_node_id());
+	thrd2.join().unwrap();
+
+	// Disconnect node 7 from all its peers so it doesn't bother to fail the HTLCs back
+	nodes[7].node.peer_disconnected(nodes[1].node.get_our_node_id());
+	nodes[7].node.peer_disconnected(nodes[2].node.get_our_node_id());
+	nodes[7].node.peer_disconnected(nodes[3].node.get_our_node_id());
+	nodes[7].node.peer_disconnected(nodes[4].node.get_our_node_id());
+	nodes[7].node.peer_disconnected(nodes[5].node.get_our_node_id());
+	nodes[7].node.peer_disconnected(nodes[6].node.get_our_node_id());
+
+	nodes[7].node.handle_update_fulfill_htlc(node_8_id, &first_updates.update_fulfill_htlcs[0]);
+	check_added_monitors(&nodes[7], 1);
+	expect_payment_forwarded!(nodes[7], nodes[1], nodes[8], Some(1000), false, false);
+	nodes[7].node.handle_commitment_signed(node_8_id, &first_updates.commitment_signed);
+	check_added_monitors(&nodes[7], 1);
+	let (raa, cs) = get_revoke_commit_msgs(&nodes[7], &node_8_id);
+
+	// Now, handle the `revoke_and_ack` from node 7. Note that `claim_funds` is still blocked on
+	// our peer lock, so we have to release a write to let it process.
+	// After this call completes, the channel previously would be locked up and should not be able
+	// to make further progress.
+	let do_a_write_background = do_a_write.clone();
+	let block_thrd3 = AtomicBool::new(true);
+	let block_thrd3_read: &'static AtomicBool = unsafe { std::mem::transmute(&block_thrd3) };
+	let thrd3 = std::thread::spawn(move || {
+		while block_thrd3_read.load(Ordering::Acquire) {
+			std::thread::yield_now();
+		}
+		std::thread::sleep(MAX_THREAD_INIT_TIME);
+		do_a_write_background.send(()).unwrap();
+		std::thread::sleep(MAX_THREAD_INIT_TIME);
+		do_a_write_background.send(()).unwrap();
+	});
+	block_thrd3.store(false, Ordering::Release);
+	nodes[8].node.handle_revoke_and_ack(node_7_id, &raa);
+	thrd3.join().unwrap();
+	assert!(!thrd.is_finished());
+
+	let thrd4 = std::thread::spawn(move || {
+		do_a_write.send(()).unwrap();
+		do_a_write.send(()).unwrap();
+	});
+
+	thrd4.join().unwrap();
+	thrd.join().unwrap();
+
+	expect_payment_claimed!(nodes[8], payment_hash, 50_000_000);
+
+	// At the end, we should have 7 ChannelMonitorUpdates - 6 for HTLC claims, and one for the
+	// above `revoke_and_ack`.
+	check_added_monitors(&nodes[8], 7);
+
+	// Now drive everything to the end, at least as far as node 7 is concerned...
+ *nodes[8].chain_monitor.write_blocker.lock().unwrap() = None; + nodes[8].node.handle_commitment_signed(node_7_id, &cs); + check_added_monitors(&nodes[8], 1); + + let (updates, raa) = get_updates_and_revoke(&nodes[8], &nodes[7].node.get_our_node_id()); + + nodes[7].node.handle_update_fulfill_htlc(node_8_id, &updates.update_fulfill_htlcs[0]); + expect_payment_forwarded!(nodes[7], nodes[2], nodes[8], Some(1000), false, false); + nodes[7].node.handle_update_fulfill_htlc(node_8_id, &updates.update_fulfill_htlcs[1]); + expect_payment_forwarded!(nodes[7], nodes[3], nodes[8], Some(1000), false, false); + let mut next_source = 4; + if let Some(update) = updates.update_fulfill_htlcs.get(2) { + nodes[7].node.handle_update_fulfill_htlc(node_8_id, update); + expect_payment_forwarded!(nodes[7], nodes[4], nodes[8], Some(1000), false, false); + next_source += 1; + } + + nodes[7].node.handle_commitment_signed(node_8_id, &updates.commitment_signed); + nodes[7].node.handle_revoke_and_ack(node_8_id, &raa); + if updates.update_fulfill_htlcs.get(2).is_some() { + check_added_monitors(&nodes[7], 5); + } else { + check_added_monitors(&nodes[7], 4); + } + + let (raa, cs) = get_revoke_commit_msgs(&nodes[7], &node_8_id); + + nodes[8].node.handle_revoke_and_ack(node_7_id, &raa); + nodes[8].node.handle_commitment_signed(node_7_id, &cs); + check_added_monitors(&nodes[8], 2); + + let (updates, raa) = get_updates_and_revoke(&nodes[8], &node_7_id); + + nodes[7].node.handle_update_fulfill_htlc(node_8_id, &updates.update_fulfill_htlcs[0]); + expect_payment_forwarded!(nodes[7], nodes[next_source], nodes[8], Some(1000), false, false); + next_source += 1; + nodes[7].node.handle_update_fulfill_htlc(node_8_id, &updates.update_fulfill_htlcs[1]); + expect_payment_forwarded!(nodes[7], nodes[next_source], nodes[8], Some(1000), false, false); + next_source += 1; + if let Some(update) = updates.update_fulfill_htlcs.get(2) { + nodes[7].node.handle_update_fulfill_htlc(node_8_id, update); + expect_payment_forwarded!(nodes[7], nodes[next_source], nodes[8], Some(1000), false, false); + } + + nodes[7].node.handle_commitment_signed(node_8_id, &updates.commitment_signed); + nodes[7].node.handle_revoke_and_ack(node_8_id, &raa); + if updates.update_fulfill_htlcs.get(2).is_some() { + check_added_monitors(&nodes[7], 5); + } else { + check_added_monitors(&nodes[7], 4); + } + + let (raa, cs) = get_revoke_commit_msgs(&nodes[7], &node_8_id); + nodes[8].node.handle_revoke_and_ack(node_7_id, &raa); + nodes[8].node.handle_commitment_signed(node_7_id, &cs); + check_added_monitors(&nodes[8], 2); + + let raa = get_event_msg!(nodes[8], MessageSendEvent::SendRevokeAndACK, node_7_id); + nodes[7].node.handle_revoke_and_ack(node_8_id, &raa); + check_added_monitors(&nodes[7], 1); +} diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index fab15bfea28..f6eeb4b50f1 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -1132,7 +1132,7 @@ pub(crate) enum MonitorUpdateCompletionAction { /// A pending MPP claim which hasn't yet completed. /// /// Not written to disk. - pending_mpp_claim: Option<(PublicKey, ChannelId, u64, PendingMPPClaimPointer)>, + pending_mpp_claim: Option<(PublicKey, ChannelId, PendingMPPClaimPointer)>, }, /// Indicates an [`events::Event`] should be surfaced to the user and possibly resume the /// operation of another channel. 
@@ -1234,10 +1234,16 @@ impl From<&MPPClaimHTLCSource> for HTLCClaimSource { } } +#[derive(Debug)] +pub(crate) struct PendingMPPClaim { + channels_without_preimage: Vec<(PublicKey, OutPoint, ChannelId)>, + channels_with_preimage: Vec<(PublicKey, OutPoint, ChannelId)>, +} + #[derive(Clone, Debug, Hash, PartialEq, Eq)] /// The source of an HTLC which is being claimed as a part of an incoming payment. Each part is -/// tracked in [`PendingMPPClaim`] as well as in [`ChannelMonitor`]s, so that it can be converted -/// to an [`HTLCClaimSource`] for claim replays on startup. +/// tracked in [`ChannelMonitor`]s, so that it can be converted to an [`HTLCClaimSource`] for claim +/// replays on startup. struct MPPClaimHTLCSource { counterparty_node_id: PublicKey, funding_txo: OutPoint, @@ -1252,12 +1258,6 @@ impl_writeable_tlv_based!(MPPClaimHTLCSource, { (6, htlc_id, required), }); -#[derive(Debug)] -pub(crate) struct PendingMPPClaim { - channels_without_preimage: Vec, - channels_with_preimage: Vec, -} - #[derive(Clone, Debug, PartialEq, Eq)] /// When we're claiming a(n MPP) payment, we want to store information about that payment in the /// [`ChannelMonitor`] so that we can replay the claim without any information from the @@ -7207,8 +7207,15 @@ where } }).collect(); let pending_mpp_claim_ptr_opt = if sources.len() > 1 { + let mut channels_without_preimage = Vec::with_capacity(mpp_parts.len()); + for part in mpp_parts.iter() { + let chan = (part.counterparty_node_id, part.funding_txo, part.channel_id); + if !channels_without_preimage.contains(&chan) { + channels_without_preimage.push(chan); + } + } Some(Arc::new(Mutex::new(PendingMPPClaim { - channels_without_preimage: mpp_parts.clone(), + channels_without_preimage, channels_with_preimage: Vec::new(), }))) } else { @@ -7219,7 +7226,7 @@ where let this_mpp_claim = pending_mpp_claim_ptr_opt.as_ref().and_then(|pending_mpp_claim| if let Some(cp_id) = htlc.prev_hop.counterparty_node_id { let claim_ptr = PendingMPPClaimPointer(Arc::clone(pending_mpp_claim)); - Some((cp_id, htlc.prev_hop.channel_id, htlc.prev_hop.htlc_id, claim_ptr)) + Some((cp_id, htlc.prev_hop.channel_id, claim_ptr)) } else { None } @@ -7552,7 +7559,7 @@ This indicates a bug inside LDK. Please report this error at https://github.com/ for action in actions.into_iter() { match action { MonitorUpdateCompletionAction::PaymentClaimed { payment_hash, pending_mpp_claim } => { - if let Some((counterparty_node_id, chan_id, htlc_id, claim_ptr)) = pending_mpp_claim { + if let Some((counterparty_node_id, chan_id, claim_ptr)) = pending_mpp_claim { let per_peer_state = self.per_peer_state.read().unwrap(); per_peer_state.get(&counterparty_node_id).map(|peer_state_mutex| { let mut peer_state = peer_state_mutex.lock().unwrap(); @@ -7563,24 +7570,17 @@ This indicates a bug inside LDK. 
Please report this error at https://github.com/ if *pending_claim == claim_ptr { let mut pending_claim_state_lock = pending_claim.0.lock().unwrap(); let pending_claim_state = &mut *pending_claim_state_lock; - pending_claim_state.channels_without_preimage.retain(|htlc_info| { + pending_claim_state.channels_without_preimage.retain(|(cp, op, cid)| { let this_claim = - htlc_info.counterparty_node_id == counterparty_node_id - && htlc_info.channel_id == chan_id - && htlc_info.htlc_id == htlc_id; + *cp == counterparty_node_id && *cid == chan_id; if this_claim { - pending_claim_state.channels_with_preimage.push(htlc_info.clone()); + pending_claim_state.channels_with_preimage.push((*cp, *op, *cid)); false } else { true } }); if pending_claim_state.channels_without_preimage.is_empty() { - for htlc_info in pending_claim_state.channels_with_preimage.iter() { - let freed_chan = ( - htlc_info.counterparty_node_id, - htlc_info.funding_txo, - htlc_info.channel_id, - blocker.clone() - ); + for (cp, op, cid) in pending_claim_state.channels_with_preimage.iter() { + let freed_chan = (*cp, *op, *cid, blocker.clone()); freed_channels.push(freed_chan); } } @@ -14786,8 +14786,16 @@ where if payment_claim.mpp_parts.is_empty() { return Err(DecodeError::InvalidValue); } + let mut channels_without_preimage = payment_claim.mpp_parts.iter() + .map(|htlc_info| (htlc_info.counterparty_node_id, htlc_info.funding_txo, htlc_info.channel_id)) + .collect::>(); + // If we have multiple MPP parts which were received over the same channel, + // we only track it once as once we get a preimage durably in the + // `ChannelMonitor` it will be used for all HTLCs with a matching hash. + channels_without_preimage.sort_unstable(); + channels_without_preimage.dedup(); let pending_claims = PendingMPPClaim { - channels_without_preimage: payment_claim.mpp_parts.clone(), + channels_without_preimage, channels_with_preimage: Vec::new(), }; let pending_claim_ptr_opt = Some(Arc::new(Mutex::new(pending_claims))); @@ -14820,7 +14828,7 @@ where for part in payment_claim.mpp_parts.iter() { let pending_mpp_claim = pending_claim_ptr_opt.as_ref().map(|ptr| ( - part.counterparty_node_id, part.channel_id, part.htlc_id, + part.counterparty_node_id, part.channel_id, PendingMPPClaimPointer(Arc::clone(&ptr)) )); let pending_claim_ptr = pending_claim_ptr_opt.as_ref().map(|ptr| diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index 64775d9e0f7..60dce316418 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -10,7 +10,7 @@ //! A bunch of useful utilities for building networks of nodes and exchanging messages between //! nodes for functional tests. 
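The two channelmanager.rs paths above collapse duplicate (peer, funding outpoint, channel id) entries in slightly different ways: the claim path does a linear `contains` check before pushing, while the read-time replay path sorts and dedups. Both implement the same rule, roughly sketched below with simplified stand-in types (not the real `PublicKey`/`OutPoint`/`ChannelId`):

```rust
// Small sketch of the dedup rule introduced above: multiple MPP parts received
// over the same channel collapse to a single "channel without preimage" entry,
// because one durable preimage in that channel's monitor covers all of them.
fn main() {
    // Simplified stand-in for (PublicKey, OutPoint, ChannelId).
    type ChanKey = (u8 /* peer */, u32 /* funding idx */, u16 /* channel id */);

    // Five MPP parts, but only three distinct channels.
    let mpp_parts: Vec<ChanKey> =
        vec![(1, 10, 7), (2, 20, 8), (1, 10, 7), (3, 30, 9), (2, 20, 8)];

    let mut channels_without_preimage = mpp_parts.clone();
    channels_without_preimage.sort_unstable();
    channels_without_preimage.dedup();

    assert_eq!(channels_without_preimage, vec![(1, 10, 7), (2, 20, 8), (3, 30, 9)]);
}
```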
-use crate::chain::{BestBlock, ChannelMonitorUpdateStatus, Confirm, Listen, Watch, chainmonitor::Persist}; +use crate::chain::{BestBlock, ChannelMonitorUpdateStatus, Confirm, Listen, Watch}; use crate::chain::channelmonitor::ChannelMonitor; use crate::chain::transaction::OutPoint; use crate::events::{ClaimedHTLC, ClosureReason, Event, HTLCDestination, PathFailure, PaymentPurpose, PaymentFailureReason}; @@ -398,7 +398,7 @@ pub struct NodeCfg<'a> { pub override_init_features: Rc>>, } -type TestChannelManager<'node_cfg, 'chan_mon_cfg> = ChannelManager< +pub(crate) type TestChannelManager<'node_cfg, 'chan_mon_cfg> = ChannelManager< &'node_cfg TestChainMonitor<'chan_mon_cfg>, &'chan_mon_cfg test_utils::TestBroadcaster, &'node_cfg test_utils::TestKeysInterface, @@ -779,6 +779,26 @@ pub fn get_revoke_commit_msgs>(node: & }) } +/// Gets a `UpdateHTLCs` and `revoke_and_ack` (i.e. after we get a responding `commitment_signed` +/// while we have updates in the holding cell). +pub fn get_updates_and_revoke>(node: &H, recipient: &PublicKey) -> (msgs::CommitmentUpdate, msgs::RevokeAndACK) { + let events = node.node().get_and_clear_pending_msg_events(); + assert_eq!(events.len(), 2); + (match events[0] { + MessageSendEvent::UpdateHTLCs { ref node_id, ref updates } => { + assert_eq!(node_id, recipient); + (*updates).clone() + }, + _ => panic!("Unexpected event"), + }, match events[1] { + MessageSendEvent::SendRevokeAndACK { ref node_id, ref msg } => { + assert_eq!(node_id, recipient); + (*msg).clone() + }, + _ => panic!("Unexpected event"), + }) +} + #[macro_export] /// Gets an RAA and CS which were sent in response to a commitment update /// @@ -3286,7 +3306,7 @@ pub fn create_node_cfgs<'a>(node_count: usize, chanmon_cfgs: &'a Vec(node_count: usize, chanmon_cfgs: &'a Vec, persisters: Vec<&'a impl Persist>) -> Vec> { +pub fn create_node_cfgs_with_persisters<'a>(node_count: usize, chanmon_cfgs: &'a Vec, persisters: Vec<&'a impl test_utils::SyncPersist>) -> Vec> { let mut nodes = Vec::new(); for i in 0..node_count { diff --git a/lightning/src/ln/monitor_tests.rs b/lightning/src/ln/monitor_tests.rs index 0a42d0a8b99..c44095a1308 100644 --- a/lightning/src/ln/monitor_tests.rs +++ b/lightning/src/ln/monitor_tests.rs @@ -3298,10 +3298,10 @@ fn test_update_replay_panics() { // Ensure applying the force-close update skipping the last normal update fails let poisoned_monitor = monitor.clone(); - std::panic::catch_unwind(|| { + std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { let _ = poisoned_monitor.update_monitor(&updates[1], &nodes[1].tx_broadcaster, &nodes[1].fee_estimator, &nodes[1].logger); // We should panic, rather than returning an error here. - }).unwrap_err(); + })).unwrap_err(); // Then apply the last normal and force-close update and make sure applying the preimage // updates out-of-order fails. @@ -3309,17 +3309,17 @@ fn test_update_replay_panics() { monitor.update_monitor(&updates[1], &nodes[1].tx_broadcaster, &nodes[1].fee_estimator, &nodes[1].logger).unwrap(); let poisoned_monitor = monitor.clone(); - std::panic::catch_unwind(|| { + std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { let _ = poisoned_monitor.update_monitor(&updates[3], &nodes[1].tx_broadcaster, &nodes[1].fee_estimator, &nodes[1].logger); // We should panic, rather than returning an error here. 
- }).unwrap_err(); + })).unwrap_err(); // Make sure re-applying the force-close update fails let poisoned_monitor = monitor.clone(); - std::panic::catch_unwind(|| { + std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { let _ = poisoned_monitor.update_monitor(&updates[1], &nodes[1].tx_broadcaster, &nodes[1].fee_estimator, &nodes[1].logger); // We should panic, rather than returning an error here. - }).unwrap_err(); + })).unwrap_err(); // ...and finally ensure that applying all the updates succeeds. monitor.update_monitor(&updates[2], &nodes[1].tx_broadcaster, &nodes[1].fee_estimator, &nodes[1].logger).unwrap(); diff --git a/lightning/src/sync/debug_sync.rs b/lightning/src/sync/debug_sync.rs index f142328e45c..991a71ffbe0 100644 --- a/lightning/src/sync/debug_sync.rs +++ b/lightning/src/sync/debug_sync.rs @@ -5,15 +5,16 @@ use core::time::Duration; use std::cell::RefCell; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Condvar as StdCondvar; -use std::sync::Mutex as StdMutex; -use std::sync::MutexGuard as StdMutexGuard; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::RwLock as StdRwLock; use std::sync::RwLockReadGuard as StdRwLockReadGuard; use std::sync::RwLockWriteGuard as StdRwLockWriteGuard; -pub use std::sync::WaitTimeoutResult; +use parking_lot::Condvar as StdCondvar; +use parking_lot::Mutex as StdMutex; +use parking_lot::MutexGuard as StdMutexGuard; + +pub use parking_lot::WaitTimeoutResult; use crate::prelude::*; @@ -46,10 +47,9 @@ impl Condvar { &'a self, guard: MutexGuard<'a, T>, condition: F, ) -> LockResult> { let mutex: &'a Mutex = guard.mutex; - self.inner - .wait_while(guard.into_inner(), condition) - .map(|lock| MutexGuard { mutex, lock }) - .map_err(|_| ()) + let mut lock = guard.into_inner(); + self.inner.wait_while(&mut lock, condition); + Ok(MutexGuard { mutex, lock: Some(lock) }) } #[allow(unused)] @@ -57,10 +57,9 @@ impl Condvar { &'a self, guard: MutexGuard<'a, T>, dur: Duration, condition: F, ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)> { let mutex = guard.mutex; - self.inner - .wait_timeout_while(guard.into_inner(), dur, condition) - .map_err(|_| ()) - .map(|(lock, e)| (MutexGuard { mutex, lock }, e)) + let mut lock = guard.into_inner(); + let e = self.inner.wait_while_for(&mut lock, condition, dur); + Ok((MutexGuard { mutex, lock: Some(lock) }, e)) } pub fn notify_all(&self) { @@ -150,7 +149,7 @@ impl LockMetadata { LOCKS_INIT.call_once(|| unsafe { LOCKS = Some(StdMutex::new(new_hash_map())); }); - let mut locks = unsafe { LOCKS.as_ref() }.unwrap().lock().unwrap(); + let mut locks = unsafe { LOCKS.as_ref() }.unwrap().lock(); match locks.entry(lock_constr_location) { hash_map::Entry::Occupied(e) => { assert_eq!(lock_constr_colno, @@ -185,7 +184,7 @@ impl LockMetadata { } } for (_locked_idx, locked) in held.borrow().iter() { - for (locked_dep_idx, _locked_dep) in locked.locked_before.lock().unwrap().iter() { + for (locked_dep_idx, _locked_dep) in locked.locked_before.lock().iter() { let is_dep_this_lock = *locked_dep_idx == this.lock_idx; let has_same_construction = *locked_dep_idx == locked.lock_idx; if is_dep_this_lock && !has_same_construction { @@ -210,7 +209,7 @@ impl LockMetadata { } } // Insert any already-held locks in our locked-before set. 
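On the `AssertUnwindSafe` changes in monitor_tests.rs above: a brief sketch of why the wrapper becomes necessary. Once the closure captures a reference to state with interior mutability that is not `UnwindSafe` (modelled here with a `RefCell`, which is an assumption for illustration, not the monitor's actual internals), `catch_unwind` rejects the bare closure and the caller must assert unwind safety explicitly:

```rust
// Sketch of the catch_unwind + AssertUnwindSafe pattern used above.
use std::cell::RefCell;
use std::panic::{self, AssertUnwindSafe};

fn main() {
    let state = RefCell::new(0u32);

    // Without AssertUnwindSafe this would not compile: `&RefCell<u32>` is not
    // `UnwindSafe`, so `catch_unwind(|| ...)` rejects the closure.
    let result = panic::catch_unwind(AssertUnwindSafe(|| {
        *state.borrow_mut() += 1;
        // Stand-in for the intentional update_monitor panic; the panic message
        // printed to stderr here is expected.
        panic!("simulated update_monitor panic");
    }));

    assert!(result.is_err());
    // The caller takes responsibility for any logically-inconsistent state the
    // panic may have left behind; here we simply observe the partial write.
    assert_eq!(*state.borrow(), 1);
}
```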
- let mut locked_before = this.locked_before.lock().unwrap(); + let mut locked_before = this.locked_before.lock(); if !locked_before.contains_key(&locked.lock_idx) { let lockdep = LockDep { lock: Arc::clone(locked), _lockdep_trace: Backtrace::new() }; locked_before.insert(lockdep.lock.lock_idx, lockdep); @@ -237,7 +236,7 @@ impl LockMetadata { // Since a try-lock will simply fail if the lock is held already, we do not // consider try-locks to ever generate lockorder inversions. However, if a try-lock // succeeds, we do consider it to have created lockorder dependencies. - let mut locked_before = this.locked_before.lock().unwrap(); + let mut locked_before = this.locked_before.lock(); for (locked_idx, locked) in held.borrow().iter() { if !locked_before.contains_key(locked_idx) { let lockdep = @@ -252,11 +251,17 @@ impl LockMetadata { pub struct Mutex { inner: StdMutex, + poisoned: AtomicBool, deps: Arc, } + impl Mutex { pub(crate) fn into_inner(self) -> LockResult { - self.inner.into_inner().map_err(|_| ()) + if self.poisoned.load(Ordering::Acquire) { + Err(()) + } else { + Ok(self.inner.into_inner()) + } } } @@ -278,14 +283,14 @@ impl fmt::Debug for Mutex { #[must_use = "if unused the Mutex will immediately unlock"] pub struct MutexGuard<'a, T: Sized + 'a> { mutex: &'a Mutex, - lock: StdMutexGuard<'a, T>, + lock: Option>, } impl<'a, T: Sized> MutexGuard<'a, T> { fn into_inner(self) -> StdMutexGuard<'a, T> { // Somewhat unclear why we cannot move out of self.lock, but doing so gets E0509. unsafe { - let v: StdMutexGuard<'a, T> = std::ptr::read(&self.lock); + let v: StdMutexGuard<'a, T> = std::ptr::read(self.lock.as_ref().unwrap()); std::mem::forget(self); v } @@ -297,6 +302,10 @@ impl Drop for MutexGuard<'_, T> { LOCKS_HELD.with(|held| { held.borrow_mut().remove(&self.mutex.deps.lock_idx); }); + if std::thread::panicking() { + self.mutex.poisoned.store(true, Ordering::Release); + } + StdMutexGuard::unlock_fair(self.lock.take().unwrap()); } } @@ -304,37 +313,52 @@ impl Deref for MutexGuard<'_, T> { type Target = T; fn deref(&self) -> &T { - &self.lock.deref() + &self.lock.as_ref().unwrap().deref() } } impl DerefMut for MutexGuard<'_, T> { fn deref_mut(&mut self) -> &mut T { - self.lock.deref_mut() + self.lock.as_mut().unwrap().deref_mut() } } impl Mutex { pub fn new(inner: T) -> Mutex { - Mutex { inner: StdMutex::new(inner), deps: LockMetadata::new() } + Mutex { + inner: StdMutex::new(inner), + poisoned: AtomicBool::new(false), + deps: LockMetadata::new(), + } } pub fn lock<'a>(&'a self) -> LockResult> { LockMetadata::pre_lock(&self.deps, false); - self.inner.lock().map(|lock| MutexGuard { mutex: self, lock }).map_err(|_| ()) + let lock = self.inner.lock(); + if self.poisoned.load(Ordering::Acquire) { + Err(()) + } else { + Ok(MutexGuard { mutex: self, lock: Some(lock) }) + } } pub fn try_lock<'a>(&'a self) -> LockResult> { - let res = - self.inner.try_lock().map(|lock| MutexGuard { mutex: self, lock }).map_err(|_| ()); + let res = self.inner.try_lock().ok_or(()); if res.is_ok() { + if self.poisoned.load(Ordering::Acquire) { + return Err(()); + } LockMetadata::try_locked(&self.deps); } - res + res.map(|lock| MutexGuard { mutex: self, lock: Some(lock) }) } pub fn get_mut<'a>(&'a mut self) -> LockResult<&'a mut T> { - self.inner.get_mut().map_err(|_| ()) + if self.poisoned.load(Ordering::Acquire) { + Err(()) + } else { + Ok(self.inner.get_mut()) + } } } @@ -345,9 +369,10 @@ impl<'a, T: 'a> LockTestExt<'a> for Mutex { } type ExclLock = MutexGuard<'a, T>; #[inline] - fn 
unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard { + fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard<'a, T> { LockMetadata::pre_lock(&self.deps, true); - self.inner.lock().map(|lock| MutexGuard { mutex: self, lock }).unwrap() + let lock = self.inner.lock(); + MutexGuard { mutex: self, lock: Some(lock) } } } diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 501207e1e22..b3264e3ba12 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -378,6 +378,24 @@ impl SignerProvider for OnlyReadsKeysInterface { } } +#[cfg(feature = "std")] +pub trait SyncBroadcaster: chaininterface::BroadcasterInterface + Sync {} +#[cfg(feature = "std")] +pub trait SyncPersist: Persist + Sync {} +#[cfg(feature = "std")] +impl SyncBroadcaster for T {} +#[cfg(feature = "std")] +impl + Sync> SyncPersist for T {} + +#[cfg(not(feature = "std"))] +pub trait SyncBroadcaster: chaininterface::BroadcasterInterface {} +#[cfg(not(feature = "std"))] +pub trait SyncPersist: Persist {} +#[cfg(not(feature = "std"))] +impl SyncBroadcaster for T {} +#[cfg(not(feature = "std"))] +impl> SyncPersist for T {} + pub struct TestChainMonitor<'a> { pub added_monitors: Mutex)>>, pub monitor_updates: Mutex>>, @@ -385,10 +403,10 @@ pub struct TestChainMonitor<'a> { pub chain_monitor: ChainMonitor< TestChannelSigner, &'a TestChainSource, - &'a dyn chaininterface::BroadcasterInterface, + &'a dyn SyncBroadcaster, &'a TestFeeEstimator, &'a TestLogger, - &'a dyn Persist, + &'a dyn SyncPersist, >, pub keys_manager: &'a TestKeysInterface, /// If this is set to Some(), the next update_channel call (not watch_channel) must be a @@ -398,13 +416,14 @@ pub struct TestChainMonitor<'a> { /// If this is set to Some(), the next round trip serialization check will not hold after an /// update_channel call (not watch_channel) for the given channel_id. pub expect_monitor_round_trip_fail: Mutex>, + #[cfg(feature = "std")] + pub write_blocker: Mutex>>, } impl<'a> TestChainMonitor<'a> { pub fn new( - chain_source: Option<&'a TestChainSource>, - broadcaster: &'a dyn chaininterface::BroadcasterInterface, logger: &'a TestLogger, - fee_estimator: &'a TestFeeEstimator, persister: &'a dyn Persist, - keys_manager: &'a TestKeysInterface, + chain_source: Option<&'a TestChainSource>, broadcaster: &'a dyn SyncBroadcaster, + logger: &'a TestLogger, fee_estimator: &'a TestFeeEstimator, + persister: &'a dyn SyncPersist, keys_manager: &'a TestKeysInterface, ) -> Self { Self { added_monitors: Mutex::new(Vec::new()), @@ -420,6 +439,8 @@ impl<'a> TestChainMonitor<'a> { keys_manager, expect_channel_force_closed: Mutex::new(None), expect_monitor_round_trip_fail: Mutex::new(None), + #[cfg(feature = "std")] + write_blocker: Mutex::new(None), } } @@ -433,6 +454,11 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { fn watch_channel( &self, channel_id: ChannelId, monitor: ChannelMonitor, ) -> Result { + #[cfg(feature = "std")] + if let Some(blocker) = &*self.write_blocker.lock().unwrap() { + blocker.recv().unwrap(); + } + // At every point where we get a monitor update, we should be able to send a useful monitor // to a watchtower and disk... 
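The debug_sync.rs changes above layer std-style poisoning on top of `parking_lot` (which never poisons) by flagging an `AtomicBool` when a guard is dropped during a panic. A condensed sketch of that pattern, assuming the `parking_lot` 0.12 dependency added in this diff; the `PoisonableMutex`/`Guard` names are illustrative, not the debug_sync types:

```rust
// Hedged sketch of emulating std-style mutex poisoning on top of parking_lot.
use parking_lot::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};

struct PoisonableMutex<T> {
    inner: Mutex<T>,
    poisoned: AtomicBool,
}

struct Guard<'a, T> {
    parent: &'a PoisonableMutex<T>,
    guard: Option<parking_lot::MutexGuard<'a, T>>,
}

impl<T> PoisonableMutex<T> {
    fn new(val: T) -> Self {
        Self { inner: Mutex::new(val), poisoned: AtomicBool::new(false) }
    }

    /// Returns Err(()) if a previous holder panicked, mirroring std's PoisonError.
    fn lock(&self) -> Result<Guard<'_, T>, ()> {
        let guard = self.inner.lock();
        if self.poisoned.load(Ordering::Acquire) {
            Err(())
        } else {
            Ok(Guard { parent: self, guard: Some(guard) })
        }
    }
}

impl<'a, T> Drop for Guard<'a, T> {
    fn drop(&mut self) {
        // parking_lot will not poison for us, so record it manually if this
        // guard is being dropped as part of a panic's unwind.
        if std::thread::panicking() {
            self.parent.poisoned.store(true, Ordering::Release);
        }
        drop(self.guard.take());
    }
}

fn main() {
    let m = PoisonableMutex::new(5u32);
    assert!(m.lock().is_ok());
}
```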
let mut w = TestVecWriter(Vec::new()); @@ -455,6 +481,11 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { fn update_channel( &self, channel_id: ChannelId, update: &ChannelMonitorUpdate, ) -> chain::ChannelMonitorUpdateStatus { + #[cfg(feature = "std")] + if let Some(blocker) = &*self.write_blocker.lock().unwrap() { + blocker.recv().unwrap(); + } + // Every monitor update should survive roundtrip let mut w = TestVecWriter(Vec::new()); update.write(&mut w).unwrap(); @@ -1739,19 +1770,17 @@ impl Drop for TestChainSource { pub struct TestScorer { /// Stores a tuple of (scid, ChannelUsage) - scorer_expectations: RefCell>>, + scorer_expectations: Mutex>>, } impl TestScorer { pub fn new() -> Self { - Self { scorer_expectations: RefCell::new(None) } + Self { scorer_expectations: Mutex::new(None) } } pub fn expect_usage(&self, scid: u64, expectation: ChannelUsage) { - self.scorer_expectations - .borrow_mut() - .get_or_insert_with(|| VecDeque::new()) - .push_back((scid, expectation)); + let mut expectations = self.scorer_expectations.lock().unwrap(); + expectations.get_or_insert_with(|| VecDeque::new()).push_back((scid, expectation)); } } @@ -1772,7 +1801,7 @@ impl ScoreLookUp for TestScorer { Some(scid) => scid, None => return 0, }; - if let Some(scorer_expectations) = self.scorer_expectations.borrow_mut().as_mut() { + if let Some(scorer_expectations) = self.scorer_expectations.lock().unwrap().as_mut() { match scorer_expectations.pop_front() { Some((scid, expectation)) => { assert_eq!(expectation, usage); @@ -1810,7 +1839,7 @@ impl Drop for TestScorer { return; } - if let Some(scorer_expectations) = self.scorer_expectations.borrow().as_ref() { + if let Some(scorer_expectations) = self.scorer_expectations.lock().unwrap().as_ref() { if !scorer_expectations.is_empty() { panic!("Unsatisfied scorer expectations: {:?}", scorer_expectations) }
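Finally, on the `TestScorer` change just above (RefCell to Mutex): `RefCell` is `!Sync`, so a scorer holding one cannot be shared by reference with the threads this new test spawns, while `Mutex<T>` restores `Sync`. A tiny illustration, with a hypothetical `SharedExpectations` type standing in for `TestScorer`:

```rust
// Sketch of why the expectations moved from RefCell to Mutex: interior
// mutability that must be usable from multiple threads needs a Sync primitive.
use std::sync::Mutex;
use std::thread;

struct SharedExpectations {
    // A RefCell<Vec<u64>> here would make SharedExpectations !Sync, and the
    // scoped spawn below would fail to compile.
    expectations: Mutex<Vec<u64>>,
}

fn main() {
    let shared = SharedExpectations { expectations: Mutex::new(vec![1, 2, 3]) };

    thread::scope(|s| {
        s.spawn(|| {
            // Interior mutability through a shared reference, now thread-safe.
            shared.expectations.lock().unwrap().pop();
        });
    });

    assert_eq!(shared.expectations.lock().unwrap().len(), 2);
}
```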