diff --git a/lightning-block-sync/src/gossip.rs b/lightning-block-sync/src/gossip.rs index 083156baab3..0fe221b9231 100644 --- a/lightning-block-sync/src/gossip.rs +++ b/lightning-block-sync/src/gossip.rs @@ -10,11 +10,10 @@ use bitcoin::hash_types::BlockHash; use bitcoin::transaction::{OutPoint, TxOut}; use lightning::ln::peer_handler::APeerManager; - use lightning::routing::gossip::{NetworkGraph, P2PGossipSync}; use lightning::routing::utxo::{UtxoFuture, UtxoLookup, UtxoLookupError, UtxoResult}; - use lightning::util::logger::Logger; +use lightning::util::native_async::FutureSpawner; use std::collections::VecDeque; use std::future::Future; @@ -43,17 +42,6 @@ pub trait UtxoSource: BlockSource + 'static { fn is_output_unspent<'a>(&'a self, outpoint: OutPoint) -> AsyncBlockSourceResult<'a, bool>; } -/// A generic trait which is able to spawn futures in the background. -/// -/// If the `tokio` feature is enabled, this is implemented on `TokioSpawner` struct which -/// delegates to `tokio::spawn()`. -pub trait FutureSpawner: Send + Sync + 'static { - /// Spawns the given future as a background task. - /// - /// This method MUST NOT block on the given future immediately. - fn spawn + Send + 'static>(&self, future: T); -} - #[cfg(feature = "tokio")] /// A trivial [`FutureSpawner`] which delegates to `tokio::spawn`. pub struct TokioSpawner; diff --git a/lightning/src/chain/chainmonitor.rs b/lightning/src/chain/chainmonitor.rs index 36d26aee971..4abd0cd88c0 100644 --- a/lightning/src/chain/chainmonitor.rs +++ b/lightning/src/chain/chainmonitor.rs @@ -46,12 +46,14 @@ use crate::ln::our_peer_storage::{DecryptedOurPeerStorage, PeerStorageMonitorHol use crate::ln::types::ChannelId; use crate::prelude::*; use crate::sign::ecdsa::EcdsaChannelSigner; -use crate::sign::{EntropySource, PeerStorageKey}; +use crate::sign::{EntropySource, PeerStorageKey, SignerProvider}; use crate::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard}; use crate::types::features::{InitFeatures, NodeFeatures}; +use crate::util::async_poll::{MaybeSend, MaybeSync}; use crate::util::errors::APIError; use crate::util::logger::{Logger, WithContext}; -use crate::util::persist::MonitorName; +use crate::util::native_async::FutureSpawner; +use crate::util::persist::{KVStore, MonitorName, MonitorUpdatingPersisterAsync}; #[cfg(peer_storage)] use crate::util::ser::{VecWriter, Writeable}; use crate::util::wakers::{Future, Notifier}; @@ -192,6 +194,17 @@ pub trait Persist { /// restart, this method must in that case be idempotent, ensuring it can handle scenarios where /// the monitor already exists in the archive. fn archive_persisted_channel(&self, monitor_name: MonitorName); + + /// Fetches the set of [`ChannelMonitorUpdate`]s, previously persisted with + /// [`Self::update_persisted_channel`], which have completed. + /// + /// Returning an update here is equivalent to calling + /// [`ChainMonitor::channel_monitor_updated`]. Because of this, this method is defaulted and + /// hidden in the docs. + #[doc(hidden)] + fn get_and_clear_completed_updates(&self) -> Vec<(ChannelId, u64)> { + Vec::new() + } } struct MonitorHolder { @@ -235,6 +248,93 @@ impl Deref for LockedChannelMonitor<'_, Chann } } +/// An unconstructable [`Persist`]er which is used under the hood when you call +/// [`ChainMonitor::new_async_beta`]. 
+pub struct AsyncPersister< + K: Deref + MaybeSend + MaybeSync + 'static, + S: FutureSpawner, + L: Deref + MaybeSend + MaybeSync + 'static, + ES: Deref + MaybeSend + MaybeSync + 'static, + SP: Deref + MaybeSend + MaybeSync + 'static, + BI: Deref + MaybeSend + MaybeSync + 'static, + FE: Deref + MaybeSend + MaybeSync + 'static, +> where + K::Target: KVStore + MaybeSync, + L::Target: Logger, + ES::Target: EntropySource + Sized, + SP::Target: SignerProvider + Sized, + BI::Target: BroadcasterInterface, + FE::Target: FeeEstimator, +{ + persister: MonitorUpdatingPersisterAsync, +} + +impl< + K: Deref + MaybeSend + MaybeSync + 'static, + S: FutureSpawner, + L: Deref + MaybeSend + MaybeSync + 'static, + ES: Deref + MaybeSend + MaybeSync + 'static, + SP: Deref + MaybeSend + MaybeSync + 'static, + BI: Deref + MaybeSend + MaybeSync + 'static, + FE: Deref + MaybeSend + MaybeSync + 'static, + > Deref for AsyncPersister +where + K::Target: KVStore + MaybeSync, + L::Target: Logger, + ES::Target: EntropySource + Sized, + SP::Target: SignerProvider + Sized, + BI::Target: BroadcasterInterface, + FE::Target: FeeEstimator, +{ + type Target = Self; + fn deref(&self) -> &Self { + self + } +} + +impl< + K: Deref + MaybeSend + MaybeSync + 'static, + S: FutureSpawner, + L: Deref + MaybeSend + MaybeSync + 'static, + ES: Deref + MaybeSend + MaybeSync + 'static, + SP: Deref + MaybeSend + MaybeSync + 'static, + BI: Deref + MaybeSend + MaybeSync + 'static, + FE: Deref + MaybeSend + MaybeSync + 'static, + > Persist<::EcdsaSigner> for AsyncPersister +where + K::Target: KVStore + MaybeSync, + L::Target: Logger, + ES::Target: EntropySource + Sized, + SP::Target: SignerProvider + Sized, + BI::Target: BroadcasterInterface, + FE::Target: FeeEstimator, + ::EcdsaSigner: MaybeSend + 'static, +{ + fn persist_new_channel( + &self, monitor_name: MonitorName, + monitor: &ChannelMonitor<::EcdsaSigner>, + ) -> ChannelMonitorUpdateStatus { + self.persister.spawn_async_persist_new_channel(monitor_name, monitor); + ChannelMonitorUpdateStatus::InProgress + } + + fn update_persisted_channel( + &self, monitor_name: MonitorName, monitor_update: Option<&ChannelMonitorUpdate>, + monitor: &ChannelMonitor<::EcdsaSigner>, + ) -> ChannelMonitorUpdateStatus { + self.persister.spawn_async_update_persisted_channel(monitor_name, monitor_update, monitor); + ChannelMonitorUpdateStatus::InProgress + } + + fn archive_persisted_channel(&self, monitor_name: MonitorName) { + self.persister.spawn_async_archive_persisted_channel(monitor_name); + } + + fn get_and_clear_completed_updates(&self) -> Vec<(ChannelId, u64)> { + self.persister.get_and_clear_completed_updates() + } +} + /// An implementation of [`chain::Watch`] for monitoring channels. 
/// /// Connected and disconnected blocks must be provided to `ChainMonitor` as documented by @@ -291,6 +391,63 @@ pub struct ChainMonitor< our_peerstorage_encryption_key: PeerStorageKey, } +impl< + K: Deref + MaybeSend + MaybeSync + 'static, + S: FutureSpawner, + SP: Deref + MaybeSend + MaybeSync + 'static, + C: Deref, + T: Deref + MaybeSend + MaybeSync + 'static, + F: Deref + MaybeSend + MaybeSync + 'static, + L: Deref + MaybeSend + MaybeSync + 'static, + ES: Deref + MaybeSend + MaybeSync + 'static, + > + ChainMonitor< + ::EcdsaSigner, + C, + T, + F, + L, + AsyncPersister, + ES, + > where + K::Target: KVStore + MaybeSync, + SP::Target: SignerProvider + Sized, + C::Target: chain::Filter, + T::Target: BroadcasterInterface, + F::Target: FeeEstimator, + L::Target: Logger, + ES::Target: EntropySource + Sized, + ::EcdsaSigner: MaybeSend + 'static, +{ + /// Creates a new `ChainMonitor` used to watch on-chain activity pertaining to channels. + /// + /// This behaves the same as [`ChainMonitor::new`] except that it relies on + /// [`MonitorUpdatingPersisterAsync`] and thus allows persistence to be completed async. + /// + /// Note that async monitor updating is considered beta, and bugs may be triggered by its use. + pub fn new_async_beta( + chain_source: Option, broadcaster: T, logger: L, feeest: F, + persister: MonitorUpdatingPersisterAsync, _entropy_source: ES, + _our_peerstorage_encryption_key: PeerStorageKey, + ) -> Self { + Self { + monitors: RwLock::new(new_hash_map()), + chain_source, + broadcaster, + logger, + fee_estimator: feeest, + persister: AsyncPersister { persister }, + _entropy_source, + pending_monitor_events: Mutex::new(Vec::new()), + highest_chain_height: AtomicUsize::new(0), + event_notifier: Notifier::new(), + pending_send_only_events: Mutex::new(Vec::new()), + #[cfg(peer_storage)] + our_peerstorage_encryption_key: _our_peerstorage_encryption_key, + } + } +} + impl< ChannelSigner: EcdsaChannelSigner, C: Deref, @@ -1357,6 +1514,9 @@ where fn release_pending_monitor_events( &self, ) -> Vec<(OutPoint, ChannelId, Vec, PublicKey)> { + for (channel_id, update_id) in self.persister.get_and_clear_completed_updates() { + let _ = self.channel_monitor_updated(channel_id, update_id); + } let mut pending_monitor_events = self.pending_monitor_events.lock().unwrap().split_off(0); for monitor_state in self.monitors.read().unwrap().values() { let monitor_events = monitor_state.monitor.get_and_clear_pending_monitor_events(); diff --git a/lightning/src/ln/chanmon_update_fail_tests.rs b/lightning/src/ln/chanmon_update_fail_tests.rs index e0de92c27fa..cd74e4a0a76 100644 --- a/lightning/src/ln/chanmon_update_fail_tests.rs +++ b/lightning/src/ln/chanmon_update_fail_tests.rs @@ -12,7 +12,9 @@ //! There are a bunch of these as their handling is relatively error-prone so they are split out //! here. See also the chanmon_fail_consistency fuzz test. 
-use crate::chain::channelmonitor::{ChannelMonitor, ANTI_REORG_DELAY}; +use crate::chain::chainmonitor::ChainMonitor; +use crate::chain::channelmonitor::{ChannelMonitor, MonitorEvent, ANTI_REORG_DELAY}; +use crate::chain::transaction::OutPoint; use crate::chain::{ChannelMonitorUpdateStatus, Listen, Watch}; use crate::events::{ClosureReason, Event, HTLCHandlingFailureType, PaymentPurpose}; use crate::ln::channel::AnnouncementSigsState; @@ -22,6 +24,13 @@ use crate::ln::msgs::{ BaseMessageHandler, ChannelMessageHandler, MessageSendEvent, RoutingMessageHandler, }; use crate::ln::types::ChannelId; +use crate::sign::NodeSigner; +use crate::util::native_async::FutureQueue; +use crate::util::persist::{ + MonitorName, MonitorUpdatingPersisterAsync, CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, + CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, +}; use crate::util::ser::{ReadableArgs, Writeable}; use crate::util::test_channel_signer::TestChannelSigner; use crate::util::test_utils::TestBroadcaster; @@ -4847,3 +4856,200 @@ fn test_single_channel_multiple_mpp() { nodes[7].node.handle_revoke_and_ack(node_i_id, &raa); check_added_monitors(&nodes[7], 1); } + +#[test] +fn native_async_persist() { + // Test ChainMonitor::new_async_beta and the backing MonitorUpdatingPersisterAsync. + // + // Because our test utils aren't really set up for such utils, we simply test them directly, + // first spinning up some nodes to create a `ChannelMonitor` and some `ChannelMonitorUpdate`s + // we can apply. + let (monitor, updates); + let mut chanmon_cfgs = create_chanmon_cfgs(2); + let node_cfgs = create_node_cfgs(2, &chanmon_cfgs); + let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); + let nodes = create_network(2, &node_cfgs, &node_chanmgrs); + + let (_, _, chan_id, funding_tx) = create_announced_chan_between_nodes(&nodes, 0, 1); + + monitor = get_monitor!(nodes[0], chan_id).clone(); + send_payment(&nodes[0], &[&nodes[1]], 1_000_000); + let mon_updates = + nodes[0].chain_monitor.monitor_updates.lock().unwrap().remove(&chan_id).unwrap(); + updates = mon_updates.into_iter().collect::>(); + assert!(updates.len() >= 4, "The test below needs at least four updates"); + + core::mem::drop(nodes); + core::mem::drop(node_chanmgrs); + core::mem::drop(node_cfgs); + + let node_0_utils = chanmon_cfgs.remove(0); + let (logger, keys_manager, tx_broadcaster, fee_estimator) = ( + node_0_utils.logger, + node_0_utils.keys_manager, + node_0_utils.tx_broadcaster, + node_0_utils.fee_estimator, + ); + + // Now that we have some updates, build a new ChainMonitor with a backing async KVStore. 
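	// A rough sketch, assuming a tokio runtime, of how an application might do this same
	// wiring in production rather than with the test utilities used below. The `kv_store`,
	// `logger`, `keys_manager`, `broadcaster`, `fee_estimator` and `chain_source` bindings are
	// placeholders for the application's existing (`Arc`-wrapped) implementations of the
	// corresponding traits, the `MaybeSend` bound is assumed to come from
	// `lightning::util::async_poll` as used by `FutureSpawner` in this patch, and the spawner
	// mirrors the `TokioSpawner` kept in `lightning-block-sync`:
	//
	// struct TokioSpawner;
	// impl FutureSpawner for TokioSpawner {
	//     fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, future: T) {
	//         tokio::spawn(future);
	//     }
	// }
	//
	// let persister = MonitorUpdatingPersisterAsync::new(
	//     Arc::clone(&kv_store),     // an async `KVStore`
	//     TokioSpawner,              // drives the spawned persistence futures
	//     Arc::clone(&logger),
	//     5,                         // maximum_pending_updates
	//     Arc::clone(&keys_manager), // EntropySource
	//     Arc::clone(&keys_manager), // SignerProvider
	//     Arc::clone(&broadcaster),
	//     Arc::clone(&fee_estimator),
	// );
	// let chain_monitor = ChainMonitor::new_async_beta(
	//     Some(Arc::clone(&chain_source)),
	//     Arc::clone(&broadcaster),
	//     Arc::clone(&logger),
	//     Arc::clone(&fee_estimator),
	//     persister,
	//     Arc::clone(&keys_manager),
	//     keys_manager.get_peer_storage_key(),
	// );
	//
	// With a real spawner there is no manual polling step: the runtime completes the spawned
	// write futures and the resulting `MonitorEvent::Completed`s surface the next time
	// `release_pending_monitor_events` is called on the `ChainMonitor`.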
+ let logger = Arc::new(logger); + let keys_manager = Arc::new(keys_manager); + let tx_broadcaster = Arc::new(tx_broadcaster); + let fee_estimator = Arc::new(fee_estimator); + + let kv_store = Arc::new(test_utils::TestStore::new(false)); + let persist_futures = Arc::new(FutureQueue::new()); + let native_async_persister = MonitorUpdatingPersisterAsync::new( + Arc::clone(&kv_store), + Arc::clone(&persist_futures), + Arc::clone(&logger), + 42, + Arc::clone(&keys_manager), + Arc::clone(&keys_manager), + Arc::clone(&tx_broadcaster), + Arc::clone(&fee_estimator), + ); + let chain_source = test_utils::TestChainSource::new(Network::Testnet); + let async_chain_monitor = ChainMonitor::new_async_beta( + Some(&chain_source), + tx_broadcaster, + logger, + fee_estimator, + native_async_persister, + Arc::clone(&keys_manager), + keys_manager.get_peer_storage_key(), + ); + + // Write the initial ChannelMonitor async, testing primarily that the `MonitorEvent::Completed` + // isn't returned until the write is completed (via `complete_all_async_writes`) and the future + // is `poll`ed (which a background spawn should do automatically in production, but which is + // needed to get the future completion through to the `ChainMonitor`). + let write_status = async_chain_monitor.watch_channel(chan_id, monitor).unwrap(); + assert_eq!(write_status, ChannelMonitorUpdateStatus::InProgress); + + // The write will remain pending until we call `complete_all_async_writes`, below. + assert_eq!(persist_futures.pending_futures(), 1); + persist_futures.poll_futures(); + assert_eq!(persist_futures.pending_futures(), 1); + + let funding_txo = OutPoint { txid: funding_tx.compute_txid(), index: 0 }; + let key = MonitorName::V1Channel(funding_txo).to_string(); + let pending_writes = kv_store.list_pending_async_writes( + CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, + CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, + &key, + ); + assert_eq!(pending_writes.len(), 1); + + // Once we complete the future, the write will still be pending until the future gets `poll`ed. + kv_store.complete_all_async_writes(); + assert_eq!(persist_futures.pending_futures(), 1); + assert_eq!(async_chain_monitor.release_pending_monitor_events().len(), 0); + + assert_eq!(persist_futures.pending_futures(), 1); + persist_futures.poll_futures(); + assert_eq!(persist_futures.pending_futures(), 0); + + let completed_persist = async_chain_monitor.release_pending_monitor_events(); + assert_eq!(completed_persist.len(), 1); + assert_eq!(completed_persist[0].2.len(), 1); + assert!(matches!(completed_persist[0].2[0], MonitorEvent::Completed { .. })); + + // Now test two async `ChannelMonitorUpdate`s in flight at once, completing them in-order but + // separately. 
+ let update_status = async_chain_monitor.update_channel(chan_id, &updates[0]); + assert_eq!(update_status, ChannelMonitorUpdateStatus::InProgress); + + let update_status = async_chain_monitor.update_channel(chan_id, &updates[1]); + assert_eq!(update_status, ChannelMonitorUpdateStatus::InProgress); + + persist_futures.poll_futures(); + assert_eq!(async_chain_monitor.release_pending_monitor_events().len(), 0); + + let pending_writes = kv_store.list_pending_async_writes( + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &key, + "1", + ); + assert_eq!(pending_writes.len(), 1); + let pending_writes = kv_store.list_pending_async_writes( + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &key, + "2", + ); + assert_eq!(pending_writes.len(), 1); + + kv_store.complete_async_writes_through( + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &key, + "1", + usize::MAX, + ); + persist_futures.poll_futures(); + // While the `ChainMonitor` could return a `MonitorEvent::Completed` here, it currently + // doesn't. If that ever changes we should validate that the `Completed` event has the correct + // `monitor_update_id` (1). + assert!(async_chain_monitor.release_pending_monitor_events().is_empty()); + + kv_store.complete_async_writes_through( + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &key, + "2", + usize::MAX, + ); + persist_futures.poll_futures(); + let completed_persist = async_chain_monitor.release_pending_monitor_events(); + assert_eq!(completed_persist.len(), 1); + assert_eq!(completed_persist[0].2.len(), 1); + assert!(matches!(completed_persist[0].2[0], MonitorEvent::Completed { .. })); + + // Finally, test two async `ChanelMonitorUpdate`s in flight at once, completing them + // out-of-order and ensuring that no `MonitorEvent::Completed` is generated until they are both + // completed (and that it marks both as completed when it is generated). + let update_status = async_chain_monitor.update_channel(chan_id, &updates[2]); + assert_eq!(update_status, ChannelMonitorUpdateStatus::InProgress); + + let update_status = async_chain_monitor.update_channel(chan_id, &updates[3]); + assert_eq!(update_status, ChannelMonitorUpdateStatus::InProgress); + + persist_futures.poll_futures(); + assert_eq!(async_chain_monitor.release_pending_monitor_events().len(), 0); + + let pending_writes = kv_store.list_pending_async_writes( + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &key, + "3", + ); + assert_eq!(pending_writes.len(), 1); + let pending_writes = kv_store.list_pending_async_writes( + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &key, + "4", + ); + assert_eq!(pending_writes.len(), 1); + + kv_store.complete_async_writes_through( + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &key, + "4", + usize::MAX, + ); + persist_futures.poll_futures(); + assert_eq!(async_chain_monitor.release_pending_monitor_events().len(), 0); + + kv_store.complete_async_writes_through( + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &key, + "3", + usize::MAX, + ); + persist_futures.poll_futures(); + let completed_persist = async_chain_monitor.release_pending_monitor_events(); + assert_eq!(completed_persist.len(), 1); + assert_eq!(completed_persist[0].2.len(), 1); + if let MonitorEvent::Completed { monitor_update_id, .. 
} = &completed_persist[0].2[0] { + assert_eq!(*monitor_update_id, 4); + } else { + panic!(); + } +} diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs index 84c0c113f85..968f8222d9a 100644 --- a/lightning/src/util/mod.rs +++ b/lightning/src/util/mod.rs @@ -26,6 +26,7 @@ pub mod base32; pub(crate) mod base32; pub mod errors; pub mod message_signing; +pub mod native_async; pub mod persist; pub mod scid_utils; pub mod ser; diff --git a/lightning/src/util/native_async.rs b/lightning/src/util/native_async.rs new file mode 100644 index 00000000000..dc26cb42bd0 --- /dev/null +++ b/lightning/src/util/native_async.rs @@ -0,0 +1,113 @@ +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +//! This module contains a few public utility which are used to run LDK in a native Rust async +//! environment. + +#[cfg(all(test, feature = "std"))] +use crate::sync::Mutex; +use crate::util::async_poll::{MaybeSend, MaybeSync}; + +#[cfg(all(test, not(feature = "std")))] +use core::cell::RefCell; +use core::future::Future; +#[cfg(test)] +use core::pin::Pin; + +/// A generic trait which is able to spawn futures in the background. +pub trait FutureSpawner: MaybeSend + MaybeSync + 'static { + /// Spawns the given future as a background task. + /// + /// This method MUST NOT block on the given future immediately. + fn spawn + MaybeSend + 'static>(&self, future: T); +} + +#[cfg(test)] +trait MaybeSendableFuture: Future + MaybeSend + 'static {} +#[cfg(test)] +impl + MaybeSend + 'static> MaybeSendableFuture for F {} + +/// A simple [`FutureSpawner`] which holds [`Future`]s until they are manually polled via +/// [`Self::poll_futures`]. 
+#[cfg(all(test, feature = "std"))] +pub(crate) struct FutureQueue(Mutex>>>); +#[cfg(all(test, not(feature = "std")))] +pub(crate) struct FutureQueue(RefCell>>>); + +#[cfg(test)] +impl FutureQueue { + pub(crate) fn new() -> Self { + #[cfg(feature = "std")] + { + FutureQueue(Mutex::new(Vec::new())) + } + #[cfg(not(feature = "std"))] + { + FutureQueue(RefCell::new(Vec::new())) + } + } + + pub(crate) fn pending_futures(&self) -> usize { + #[cfg(feature = "std")] + { + self.0.lock().unwrap().len() + } + #[cfg(not(feature = "std"))] + { + self.0.borrow().len() + } + } + + pub(crate) fn poll_futures(&self) { + let mut futures; + #[cfg(feature = "std")] + { + futures = self.0.lock().unwrap(); + } + #[cfg(not(feature = "std"))] + { + futures = self.0.borrow_mut(); + } + futures.retain_mut(|fut| { + use core::task::{Context, Poll}; + let waker = crate::util::async_poll::dummy_waker(); + match fut.as_mut().poll(&mut Context::from_waker(&waker)) { + Poll::Ready(()) => false, + Poll::Pending => true, + } + }); + } +} + +#[cfg(test)] +impl FutureSpawner for FutureQueue { + fn spawn + MaybeSend + 'static>(&self, future: T) { + #[cfg(feature = "std")] + { + self.0.lock().unwrap().push(Box::pin(future)); + } + #[cfg(not(feature = "std"))] + { + self.0.borrow_mut().push(Box::pin(future)); + } + } +} + +#[cfg(test)] +impl + MaybeSend + MaybeSync + 'static> FutureSpawner + for D +{ + fn spawn + MaybeSend + 'static>(&self, future: T) { + #[cfg(feature = "std")] + { + self.0.lock().unwrap().push(Box::pin(future)); + } + #[cfg(not(feature = "std"))] + { + self.0.borrow_mut().push(Box::pin(future)); + } + } +} diff --git a/lightning/src/util/persist.rs b/lightning/src/util/persist.rs index e3fb86fb88a..9036a27f49c 100644 --- a/lightning/src/util/persist.rs +++ b/lightning/src/util/persist.rs @@ -11,13 +11,17 @@ //! [`ChannelManager`]: crate::ln::channelmanager::ChannelManager //! [`NetworkGraph`]: crate::routing::gossip::NetworkGraph +use alloc::sync::Arc; + use bitcoin::hashes::hex::FromHex; use bitcoin::{BlockHash, Txid}; -use core::cmp; + use core::future::Future; +use core::mem; use core::ops::Deref; use core::pin::Pin; use core::str::FromStr; +use core::task; use crate::prelude::*; use crate::{io, log_error}; @@ -29,7 +33,10 @@ use crate::chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate}; use crate::chain::transaction::OutPoint; use crate::ln::types::ChannelId; use crate::sign::{ecdsa::EcdsaChannelSigner, EntropySource, SignerProvider}; +use crate::sync::Mutex; +use crate::util::async_poll::{dummy_waker, MaybeSend, MaybeSync}; use crate::util::logger::Logger; +use crate::util::native_async::FutureSpawner; use crate::util::ser::{Readable, ReadableArgs, Writeable}; /// The alphabet of characters allowed for namespaces and keys. @@ -405,6 +412,26 @@ where Ok(res) } +struct PanicingSpawner; +impl FutureSpawner for PanicingSpawner { + fn spawn + MaybeSend + 'static>(&self, _: T) { + unreachable!(); + } +} + +fn poll_sync_future(future: F) -> F::Output { + let mut waker = dummy_waker(); + let mut ctx = task::Context::from_waker(&mut waker); + // TODO A future MSRV bump to 1.68 should allow for the pin macro + match Pin::new(&mut Box::pin(future)).poll(&mut ctx) { + task::Poll::Ready(result) => result, + task::Poll::Pending => { + // In a sync context, we can't wait for the future to complete. 
+ unreachable!("Sync KVStore-derived futures can not be pending in a sync context"); + }, + } +} + /// Implements [`Persist`] in a way that writes and reads both [`ChannelMonitor`]s and /// [`ChannelMonitorUpdate`]s. /// @@ -489,25 +516,17 @@ where /// If you have many stale updates stored (such as after a crash with pending lazy deletes), and /// would like to get rid of them, consider using the /// [`MonitorUpdatingPersister::cleanup_stale_updates`] function. -pub struct MonitorUpdatingPersister +pub struct MonitorUpdatingPersister( + MonitorUpdatingPersisterAsync, PanicingSpawner, L, ES, SP, BI, FE>, +) where K::Target: KVStoreSync, L::Target: Logger, ES::Target: EntropySource + Sized, SP::Target: SignerProvider + Sized, BI::Target: BroadcasterInterface, - FE::Target: FeeEstimator, -{ - kv_store: K, - logger: L, - maximum_pending_updates: u64, - entropy_source: ES, - signer_provider: SP, - broadcaster: BI, - fee_estimator: FE, -} + FE::Target: FeeEstimator; -#[allow(dead_code)] impl MonitorUpdatingPersister where @@ -534,19 +553,27 @@ where /// less frequent "waves." /// - [`MonitorUpdatingPersister`] will potentially have more listing to do if you need to run /// [`MonitorUpdatingPersister::cleanup_stale_updates`]. + /// + /// Note that you can disable the update-writing entirely by setting `maximum_pending_updates` + /// to zero, causing this [`Persist`] implementation to behave like the blanket [`Persist`] + /// implementation for all [`KVStoreSync`]s. pub fn new( kv_store: K, logger: L, maximum_pending_updates: u64, entropy_source: ES, signer_provider: SP, broadcaster: BI, fee_estimator: FE, ) -> Self { - MonitorUpdatingPersister { - kv_store, + // Note that calling the spawner only happens in the `pub(crate)` `spawn_*` methods defined + // with additional bounds on `MonitorUpdatingPersisterAsync`. Thus its safe to provide a + // dummy always-panic implementation here. + MonitorUpdatingPersister(MonitorUpdatingPersisterAsync::new( + KVStoreSyncWrapper(kv_store), + PanicingSpawner, logger, maximum_pending_updates, entropy_source, signer_provider, broadcaster, fee_estimator, - } + )) } /// Reads all stored channel monitors, along with any stored updates for them. @@ -560,13 +587,222 @@ where Vec<(BlockHash, ChannelMonitor<::EcdsaSigner>)>, io::Error, > { - let monitor_list = self.kv_store.list( - CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, - CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, - )?; + poll_sync_future(self.0.read_all_channel_monitors_with_updates()) + } + + /// Read a single channel monitor, along with any stored updates for it. + /// + /// It is extremely important that your [`KVStoreSync::read`] implementation uses the + /// [`io::ErrorKind::NotFound`] variant correctly. For more information, please see the + /// documentation for [`MonitorUpdatingPersister`]. + /// + /// For `monitor_key`, channel storage keys can be the channel's funding [`OutPoint`], with an + /// underscore `_` between txid and index for v1 channels. For example, given: + /// + /// - Transaction ID: `deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef` + /// - Index: `1` + /// + /// The correct `monitor_key` would be: + /// `deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef_1` + /// + /// For v2 channels, the hex-encoded [`ChannelId`] is used directly for `monitor_key` instead. + /// + /// Loading a large number of monitors will be faster if done in parallel. You can use this + /// function to accomplish this. Take care to limit the number of parallel readers. 
+ pub fn read_channel_monitor_with_updates( + &self, monitor_key: &str, + ) -> Result<(BlockHash, ChannelMonitor<::EcdsaSigner>), io::Error> + { + poll_sync_future(self.0.read_channel_monitor_with_updates(monitor_key)) + } + + /// Cleans up stale updates for all monitors. + /// + /// This function works by first listing all monitors, and then for each of them, listing all + /// updates. The updates that have an `update_id` less than or equal to than the stored monitor + /// are deleted. The deletion can either be lazy or non-lazy based on the `lazy` flag; this will + /// be passed to [`KVStoreSync::remove`]. + pub fn cleanup_stale_updates(&self, lazy: bool) -> Result<(), io::Error> { + poll_sync_future(self.0.cleanup_stale_updates(lazy)) + } +} + +impl< + ChannelSigner: EcdsaChannelSigner, + K: Deref, + L: Deref, + ES: Deref, + SP: Deref, + BI: Deref, + FE: Deref, + > Persist for MonitorUpdatingPersister +where + K::Target: KVStoreSync, + L::Target: Logger, + ES::Target: EntropySource + Sized, + SP::Target: SignerProvider + Sized, + BI::Target: BroadcasterInterface, + FE::Target: FeeEstimator, +{ + /// Persists a new channel. This means writing the entire monitor to the + /// parametrized [`KVStoreSync`]. + fn persist_new_channel( + &self, monitor_name: MonitorName, monitor: &ChannelMonitor, + ) -> chain::ChannelMonitorUpdateStatus { + let res = poll_sync_future(self.0 .0.persist_new_channel(monitor_name, monitor)); + match res { + Ok(_) => chain::ChannelMonitorUpdateStatus::Completed, + Err(e) => { + log_error!( + self.0 .0.logger, + "Failed to write ChannelMonitor {}/{}/{} reason: {}", + CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, + CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, + monitor_name, + e + ); + chain::ChannelMonitorUpdateStatus::UnrecoverableError + }, + } + } + + /// Persists a channel update, writing only the update to the parameterized [`KVStoreSync`] if possible. + /// + /// In some cases, this will forward to [`MonitorUpdatingPersister::persist_new_channel`]: + /// + /// - No full monitor is found in [`KVStoreSync`] + /// - The number of pending updates exceeds `maximum_pending_updates` as given to [`Self::new`] + /// - LDK commands re-persisting the entire monitor through this function, specifically when + /// `update` is `None`. + /// - The update is at [`u64::MAX`], indicating an update generated by pre-0.1 LDK. + fn update_persisted_channel( + &self, monitor_name: MonitorName, update: Option<&ChannelMonitorUpdate>, + monitor: &ChannelMonitor, + ) -> chain::ChannelMonitorUpdateStatus { + let inner = Arc::clone(&self.0 .0); + let res = poll_sync_future(inner.update_persisted_channel(monitor_name, update, monitor)); + match res { + Ok(()) => chain::ChannelMonitorUpdateStatus::Completed, + Err(e) => { + log_error!( + self.0 .0.logger, + "Failed to write ChannelMonitorUpdate {} id {} reason: {}", + monitor_name, + update.as_ref().map(|upd| upd.update_id).unwrap_or(0), + e + ); + chain::ChannelMonitorUpdateStatus::UnrecoverableError + }, + } + } + + fn archive_persisted_channel(&self, monitor_name: MonitorName) { + poll_sync_future(self.0 .0.archive_persisted_channel(monitor_name)); + } +} + +/// A variant of the [`MonitorUpdatingPersister`] which utilizes the async [`KVStore`] and offers +/// async versions of the public accessors. +/// +/// Note that async monitor updating is considered beta, and bugs may be triggered by its use. 
+/// +/// Unlike [`MonitorUpdatingPersister`], this does not implement [`Persist`], but is instead used +/// directly by the [`ChainMonitor`] via [`ChainMonitor::new_async_beta`]. +/// +/// [`ChainMonitor`]: crate::chain::chainmonitor::ChainMonitor +/// [`ChainMonitor::new_async_beta`]: crate::chain::chainmonitor::ChainMonitor::new_async_beta +pub struct MonitorUpdatingPersisterAsync< + K: Deref, + S: FutureSpawner, + L: Deref, + ES: Deref, + SP: Deref, + BI: Deref, + FE: Deref, +>(Arc>) +where + K::Target: KVStore, + L::Target: Logger, + ES::Target: EntropySource + Sized, + SP::Target: SignerProvider + Sized, + BI::Target: BroadcasterInterface, + FE::Target: FeeEstimator; + +struct MonitorUpdatingPersisterAsyncInner< + K: Deref, + S: FutureSpawner, + L: Deref, + ES: Deref, + SP: Deref, + BI: Deref, + FE: Deref, +> where + K::Target: KVStore, + L::Target: Logger, + ES::Target: EntropySource + Sized, + SP::Target: SignerProvider + Sized, + BI::Target: BroadcasterInterface, + FE::Target: FeeEstimator, +{ + kv_store: K, + async_completed_updates: Mutex>, + future_spawner: S, + logger: L, + maximum_pending_updates: u64, + entropy_source: ES, + signer_provider: SP, + broadcaster: BI, + fee_estimator: FE, +} + +impl + MonitorUpdatingPersisterAsync +where + K::Target: KVStore, + L::Target: Logger, + ES::Target: EntropySource + Sized, + SP::Target: SignerProvider + Sized, + BI::Target: BroadcasterInterface, + FE::Target: FeeEstimator, +{ + /// Constructs a new [`MonitorUpdatingPersisterAsync`]. + /// + /// See [`MonitorUpdatingPersister::new`] for more info. + pub fn new( + kv_store: K, future_spawner: S, logger: L, maximum_pending_updates: u64, + entropy_source: ES, signer_provider: SP, broadcaster: BI, fee_estimator: FE, + ) -> Self { + MonitorUpdatingPersisterAsync(Arc::new(MonitorUpdatingPersisterAsyncInner { + kv_store, + async_completed_updates: Mutex::new(Vec::new()), + future_spawner, + logger, + maximum_pending_updates, + entropy_source, + signer_provider, + broadcaster, + fee_estimator, + })) + } + + /// Reads all stored channel monitors, along with any stored updates for them. + /// + /// It is extremely important that your [`KVStore::read`] implementation uses the + /// [`io::ErrorKind::NotFound`] variant correctly. For more information, please see the + /// documentation for [`MonitorUpdatingPersister`]. + pub async fn read_all_channel_monitors_with_updates( + &self, + ) -> Result< + Vec<(BlockHash, ChannelMonitor<::EcdsaSigner>)>, + io::Error, + > { + let primary = CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; + let secondary = CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE; + let monitor_list = self.0.kv_store.list(primary, secondary).await?; let mut res = Vec::with_capacity(monitor_list.len()); + // TODO: Parallelize this loop for monitor_key in monitor_list { - res.push(self.read_channel_monitor_with_updates(monitor_key.as_str())?) + res.push(self.read_channel_monitor_with_updates(monitor_key.as_str()).await?) } Ok(res) } @@ -590,20 +826,132 @@ where /// /// Loading a large number of monitors will be faster if done in parallel. You can use this /// function to accomplish this. Take care to limit the number of parallel readers. - pub fn read_channel_monitor_with_updates( + pub async fn read_channel_monitor_with_updates( + &self, monitor_key: &str, + ) -> Result<(BlockHash, ChannelMonitor<::EcdsaSigner>), io::Error> + { + self.0.read_channel_monitor_with_updates(monitor_key).await + } + + /// Cleans up stale updates for all monitors. 
+ /// + /// This function works by first listing all monitors, and then for each of them, listing all + /// updates. The updates that have an `update_id` less than or equal to than the stored monitor + /// are deleted. The deletion can either be lazy or non-lazy based on the `lazy` flag; this will + /// be passed to [`KVStoreSync::remove`]. + pub async fn cleanup_stale_updates(&self, lazy: bool) -> Result<(), io::Error> { + self.0.cleanup_stale_updates(lazy).await + } +} + +impl< + K: Deref + MaybeSend + MaybeSync + 'static, + S: FutureSpawner, + L: Deref + MaybeSend + MaybeSync + 'static, + ES: Deref + MaybeSend + MaybeSync + 'static, + SP: Deref + MaybeSend + MaybeSync + 'static, + BI: Deref + MaybeSend + MaybeSync + 'static, + FE: Deref + MaybeSend + MaybeSync + 'static, + > MonitorUpdatingPersisterAsync +where + K::Target: KVStore + MaybeSync, + L::Target: Logger, + ES::Target: EntropySource + Sized, + SP::Target: SignerProvider + Sized, + BI::Target: BroadcasterInterface, + FE::Target: FeeEstimator, + ::EcdsaSigner: MaybeSend + 'static, +{ + pub(crate) fn spawn_async_persist_new_channel( + &self, monitor_name: MonitorName, + monitor: &ChannelMonitor<::EcdsaSigner>, + ) { + let inner = Arc::clone(&self.0); + // Note that `persist_new_channel` is a sync method which calls all the way through to the + // sync KVStore::write method (which returns a future) to ensure writes are well-ordered. + let future = inner.persist_new_channel(monitor_name, monitor); + let channel_id = monitor.channel_id(); + let completion = (monitor.channel_id(), monitor.get_latest_update_id()); + self.0.future_spawner.spawn(async move { + match future.await { + Ok(()) => inner.async_completed_updates.lock().unwrap().push(completion), + Err(e) => { + log_error!( + inner.logger, + "Failed to persist new ChannelMonitor {channel_id}: {e}. The node will now likely stall as this channel will not be able to make progress. You should restart as soon as possible.", + ); + }, + } + }); + } + + pub(crate) fn spawn_async_update_persisted_channel( + &self, monitor_name: MonitorName, update: Option<&ChannelMonitorUpdate>, + monitor: &ChannelMonitor<::EcdsaSigner>, + ) { + let inner = Arc::clone(&self.0); + // Note that `update_persisted_channel` is a sync method which calls all the way through to + // the sync KVStore::write method (which returns a future) to ensure writes are well-ordered + let future = inner.update_persisted_channel(monitor_name, update, monitor); + let channel_id = monitor.channel_id(); + let completion = if let Some(update) = update { + Some((monitor.channel_id(), update.update_id)) + } else { + None + }; + let inner = Arc::clone(&self.0); + self.0.future_spawner.spawn(async move { + match future.await { + Ok(()) => if let Some(completion) = completion { + inner.async_completed_updates.lock().unwrap().push(completion); + }, + Err(e) => { + log_error!( + inner.logger, + "Failed to persist new ChannelMonitor {channel_id}: {e}. The node will now likely stall as this channel will not be able to make progress. 
You should restart as soon as possible.", + ); + }, + } + }); + } + + pub(crate) fn spawn_async_archive_persisted_channel(&self, monitor_name: MonitorName) { + let inner = Arc::clone(&self.0); + self.0.future_spawner.spawn(async move { + inner.archive_persisted_channel(monitor_name).await; + }); + } + + pub(crate) fn get_and_clear_completed_updates(&self) -> Vec<(ChannelId, u64)> { + mem::take(&mut *self.0.async_completed_updates.lock().unwrap()) + } +} + +impl + MonitorUpdatingPersisterAsyncInner +where + K::Target: KVStore, + L::Target: Logger, + ES::Target: EntropySource + Sized, + SP::Target: SignerProvider + Sized, + BI::Target: BroadcasterInterface, + FE::Target: FeeEstimator, +{ + pub async fn read_channel_monitor_with_updates( &self, monitor_key: &str, ) -> Result<(BlockHash, ChannelMonitor<::EcdsaSigner>), io::Error> { let monitor_name = MonitorName::from_str(monitor_key)?; - let (block_hash, monitor) = self.read_monitor(&monitor_name, monitor_key)?; + let (block_hash, monitor) = self.read_monitor(&monitor_name, monitor_key).await?; let mut current_update_id = monitor.get_latest_update_id(); + // TODO: Parallelize this loop by speculatively reading a batch of updates loop { current_update_id = match current_update_id.checked_add(1) { Some(next_update_id) => next_update_id, None => break, }; let update_name = UpdateName::from(current_update_id); - let update = match self.read_monitor_update(monitor_key, &update_name) { + let update = match self.read_monitor_update(monitor_key, &update_name).await { Ok(update) => update, Err(err) if err.kind() == io::ErrorKind::NotFound => { // We can't find any more updates, so we are done. @@ -629,15 +977,14 @@ where } /// Read a channel monitor. - fn read_monitor( + async fn read_monitor( &self, monitor_name: &MonitorName, monitor_key: &str, ) -> Result<(BlockHash, ChannelMonitor<::EcdsaSigner>), io::Error> { - let mut monitor_cursor = io::Cursor::new(self.kv_store.read( - CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, - CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, - monitor_key, - )?); + let primary = CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; + let secondary = CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE; + let monitor_bytes = self.kv_store.read(primary, secondary, monitor_key).await?; + let mut monitor_cursor = io::Cursor::new(monitor_bytes); // Discard the sentinel bytes if found. if monitor_cursor.get_ref().starts_with(MONITOR_UPDATING_PERSISTER_PREPEND_SENTINEL) { monitor_cursor.set_position(MONITOR_UPDATING_PERSISTER_PREPEND_SENTINEL.len() as u64); @@ -674,15 +1021,12 @@ where } /// Read a channel monitor update. - fn read_monitor_update( + async fn read_monitor_update( &self, monitor_key: &str, update_name: &UpdateName, ) -> Result { - let update_bytes = self.kv_store.read( - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - monitor_key, - update_name.as_str(), - )?; - ChannelMonitorUpdate::read(&mut io::Cursor::new(update_bytes)).map_err(|e| { + let primary = CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE; + let update_bytes = self.kv_store.read(primary, monitor_key, update_name.as_str()).await?; + ChannelMonitorUpdate::read(&mut &update_bytes[..]).map_err(|e| { log_error!( self.logger, "Failed to read ChannelMonitorUpdate {}/{}/{}, reason: {}", @@ -695,222 +1039,166 @@ where }) } - /// Cleans up stale updates for all monitors. - /// - /// This function works by first listing all monitors, and then for each of them, listing all - /// updates. 
The updates that have an `update_id` less than or equal to than the stored monitor - /// are deleted. The deletion can either be lazy or non-lazy based on the `lazy` flag; this will - /// be passed to [`KVStoreSync::remove`]. - pub fn cleanup_stale_updates(&self, lazy: bool) -> Result<(), io::Error> { - let monitor_keys = self.kv_store.list( - CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, - CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, - )?; + async fn cleanup_stale_updates(&self, lazy: bool) -> Result<(), io::Error> { + let primary = CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; + let secondary = CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE; + let monitor_keys = self.kv_store.list(primary, secondary).await?; for monitor_key in monitor_keys { let monitor_name = MonitorName::from_str(&monitor_key)?; - let (_, current_monitor) = self.read_monitor(&monitor_name, &monitor_key)?; - let updates = self - .kv_store - .list(CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, monitor_key.as_str())?; - for update in updates { - let update_name = UpdateName::new(update)?; - // if the update_id is lower than the stored monitor, delete - if update_name.0 <= current_monitor.get_latest_update_id() { - self.kv_store.remove( - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - monitor_key.as_str(), - update_name.as_str(), - lazy, - )?; - } + let (_, current_monitor) = self.read_monitor(&monitor_name, &monitor_key).await?; + let latest_update_id = current_monitor.get_latest_update_id(); + self.cleanup_stale_updates_for_monitor_to(&monitor_key, latest_update_id, lazy).await?; + } + Ok(()) + } + + async fn cleanup_stale_updates_for_monitor_to( + &self, monitor_key: &str, latest_update_id: u64, lazy: bool, + ) -> Result<(), io::Error> { + let primary = CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE; + let updates = self.kv_store.list(primary, monitor_key).await?; + for update in updates { + let update_name = UpdateName::new(update)?; + // if the update_id is lower than the stored monitor, delete + if update_name.0 <= latest_update_id { + self.kv_store.remove(primary, monitor_key, update_name.as_str(), lazy).await?; } } Ok(()) } -} -impl< - ChannelSigner: EcdsaChannelSigner, - K: Deref, - L: Deref, - ES: Deref, - SP: Deref, - BI: Deref, - FE: Deref, - > Persist for MonitorUpdatingPersister -where - K::Target: KVStoreSync, - L::Target: Logger, - ES::Target: EntropySource + Sized, - SP::Target: SignerProvider + Sized, - BI::Target: BroadcasterInterface, - FE::Target: FeeEstimator, -{ - /// Persists a new channel. This means writing the entire monitor to the - /// parametrized [`KVStoreSync`]. 
- fn persist_new_channel( + fn persist_new_channel( &self, monitor_name: MonitorName, monitor: &ChannelMonitor, - ) -> chain::ChannelMonitorUpdateStatus { + ) -> impl Future> { // Determine the proper key for this monitor let monitor_key = monitor_name.to_string(); // Serialize and write the new monitor let mut monitor_bytes = Vec::with_capacity( MONITOR_UPDATING_PERSISTER_PREPEND_SENTINEL.len() + monitor.serialized_length(), ); - monitor_bytes.extend_from_slice(MONITOR_UPDATING_PERSISTER_PREPEND_SENTINEL); - monitor.write(&mut monitor_bytes).unwrap(); - match self.kv_store.write( - CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, - CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, - monitor_key.as_str(), - monitor_bytes, - ) { - Ok(_) => chain::ChannelMonitorUpdateStatus::Completed, - Err(e) => { - log_error!( - self.logger, - "Failed to write ChannelMonitor {}/{}/{} reason: {}", - CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, - CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, - monitor_key.as_str(), - e - ); - chain::ChannelMonitorUpdateStatus::UnrecoverableError - }, + // If `maximum_pending_updates` is zero, we aren't actually writing monitor updates at all. + // Thus, there's no need to add the sentinel prefix as the monitor can be read directly + // from disk without issue. + if self.maximum_pending_updates != 0 { + monitor_bytes.extend_from_slice(MONITOR_UPDATING_PERSISTER_PREPEND_SENTINEL); } + monitor.write(&mut monitor_bytes).unwrap(); + // Note that this is NOT an async function, but rather calls the *sync* KVStore write + // method, allowing it to do its queueing immediately, and then return a future for the + // completion of the write. This ensures monitor persistence ordering is preserved. + let primary = CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; + let secondary = CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE; + self.kv_store.write(primary, secondary, monitor_key.as_str(), monitor_bytes) } - /// Persists a channel update, writing only the update to the parameterized [`KVStoreSync`] if possible. - /// - /// In some cases, this will forward to [`MonitorUpdatingPersister::persist_new_channel`]: - /// - /// - No full monitor is found in [`KVStoreSync`] - /// - The number of pending updates exceeds `maximum_pending_updates` as given to [`Self::new`] - /// - LDK commands re-persisting the entire monitor through this function, specifically when - /// `update` is `None`. - /// - The update is at [`u64::MAX`], indicating an update generated by pre-0.1 LDK. 
- fn update_persisted_channel( - &self, monitor_name: MonitorName, update: Option<&ChannelMonitorUpdate>, + fn update_persisted_channel<'a, ChannelSigner: EcdsaChannelSigner + 'a>( + self: Arc, monitor_name: MonitorName, update: Option<&ChannelMonitorUpdate>, monitor: &ChannelMonitor, - ) -> chain::ChannelMonitorUpdateStatus { + ) -> impl Future> + 'a + where + Self: 'a, + { const LEGACY_CLOSED_CHANNEL_UPDATE_ID: u64 = u64::MAX; + let mut res_a = None; + let mut res_b = None; + let mut res_c = None; if let Some(update) = update { let persist_update = update.update_id != LEGACY_CLOSED_CHANNEL_UPDATE_ID + && self.maximum_pending_updates != 0 && update.update_id % self.maximum_pending_updates != 0; if persist_update { let monitor_key = monitor_name.to_string(); let update_name = UpdateName::from(update.update_id); - match self.kv_store.write( - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - monitor_key.as_str(), + let primary = CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE; + // Note that this is NOT an async function, but rather calls the *sync* KVStore + // write method, allowing it to do its queueing immediately, and then return a + // future for the completion of the write. This ensures monitor persistence + // ordering is preserved. + res_a = Some(self.kv_store.write( + primary, + &monitor_key, update_name.as_str(), update.encode(), - ) { - Ok(()) => chain::ChannelMonitorUpdateStatus::Completed, - Err(e) => { - log_error!( - self.logger, - "Failed to write ChannelMonitorUpdate {}/{}/{} reason: {}", - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - monitor_key.as_str(), - update_name.as_str(), - e - ); - chain::ChannelMonitorUpdateStatus::UnrecoverableError - }, - } + )); } else { - // In case of channel-close monitor update, we need to read old monitor before persisting - // the new one in order to determine the cleanup range. - let maybe_old_monitor = match monitor.get_latest_update_id() { - LEGACY_CLOSED_CHANNEL_UPDATE_ID => { - let monitor_key = monitor_name.to_string(); - self.read_monitor(&monitor_name, &monitor_key).ok() - }, - _ => None, - }; - // We could write this update, but it meets criteria of our design that calls for a full monitor write. - let monitor_update_status = self.persist_new_channel(monitor_name, monitor); - - if let chain::ChannelMonitorUpdateStatus::Completed = monitor_update_status { - let channel_closed_legacy = - monitor.get_latest_update_id() == LEGACY_CLOSED_CHANNEL_UPDATE_ID; - let cleanup_range = if channel_closed_legacy { - // If there is an error while reading old monitor, we skip clean up. - maybe_old_monitor.map(|(_, ref old_monitor)| { - let start = old_monitor.get_latest_update_id(); - // We never persist an update with the legacy closed update_id - let end = cmp::min( - start.saturating_add(self.maximum_pending_updates), - LEGACY_CLOSED_CHANNEL_UPDATE_ID - 1, - ); - (start, end) - }) - } else { - let end = monitor.get_latest_update_id(); - let start = end.saturating_sub(self.maximum_pending_updates); - Some((start, end)) - }; - - if let Some((start, end)) = cleanup_range { - self.cleanup_in_range(monitor_name, start, end); + // Note that this is NOT an async function, but rather calls the *sync* KVStore + // write method, allowing it to do its queueing immediately, and then return a + // future for the completion of the write. This ensures monitor persistence + // ordering is preserved. This, thus, must happen before any await we do below. 
+ let write_fut = self.persist_new_channel(monitor_name, monitor); + let latest_update_id = monitor.get_latest_update_id(); + + res_b = Some(async move { + let write_status = write_fut.await; + if let Ok(()) = write_status { + if latest_update_id == LEGACY_CLOSED_CHANNEL_UPDATE_ID { + let monitor_key = monitor_name.to_string(); + self.cleanup_stale_updates_for_monitor_to( + &monitor_key, + latest_update_id, + true, + ) + .await?; + } else { + let end = latest_update_id; + let start = end.saturating_sub(self.maximum_pending_updates); + self.cleanup_in_range(monitor_name, start, end).await; + } } - } - monitor_update_status + write_status + }); } } else { // There is no update given, so we must persist a new monitor. - self.persist_new_channel(monitor_name, monitor) + // Note that this is NOT an async function, but rather calls the *sync* KVStore write + // method, allowing it to do its queueing immediately, and then return a future for the + // completion of the write. This ensures monitor persistence ordering is preserved. + res_c = Some(self.persist_new_channel(monitor_name, monitor)); + } + async move { + // Complete any pending future(s). Note that to keep one return type we have to end + // with a single async move block that we return, rather than trying to return the + // individual futures themselves. + if let Some(a) = res_a { + a.await?; + } + if let Some(b) = res_b { + b.await?; + } + if let Some(c) = res_c { + c.await?; + } + Ok(()) } } - fn archive_persisted_channel(&self, monitor_name: MonitorName) { + async fn archive_persisted_channel(&self, monitor_name: MonitorName) { let monitor_key = monitor_name.to_string(); - let monitor = match self.read_channel_monitor_with_updates(&monitor_key) { + let monitor = match self.read_channel_monitor_with_updates(&monitor_key).await { Ok((_block_hash, monitor)) => monitor, Err(_) => return, }; - match self.kv_store.write( - ARCHIVED_CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, - ARCHIVED_CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, - monitor_key.as_str(), - monitor.encode(), - ) { + let primary = ARCHIVED_CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; + let secondary = ARCHIVED_CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE; + match self.kv_store.write(primary, secondary, &monitor_key, monitor.encode()).await { Ok(()) => {}, Err(_e) => return, }; - let _ = self.kv_store.remove( - CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE, - CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE, - monitor_key.as_str(), - true, - ); + let primary = CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; + let secondary = CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE; + let _ = self.kv_store.remove(primary, secondary, &monitor_key, true).await; } -} -impl - MonitorUpdatingPersister -where - ES::Target: EntropySource + Sized, - K::Target: KVStoreSync, - L::Target: Logger, - SP::Target: SignerProvider + Sized, - BI::Target: BroadcasterInterface, - FE::Target: FeeEstimator, -{ // Cleans up monitor updates for given monitor in range `start..=end`. 
- fn cleanup_in_range(&self, monitor_name: MonitorName, start: u64, end: u64) { + async fn cleanup_in_range(&self, monitor_name: MonitorName, start: u64, end: u64) { let monitor_key = monitor_name.to_string(); for update_id in start..=end { let update_name = UpdateName::from(update_id); - if let Err(e) = self.kv_store.remove( - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - monitor_key.as_str(), - update_name.as_str(), - true, - ) { + let primary = CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE; + let res = self.kv_store.remove(primary, &monitor_key, update_name.as_str(), true).await; + if let Err(e) = res { log_error!( self.logger, "Failed to clean up channel monitor updates for monitor {}, reason: {}", @@ -1111,9 +1399,10 @@ mod tests { use crate::ln::msgs::BaseMessageHandler; use crate::sync::Arc; use crate::util::test_channel_signer::TestChannelSigner; - use crate::util::test_utils::{self, TestLogger, TestStore}; + use crate::util::test_utils::{self, TestStore}; use crate::{check_added_monitors, check_closed_broadcast}; use bitcoin::hashes::hex::FromHex; + use core::cmp; const EXPECTED_UPDATES_PER_PAYMENT: u64 = 5; @@ -1188,31 +1477,28 @@ mod tests { } // Exercise the `MonitorUpdatingPersister` with real channels and payments. - #[test] - fn persister_with_real_monitors() { - // This value is used later to limit how many iterations we perform. - let persister_0_max_pending_updates = 7; - // Intentionally set this to a smaller value to test a different alignment. - let persister_1_max_pending_updates = 3; + fn do_persister_with_real_monitors(max_pending_updates_0: u64, max_pending_updates_1: u64) { let chanmon_cfgs = create_chanmon_cfgs(4); - let persister_0 = MonitorUpdatingPersister { - kv_store: &TestStore::new(false), - logger: &TestLogger::new(), - maximum_pending_updates: persister_0_max_pending_updates, - entropy_source: &chanmon_cfgs[0].keys_manager, - signer_provider: &chanmon_cfgs[0].keys_manager, - broadcaster: &chanmon_cfgs[0].tx_broadcaster, - fee_estimator: &chanmon_cfgs[0].fee_estimator, - }; - let persister_1 = MonitorUpdatingPersister { - kv_store: &TestStore::new(false), - logger: &TestLogger::new(), - maximum_pending_updates: persister_1_max_pending_updates, - entropy_source: &chanmon_cfgs[1].keys_manager, - signer_provider: &chanmon_cfgs[1].keys_manager, - broadcaster: &chanmon_cfgs[1].tx_broadcaster, - fee_estimator: &chanmon_cfgs[1].fee_estimator, - }; + let kv_store_0 = TestStore::new(false); + let persister_0 = MonitorUpdatingPersister::new( + &kv_store_0, + &chanmon_cfgs[0].logger, + max_pending_updates_0, + &chanmon_cfgs[0].keys_manager, + &chanmon_cfgs[0].keys_manager, + &chanmon_cfgs[0].tx_broadcaster, + &chanmon_cfgs[0].fee_estimator, + ); + let kv_store_1 = TestStore::new(false); + let persister_1 = MonitorUpdatingPersister::new( + &kv_store_1, + &chanmon_cfgs[1].logger, + max_pending_updates_1, + &chanmon_cfgs[1].keys_manager, + &chanmon_cfgs[1].keys_manager, + &chanmon_cfgs[1].tx_broadcaster, + &chanmon_cfgs[1].fee_estimator, + ); let mut node_cfgs = create_node_cfgs(2, &chanmon_cfgs); let chain_mon_0 = test_utils::TestChainMonitor::new( Some(&chanmon_cfgs[0].chain_source), @@ -1256,17 +1542,17 @@ mod tests { assert_eq!(mon.get_latest_update_id(), $expected_update_id); let monitor_name = mon.persistence_key(); - assert_eq!( - KVStoreSync::list( - &*persister_0.kv_store, - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - &monitor_name.to_string() - ) - .unwrap() - .len() as u64, - mon.get_latest_update_id() % 
persister_0_max_pending_updates, - "Wrong number of updates stored in persister 0", + let expected_updates = if max_pending_updates_0 == 0 { + 0 + } else { + mon.get_latest_update_id() % max_pending_updates_0 + }; + let update_list = KVStoreSync::list( + &kv_store_0, + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &monitor_name.to_string(), ); + assert_eq!(update_list.unwrap().len() as u64, expected_updates, "persister 0"); } persisted_chan_data_1 = persister_1.read_all_channel_monitors_with_updates().unwrap(); @@ -1274,17 +1560,17 @@ mod tests { for (_, mon) in persisted_chan_data_1.iter() { assert_eq!(mon.get_latest_update_id(), $expected_update_id); let monitor_name = mon.persistence_key(); - assert_eq!( - KVStoreSync::list( - &*persister_1.kv_store, - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - &monitor_name.to_string() - ) - .unwrap() - .len() as u64, - mon.get_latest_update_id() % persister_1_max_pending_updates, - "Wrong number of updates stored in persister 1", + let expected_updates = if max_pending_updates_1 == 0 { + 0 + } else { + mon.get_latest_update_id() % max_pending_updates_1 + }; + let update_list = KVStoreSync::list( + &kv_store_1, + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &monitor_name.to_string(), ); + assert_eq!(update_list.unwrap().len() as u64, expected_updates, "persister 1"); } }; } @@ -1302,7 +1588,7 @@ mod tests { // Send a few more payments to try all the alignments of max pending updates with // updates for a payment sent and received. let mut sender = 0; - for i in 3..=persister_0_max_pending_updates * 2 { + for i in 3..=max_pending_updates_0 * 2 { let receiver; if sender == 0 { sender = 1; @@ -1345,11 +1631,19 @@ mod tests { check_added_monitors!(nodes[1], 1); // Make sure everything is persisted as expected after close. + // We always send at least two payments, and loop up to max_pending_updates_0 * 2. check_persisted_data!( - persister_0_max_pending_updates * 2 * EXPECTED_UPDATES_PER_PAYMENT + 1 + cmp::max(2, max_pending_updates_0 * 2) * EXPECTED_UPDATES_PER_PAYMENT + 1 ); } + #[test] + fn persister_with_real_monitors() { + do_persister_with_real_monitors(7, 3); + do_persister_with_real_monitors(0, 1); + do_persister_with_real_monitors(4, 2); + } + // Test that if the `MonitorUpdatingPersister`'s can't actually write, trying to persist a // monitor or update with it results in the persister returning an UnrecoverableError status. 
#[test] @@ -1377,15 +1671,16 @@ mod tests { let cmu_map = nodes[1].chain_monitor.monitor_updates.lock().unwrap(); let cmu = &cmu_map.get(&added_monitors[0].1.channel_id()).unwrap()[0]; - let ro_persister = MonitorUpdatingPersister { - kv_store: &TestStore::new(true), - logger: &TestLogger::new(), - maximum_pending_updates: 11, - entropy_source: node_cfgs[0].keys_manager, - signer_provider: node_cfgs[0].keys_manager, - broadcaster: node_cfgs[0].tx_broadcaster, - fee_estimator: node_cfgs[0].fee_estimator, - }; + let store = TestStore::new(true); + let ro_persister = MonitorUpdatingPersister::new( + &store, + node_cfgs[0].logger, + 11, + node_cfgs[0].keys_manager, + node_cfgs[0].keys_manager, + node_cfgs[0].tx_broadcaster, + node_cfgs[0].fee_estimator, + ); let monitor_name = added_monitors[0].1.persistence_key(); match ro_persister.persist_new_channel(monitor_name, &added_monitors[0].1) { ChannelMonitorUpdateStatus::UnrecoverableError => { @@ -1423,24 +1718,26 @@ mod tests { fn clean_stale_updates_works() { let test_max_pending_updates = 7; let chanmon_cfgs = create_chanmon_cfgs(3); - let persister_0 = MonitorUpdatingPersister { - kv_store: &TestStore::new(false), - logger: &TestLogger::new(), - maximum_pending_updates: test_max_pending_updates, - entropy_source: &chanmon_cfgs[0].keys_manager, - signer_provider: &chanmon_cfgs[0].keys_manager, - broadcaster: &chanmon_cfgs[0].tx_broadcaster, - fee_estimator: &chanmon_cfgs[0].fee_estimator, - }; - let persister_1 = MonitorUpdatingPersister { - kv_store: &TestStore::new(false), - logger: &TestLogger::new(), - maximum_pending_updates: test_max_pending_updates, - entropy_source: &chanmon_cfgs[1].keys_manager, - signer_provider: &chanmon_cfgs[1].keys_manager, - broadcaster: &chanmon_cfgs[1].tx_broadcaster, - fee_estimator: &chanmon_cfgs[1].fee_estimator, - }; + let kv_store_0 = TestStore::new(false); + let persister_0 = MonitorUpdatingPersister::new( + &kv_store_0, + &chanmon_cfgs[0].logger, + test_max_pending_updates, + &chanmon_cfgs[0].keys_manager, + &chanmon_cfgs[0].keys_manager, + &chanmon_cfgs[0].tx_broadcaster, + &chanmon_cfgs[0].fee_estimator, + ); + let kv_store_1 = TestStore::new(false); + let persister_1 = MonitorUpdatingPersister::new( + &kv_store_1, + &chanmon_cfgs[1].logger, + test_max_pending_updates, + &chanmon_cfgs[1].keys_manager, + &chanmon_cfgs[1].keys_manager, + &chanmon_cfgs[1].tx_broadcaster, + &chanmon_cfgs[1].fee_estimator, + ); let mut node_cfgs = create_node_cfgs(2, &chanmon_cfgs); let chain_mon_0 = test_utils::TestChainMonitor::new( Some(&chanmon_cfgs[0].chain_source), @@ -1480,7 +1777,7 @@ mod tests { let (_, monitor) = &persisted_chan_data[0]; let monitor_name = monitor.persistence_key(); KVStoreSync::write( - &*persister_0.kv_store, + &kv_store_0, CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, &monitor_name.to_string(), UpdateName::from(1).as_str(), @@ -1493,7 +1790,7 @@ mod tests { // Confirm the stale update is unreadable/gone assert!(KVStoreSync::read( - &*persister_0.kv_store, + &kv_store_0, CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, &monitor_name.to_string(), UpdateName::from(1).as_str() diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 769c2a3ed3e..698e7519cbb 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -89,6 +89,7 @@ use core::future::Future; use core::mem; use core::pin::Pin; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use core::task::{Context, Poll, Waker}; use core::time::Duration; use 
bitcoin::psbt::Psbt;
@@ -856,26 +857,100 @@ impl Persist for TestPersister
 	}
 }
 
+// A simple multi-producer-single-consumer one-shot channel
+type OneShotChannelState = Arc<Mutex<(Option<Result<(), io::Error>>, Option<Waker>)>>;
+struct OneShotChannel(OneShotChannelState);
+impl Future for OneShotChannel {
+	type Output = Result<(), io::Error>;
+	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+		let mut state = self.0.lock().unwrap();
+		// If the future is complete, take() the result and return it,
+		state.0.take().map(|res| Poll::Ready(res)).unwrap_or_else(|| {
+			// otherwise, store the waker so that the future will be poll()ed again when the result
+			// is ready.
+			state.1 = Some(cx.waker().clone());
+			Poll::Pending
+		})
+	}
+}
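The `OneShotChannel` above is the usual "stored waker" one-shot pattern: whoever completes the operation stores the result and wakes the task that last polled the future. The following is a self-contained, std-only sketch of the same idea (it is not part of this diff; `OneShot`, `ThreadWaker`, and `block_on` are names invented for the example), showing a producer thread completing the one-shot while a consumer blocks on it:

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};
use std::thread;

// Shared state: the eventual result plus the waker of whoever is currently waiting on it.
type Shared = Arc<Mutex<(Option<Result<(), std::io::Error>>, Option<Waker>)>>;

struct OneShot(Shared);

impl Future for OneShot {
	type Output = Result<(), std::io::Error>;
	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
		let mut state = self.0.lock().unwrap();
		if let Some(res) = state.0.take() {
			Poll::Ready(res)
		} else {
			// Not complete yet: stash the waker so the completer can re-schedule us.
			state.1 = Some(cx.waker().clone());
			Poll::Pending
		}
	}
}

// A waker that unparks the blocked thread, so we can block on the future without an executor.
struct ThreadWaker(thread::Thread);
impl Wake for ThreadWaker {
	fn wake(self: Arc<Self>) {
		self.0.unpark();
	}
}

fn block_on<F: Future>(mut fut: F) -> F::Output {
	let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
	let mut cx = Context::from_waker(&waker);
	// Safety: `fut` is a local we never move after pinning.
	let mut fut = unsafe { Pin::new_unchecked(&mut fut) };
	loop {
		match fut.as_mut().poll(&mut cx) {
			Poll::Ready(out) => return out,
			Poll::Pending => thread::park(),
		}
	}
}

fn main() {
	let state: Shared = Arc::new(Mutex::new((None, None)));
	let completer = Arc::clone(&state);
	thread::spawn(move || {
		// "Complete" the operation from another thread, waking the waiting future if any.
		let mut st = completer.lock().unwrap();
		st.0 = Some(Ok(()));
		if let Some(waker) = st.1.take() {
			waker.wake();
		}
	});
	assert!(block_on(OneShot(state)).is_ok());
}
```

Parking and unparking the consuming thread merely stands in for an async runtime here; `TestStore` instead relies on whatever executor polls the returned future, which is exactly why `poll` must store the most recent `Waker` before returning `Pending`.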
+
+/// An in-memory KVStore for testing.
+///
+/// Sync writes always complete immediately while async writes always block until manually
+/// completed with [`Self::complete_async_writes_through`] or [`Self::complete_all_async_writes`].
+///
+/// Removes always complete immediately.
 pub struct TestStore {
+	pending_async_writes: Mutex<HashMap<String, Vec<(usize, OneShotChannelState, Vec<u8>)>>>,
 	persisted_bytes: Mutex<HashMap<String, HashMap<String, Vec<u8>>>>,
 	read_only: bool,
 }
 
 impl TestStore {
 	pub fn new(read_only: bool) -> Self {
+		let pending_async_writes = Mutex::new(new_hash_map());
 		let persisted_bytes = Mutex::new(new_hash_map());
-		Self { persisted_bytes, read_only }
+		Self { pending_async_writes, persisted_bytes, read_only }
+	}
+
+	pub fn list_pending_async_writes(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str,
+	) -> Vec<usize> {
+		let key = format!("{primary_namespace}/{secondary_namespace}/{key}");
+		let writes_lock = self.pending_async_writes.lock().unwrap();
+		writes_lock
+			.get(&key)
+			.map(|v| v.iter().map(|(id, _, _)| *id).collect())
+			.unwrap_or(Vec::new())
+	}
+
+	/// Completes all pending async writes for the given namespace and key, up to and through the
+	/// given `write_id` (which can be fetched from [`Self::list_pending_async_writes`]).
+	pub fn complete_async_writes_through(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, write_id: usize,
+	) {
+		let prefix = format!("{primary_namespace}/{secondary_namespace}");
+		let key = format!("{primary_namespace}/{secondary_namespace}/{key}");
+
+		let mut persisted_lock = self.persisted_bytes.lock().unwrap();
+		let mut writes_lock = self.pending_async_writes.lock().unwrap();
+
+		let pending_writes = writes_lock.get_mut(&key).expect("No pending writes for given key");
+		pending_writes.retain(|(id, res, data)| {
+			if *id <= write_id {
+				let namespace = persisted_lock.entry(prefix.clone()).or_insert(new_hash_map());
+				*namespace.entry(key.to_string()).or_default() = data.clone();
+				let mut future_state = res.lock().unwrap();
+				future_state.0 = Some(Ok(()));
+				if let Some(waker) = future_state.1.take() {
+					waker.wake();
+				}
+				false
+			} else {
+				true
+			}
+		});
+	}
+
+	/// Completes all pending async writes on all namespaces and keys.
+	pub fn complete_all_async_writes(&self) {
+		let pending_writes: Vec<String> =
+			self.pending_async_writes.lock().unwrap().keys().cloned().collect();
+		for key in pending_writes {
+			let mut levels = key.split("/");
+			let primary = levels.next().unwrap();
+			let secondary = levels.next().unwrap();
+			let key = levels.next().unwrap();
+			assert!(levels.next().is_none());
+			self.complete_async_writes_through(primary, secondary, key, usize::MAX);
+		}
 	}
 
 	fn read_internal(
 		&self, primary_namespace: &str, secondary_namespace: &str, key: &str,
 	) -> io::Result<Vec<u8>> {
 		let persisted_lock = self.persisted_bytes.lock().unwrap();
-		let prefixed = if secondary_namespace.is_empty() {
-			primary_namespace.to_string()
-		} else {
-			format!("{}/{}", primary_namespace, secondary_namespace)
-		};
+		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
 
 		if let Some(outer_ref) = persisted_lock.get(&prefixed) {
 			if let Some(inner_ref) = outer_ref.get(key) {
@@ -889,29 +964,6 @@ impl TestStore {
 		}
 	}
 
-	fn write_internal(
-		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
-	) -> io::Result<()> {
-		if self.read_only {
-			return Err(io::Error::new(
-				io::ErrorKind::PermissionDenied,
-				"Cannot modify read-only store",
-			));
-		}
-		let mut persisted_lock = self.persisted_bytes.lock().unwrap();
-
-		let prefixed = if secondary_namespace.is_empty() {
-			primary_namespace.to_string()
-		} else {
-			format!("{}/{}", primary_namespace, secondary_namespace)
-		};
-		let outer_e = persisted_lock.entry(prefixed).or_insert(new_hash_map());
-		let mut bytes = Vec::new();
-		bytes.write_all(&buf)?;
-		outer_e.insert(key.to_string(), bytes);
-		Ok(())
-	}
-
 	fn remove_internal(
 		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool,
 	) -> io::Result<()> {
@@ -923,16 +975,23 @@ impl TestStore {
 		}
 
 		let mut persisted_lock = self.persisted_bytes.lock().unwrap();
+		let mut async_writes_lock = self.pending_async_writes.lock().unwrap();
-		let prefixed = if secondary_namespace.is_empty() {
-			primary_namespace.to_string()
-		} else {
-			format!("{}/{}", primary_namespace, secondary_namespace)
-		};
+		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
 		if let Some(outer_ref) = persisted_lock.get_mut(&prefixed) {
 			outer_ref.remove(&key.to_string());
 		}
 
+		if let Some(pending_writes) = async_writes_lock.remove(&format!("{prefixed}/{key}")) {
+			for (_, future, _) in pending_writes {
+				let mut future_lock = future.lock().unwrap();
+				future_lock.0 = Some(Ok(()));
+				if let Some(waker) = future_lock.1.take() {
+					waker.wake();
+				}
+			}
+		}
+
 		Ok(())
 	}
 
@@ -941,11 +1000,7 @@ impl TestStore {
 	) -> io::Result<Vec<String>> {
 		let mut persisted_lock = self.persisted_bytes.lock().unwrap();
 
-		let prefixed = if secondary_namespace.is_empty() {
-			primary_namespace.to_string()
-		} else {
-			format!("{}/{}", primary_namespace, secondary_namespace)
-		};
+		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
 		match persisted_lock.entry(prefixed) {
 			hash_map::Entry::Occupied(e) => Ok(e.get().keys().cloned().collect()),
 			hash_map::Entry::Vacant(_) => Ok(Vec::new()),
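Taken together with the doc comment on `TestStore`, the intended flow is: an async `KVStore::write` only registers a pending write and returns a pending future; the test later completes it by id (or all at once), which persists the buffered bytes and wakes the future. A rough usage sketch follows, assuming crate-internal access to the test utilities added in this diff (the `poll_once` helper is hypothetical and only stands in for however the test drives the returned future):

```rust
// Sketch only: assumes the crate-internal TestStore API added in this diff.
let store = TestStore::new(false);

// An async write is queued but stays pending until completed manually.
let fut = KVStore::write(&store, "primary", "secondary", "key", vec![42]);
assert_eq!(store.list_pending_async_writes("primary", "secondary", "key"), vec![0]);

// Completing through write id 0 persists the buffered bytes and wakes the future.
store.complete_async_writes_through("primary", "secondary", "key", 0);
assert!(matches!(poll_once(fut), Poll::Ready(Ok(()))));

// A sync write to the same key would also have completed any pending async writes for it,
// and `complete_all_async_writes()` flushes every key at once.
KVStoreSync::write(&store, "primary", "secondary", "key", vec![43]).unwrap();
```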
@@ -963,8 +1018,15 @@ impl KVStore for TestStore {
 	fn write(
 		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
 	) -> Pin<Box<dyn Future<Output = Result<(), io::Error>> + 'static + Send>> {
-		let res = self.write_internal(&primary_namespace, &secondary_namespace, &key, buf);
-		Box::pin(async move { res })
+		let path = format!("{primary_namespace}/{secondary_namespace}/{key}");
+		let future = Arc::new(Mutex::new((None, None)));
+
+		let mut async_writes_lock = self.pending_async_writes.lock().unwrap();
+		let pending_writes = async_writes_lock.entry(path).or_insert(Vec::new());
+		let new_id = pending_writes.last().map(|(id, _, _)| id + 1).unwrap_or(0);
+		pending_writes.push((new_id, Arc::clone(&future), buf));
+
+		Box::pin(OneShotChannel(future))
 	}
 	fn remove(
 		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool,
@@ -990,7 +1052,30 @@ impl KVStoreSync for TestStore {
 	fn write(
 		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
 	) -> io::Result<()> {
-		self.write_internal(primary_namespace, secondary_namespace, key, buf)
+		if self.read_only {
+			return Err(io::Error::new(
+				io::ErrorKind::PermissionDenied,
+				"Cannot modify read-only store",
+			));
+		}
+		let mut persisted_lock = self.persisted_bytes.lock().unwrap();
+		let mut async_writes_lock = self.pending_async_writes.lock().unwrap();
+
+		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
+		let async_writes_pending = async_writes_lock.remove(&format!("{prefixed}/{key}"));
+		let outer_e = persisted_lock.entry(prefixed).or_insert(new_hash_map());
+		outer_e.insert(key.to_string(), buf);
+
+		if let Some(pending_writes) = async_writes_pending {
+			for (_, future, _) in pending_writes {
+				let mut future_lock = future.lock().unwrap();
+				future_lock.0 = Some(Ok(()));
+				if let Some(waker) = future_lock.1.take() {
+					waker.wake();
+				}
+			}
+		}
+		Ok(())
 	}
 
 	fn remove(