diff --git a/lightning/src/util/persist.rs b/lightning/src/util/persist.rs index 6b2ceaf4c34..e3fb86fb88a 100644 --- a/lightning/src/util/persist.rs +++ b/lightning/src/util/persist.rs @@ -1257,14 +1257,13 @@ mod tests { let monitor_name = mon.persistence_key(); assert_eq!( - persister_0 - .kv_store - .list( - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - &monitor_name.to_string() - ) - .unwrap() - .len() as u64, + KVStoreSync::list( + &*persister_0.kv_store, + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &monitor_name.to_string() + ) + .unwrap() + .len() as u64, mon.get_latest_update_id() % persister_0_max_pending_updates, "Wrong number of updates stored in persister 0", ); @@ -1276,14 +1275,13 @@ mod tests { assert_eq!(mon.get_latest_update_id(), $expected_update_id); let monitor_name = mon.persistence_key(); assert_eq!( - persister_1 - .kv_store - .list( - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - &monitor_name.to_string() - ) - .unwrap() - .len() as u64, + KVStoreSync::list( + &*persister_1.kv_store, + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &monitor_name.to_string() + ) + .unwrap() + .len() as u64, mon.get_latest_update_id() % persister_1_max_pending_updates, "Wrong number of updates stored in persister 1", ); @@ -1481,28 +1479,26 @@ mod tests { let persisted_chan_data = persister_0.read_all_channel_monitors_with_updates().unwrap(); let (_, monitor) = &persisted_chan_data[0]; let monitor_name = monitor.persistence_key(); - persister_0 - .kv_store - .write( - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - &monitor_name.to_string(), - UpdateName::from(1).as_str(), - vec![0u8; 1], - ) - .unwrap(); + KVStoreSync::write( + &*persister_0.kv_store, + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &monitor_name.to_string(), + UpdateName::from(1).as_str(), + vec![0u8; 1], + ) + .unwrap(); // Do the stale update cleanup persister_0.cleanup_stale_updates(false).unwrap(); // Confirm the stale update is unreadable/gone - assert!(persister_0 - .kv_store - .read( - CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, - &monitor_name.to_string(), - UpdateName::from(1).as_str() - ) - .is_err()); + assert!(KVStoreSync::read( + &*persister_0.kv_store, + CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, + &monitor_name.to_string(), + UpdateName::from(1).as_str() + ) + .is_err()); } fn persist_fn(_persist: P) -> bool diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index d28d0abbc32..9813071789b 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -57,7 +57,7 @@ use crate::util::dyn_signer::{ use crate::util::logger::{Logger, Record}; #[cfg(feature = "std")] use crate::util::mut_global::MutGlobal; -use crate::util::persist::{KVStoreSync, MonitorName}; +use crate::util::persist::{KVStore, KVStoreSync, MonitorName}; use crate::util::ser::{Readable, ReadableArgs, Writeable, Writer}; use crate::util::test_channel_signer::{EnforcementState, TestChannelSigner}; @@ -84,7 +84,10 @@ use crate::io; use crate::prelude::*; use crate::sign::{EntropySource, NodeSigner, RandomBytes, Recipient, SignerProvider}; use crate::sync::{Arc, Mutex}; +use alloc::boxed::Box; +use core::future::Future; use core::mem; +use core::pin::Pin; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use core::time::Duration; @@ -863,10 +866,8 @@ impl TestStore { let persisted_bytes = Mutex::new(new_hash_map()); Self { persisted_bytes, read_only } } -} -impl KVStoreSync for TestStore { - fn read( + fn read_internal( &self, primary_namespace: &str, secondary_namespace: &str, key: &str, ) -> io::Result> { let persisted_lock = self.persisted_bytes.lock().unwrap(); @@ -888,7 +889,7 @@ impl KVStoreSync for TestStore { } } - fn write( + fn write_internal( &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, ) -> io::Result<()> { if self.read_only { @@ -911,7 +912,7 @@ impl KVStoreSync for TestStore { Ok(()) } - fn remove( + fn remove_internal( &self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool, ) -> io::Result<()> { if self.read_only { @@ -935,7 +936,9 @@ impl KVStoreSync for TestStore { Ok(()) } - fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> io::Result> { + fn list_internal( + &self, primary_namespace: &str, secondary_namespace: &str, + ) -> io::Result> { let mut persisted_lock = self.persisted_bytes.lock().unwrap(); let prefixed = if secondary_namespace.is_empty() { @@ -950,6 +953,89 @@ impl KVStoreSync for TestStore { } } +impl KVStore for TestStore { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> Pin, io::Error>> + 'static + Send>> { + let res = self.read_internal(&primary_namespace, &secondary_namespace, &key); + Box::pin(async move { TestStoreFuture::new(res).await }) + } + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, + ) -> Pin> + 'static + Send>> { + let res = self.write_internal(&primary_namespace, &secondary_namespace, &key, buf); + Box::pin(async move { TestStoreFuture::new(res).await }) + } + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool, + ) -> Pin> + 'static + Send>> { + let res = self.remove_internal(&primary_namespace, &secondary_namespace, &key, lazy); + Box::pin(async move { TestStoreFuture::new(res).await }) + } + fn list( + &self, primary_namespace: &str, secondary_namespace: &str, + ) -> Pin, io::Error>> + 'static + Send>> { + let res = self.list_internal(primary_namespace, secondary_namespace); + Box::pin(async move { TestStoreFuture::new(res).await }) + } +} + +impl KVStoreSync for TestStore { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> io::Result> { + self.read_internal(primary_namespace, secondary_namespace, key) + } + + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, + ) -> io::Result<()> { + self.write_internal(primary_namespace, secondary_namespace, key, buf) + } + + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool, + ) -> io::Result<()> { + self.remove_internal(primary_namespace, secondary_namespace, key, lazy) + } + + fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> io::Result> { + self.list_internal(primary_namespace, secondary_namespace) + } +} + +// A `Future` that returns the result only on the second poll. +pub(crate) struct TestStoreFuture { + inner: Mutex<(Option, Option>)>, +} + +impl TestStoreFuture { + fn new(res: io::Result) -> Self { + let inner = Mutex::new((None, Some(res))); + Self { inner } + } +} + +impl Future for TestStoreFuture { + type Output = Result; + fn poll( + self: Pin<&mut Self>, cx: &mut core::task::Context<'_>, + ) -> core::task::Poll { + let mut inner_lock = self.inner.lock().unwrap(); + let first_poll = inner_lock.0.is_none(); + if first_poll { + (*inner_lock).0 = Some(cx.waker().clone()); + core::task::Poll::Pending + } else { + let waker = inner_lock.0.take().expect("We should never poll more than twice"); + let res = inner_lock.1.take().expect("We should never poll more than twice"); + drop(inner_lock); + waker.wake(); + core::task::Poll::Ready(res) + } + } +} + unsafe impl Sync for TestStore {} unsafe impl Send for TestStore {}