diff --git a/README.md b/README.md
index ecdee1f..9f18192 100644
--- a/README.md
+++ b/README.md
@@ -21,14 +21,15 @@ Rust implementation of [TAPIR](https://syslab.cs.washington.edu/papers/tapir-tr-
   - [x] IR sync & merge
   - [x] Prepare retries
   - [x] Coordinator recovery
-  - [ ] Sharding
+  - [x] Sharding
   - [ ] Persistent storage (e.g. `sled`)
-  - [ ] Pessimistic read only transactions
+  - [ ] Snapshot read
 - [ ] Planned extensions
   - [x] Delete key operation
   - [ ] Garbage collection
-  - [ ] Quorum range scan
+  - [ ] Range scan
   - [ ] Automatic shard balancing
+  - [ ] Disaster recovery
 - [ ] Testing
   - [x] IR lock server (very simple)
   - [x] TAPIR-KV (simple)
@@ -39,6 +40,12 @@ Rust implementation of [TAPIR](https://syslab.cs.washington.edu/papers/tapir-tr-
   - [ ] Reduce allocations
   - [ ] Reduce temporary unavailability
 
-## Acknowledgement
+## Acknowledgements
 
-Thank you to the TAPIR authors for answering my questions about the paper!
\ No newline at end of file
+Thanks to [James Wilcox](https://jamesrwilcox.com) for assigning TAPIR as a reading.
+
+Thanks to [the TAPIR authors](https://github.com/UWSysLab/tapir#contact-and-questions) for answering questions about
+the paper!
+
+Thanks to [Kyle](https://aphyr.com) at [Jepsen](https://jepsen.io) for clarifying the relative
+strength of isolation levels.
\ No newline at end of file
diff --git a/src/bin/maelstrom.rs b/src/bin/maelstrom.rs
index 65ff15a..095781c 100644
--- a/src/bin/maelstrom.rs
+++ b/src/bin/maelstrom.rs
@@ -19,7 +19,9 @@ use std::sync::atomic::AtomicU64;
 use std::sync::atomic::Ordering::SeqCst;
 use std::sync::{Arc, Mutex};
 use std::time::Duration;
-use tapirs::{IrMembership, IrMessage, IrReplica, TapirClient, TapirReplica, Transport};
+use tapirs::{
+    IrMembership, IrMessage, IrReplica, TapirClient, TapirReplica, TapirTransport, Transport,
+};
 use tokio::spawn;
 
 type K = String;
@@ -193,6 +195,20 @@ impl Transport> for Maelstrom {
     }
 }
 
+impl TapirTransport for Maelstrom {
+    fn shard_addresses(
+        &self,
+        shard: tapirs::ShardNumber,
+    ) -> impl futures::Future> + Send + 'static {
+        assert_eq!(shard.0, 0);
+        std::future::ready(IrMembership::new(vec![
+            IdEnum::Replica(0),
+            IdEnum::Replica(1),
+            IdEnum::Replica(2),
+        ]))
+    }
+}
+
 #[async_trait]
 impl Process for KvNode {
     fn init(
@@ -219,12 +235,11 @@ impl Process for KvNode {
                 match id {
                     IdEnum::Replica(_) => KvNodeInner::Replica(Arc::new(IrReplica::new(
                         membership,
-                        TapirReplica::new(true),
+                        TapirReplica::new(tapirs::ShardNumber(0), true),
                         transport,
+                        Some(TapirReplica::tick),
                     ))),
-                    IdEnum::App(_) => {
-                        KvNodeInner::App(Arc::new(TapirClient::new(membership, transport)))
-                    }
+                    IdEnum::App(_) => KvNodeInner::App(Arc::new(TapirClient::new(transport))),
                     id => panic!("{id}"),
                 },
             ));
diff --git a/src/ir/client.rs b/src/ir/client.rs
index 5665e35..685cc17 100644
--- a/src/ir/client.rs
+++ b/src/ir/client.rs
@@ -27,7 +27,7 @@ use tokio::select;
 pub struct Id(pub u64);
 
 impl Id {
-    fn new() -> Self {
+    pub fn new() -> Self {
         Self(thread_rng().gen())
     }
 }
@@ -95,6 +95,10 @@ impl> Client {
         self.id
     }
 
+    pub fn set_id(&mut self, id: Id) {
+        self.id = id;
+    }
+
     pub fn transport(&self) -> &T {
         &self.inner.transport
     }
diff --git a/src/ir/replica.rs b/src/ir/replica.rs
index eb8ff4c..ff3ad9e 100644
--- a/src/ir/replica.rs
+++ b/src/ir/replica.rs
@@ -63,10 +63,6 @@ pub trait Upcalls: Sized + Send + Serialize + DeserializeOwned + 'static {
         d: HashMap,
         u: Vec<(OpId, Self::CO, Self::CR)>,
     ) -> HashMap;
-    fn tick>(&mut self, membership: &Membership, transport: &T) {
-        let _ = (membership, transport);
-        // No-op.
-    }
 }
 
 pub struct Replica> {
@@ -89,10 +85,11 @@ impl> Debug for Replica {
 
 struct Inner> {
     transport: T,
-    sync: Mutex>,
+    app_tick: Option)>,
+    sync: Mutex>,
 }
 
-struct Sync> {
+struct SyncInner> {
     status: Status,
     view: View,
     latest_normal_view: View,
@@ -113,7 +110,12 @@ struct PersistentViewInfo {
 impl> Replica {
     const VIEW_CHANGE_INTERVAL: Duration = Duration::from_secs(4);
 
-    pub fn new(membership: Membership, upcalls: U, transport: T) -> Self {
+    pub fn new(
+        membership: Membership,
+        upcalls: U,
+        transport: T,
+        app_tick: Option)>,
+    ) -> Self {
         let view = View {
             membership,
             number: ViewNumber(0),
@@ -121,7 +123,8 @@ impl> Replica {
         let ret = Self {
             inner: Arc::new(Inner {
                 transport,
-                sync: Mutex::new(Sync {
+                app_tick,
+                sync: Mutex::new(SyncInner {
                     status: Status::Normal,
                     latest_normal_view: view.clone(),
                     view,
@@ -173,7 +176,7 @@ impl> Replica {
         format!("ir_replica_{}", self.inner.transport.address())
     }
 
-    fn persist_view_info(&self, sync: &Sync) {
+    fn persist_view_info(&self, sync: &SyncInner) {
         if sync.view.membership.len() == 1 {
             return;
         }
@@ -244,12 +247,16 @@ impl> Replica {
                 };
                 let mut sync = inner.sync.lock().unwrap();
                 let sync = &mut *sync;
-                sync.upcalls.tick(&sync.view.membership, &transport);
+                if let Some(tick) = inner.app_tick.as_ref() {
+                    tick(&sync.upcalls, &transport, &sync.view.membership);
+                } else {
+                    break;
+                }
             }
         });
     }
 
-    fn broadcast_do_view_change(transport: &T, sync: &mut Sync) {
+    fn broadcast_do_view_change(transport: &T, sync: &mut SyncInner) {
         sync.changed_view_recently = true;
         let destinations = sync
             .view
@@ -461,7 +468,7 @@ impl> Replica {
                     })
                     .map(|(_, r)| r.addendum.as_ref().unwrap().record.clone())
                     .collect::>();
-
+
                 eprintln!(
                     "have {} latest ({:?})",
                     latest_records.len(),
@@ -608,7 +615,7 @@ impl> Replica {
                     sync.latest_normal_view.number = msg_view_number;
                     sync.latest_normal_view.membership = sync.view.membership.clone();
                     self.persist_view_info(&*sync);
-
+
                     for address in destinations {
                         if address == self.inner.transport.address() {
                             continue;
diff --git a/src/ir/tests/lock_server.rs b/src/ir/tests/lock_server.rs
index 7e7b0e2..6dc1515 100644
--- a/src/ir/tests/lock_server.rs
+++ b/src/ir/tests/lock_server.rs
@@ -175,7 +175,7 @@ async fn lock_server(num_replicas: usize) {
             let channel =
                 registry.channel(move |from, message| weak.upgrade()?.receive(from, message));
             let upcalls = Upcalls { locked: None };
-            IrReplica::new(membership.clone(), upcalls, channel)
+            IrReplica::new(membership.clone(), upcalls, channel, None)
         },
     )
 }
diff --git a/src/lib.rs b/src/lib.rs
index 1b3cce3..6049657 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -23,6 +23,10 @@ pub use occ::{
     PrepareResult as OccPrepareResult, Store as OccStore, Timestamp as OccTimestamp,
     Transaction as OccTransaction, TransactionId as OccTransactionId,
 };
-pub use tapir::{Client as TapirClient, Replica as TapirReplica, Timestamp as TapirTimestamp};
-pub use transport::{Channel as ChannelTransport, ChannelRegistry};
-pub use transport::{Message as TransportMessage, Transport};
+pub use tapir::{
+    Client as TapirClient, Replica as TapirReplica, ShardNumber, Timestamp as TapirTimestamp,
+};
+pub use transport::{
+    Channel as ChannelTransport, ChannelRegistry, Message as TransportMessage, TapirTransport,
+    Transport,
+};
diff --git a/src/occ/store.rs b/src/occ/store.rs
index eb37064..e616234 100644
--- a/src/occ/store.rs
+++ b/src/occ/store.rs
@@ -1,6 +1,6 @@
 use super::{Timestamp, Transaction, TransactionId};
 use crate::{
-    tapir::{Key, Value},
+    tapir::{Key, ShardNumber, Value},
     util::{vectorize, vectorize_btree},
     MvccStore,
 };
@@ -15,6 +15,7 @@ use std::{
 #[derive(Serialize, Deserialize)]
 pub struct Store {
+    shard: ShardNumber,
     linearizable: bool,
     #[serde(bound(
         serialize = "K: Serialize, V: Serialize, TS: Serialize",
@@ -99,11 +100,28 @@ impl PrepareResult {
     pub fn is_fail(&self) -> bool {
         matches!(self, Self::Fail)
     }
+
+    pub fn is_abstain(&self) -> bool {
+        matches!(self, Self::Abstain)
+    }
+
+    pub fn is_retry(&self) -> bool {
+        matches!(self, Self::Retry { .. })
+    }
+
+    pub fn is_too_late(&self) -> bool {
+        matches!(self, Self::TooLate)
+    }
+
+    pub fn is_too_old(&self) -> bool {
+        matches!(self, Self::TooOld)
+    }
 }
 
 impl Store {
-    pub fn new(linearizable: bool) -> Self {
+    pub fn new(shard: ShardNumber, linearizable: bool) -> Self {
         Self {
+            shard,
             linearizable,
             inner: Default::default(),
             prepared: Default::default(),
@@ -174,8 +192,8 @@ impl Store {
     fn occ_check(&self, transaction: &Transaction, commit: TS) -> PrepareResult {
         // Check for conflicts with the read set.
-        for (key, read) in &transaction.read_set {
-            if *read > commit {
+        for (key, read) in transaction.shard_read_set(self.shard) {
+            if read > commit {
                 debug_assert!(false, "client picked too low commit timestamp for read");
                 return PrepareResult::Retry {
                     proposed: read.time(),
                 }
 
             // If we don't have this key then no conflicts for read.
-            let (beginning, end) = self.inner.get_range(key, *read);
+            let (beginning, end) = self.inner.get_range(key, read);
 
-            if beginning == *read {
+            if beginning == read {
                 if let Some(end) = end && (self.linearizable || commit > end) {
                     // Read value is now invalid (not the latest version), so
                     // the prepare isn't linearizable and may not be serializable.
@@ -204,7 +222,7 @@ impl Store {
                         if self.linearizable {
                             Bound::Unbounded
                         } else {
-                            Bound::Excluded(*read)
+                            Bound::Excluded(read)
                         },
                         Bound::Excluded(commit),
                     ))
@@ -218,7 +236,7 @@ impl Store {
         }
 
         // Check for conflicts with the write set.
-        for key in transaction.write_set.keys() {
+        for (key, _) in transaction.shard_write_set(self.shard) {
             {
                 let (_, timestamp) = self.inner.get(key);
                 // If the last commited write is after the write...
@@ -262,13 +280,13 @@ impl Store {
         PrepareResult::Ok
     }
 
-    pub fn commit(&mut self, id: TransactionId, transaction: Transaction, commit: TS) {
-        for (key, read) in transaction.read_set {
+    pub fn commit(&mut self, id: TransactionId, transaction: &Transaction, commit: TS) {
+        for (key, read) in transaction.shard_read_set(self.shard) {
             self.inner.commit_get(key.clone(), read, commit);
         }
 
-        for (key, value) in transaction.write_set {
-            self.inner.put(key, value, commit);
+        for (key, value) in transaction.shard_write_set(self.shard) {
+            self.inner.put(key.clone(), value.clone(), commit);
         }
 
         // Note: Transaction may not be in the prepared list of this particular replica, and that's okay.
@@ -307,13 +325,13 @@ impl Store {
     }
 
     fn add_prepared_inner(&mut self, transaction: Transaction, commit: TS) {
-        for key in transaction.read_set.keys() {
+        for (key, _) in transaction.shard_read_set(self.shard) {
             self.prepared_reads
                 .entry(key.clone())
                 .or_default()
                 .insert(commit, ());
         }
-        for key in transaction.write_set.keys() {
+        for (key, _) in transaction.shard_write_set(self.shard) {
             self.prepared_writes
                 .entry(key.clone())
                 .or_default()
@@ -332,16 +350,16 @@ impl Store {
     }
 
     fn remove_prepared_inner(&mut self, transaction: Transaction, commit: TS) {
-        for key in transaction.read_set.into_keys() {
-            if let Entry::Occupied(mut occupied) = self.prepared_reads.entry(key) {
+        for (key, _) in transaction.shard_read_set(self.shard) {
+            if let Entry::Occupied(mut occupied) = self.prepared_reads.entry(key.clone()) {
                 occupied.get_mut().remove(&commit);
                 if occupied.get().is_empty() {
                     occupied.remove();
                 }
             }
         }
-        for key in transaction.write_set.into_keys() {
-            if let Entry::Occupied(mut occupied) = self.prepared_writes.entry(key) {
+        for (key, _) in transaction.shard_write_set(self.shard) {
+            if let Entry::Occupied(mut occupied) = self.prepared_writes.entry(key.clone()) {
                 occupied.get_mut().remove(&commit);
                 if occupied.get().is_empty() {
                     occupied.remove();
diff --git a/src/occ/transaction.rs b/src/occ/transaction.rs
index fbd28c0..bcd73e8 100644
--- a/src/occ/transaction.rs
+++ b/src/occ/transaction.rs
@@ -1,22 +1,34 @@
 use crate::{
-    tapir::{Key, Value},
+    tapir::{Key, ShardNumber, Sharded, Value},
     util::vectorize,
     IrClientId,
 };
 use serde::{Deserialize, Serialize};
 use std::{
-    collections::{hash_map::Entry, HashMap},
+    collections::{hash_map::Entry, HashMap, HashSet},
     fmt::Debug,
     hash::Hash,
 };
 
+#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
+pub struct Id {
+    pub client_id: IrClientId,
+    pub number: u64,
+}
+
+impl Debug for Id {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "Txn({}, {:?})", self.client_id.0, self.number)
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Transaction {
     #[serde(
         with = "vectorize",
         bound(serialize = "TS: Serialize", deserialize = "TS: Deserialize<'de>")
     )]
-    pub read_set: HashMap,
+    pub read_set: HashMap, TS>,
@@ -24,7 +36,7 @@ pub struct Transaction {
             deserialize = "K: Deserialize<'de> + Eq + Hash, V: Deserialize<'de>"
         )
     )]
-    pub write_set: HashMap>,
+    pub write_set: HashMap, Option>,
 }
 
 impl PartialEq for Transaction {
@@ -35,15 +47,30 @@ impl PartialEq for Transaction
 
 impl Eq for Transaction {}
 
-#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
-pub struct Id {
-    pub client_id: IrClientId,
-    pub number: u64,
-}
+impl Transaction {
+    pub fn participants(&self) -> HashSet {
+        self.read_set
+            .iter()
+            .map(|(k, _)| k.shard)
+            .chain(self.write_set.iter().map(|(k, _)| k.shard))
+            .collect()
+    }
 
-impl Debug for Id {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "Txn({}, {:?})", self.client_id.0, self.number)
+    pub fn shard_read_set(&self, shard: ShardNumber) -> impl Iterator + '_ {
+        self.read_set
+            .iter()
+            .filter(move |(k, _)| k.shard == shard)
+            .map(|(k, ts)| (&k.key, *ts))
+    }
+
+    pub fn shard_write_set(
+        &self,
+        shard: ShardNumber,
+    ) -> impl Iterator)> + '_ {
+        self.write_set
+            .iter()
+            .filter(move |(k, _)| k.shard == shard)
+            .map(|(k, v)| (&k.key, v))
     }
 }
 
@@ -56,8 +83,8 @@ impl Default for Transaction {
     }
 }
 
-impl Transaction {
-    pub fn add_read(&mut self, key: K, timestamp: TS) {
+impl Transaction {
+    pub fn add_read(&mut self, key: Sharded, timestamp: TS) {
         match self.read_set.entry(key) {
             Entry::Vacant(vacant) => {
                 vacant.insert(timestamp);
@@ -68,7 +95,7 @@ impl Transaction {
         }
     }
 
-    pub fn add_write(&mut self, key: K, value: Option) {
+    pub fn add_write(&mut self, key: Sharded, value: Option) {
         self.write_set.insert(key, value);
     }
 }
diff --git a/src/tapir/client.rs b/src/tapir/client.rs
index 9605735..109e582 100644
--- a/src/tapir/client.rs
+++ b/src/tapir/client.rs
@@ -1,87 +1,228 @@
-use super::{Key, Replica, ShardClient, ShardTransaction, Timestamp, Value};
-use crate::{IrMembership, OccPrepareResult, OccTransactionId, Transport};
+use super::{Key, ShardClient, ShardNumber, Sharded, Timestamp, Value};
+use crate::{
+    util::join, IrClientId, OccPrepareResult, OccTransaction, OccTransactionId, TapirTransport,
+};
+use futures::future::join_all;
 use rand::{thread_rng, Rng};
 use std::{
+    collections::HashMap,
     future::Future,
-    sync::atomic::{AtomicU64, Ordering},
+    sync::{
+        atomic::{AtomicU64, Ordering},
+        Arc, Mutex,
+    },
+    task::Context,
     time::Duration,
 };
 use tokio::select;
 
-pub struct Client>> {
-    /// TODO: Add multiple shards.
-    inner: ShardClient,
-    #[allow(unused)]
-    transport: T,
+pub struct Client> {
+    inner: Arc>>,
     next_transaction_number: AtomicU64,
 }
 
-pub struct Transaction>> {
-    #[allow(unused)]
+pub struct Inner> {
+    id: IrClientId,
+    clients: HashMap>,
+    transport: T,
+}
+
+impl> Inner {
+    fn shard_client(
+        this: &Arc>,
+        shard: ShardNumber,
+    ) -> impl Future> + Send + 'static {
+        let this = Arc::clone(this);
+
+        async move {
+            let future = {
+                let lock = this.lock().unwrap();
+                if let Some(client) = lock.clients.get(&shard) {
+                    return client.clone();
+                }
+                lock.transport.shard_addresses(shard)
+            };
+
+            let membership = future.await;
+
+            let mut lock = this.lock().unwrap();
+            let lock = &mut *lock;
+            lock.clients
+                .entry(shard)
+                .or_insert_with(|| {
+                    ShardClient::new(lock.id, shard, membership, lock.transport.clone())
+                })
+                .clone()
+        }
+    }
+}
+
+pub struct Transaction> {
     id: OccTransactionId,
-    // TODO: Multiple shards.
-    inner: ShardTransaction,
+    client: Arc>>,
+    inner: Arc>>,
+}
+
+struct TransactionInner {
+    inner: OccTransaction,
+    read_cache: HashMap, Option>,
 }
 
-impl>> Client {
-    pub fn new(membership: IrMembership, transport: T) -> Self {
+impl> Client {
+    pub fn new(transport: T) -> Self {
         Self {
-            inner: ShardClient::new(membership, transport.clone()),
-            transport,
+            inner: Arc::new(Mutex::new(Inner {
+                id: IrClientId::new(),
+                clients: Default::default(),
+                transport,
+            })),
             next_transaction_number: AtomicU64::new(thread_rng().gen()),
         }
     }
 
     pub fn begin(&self) -> Transaction {
         let transaction_id = OccTransactionId {
-            client_id: self.inner.id(),
+            client_id: self.inner.lock().unwrap().id,
             number: self.next_transaction_number.fetch_add(1, Ordering::Relaxed),
         };
         Transaction {
             id: transaction_id,
-            inner: self.inner.begin(transaction_id),
+            client: Arc::clone(&self.inner),
+            inner: Arc::new(Mutex::new(TransactionInner {
+                inner: Default::default(),
+                read_cache: Default::default(),
+            })),
         }
     }
 }
 
-impl>> Transaction {
-    pub fn get(&self, key: K) -> impl Future> {
-        self.inner.get(key)
+impl> Transaction {
+    pub fn get(&self, key: impl Into>) -> impl Future> {
+        let key = key.into();
+        let client = Arc::clone(&self.client);
+        let inner = Arc::clone(&self.inner);
+
+        async move {
+            let client = Inner::shard_client(&client, key.shard).await;
+
+            loop {
+                {
+                    let lock = inner.lock().unwrap();
+
+                    // Read own writes.
+                    if let Some(write) = lock.inner.write_set.get(&key) {
+                        return write.as_ref().cloned();
+                    }
+
+                    // Consistent reads.
+                    if let Some(read) = lock.read_cache.get(&key) {
+                        return read.as_ref().cloned();
+                    }
+                }
+
+                let (value, timestamp) = client.get(key.key.clone(), None).await;
+
+                let mut lock = inner.lock().unwrap();
+
+                // Read own writes.
+                if let Some(write) = lock.inner.write_set.get(&key) {
+                    return write.as_ref().cloned();
+                }
+
+                // Consistent reads.
+                if let Some(read) = lock.read_cache.get(&key) {
+                    return read.as_ref().cloned();
+                }
+
+                lock.read_cache.insert(key.clone(), value.clone());
+                lock.inner.add_read(key, timestamp);
+                return value;
+            }
+        }
     }
 
-    pub fn put(&self, key: K, value: Option) {
-        self.inner.put(key, value);
+    pub fn put(&self, key: impl Into>, value: Option) {
+        let key = key.into();
+        let mut lock = self.inner.lock().unwrap();
+        lock.inner.add_write(key, value);
     }
 
     fn commit_inner(&self, only_prepare: bool) -> impl Future> {
+        let id = self.id;
+        let client = self.client.clone();
         let inner = self.inner.clone();
-        let min_commit_timestamp = inner.max_read_timestamp().saturating_add(1);
-        let mut timestamp = Timestamp {
-            time: inner.client.transport().time().max(min_commit_timestamp),
-            client_id: inner.client.id(),
+
+        let transaction = {
+            let lock = inner.lock().unwrap();
+            lock.inner.clone()
+        };
+
+        let min_commit_timestamp = max_read_timestamp(&transaction).saturating_add(1);
+        let mut timestamp = {
+            let client = self.client.lock().unwrap();
+            Timestamp {
+                time: client.transport.time().max(min_commit_timestamp),
+                client_id: client.id,
+            }
         };
+        let participants = transaction.participants();
 
         async move {
+            // Writes are buffered; make sure the shard clients exist.
+            for key in transaction.write_set.keys() {
+                Inner::shard_client(&client, key.shard).await;
+            }
+
             let mut remaining_tries = 3u8;
 
             loop {
-                let result = inner.prepare(timestamp).await;
+                let future = {
+                    let client = client.lock().unwrap();
+                    join(participants.iter().map(|shard| {
+                        let shard_client = client.clients.get(shard).unwrap();
+                        let future = shard_client.prepare(id, &transaction, timestamp);
+                        (*shard, future)
+                    }))
+                };
+
+                let results = future
+                    .until(
+                        |results: &HashMap>,
+                         _cx: &mut Context<'_>| {
+                            results.values().any(|v| {
+                                v.is_fail() || v.is_abstain() || v.is_too_late() || v.is_too_old()
+                            })
+                        },
+                    )
+                    .await;
+
+                if results.values().any(|v| v.is_too_late() || v.is_too_old()) {
+                    continue;
+                }
 
-                if let OccPrepareResult::Retry { proposed } = &result && let Some(new_remaining_tries) = remaining_tries.checked_sub(1) {
+                if participants.len() == 1 && let Some(OccPrepareResult::Retry { proposed }) = results.values().next() && let Some(new_remaining_tries) = remaining_tries.checked_sub(1) {
                     remaining_tries = new_remaining_tries;
-                    let new_time = inner.client.transport().time().max(proposed.saturating_add(1)).max(min_commit_timestamp);
+                    let new_time = client.lock().unwrap().transport.time().max(proposed.saturating_add(1)).max(min_commit_timestamp);
                     if new_time != timestamp.time {
                         timestamp.time = new_time;
                         continue;
                     }
                 }
 
-                if matches!(result, OccPrepareResult::TooLate | OccPrepareResult::TooOld) {
-                    continue;
-                }
-                let ok = matches!(result, OccPrepareResult::Ok);
+
+                // Ok if all participant shards are ok.
+                let ok = results.len() == participants.len() && results.values().all(|r| r.is_ok());
+
                 if !only_prepare {
-                    inner.end(timestamp, ok).await;
+                    let future = {
+                        let client = client.lock().unwrap();
+                        join_all(participants.iter().map(|shard| {
+                            let shard_client = client.clients.get(shard).unwrap();
+                            shard_client.end(id, &transaction, timestamp, ok)
+                        }))
+                    };
+
+                    future.await;
                 }
 
                 if ok && remaining_tries != 3 {
@@ -127,3 +268,12 @@ impl>> Transaction {
         }
     }
 }
+
+pub fn max_read_timestamp(transaction: &OccTransaction) -> u64 {
+    transaction
+        .read_set
+        .values()
+        .map(|v| v.time)
+        .max()
+        .unwrap_or_default()
+}
diff --git a/src/tapir/mod.rs b/src/tapir/mod.rs
index d869aa3..8d939e0 100644
--- a/src/tapir/mod.rs
+++ b/src/tapir/mod.rs
@@ -5,6 +5,7 @@ mod replica;
 mod shard_client;
 mod timestamp;
 
+mod shard;
 #[cfg(test)]
 mod tests;
 
@@ -12,5 +13,6 @@ pub use client::Client;
 pub use key_value::{Key, Value};
 pub use message::{CO, CR, IO, UO, UR};
 pub use replica::Replica;
-pub use shard_client::{ShardClient, ShardTransaction};
+pub use shard::{Number as ShardNumber, Sharded};
+pub use shard_client::ShardClient;
 pub use timestamp::Timestamp;
diff --git a/src/tapir/replica.rs b/src/tapir/replica.rs
index 354f487..d085a08 100644
--- a/src/tapir/replica.rs
+++ b/src/tapir/replica.rs
@@ -1,10 +1,11 @@
-use super::{Key, Timestamp, Value, CO, CR, IO, UO, UR};
+use super::{Key, ShardNumber, Timestamp, Value, CO, CR, IO, UO, UR};
 use crate::ir::ReplyUnlogged;
 use crate::util::vectorize;
 use crate::{
     IrClient, IrMembership, IrMembershipSize, IrOpId, IrRecord, IrReplicaUpcalls, OccPrepareResult,
-    OccStore, OccTransaction, OccTransactionId, Transport,
+    OccStore, OccTransaction, OccTransactionId, TapirTransport,
 };
+use futures::future::join_all;
 use serde::{Deserialize, Serialize};
 use std::task::Context;
 use std::time::Duration;
@@ -42,9 +43,9 @@ pub struct Replica {
 }
 
 impl Replica {
-    pub fn new(linearizable: bool) -> Self {
+    pub fn new(shard: ShardNumber, linearizable: bool) -> Self {
         Self {
-            inner: OccStore::new(linearizable),
+            inner: OccStore::new(shard, linearizable),
             transaction_log: HashMap::new(),
             gc_watermark: 0,
             min_prepare_time: 0,
@@ -52,18 +53,25 @@ impl Replica {
         }
     }
 
-    fn recover_coordination>(
+    fn recover_coordination>(
         transaction_id: OccTransactionId,
         transaction: OccTransaction,
         commit: Timestamp,
-        membership: IrMembership,
+        // TODO: Optimize.
+        _membership: IrMembership,
         transport: T,
     ) -> impl Future {
         eprintln!("trying to recover {transaction_id:?}");
-        let client = IrClient::::new(membership, transport);
+
         async move {
-            let min_prepare = client
-                .invoke_consensus(
+            let mut participants = HashMap::new();
+            for shard in transaction.participants() {
+                let membership = transport.shard_addresses(shard).await;
+                participants.insert(shard, IrClient::new(membership, transport.clone()));
+            }
+
+            let min_prepares = join_all(participants.values().map(|client| {
+                client.invoke_consensus(
                     CO::RaiseMinPrepareTime {
                         time: commit.time + 1,
                     },
@@ -98,15 +106,22 @@ impl Replica {
                         }
                     },
                 )
-            .await;
+            }))
+            .await;
 
-            let CR::RaiseMinPrepareTime { time: min_prepare_time } = min_prepare else {
-                debug_assert!(false);
-                return;
-            };
+            if min_prepares.into_iter().any(|min_prepare| {
+                let CR::RaiseMinPrepareTime { time: min_prepare_time } = min_prepare else {
+                    debug_assert!(false);
+                    return true;
+                };
+
+                if commit.time >= min_prepare_time {
+                    // Not ready.
+                    return true;
+                }
 
-            if commit.time >= min_prepare_time {
-                // Not ready.
+                false
+            }) {
                 return;
             }
 
@@ -148,48 +163,62 @@ impl Replica {
             )
         }
 
-            let (future, membership) = client.invoke_unlogged_joined(UO::CheckPrepare {
-                transaction_id,
-                commit,
-            });
-
-            let mut timeout = std::pin::pin!(T::sleep(Duration::from_millis(1000)));
-            let results = future
-                .until(
-                    |results: &HashMap, T::Address>>,
-                     cx: &mut Context<'_>| {
-                        decide(results, membership).is_some()
-                            || timeout.as_mut().poll(cx).is_ready()
-                    },
-                )
-                .await;
+            let results = join_all(participants.values().map(|client| {
+                let (future, membership) = client.invoke_unlogged_joined(UO::CheckPrepare {
+                    transaction_id,
+                    commit,
+                });
+
+                async move {
+                    let mut timeout = std::pin::pin!(T::sleep(Duration::from_millis(1000)));
 
-            let Some(result) = decide(&results, membership) else {
+                    let results = future
+                        .until(
+                            |results: &HashMap, T::Address>>,
+                             cx: &mut Context<'_>| {
+                                decide(results, membership).is_some()
+                                    || timeout.as_mut().poll(cx).is_ready()
+                            },
+                        )
+                        .await;
+                    decide(&results, membership)
+                }
+            }))
+            .await;
+
+            if results.iter().any(|r| r.is_none()) {
+                // Try again later.
                 return;
-            };
+            }
 
-            eprintln!("BACKUP COORD got {result:?} for {transaction_id:?} @ {commit:?}");
+            let ok = results
+                .iter()
+                .all(|r| matches!(r, Some(OccPrepareResult::Ok)));
 
-            match result {
-                OccPrepareResult::Ok => {
-                    client
-                        .invoke_inconsistent(IO::Commit {
-                            transaction_id,
-                            transaction,
-                            commit,
-                        })
-                        .await
-                }
-                OccPrepareResult::Fail | OccPrepareResult::TooLate => {
-                    client
-                        .invoke_inconsistent(IO::Abort {
-                            transaction_id,
-                            commit: Some(commit),
-                        })
-                        .await
+            eprintln!("BACKUP COORD got ok={ok} for {transaction_id:?} @ {commit:?}");
+
+            join_all(participants.values().map(|client| {
+                let transaction = transaction.clone();
+                async move {
+                    if ok {
+                        client
+                            .invoke_inconsistent(IO::Commit {
+                                transaction_id,
+                                transaction,
+                                commit,
+                            })
+                            .await
+                    } else {
+                        client
+                            .invoke_inconsistent(IO::Abort {
+                                transaction_id,
+                                commit: Some(commit),
+                            })
+                            .await
+                    }
                 }
-                _ => {}
-            }
+            }))
+            .await;
         }
     }
 }
@@ -277,8 +306,7 @@ impl IrReplicaUpcalls for Replica {
                         "{transaction_id:?} committed at (different) {ts:?}"
                     );
                 }
-                self.inner
-                    .commit(*transaction_id, transaction.clone(), *commit);
+                self.inner.commit(*transaction_id, transaction, *commit);
             }
             IO::Abort {
                 transaction_id,
@@ -573,8 +601,14 @@ impl IrReplicaUpcalls for Replica {
 
         ret
     }
+}
 
-    fn tick>(&mut self, membership: &IrMembership, transport: &T) {
+impl Replica {
+    pub fn tick>(
+        &self,
+        transport: &T,
+        membership: &IrMembership,
+    ) {
         eprintln!(
             "there are {} prepared transactions",
             self.inner.prepared.len()
diff --git a/src/tapir/shard.rs b/src/tapir/shard.rs
new file mode 100644
index 0000000..34eae0f
--- /dev/null
+++ b/src/tapir/shard.rs
@@ -0,0 +1,27 @@
+use serde::{Deserialize, Serialize};
+use std::fmt::{self, Debug, Formatter};
+
+/// Identifies a shard consisting of a group of replicas.
+#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
+pub struct Number(pub u32);
+
+impl Debug for Number {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        write!(f, "S({})", self.0)
+    }
+}
+
+#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
+pub struct Sharded {
+    pub shard: Number,
+    pub key: K,
+}
+
+impl From for Sharded {
+    fn from(key: K) -> Self {
+        Self {
+            shard: Number(0),
+            key,
+        }
+    }
+}
diff --git a/src/tapir/shard_client.rs b/src/tapir/shard_client.rs
index 8b1038f..463a4f6 100644
--- a/src/tapir/shard_client.rs
+++ b/src/tapir/shard_client.rs
@@ -1,144 +1,71 @@
-use super::{Key, Replica, Timestamp, Value, CO, CR, IO, UO, UR};
+use super::{Key, Replica, ShardNumber, Timestamp, Value, CO, CR, IO, UO, UR};
 use crate::{
     transport::Transport,
     IrClient, IrClientId, IrMembership, OccPrepareResult, OccTransaction, OccTransactionId,
 };
-use std::{
-    collections::HashMap,
-    future::Future,
-    sync::{Arc, Mutex},
-};
+use std::future::Future;
 
 pub struct ShardClient>> {
+    shard: ShardNumber,
     inner: IrClient, T>,
 }
 
-impl>> ShardClient {
-    pub fn new(membership: IrMembership, transport: T) -> Self {
-        Self {
-            inner: IrClient::new(membership, transport),
-        }
-    }
-
-    // TODO: Use same id for all shards?
-    pub fn id(&self) -> IrClientId {
-        self.inner.id()
-    }
-
-    pub fn begin(&self, transaction_id: OccTransactionId) -> ShardTransaction {
-        ShardTransaction::new(self.inner.clone(), transaction_id)
-    }
-}
-
-pub struct ShardTransaction>> {
-    pub client: IrClient, T>,
-    inner: Arc>>,
-}
-
-impl>> Clone for ShardTransaction {
+impl>> Clone for ShardClient {
     fn clone(&self) -> Self {
         Self {
-            client: self.client.clone(),
-            inner: Arc::clone(&self.inner),
+            shard: self.shard,
+            inner: self.inner.clone(),
         }
     }
 }
 
-struct Inner {
-    id: OccTransactionId,
-    inner: OccTransaction,
-    read_cache: HashMap>,
-}
-
-impl>> ShardTransaction {
-    fn new(client: IrClient, T>, id: OccTransactionId) -> Self {
-        Self {
-            client,
-            inner: Arc::new(Mutex::new(Inner {
-                id,
-                inner: Default::default(),
-                read_cache: Default::default(),
-            })),
-        }
-    }
-
-    pub fn max_read_timestamp(&self) -> u64 {
-        self.inner
-            .lock()
-            .unwrap()
-            .inner
-            .read_set
-            .values()
-            .map(|v| v.time)
-            .max()
-            .unwrap_or_default()
+impl>> ShardClient {
+    pub fn new(
+        id: IrClientId,
+        shard: ShardNumber,
+        membership: IrMembership,
+        transport: T,
+    ) -> Self {
+        let mut inner = IrClient::new(membership, transport);
+
+        // Id of all shard clients must match for the timestamps to match during recovery.
+        inner.set_id(id);
+
+        Self { shard, inner }
     }
 
-    pub fn get(&self, key: K) -> impl Future> + Send {
-        let client = self.client.clone();
-        let inner = Arc::clone(&self.inner);
+    pub fn get(
+        &self,
+        key: K,
+        timestamp: Option,
+    ) -> impl Future, Timestamp)> {
+        let future = self.inner.invoke_unlogged(UO::Get { key, timestamp });
 
         async move {
-            loop {
-                {
-                    let lock = inner.lock().unwrap();
-
-                    // Read own writes.
-                    if let Some(write) = lock.inner.write_set.get(&key) {
-                        return write.as_ref().cloned();
-                    }
-
-                    // Consistent reads.
-                    if let Some(read) = lock.read_cache.get(&key) {
-                        return read.as_ref().cloned();
-                    }
-                }
-
-                let future = client.invoke_unlogged(UO::Get {
-                    key: key.clone(),
-                    timestamp: None,
-                });
-
-                let reply = future.await;
-
-                let UR::Get(value, timestamp) = reply else {
-                    debug_assert!(false);
-                    continue;
-                };
-
-                let mut lock = inner.lock().unwrap();
-
-                // Read own writes.
-                if let Some(write) = lock.inner.write_set.get(&key) {
-                    return write.as_ref().cloned();
-                }
+            let reply = future.await;
 
-                // Consistent reads.
-                if let Some(read) = lock.read_cache.get(&key) {
-                    return read.as_ref().cloned();
-                }
+            if let UR::Get(value, timestamp) = reply {
+                (value, timestamp)
+            } else {
+                debug_assert!(false);
 
-                lock.read_cache.insert(key.clone(), value.clone());
-                lock.inner.add_read(key, timestamp);
-                return value;
+                // Was valid at the beginning of time (the transaction will
+                // abort if that's too old).
+                (None, Default::default())
             }
         }
     }
 
-    pub fn put(&self, key: K, value: Option) {
-        let mut lock = self.inner.lock().unwrap();
-        lock.inner.add_write(key, value);
-    }
-
     pub fn prepare(
         &self,
+        transaction_id: OccTransactionId,
+        transaction: &OccTransaction,
         timestamp: Timestamp,
     ) -> impl Future> + Send {
-        let lock = self.inner.lock().unwrap();
-        let future = self.client.invoke_consensus(
+        let future = self.inner.invoke_consensus(
             CO::Prepare {
-                transaction_id: lock.id,
-                transaction: lock.inner.clone(),
+                transaction_id,
+                transaction: transaction.clone(),
                 commit: timestamp,
             },
             |results, membership_size| {
@@ -187,7 +114,6 @@ impl>> ShardTransaction {
                 })
             },
         );
-        drop(lock);
 
         async move {
             let reply = future.await;
@@ -202,23 +128,22 @@ impl>> ShardTransaction {
 
     pub fn end(
         &self,
+        transaction_id: OccTransactionId,
+        transaction: &OccTransaction,
         prepared_timestamp: Timestamp,
         commit: bool,
     ) -> impl Future + Send {
-        let lock = self.inner.lock().unwrap();
-        let future = self.client.invoke_inconsistent(if commit {
+        self.inner.invoke_inconsistent(if commit {
             IO::Commit {
-                transaction_id: lock.id,
-                transaction: lock.inner.clone(),
+                transaction_id,
+                transaction: transaction.clone(),
                 commit: prepared_timestamp,
             }
         } else {
             IO::Abort {
-                transaction_id: lock.id,
+                transaction_id,
                 commit: None,
             }
-        });
-        drop(lock);
-        future
+        })
     }
 }
diff --git a/src/tapir/tests/kv.rs b/src/tapir/tests/kv.rs
index 2a222c7..f7f1519 100644
--- a/src/tapir/tests/kv.rs
+++ b/src/tapir/tests/kv.rs
@@ -1,11 +1,9 @@
-use futures::future::join_all;
-use rand::{thread_rng, Rng};
-use tokio::time::timeout;
-
 use crate::{
-    ChannelRegistry, ChannelTransport, IrMembership, IrReplica, TapirClient, TapirReplica,
-    TapirTimestamp, Transport as _,
+    tapir::Sharded, ChannelRegistry, ChannelTransport, IrMembership, IrReplica, ShardNumber,
+    TapirClient, TapirReplica, TapirTimestamp, Transport as _,
 };
+use futures::future::join_all;
+use rand::{thread_rng, Rng};
 use std::{
     sync::{
         atomic::{AtomicU64, Ordering},
@@ -13,28 +11,28 @@ use std::{
     },
     time::Duration,
 };
+use tokio::time::timeout;
 
 type K = i64;
 type V = i64;
 type Transport = ChannelTransport>;
 
-fn build_kv(
+fn build_shard(
+    shard: ShardNumber,
     linearizable: bool,
     num_replicas: usize,
-    num_clients: usize,
-) -> (
-    Vec, ChannelTransport>>>>,
-    Vec>>>>,
-) {
-    println!("---------------------------");
-    println!(" linearizable={linearizable} num_replicas={num_replicas}");
-    println!("---------------------------");
-
-    let registry = ChannelRegistry::default();
-    let membership = IrMembership::new((0..num_replicas).collect::>());
+    registry: &ChannelRegistry>,
+) -> Vec, ChannelTransport>>>> {
+    let initial_address = registry.len();
+    let membership = IrMembership::new(
+        (0..num_replicas)
+            .map(|n| n + initial_address)
+            .collect::>(),
+    );
 
     fn create_replica(
         registry: &ChannelRegistry>,
+        shard: ShardNumber,
         membership: &IrMembership,
         linearizable: bool,
     ) -> Arc, ChannelTransport>>> {
@@ -45,29 +43,86 @@ fn build_kv(
                 let weak = weak.clone();
                 let channel =
                     registry.channel(move |from, message| weak.upgrade()?.receive(from, message));
-                let upcalls = TapirReplica::new(linearizable);
-                IrReplica::new(membership.clone(), upcalls, channel)
+                let upcalls = TapirReplica::new(shard, linearizable);
+                IrReplica::new(
+                    membership.clone(),
+                    upcalls,
+                    channel,
+                    Some(TapirReplica::tick),
+                )
             },
         )
     }
 
-    let replicas = std::iter::repeat_with(|| create_replica(&registry, &membership, linearizable))
-        .take(num_replicas)
-        .collect::>();
+    let replicas =
+        std::iter::repeat_with(|| create_replica(&registry, shard, &membership, linearizable))
+            .take(num_replicas)
+            .collect::>();
+
+    registry.put_shard_addresses(shard, membership.clone());
+
+    replicas
+}
 
+fn build_clients(
+    num_clients: usize,
+    registry: &ChannelRegistry>,
+) -> Vec>>>> {
     fn create_client(
         registry: &ChannelRegistry>,
-        membership: &IrMembership,
     ) -> Arc>>> {
         let channel = registry.channel(move |_, _| unreachable!());
-        Arc::new(TapirClient::new(membership.clone(), channel))
+        Arc::new(TapirClient::new(channel))
     }
 
-    let clients = std::iter::repeat_with(|| create_client(&registry, &membership))
+    let clients = std::iter::repeat_with(|| create_client(&registry))
         .take(num_clients)
         .collect::>();
 
-    (replicas, clients)
+    clients
+}
+
+fn build_kv(
+    linearizable: bool,
+    num_replicas: usize,
+    num_clients: usize,
+) -> (
+    Vec, ChannelTransport>>>>,
+    Vec>>>>,
+) {
+    let (mut shards, clients) = build_sharded_kv(linearizable, 1, num_replicas, num_clients);
+    (shards.remove(0), clients)
+}
+
+fn build_sharded_kv(
+    linearizable: bool,
+    num_shards: usize,
+    num_replicas: usize,
+    num_clients: usize,
+) -> (
+    Vec, ChannelTransport>>>>>,
+    Vec>>>>,
+) {
+    println!("---------------------------");
+    println!(" linearizable={linearizable} num_shards={num_shards} num_replicas={num_replicas}");
+    println!("---------------------------");
+
+    let registry = ChannelRegistry::default();
+
+    let mut shards = Vec::new();
+    for shard in 0..num_shards {
+        let replicas = build_shard(
+            ShardNumber(shard as u32),
+            linearizable,
+            num_replicas,
+            &registry,
+        );
+        shards.push(replicas);
+    }
+
+    let clients = build_clients(num_clients, &registry);
+
+    (shards, clients)
 }
 
@@ -132,6 +187,37 @@ async fn rwr(linearizable: bool, num_replicas: usize) {
     }
 }
 
+#[tokio::test]
+async fn sharded() {
+    let (_shards, clients) = build_sharded_kv(true, 5, 3, 2);
+
+    let txn = clients[0].begin();
+    assert_eq!(
+        txn.get(Sharded {
+            shard: ShardNumber(0),
+            key: 0
+        })
+        .await,
+        None
+    );
+    assert_eq!(
+        txn.get(Sharded {
+            shard: ShardNumber(1),
+            key: 0
+        })
+        .await,
+        None
+    );
+    txn.put(
+        Sharded {
+            shard: ShardNumber(2),
+            key: 0,
+        },
+        Some(0),
+    );
+    assert!(txn.commit().await.is_some());
+}
+
 #[tokio::test]
 async fn increment_sequential_3() {
     increment_sequential_timeout(3).await;
diff --git a/src/transport/channel.rs b/src/transport/channel.rs
index 42db063..8bba95f 100644
--- a/src/transport/channel.rs
+++ b/src/transport/channel.rs
@@ -1,5 +1,8 @@
-use super::Transport;
-use crate::{IrMessage, IrReplicaUpcalls};
+use super::{TapirTransport, Transport};
+use crate::{
+    tapir::{Key, Value},
+    IrMembership, IrMessage, IrReplicaUpcalls, ShardNumber, TapirReplica,
+};
 use rand::{thread_rng, Rng};
 use serde::{de::DeserializeOwned, Serialize};
 use std::{
@@ -34,12 +37,14 @@ struct Inner {
             + Sync,
         >,
     >,
+    shards: HashMap>,
 }
 
 impl Default for Inner {
     fn default() -> Self {
         Self {
             callbacks: Vec::new(),
+            shards: Default::default(),
         }
     }
 }
@@ -61,6 +66,15 @@ impl Registry {
             inner: Arc::clone(&self.inner),
         }
     }
+
+    pub fn put_shard_addresses(&self, shard: ShardNumber, membership: IrMembership) {
+        let mut inner = self.inner.write().unwrap();
+        inner.shards.insert(shard, membership);
+    }
+
+    pub fn len(&self) -> usize {
+        self.inner.read().unwrap().callbacks.len()
+    }
 }
 
 pub struct Channel {
@@ -208,3 +222,24 @@ impl Transport for Channel {
         }
     }
 }
+
+impl TapirTransport for Channel> {
+    fn shard_addresses(
+        &self,
+        shard: ShardNumber,
+    ) -> impl Future> + Send + 'static {
+        let inner = Arc::clone(&self.inner);
+        async move {
+            loop {
+                {
+                    let inner = inner.read().unwrap();
+                    if let Some(membership) = inner.shards.get(&shard) {
+                        break membership.clone();
+                    }
+                }
+
+                >>::sleep(Duration::from_millis(100)).await;
+            }
+        }
+    }
+}
diff --git a/src/transport/mod.rs b/src/transport/mod.rs
index e00aea5..4270c56 100644
--- a/src/transport/mod.rs
+++ b/src/transport/mod.rs
@@ -1,4 +1,7 @@
-use crate::{IrMessage, IrReplicaUpcalls};
+use crate::{
+    tapir::{Key, Value},
+    IrMembership, IrMessage, IrReplicaUpcalls, ShardNumber, TapirReplica,
+};
 pub use channel::{Channel, Registry as ChannelRegistry};
 pub use message::Message;
 use serde::{de::DeserializeOwned, Serialize};
@@ -35,13 +38,13 @@ pub trait Transport: Clone + Send + Sync + 'static {
     /// Sleep for duration.
     fn sleep(duration: Duration) -> Self::Sleep;
 
-    /// Synchronously persist a key-value pair. Any future calls
-    /// to `persisted` should return this value unless/until it
-    /// is overwritten.
-    // TODO: Allow safe expiration mechanism for checkpoints.
+    /// Synchronously and atomically persist a key-value pair. Any
+    /// future calls to `persisted` should return this value
+    /// unless/until it is overwritten.
     fn persist(&self, key: &str, value: Option<&T>);
 
-    /// Synchronously load a persisted key-value pair.
+    /// Synchronously load the last key-value pair successfully persisted
+    /// at the given key.
     fn persisted(&self, key: &str) -> Option;
 
     /// Send/retry, ignoring any errors, until there is a reply.
@@ -54,3 +57,12 @@ pub trait Transport: Clone + Send + Sync + 'static {
     /// Send once and don't wait for a reply.
     fn do_send(&self, address: Self::Address, message: impl Into> + Debug);
 }
+
+pub trait TapirTransport: Transport> {
+    /// Look up the addresses of replicas in a shard, on a best-effort basis; results
+    /// may be arbitrarily out of date.
+    fn shard_addresses(
+        &self,
+        shard: ShardNumber,
+    ) -> impl Future> + Send + 'static;
+}
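
A note for reviewers, not part of the patch: the client-facing flow this change enables, condensed from the new `sharded()` test in `src/tapir/tests/kv.rs`. It is a sketch under that test module's assumptions — `build_sharded_kv` is the helper defined there, and `Sharded` is the crate-internal key wrapper from `src/tapir/shard.rs` (external callers can instead rely on the `From<K> for Sharded<K>` impl, which defaults to `ShardNumber(0)`).

```rust
// Sketch: condensed from the `sharded()` test in src/tapir/tests/kv.rs;
// assumes that module's helpers (`build_sharded_kv`, `Sharded`) are in scope.
#[tokio::test]
async fn cross_shard_sketch() {
    // Five shards, three replicas each, two clients.
    let (_shards, clients) = build_sharded_kv(true, 5, 3, 2);

    let txn = clients[0].begin();

    // Keys are now addressed as (shard, key) pairs.
    let read = txn
        .get(Sharded {
            shard: ShardNumber(1),
            key: 0,
        })
        .await;
    assert_eq!(read, None);

    // Writes are buffered client-side; the shard client is created lazily
    // (see `Inner::shard_client` in src/tapir/client.rs).
    txn.put(
        Sharded {
            shard: ShardNumber(2),
            key: 0,
        },
        Some(0),
    );

    // Commit prepares on every participant shard and reports success only
    // if all of them vote Ok.
    assert!(txn.commit().await.is_some());
}
```

One design point worth noting from `Transaction::commit_inner`: the timestamp-`Retry` fast path is applied only when `participants.len() == 1`, so multi-shard prepares that disagree fall through to the ordinary all-or-nothing outcome rather than retrying at a proposed timestamp.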