diff --git a/src/lib.rs b/src/lib.rs index 6a8343b..0272652 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -99,7 +99,7 @@ //! #[test] //! #[should_panic] //! fn test_fallible_work() { -//! let local_registry = fail::new_fail_group(); +//! let local_registry = fail::create_registry(); //! local_registry.register_current(); //! fail::cfg("read-dir", "panic").unwrap(); //! @@ -109,19 +109,18 @@ //! } //! ``` //! -//! It should be noted that the local registry will inherit the global registry when -//! it is created, which means that the inherited part can be shared. When you remove -//! a global fail point action from the local registry, it will affect all threads -//! using this fail point. +//! It should be noted that the local registry will will overwrite the global registry +//! if you register the current thread here. This means that the current thread can only +//! use the fail points configuration of the local registry after registration. //! -//! Here's a example to show the inheritance process: +//! Here's a example to show the process: //! //! ```rust //! fail::setup(); //! fail::cfg("p1", "sleep(100)").unwrap(); //! println!("Global registry: {:?}", fail::list()); //! { -//! let local_registry = fail::new_fail_group(); +//! let local_registry = fail::create_registry(); //! local_registry.register_current(); //! fail::cfg("p0", "pause").unwrap(); //! println!("Local registry: {:?}", fail::list()); @@ -138,10 +137,10 @@ //! FAILPOINTS=p0=return cargo run --features fail/failpoints //! Finished dev [unoptimized + debuginfo] target(s) in 0.01s //! Running `target/debug/failpointtest` -//! Global registry: [("p1", "sleep(100)"), ("p0", "return")] -//! Local registry: [("p1", "sleep(100)"), ("p0", "pause")] +//! Global registry: [("p1", "sleep(100)")] +//! Local registry: [("p0", "pause")] //! Local registry: [] -//! Global registry: [("p1", "sleep(100)"), ("p0", "return")] +//! Global registry: [("p1", "sleep(100)")] //! ``` //! //! In this example, program update global registry with environment variable first. @@ -325,19 +324,6 @@ struct Action { count: Option, } -impl Clone for Action { - fn clone(&self) -> Self { - Action { - count: self - .count - .as_ref() - .map(|c| AtomicUsize::new(c.load(Ordering::Relaxed))), - task: self.task.clone(), - freq: self.freq, - } - } -} - impl PartialEq for Action { fn eq(&self, hs: &Action) -> bool { if self.task != hs.task || self.freq != hs.freq { @@ -477,16 +463,6 @@ struct FailPoint { actions_str: RwLock, } -impl Clone for FailPoint { - fn clone(&self) -> Self { - FailPoint { - actions: RwLock::new(self.actions.read().unwrap().clone()), - actions_str: RwLock::new(self.actions_str.read().unwrap().clone()), - ..Default::default() - } - } -} - #[cfg_attr(feature = "cargo-clippy", allow(clippy::mutex_atomic))] impl FailPoint { fn new() -> FailPoint { @@ -575,25 +551,26 @@ pub struct FailPointRegistry { /// /// Each thread should be bound to exact one registry. Threads bound to the /// same registry share the same failpoints configuration. -pub fn new_fail_group() -> FailPointRegistry { - let registry = REGISTRY_GLOBAL.registry.read().unwrap(); - let mut new_registry = Registry::new(); - for (name, failpoint) in registry.iter() { - new_registry.insert(name.clone(), Arc::new(FailPoint::clone(failpoint))); - } +pub fn create_registry() -> FailPointRegistry { FailPointRegistry { - registry: Arc::new(RwLock::new(new_registry)), + registry: Arc::new(RwLock::new(Registry::new())), } } impl FailPointRegistry { /// Register the current thread to this failpoints registry. - pub fn register_current(&self) { + pub fn register_current(&self) -> Result<(), String> { let id = thread::current().id(); - REGISTRY_GROUP + let ret = REGISTRY_GROUP .write() .unwrap() .insert(id, self.registry.clone()); + + if ret.is_some() { + Err("current thread has been registered with one registry".to_owned()) + } else { + Ok(()) + } } /// Deregister the current thread to this failpoints registry. @@ -696,14 +673,13 @@ pub const fn has_failpoints() -> bool { /// /// Return a vector of `(name, actions)` pairs. pub fn list() -> Vec<(String, String)> { - let id = thread::current().id(); - let group = REGISTRY_GROUP.read().unwrap(); + let registry = { + let group = REGISTRY_GROUP.read().unwrap(); + let id = thread::current().id(); + group.get(&id).unwrap_or(®ISTRY_GLOBAL.registry).clone() + }; - let registry = group - .get(&id) - .unwrap_or(®ISTRY_GLOBAL.registry) - .read() - .unwrap(); + let registry = registry.read().unwrap(); registry .iter() @@ -713,15 +689,15 @@ pub fn list() -> Vec<(String, String)> { #[doc(hidden)] pub fn eval) -> R>(name: &str, f: F) -> Option { - let id = thread::current().id(); - let p = { - let group = REGISTRY_GROUP.read().unwrap(); - let registry = group - .get(&id) - .unwrap_or(®ISTRY_GLOBAL.registry) - .read() - .unwrap(); + let registry = { + let group = REGISTRY_GROUP.read().unwrap(); + let id = thread::current().id(); + group.get(&id).unwrap_or(®ISTRY_GLOBAL.registry).clone() + }; + + let registry = registry.read().unwrap(); + match registry.get(name) { None => return None, Some(p) => p.clone(), @@ -1125,6 +1101,28 @@ mod tests { "setup_and_teardown1=return;setup_and_teardown2=pause;", ); setup(); + + let group = create_registry(); + let handler = thread::spawn(move || { + group.register_current().unwrap(); + cfg("setup_and_teardown1", "panic").unwrap(); + cfg("setup_and_teardown2", "panic").unwrap(); + let l = list(); + assert!( + l.iter() + .find(|&x| x == &("setup_and_teardown1".to_owned(), "panic".to_owned())) + .is_some() + && l.iter() + .find(|&x| x == &("setup_and_teardown2".to_owned(), "panic".to_owned())) + .is_some() + && l.len() == 2 + ); + remove("setup_and_teardown2"); + let l = list(); + assert!(l.len() == 1); + }); + handler.join().unwrap(); + assert_eq!(f1(), 1); let (tx, rx) = mpsc::channel(); diff --git a/tests/tests.rs b/tests/tests.rs index 92e11c6..4d37cb1 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -10,8 +10,8 @@ use fail::fail_point; #[test] #[cfg_attr(not(feature = "failpoints"), ignore)] fn test_pause() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let f = || { fail_point!("pause"); }; @@ -21,7 +21,7 @@ fn test_pause() { let (tx, rx) = mpsc::channel(); let thread_registry = local_registry.clone(); thread::spawn(move || { - thread_registry.register_current(); + thread_registry.register_current().unwrap(); // pause tx.send(f()).unwrap(); // woken up by new order pause, and then pause again. @@ -43,8 +43,8 @@ fn test_pause() { #[test] fn test_off() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let f = || { fail_point!("off", |_| 2); @@ -59,8 +59,8 @@ fn test_off() { #[test] #[cfg_attr(not(feature = "failpoints"), ignore)] fn test_return() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let f = || { fail_point!("return", |s: Option| s @@ -79,8 +79,8 @@ fn test_return() { #[test] #[cfg_attr(not(feature = "failpoints"), ignore)] fn test_sleep() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let f = || { fail_point!("sleep"); @@ -99,8 +99,8 @@ fn test_sleep() { #[should_panic] #[cfg_attr(not(feature = "failpoints"), ignore)] fn test_panic() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let f = || { fail_point!("panic"); @@ -112,8 +112,8 @@ fn test_panic() { #[test] #[cfg_attr(not(feature = "failpoints"), ignore)] fn test_print() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); struct LogCollector(Arc>>); impl log::Log for LogCollector { @@ -148,8 +148,8 @@ fn test_print() { #[test] fn test_yield() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let f = || { fail_point!("yield"); @@ -161,8 +161,8 @@ fn test_yield() { #[test] #[cfg_attr(not(feature = "failpoints"), ignore)] fn test_callback() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let f1 = || { fail_point!("cb"); @@ -185,8 +185,8 @@ fn test_callback() { #[test] #[cfg_attr(not(feature = "failpoints"), ignore)] fn test_delay() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let f = || fail_point!("delay"); let timer = Instant::now(); @@ -198,8 +198,8 @@ fn test_delay() { #[test] #[cfg_attr(not(feature = "failpoints"), ignore)] fn test_freq_and_count() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let f = || { fail_point!("freq_and_count", |s: Option| s @@ -222,8 +222,8 @@ fn test_freq_and_count() { #[test] #[cfg_attr(not(feature = "failpoints"), ignore)] fn test_condition() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let f = |_enabled| { fail_point!("condition", _enabled, |_| 2); @@ -239,8 +239,8 @@ fn test_condition() { #[test] fn test_list() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); assert!(!fail::list().contains(&("list".to_string(), "off".to_string()))); fail::cfg("list", "off").unwrap(); @@ -251,12 +251,12 @@ fn test_list() { #[test] fn test_multiple_threads_cleanup() { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); let (tx, rx) = mpsc::channel(); thread::spawn(move || { - local_registry.register_current(); + local_registry.register_current().unwrap(); fail::cfg("thread_point", "sleep(10)").unwrap(); tx.send(()).unwrap(); }); @@ -271,8 +271,8 @@ fn test_multiple_threads_cleanup() { let (tx, rx) = mpsc::channel(); let t = thread::spawn(move || { - let local_registry = fail::new_fail_group(); - local_registry.register_current(); + let local_registry = fail::create_registry(); + local_registry.register_current().unwrap(); fail::cfg("thread_point", "panic").unwrap(); let l = fail::list(); assert!(