From 3b9e7fc5e6715a694b27e059ee7eedd1f6f44735 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Mon, 26 Apr 2021 13:07:51 +0300 Subject: [PATCH 1/2] Use explicit threads. Remove 'async' usage a much as feasible. Async code is harder to debug, and mixing async and non-async code is a recipe for confusion and bugs. There are a couple of exceptions: - The code in walredo.rs, which needs to read and write to the child process simultaneously, still uses async. It's more convenient there. The 'async' usage is carefully limited to just the functions that communicate with the child process. - Code in walreceiver.rs that uses tokio-postgres to do streaming replication. We have to use async there, because tokio-postgres is async. Most rust-postgres functionality has non-async wrappers, but not the new replication client code. The async usage is very limited here, too: we use just block_on to call the tokio-postgres functions. The code in 'page_service.rs' now launches a dedicated thread for each connection. This replaces tokio::sync::watch::channel with std::sync:mpsc in 'seqwait.rs', to make that non-async. It's not a drop-in replacement, though: std::sync::mpsc doesn't support multiple consumers, so we cannot share a channel between multiple waiters. So this removes the code to check if an existing channel can be reused, and creates a new one for each waiter. That created another problem: BTreeMap cannot hold duplicates, so I replaced that with BinaryHeap. Similarly, the tokio::{mpsc, oneshot} channels used between WAL redo manager and PageCache are replaced with std::sync::mpsc. (There is no separate 'oneshot' channel in the standard library.) Fixes github issue #58, and coincidentally also issue #66. --- Cargo.lock | 1 - pageserver/src/page_cache.rs | 32 +-- pageserver/src/page_service.rs | 489 +++++++++++++++------------------ pageserver/src/walreceiver.rs | 75 +++-- pageserver/src/walredo.rs | 66 +++-- zenith_utils/Cargo.toml | 4 - zenith_utils/src/seqwait.rs | 166 ++++++----- 7 files changed, 412 insertions(+), 421 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index da3877927925..072aebc03257 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2761,5 +2761,4 @@ name = "zenith_utils" version = "0.1.0" dependencies = [ "thiserror", - "tokio", ] diff --git a/pageserver/src/page_cache.rs b/pageserver/src/page_cache.rs index 667b16fb9062..aae2fcef7774 100644 --- a/pageserver/src/page_cache.rs +++ b/pageserver/src/page_cache.rs @@ -62,7 +62,7 @@ pub struct PageCache { // WAL redo manager walredo_mgr: WalRedoManager, - // Allows .await on the arrival of a particular LSN. + // Allows waiting for the arrival of a particular LSN. seqwait_lsn: SeqWait, // Counters, for metrics collection. @@ -170,12 +170,7 @@ fn gc_thread_main(conf: &PageServerConf, timelineid: ZTimelineId) { info!("Garbage collection thread started {}", timelineid); let pcache = get_pagecache(conf, timelineid).unwrap(); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - - runtime.block_on(pcache.do_gc(conf)).unwrap(); + pcache.do_gc(conf).unwrap(); } fn open_rocksdb(_conf: &PageServerConf, timelineid: ZTimelineId) -> rocksdb::DB { @@ -380,10 +375,10 @@ impl PageCache { /// /// Returns an 8k page image /// - pub async fn get_page_at_lsn(&self, tag: BufferTag, req_lsn: Lsn) -> anyhow::Result { + pub fn get_page_at_lsn(&self, tag: BufferTag, req_lsn: Lsn) -> anyhow::Result { self.num_getpage_requests.fetch_add(1, Ordering::Relaxed); - let lsn = self.wait_lsn(req_lsn).await?; + let lsn = self.wait_lsn(req_lsn)?; // Look up cache entry. If it's a page image, return that. If it's a WAL record, // ask the WAL redo service to reconstruct the page image from the WAL records. @@ -409,7 +404,7 @@ impl PageCache { page_img = img.clone(); } else if content.wal_record.is_some() { // Request the WAL redo manager to apply the WAL records for us. - page_img = self.walredo_mgr.request_redo(tag, lsn).await?; + page_img = self.walredo_mgr.request_redo(tag, lsn)?; } else { // No base image, and no WAL record. Huh? bail!("no page image or WAL record for requested page"); @@ -441,16 +436,16 @@ impl PageCache { /// /// Get size of relation at given LSN. /// - pub async fn relsize_get(&self, rel: &RelTag, lsn: Lsn) -> anyhow::Result { - self.wait_lsn(lsn).await?; + pub fn relsize_get(&self, rel: &RelTag, lsn: Lsn) -> anyhow::Result { + self.wait_lsn(lsn)?; return self.relsize_get_nowait(rel, lsn); } /// /// Does relation exist at given LSN? /// - pub async fn relsize_exist(&self, rel: &RelTag, req_lsn: Lsn) -> anyhow::Result { - let lsn = self.wait_lsn(req_lsn).await?; + pub fn relsize_exist(&self, rel: &RelTag, req_lsn: Lsn) -> anyhow::Result { + let lsn = self.wait_lsn(req_lsn)?; let key = CacheKey { tag: BufferTag { @@ -815,7 +810,7 @@ impl PageCache { Ok(0) } - async fn do_gc(&self, conf: &PageServerConf) -> anyhow::Result { + fn do_gc(&self, conf: &PageServerConf) -> anyhow::Result { let mut buf = BytesMut::new(); loop { thread::sleep(conf.gc_period); @@ -867,7 +862,7 @@ impl PageCache { if (v[0] & PAGE_IMAGE_FLAG) == 0 { trace!("Reconstruct most recent page {:?}", key); // force reconstruction of most recent page version - self.walredo_mgr.request_redo(key.tag, key.lsn).await?; + self.walredo_mgr.request_redo(key.tag, key.lsn)?; reconstructed += 1; } @@ -887,7 +882,7 @@ impl PageCache { let v = iter.value().unwrap(); if (v[0] & PAGE_IMAGE_FLAG) == 0 { trace!("Reconstruct horizon page {:?}", key); - self.walredo_mgr.request_redo(key.tag, key.lsn).await?; + self.walredo_mgr.request_redo(key.tag, key.lsn)?; truncated += 1; } } @@ -930,7 +925,7 @@ impl PageCache { // // Wait until WAL has been received up to the given LSN. // - async fn wait_lsn(&self, mut lsn: Lsn) -> anyhow::Result { + fn wait_lsn(&self, mut lsn: Lsn) -> anyhow::Result { // When invalid LSN is requested, it means "don't wait, return latest version of the page" // This is necessary for bootstrap. if lsn == Lsn(0) { @@ -945,7 +940,6 @@ impl PageCache { self.seqwait_lsn .wait_for_timeout(lsn, TIMEOUT) - .await .with_context(|| { format!( "Timed out while waiting for WAL record at LSN {} to arrive", diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 72f97aaaa7c3..2721aa487362 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -10,21 +10,16 @@ // *callmemaybe $url* -- ask pageserver to start walreceiver on $url // -use byteorder::{BigEndian, ByteOrder}; +use byteorder::{ReadBytesExt, WriteBytesExt, BE}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use log::*; use regex::Regex; use std::io; +use std::io::{BufReader, BufWriter, Read, Write}; +use std::net::{TcpListener, TcpStream}; use std::str::FromStr; -use std::sync::Arc; use std::thread; use std::time::Duration; -use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; -use tokio::net::{TcpListener, TcpStream}; -use tokio::runtime; -use tokio::runtime::Runtime; -use tokio::sync::mpsc; -use tokio::task; use zenith_utils::lsn::Lsn; use crate::basebackup; @@ -116,26 +111,41 @@ enum StartupRequestCode { } impl FeStartupMessage { - pub fn parse(buf: &mut BytesMut) -> Result> { + pub fn read(stream: &mut dyn std::io::Read) -> Result> { const MAX_STARTUP_PACKET_LENGTH: u32 = 10000; const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678; const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679; const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680; - if buf.len() < 4 { - return Ok(None); - } - let len = BigEndian::read_u32(&buf[0..4]); - + // Read length. If the connection is closed before reading anything (or before + // reading 4 bytes, to be precise), return None to indicate that the connection + // was closed. This matches the PostgreSQL server's behavior, which avoids noise + // in the log if the client opens connection but closes it immediately. + let len = match stream.read_u32::() { + Ok(len) => len, + Err(err) => { + if err.kind() == std::io::ErrorKind::UnexpectedEof { + return Ok(None); + } else { + return Err(err); + } + } + }; if len < 4 || len as u32 > MAX_STARTUP_PACKET_LENGTH { return Err(io::Error::new( io::ErrorKind::InvalidData, "invalid message length", )); } + let bodylen = len - 4; - let version = BigEndian::read_u32(&buf[4..8]); + // Read the rest of the startup packet + let mut body_buf: Vec = vec![0; bodylen as usize]; + stream.read_exact(&mut body_buf)?; + let mut body = Bytes::from(body_buf); + // Parse the first field, which indicates what kind of a packet it is + let version = body.get_u32(); let kind = match version { CANCEL_REQUEST_CODE => StartupRequestCode::Cancel, NEGOTIATE_SSL_CODE => StartupRequestCode::NegotiateSsl, @@ -143,7 +153,8 @@ impl FeStartupMessage { _ => StartupRequestCode::Normal, }; - buf.advance(len as usize); + // Ignore the rest of the packet + Ok(Some(FeMessage::StartupMessage(FeStartupMessage { version, kind, @@ -328,35 +339,38 @@ impl FeCloseMessage { } impl FeMessage { - pub fn parse(buf: &mut BytesMut) -> Result> { - if buf.len() < 5 { - let to_read = 5 - buf.len(); - buf.reserve(to_read); - return Ok(None); - } - - let tag = buf[0]; - let len = BigEndian::read_u32(&buf[1..5]); + pub fn read(stream: &mut dyn Read) -> Result> { + // Each libpq message begins with a message type byte, followed by message length + // If the client closes the connection, return None. But if the client closes the + // connection in the middle of a message, we will return an error. + let tag = match stream.read_u8() { + Ok(b) => b, + Err(err) => { + if err.kind() == std::io::ErrorKind::UnexpectedEof { + return Ok(None); + } else { + return Err(err); + } + } + }; + let len = stream.read_u32::()?; + // The message length includes itself, so it better be at least 4 if len < 4 { return Err(io::Error::new( io::ErrorKind::InvalidInput, "invalid message length: parsing u32", )); } + let bodylen = len - 4; - let total_len = len as usize + 1; - if buf.len() < total_len { - let to_read = total_len - buf.len(); - buf.reserve(to_read); - return Ok(None); - } - - let mut body = buf.split_to(total_len); - body.advance(5); + // Read message body + let mut body_buf: Vec = vec![0; bodylen as usize]; + stream.read_exact(&mut body_buf)?; - let mut body = body.freeze(); + let mut body = Bytes::from(body_buf); + // Parse it match tag { b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))), b'P' => Ok(Some(FeParseMessage::parse(body)?)), @@ -385,13 +399,13 @@ impl FeMessage { 2 => Ok(Some(FeMessage::ZenithReadRequest(zreq))), _ => Err(io::Error::new( io::ErrorKind::InvalidInput, - format!("unknown smgr message tag: {},'{:?}'", smgr_tag, buf), + format!("unknown smgr message tag: {},'{:?}'", smgr_tag, body), )), } } tag => Err(io::Error::new( io::ErrorKind::InvalidInput, - format!("unknown message tag: {},'{:?}'", tag, buf), + format!("unknown message tag: {},'{:?}'", tag, body), )), } } @@ -399,152 +413,118 @@ impl FeMessage { /////////////////////////////////////////////////////////////////////////////// +/// +/// Main loop of the page service. +/// +/// Listens for connections, and launches a new handler thread for each. +/// pub fn thread_main(conf: &PageServerConf) { - // Create a new thread pool - // - // FIXME: It would be nice to keep this single-threaded for debugging purposes, - // but that currently leads to a deadlock: if a GetPage@LSN request arrives - // for an LSN that hasn't been received yet, the thread gets stuck waiting for - // the WAL to arrive. If the WAL receiver hasn't been launched yet, i.e - // we haven't received a "callmemaybe" request yet to tell us where to get the - // WAL, we will not have a thread available to process the "callmemaybe" - // request when it does arrive. Using a thread pool alleviates the problem so - // that it doesn't happen in the tests anymore, but in principle it could still - // happen if we receive enough GetPage@LSN requests to consume all of the - // available threads. - //let runtime = runtime::Builder::new_current_thread().enable_all().build().unwrap(); - let runtime = runtime::Runtime::new().unwrap(); - info!("Starting page server on {}", conf.listen_addr); - let runtime_ref = Arc::new(runtime); - - runtime_ref.block_on(async { - let listener = TcpListener::bind(conf.listen_addr).await.unwrap(); + let listener = TcpListener::bind(conf.listen_addr).unwrap(); - loop { - let (socket, peer_addr) = listener.accept().await.unwrap(); - debug!("accepted connection from {}", peer_addr); - socket.set_nodelay(true).unwrap(); - let mut conn_handler = Connection::new(conf.clone(), socket, &runtime_ref); - - task::spawn(async move { - if let Err(err) = conn_handler.run().await { - error!("error: {}", err); - } - }); - } - }); + loop { + let (socket, peer_addr) = listener.accept().unwrap(); + debug!("accepted connection from {}", peer_addr); + socket.set_nodelay(true).unwrap(); + let mut conn_handler = Connection::new(conf.clone(), socket); + + thread::spawn(move || { + if let Err(err) = conn_handler.run() { + error!("error: {}", err); + } + }); + } } #[derive(Debug)] struct Connection { + stream_in: BufReader, stream: BufWriter, buffer: BytesMut, init_done: bool, conf: PageServerConf, - runtime: Arc, } impl Connection { - pub fn new(conf: PageServerConf, socket: TcpStream, runtime: &Arc) -> Connection { + pub fn new(conf: PageServerConf, socket: TcpStream) -> Connection { Connection { + stream_in: BufReader::new(socket.try_clone().unwrap()), stream: BufWriter::new(socket), buffer: BytesMut::with_capacity(10 * 1024), init_done: false, conf, - runtime: Arc::clone(runtime), } } // // Read full message or return None if connection is closed // - async fn read_message(&mut self) -> Result> { - loop { - if let Some(message) = self.parse_message()? { - return Ok(Some(message)); - } - - if self.stream.read_buf(&mut self.buffer).await? == 0 { - if self.buffer.is_empty() { - return Ok(None); - } else { - return Err(io::Error::new( - io::ErrorKind::Other, - "connection reset by peer", - )); - } - } - } - } - - fn parse_message(&mut self) -> Result> { + fn read_message(&mut self) -> Result> { if !self.init_done { - FeStartupMessage::parse(&mut self.buffer) + FeStartupMessage::read(&mut self.stream_in) } else { - FeMessage::parse(&mut self.buffer) + FeMessage::read(&mut self.stream_in) } } - async fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<()> { + fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<()> { match message { BeMessage::AuthenticationOk => { - self.stream.write_u8(b'R').await?; - self.stream.write_i32(4 + 4).await?; - self.stream.write_i32(0).await?; + self.stream.write_u8(b'R')?; + self.stream.write_i32::(4 + 4)?; + self.stream.write_i32::(0)?; } BeMessage::ReadyForQuery => { - self.stream.write_u8(b'Z').await?; - self.stream.write_i32(4 + 1).await?; - self.stream.write_u8(b'I').await?; + self.stream.write_u8(b'Z')?; + self.stream.write_i32::(4 + 1)?; + self.stream.write_u8(b'I')?; } BeMessage::ParseComplete => { - self.stream.write_u8(b'1').await?; - self.stream.write_i32(4).await?; + self.stream.write_u8(b'1')?; + self.stream.write_i32::(4)?; } BeMessage::BindComplete => { - self.stream.write_u8(b'2').await?; - self.stream.write_i32(4).await?; + self.stream.write_u8(b'2')?; + self.stream.write_i32::(4)?; } BeMessage::CloseComplete => { - self.stream.write_u8(b'3').await?; - self.stream.write_i32(4).await?; + self.stream.write_u8(b'3')?; + self.stream.write_i32::(4)?; } BeMessage::NoData => { - self.stream.write_u8(b'n').await?; - self.stream.write_i32(4).await?; + self.stream.write_u8(b'n')?; + self.stream.write_i32::(4)?; } BeMessage::ParameterDescription => { - self.stream.write_u8(b't').await?; - self.stream.write_i32(6).await?; + self.stream.write_u8(b't')?; + self.stream.write_i32::(6)?; // we don't support params, so always 0 - self.stream.write_i16(0).await?; + self.stream.write_i16::(0)?; } BeMessage::RowDescription => { // XXX let b = Bytes::from("data\0"); - self.stream.write_u8(b'T').await?; + self.stream.write_u8(b'T')?; self.stream - .write_i32(4 + 2 + b.len() as i32 + 3 * (4 + 2)) - .await?; - - self.stream.write_i16(1).await?; - self.stream.write_all(&b).await?; - self.stream.write_i32(0).await?; /* table oid */ - self.stream.write_i16(0).await?; /* attnum */ - self.stream.write_i32(25).await?; /* TEXTOID */ - self.stream.write_i16(-1).await?; /* typlen */ - self.stream.write_i32(0).await?; /* typmod */ - self.stream.write_i16(0).await?; /* format code */ + .write_i32::(4 + 2 + b.len() as i32 + 3 * (4 + 2))?; + + self.stream.write_i16::(1)?; + self.stream.write_all(&b)?; + self.stream.write_i32::(0)?; /* table oid */ + self.stream.write_i16::(0)?; /* attnum */ + self.stream.write_i32::(25)?; /* TEXTOID */ + self.stream.write_i16::(-1)?; /* typlen */ + self.stream.write_i32::(0)?; /* typmod */ + self.stream.write_i16::(0)?; /* format code */ } // XXX: accept some text data @@ -552,74 +532,73 @@ impl Connection { // XXX let b = Bytes::from("hello world"); - self.stream.write_u8(b'D').await?; - self.stream.write_i32(4 + 2 + 4 + b.len() as i32).await?; + self.stream.write_u8(b'D')?; + self.stream.write_i32::(4 + 2 + 4 + b.len() as i32)?; - self.stream.write_i16(1).await?; - self.stream.write_i32(b.len() as i32).await?; - self.stream.write_all(&b).await?; + self.stream.write_i16::(1)?; + self.stream.write_i32::(b.len() as i32)?; + self.stream.write_all(&b)?; } BeMessage::ControlFile => { // TODO pass checkpoint and xid info in this message let b = Bytes::from("hello pg_control"); - self.stream.write_u8(b'D').await?; - self.stream.write_i32(4 + 2 + 4 + b.len() as i32).await?; + self.stream.write_u8(b'D')?; + self.stream.write_i32::(4 + 2 + 4 + b.len() as i32)?; - self.stream.write_i16(1).await?; - self.stream.write_i32(b.len() as i32).await?; - self.stream.write_all(&b).await?; + self.stream.write_i16::(1)?; + self.stream.write_i32::(b.len() as i32)?; + self.stream.write_all(&b)?; } BeMessage::CommandComplete => { let b = Bytes::from("SELECT 1\0"); - self.stream.write_u8(b'C').await?; - self.stream.write_i32(4 + b.len() as i32).await?; - self.stream.write_all(&b).await?; + self.stream.write_u8(b'C')?; + self.stream.write_i32::(4 + b.len() as i32)?; + self.stream.write_all(&b)?; } BeMessage::ZenithStatusResponse(resp) => { - self.stream.write_u8(b'd').await?; - self.stream.write_u32(4 + 1 + 1 + 4).await?; - self.stream.write_u8(100).await?; /* tag from pagestore_client.h */ - self.stream.write_u8(resp.ok as u8).await?; - self.stream.write_u32(resp.n_blocks).await?; + self.stream.write_u8(b'd')?; + self.stream.write_u32::(4 + 1 + 1 + 4)?; + self.stream.write_u8(100)?; /* tag from pagestore_client.h */ + self.stream.write_u8(resp.ok as u8)?; + self.stream.write_u32::(resp.n_blocks)?; } BeMessage::ZenithNblocksResponse(resp) => { - self.stream.write_u8(b'd').await?; - self.stream.write_u32(4 + 1 + 1 + 4).await?; - self.stream.write_u8(101).await?; /* tag from pagestore_client.h */ - self.stream.write_u8(resp.ok as u8).await?; - self.stream.write_u32(resp.n_blocks).await?; + self.stream.write_u8(b'd')?; + self.stream.write_u32::(4 + 1 + 1 + 4)?; + self.stream.write_u8(101)?; /* tag from pagestore_client.h */ + self.stream.write_u8(resp.ok as u8)?; + self.stream.write_u32::(resp.n_blocks)?; } BeMessage::ZenithReadResponse(resp) => { - self.stream.write_u8(b'd').await?; + self.stream.write_u8(b'd')?; self.stream - .write_u32(4 + 1 + 1 + 4 + resp.page.len() as u32) - .await?; - self.stream.write_u8(102).await?; /* tag from pagestore_client.h */ - self.stream.write_u8(resp.ok as u8).await?; - self.stream.write_u32(resp.n_blocks).await?; - self.stream.write_all(&resp.page.clone()).await?; + .write_u32::(4 + 1 + 1 + 4 + resp.page.len() as u32)?; + self.stream.write_u8(102)?; /* tag from pagestore_client.h */ + self.stream.write_u8(resp.ok as u8)?; + self.stream.write_u32::(resp.n_blocks)?; + self.stream.write_all(&resp.page.clone())?; } } Ok(()) } - async fn write_message(&mut self, message: &BeMessage) -> io::Result<()> { - self.write_message_noflush(message).await?; - self.stream.flush().await + fn write_message(&mut self, message: &BeMessage) -> io::Result<()> { + self.write_message_noflush(message)?; + self.stream.flush() } - async fn run(&mut self) -> Result<()> { + fn run(&mut self) -> Result<()> { let mut unnamed_query_string = Bytes::new(); loop { - let msg = self.read_message().await?; + let msg = self.read_message()?; trace!("got message {:?}", msg); match msg { Some(FeMessage::StartupMessage(m)) => { @@ -628,41 +607,39 @@ impl Connection { match m.kind { StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { let b = Bytes::from("N"); - self.stream.write_all(&b).await?; - self.stream.flush().await?; + self.stream.write_all(&b)?; + self.stream.flush()?; } StartupRequestCode::Normal => { - self.write_message_noflush(&BeMessage::AuthenticationOk) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await?; + self.write_message_noflush(&BeMessage::AuthenticationOk)?; + self.write_message(&BeMessage::ReadyForQuery)?; self.init_done = true; } StartupRequestCode::Cancel => return Ok(()), } } Some(FeMessage::Query(m)) => { - self.process_query(m.body).await?; + self.process_query(m.body)?; } Some(FeMessage::Parse(m)) => { unnamed_query_string = m.query_string; - self.write_message(&BeMessage::ParseComplete).await?; + self.write_message(&BeMessage::ParseComplete)?; } Some(FeMessage::Describe(_)) => { - self.write_message_noflush(&BeMessage::ParameterDescription) - .await?; - self.write_message(&BeMessage::NoData).await?; + self.write_message_noflush(&BeMessage::ParameterDescription)?; + self.write_message(&BeMessage::NoData)?; } Some(FeMessage::Bind(_)) => { - self.write_message(&BeMessage::BindComplete).await?; + self.write_message(&BeMessage::BindComplete)?; } Some(FeMessage::Close(_)) => { - self.write_message(&BeMessage::CloseComplete).await?; + self.write_message(&BeMessage::CloseComplete)?; } Some(FeMessage::Execute(_)) => { - self.process_query(unnamed_query_string.clone()).await?; + self.process_query(unnamed_query_string.clone())?; } Some(FeMessage::Sync) => { - self.write_message(&BeMessage::ReadyForQuery).await?; + self.write_message(&BeMessage::ReadyForQuery)?; } Some(FeMessage::Terminate) => { break; @@ -681,7 +658,7 @@ impl Connection { Ok(()) } - async fn process_query(&mut self, query_string: Bytes) -> Result<()> { + fn process_query(&mut self, query_string: Bytes) -> Result<()> { debug!("process query {:?}", query_string); // remove null terminator, if any @@ -691,13 +668,13 @@ impl Connection { } if query_string.starts_with(b"controlfile") { - self.handle_controlfile().await + self.handle_controlfile() } else if query_string.starts_with(b"pagestream ") { let (_l, r) = query_string.split_at("pagestream ".len()); let timelineid_str = String::from_utf8(r.to_vec()).unwrap(); let timelineid = ZTimelineId::from_str(&timelineid_str).unwrap(); - self.handle_pagerequests(timelineid).await + self.handle_pagerequests(timelineid) } else if query_string.starts_with(b"basebackup ") { let (_l, r) = query_string.split_at("basebackup ".len()); let r = r.to_vec(); @@ -706,10 +683,9 @@ impl Connection { let timelineid = ZTimelineId::from_str(&timelineid_str).unwrap(); // Check that the timeline exists - self.handle_basebackup_request(timelineid).await?; - self.write_message_noflush(&BeMessage::CommandComplete) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await + self.handle_basebackup_request(timelineid)?; + self.write_message_noflush(&BeMessage::CommandComplete)?; + self.write_message(&BeMessage::ReadyForQuery) } else if query_string.starts_with(b"callmemaybe ") { let query_str = String::from_utf8(query_string.to_vec()) .unwrap() @@ -733,36 +709,29 @@ impl Connection { walreceiver::launch_wal_receiver(&self.conf, timelineid, &connstr); - self.write_message_noflush(&BeMessage::CommandComplete) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await + self.write_message_noflush(&BeMessage::CommandComplete)?; + self.write_message(&BeMessage::ReadyForQuery) } else if query_string.starts_with(b"status") { - self.write_message_noflush(&BeMessage::RowDescription) - .await?; - self.write_message_noflush(&BeMessage::DataRow).await?; - self.write_message_noflush(&BeMessage::CommandComplete) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await + self.write_message_noflush(&BeMessage::RowDescription)?; + self.write_message_noflush(&BeMessage::DataRow)?; + self.write_message_noflush(&BeMessage::CommandComplete)?; + self.write_message(&BeMessage::ReadyForQuery) } else { - self.write_message_noflush(&BeMessage::RowDescription) - .await?; - self.write_message_noflush(&BeMessage::DataRow).await?; - self.write_message_noflush(&BeMessage::CommandComplete) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await + self.write_message_noflush(&BeMessage::RowDescription)?; + self.write_message_noflush(&BeMessage::DataRow)?; + self.write_message_noflush(&BeMessage::CommandComplete)?; + self.write_message(&BeMessage::ReadyForQuery) } } - async fn handle_controlfile(&mut self) -> Result<()> { - self.write_message_noflush(&BeMessage::RowDescription) - .await?; - self.write_message_noflush(&BeMessage::ControlFile).await?; - self.write_message_noflush(&BeMessage::CommandComplete) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await + fn handle_controlfile(&mut self) -> Result<()> { + self.write_message_noflush(&BeMessage::RowDescription)?; + self.write_message_noflush(&BeMessage::ControlFile)?; + self.write_message_noflush(&BeMessage::CommandComplete)?; + self.write_message(&BeMessage::ReadyForQuery) } - async fn handle_pagerequests(&mut self, timelineid: ZTimelineId) -> Result<()> { + fn handle_pagerequests(&mut self, timelineid: ZTimelineId) -> Result<()> { // Check that the timeline exists let pcache = page_cache::get_or_restore_pagecache(&self.conf, timelineid); if pcache.is_err() { @@ -773,14 +742,14 @@ impl Connection { let pcache = pcache.unwrap(); /* switch client to COPYBOTH */ - self.stream.write_u8(b'W').await?; - self.stream.write_i32(4 + 1 + 2).await?; - self.stream.write_u8(0).await?; /* copy_is_binary */ - self.stream.write_i16(0).await?; /* numAttributes */ - self.stream.flush().await?; + self.stream.write_u8(b'W')?; + self.stream.write_i32::(4 + 1 + 2)?; + self.stream.write_u8(0)?; /* copy_is_binary */ + self.stream.write_i16::(0)?; /* numAttributes */ + self.stream.flush()?; loop { - let message = self.read_message().await?; + let message = self.read_message()?; if let Some(m) = &message { trace!("query({:?}): {:?}", timelineid, m); @@ -800,13 +769,12 @@ impl Connection { forknum: req.forknum, }; - let exist = pcache.relsize_exist(&tag, req.lsn).await.unwrap_or(false); + let exist = pcache.relsize_exist(&tag, req.lsn).unwrap_or(false); self.write_message(&BeMessage::ZenithStatusResponse(ZenithStatusResponse { ok: exist, n_blocks: 0, - })) - .await? + }))? } Some(FeMessage::ZenithNblocksRequest(req)) => { let tag = page_cache::RelTag { @@ -816,13 +784,12 @@ impl Connection { forknum: req.forknum, }; - let n_blocks = pcache.relsize_get(&tag, req.lsn).await.unwrap_or(0); + let n_blocks = pcache.relsize_get(&tag, req.lsn).unwrap_or(0); self.write_message(&BeMessage::ZenithNblocksResponse(ZenithStatusResponse { ok: true, n_blocks, - })) - .await? + }))? } Some(FeMessage::ZenithReadRequest(req)) => { let buf_tag = page_cache::BufferTag { @@ -835,7 +802,7 @@ impl Connection { blknum: req.blkno, }; - let msg = match pcache.get_page_at_lsn(buf_tag, req.lsn).await { + let msg = match pcache.get_page_at_lsn(buf_tag, req.lsn) { Ok(p) => BeMessage::ZenithReadResponse(ZenithReadResponse { ok: true, n_blocks: 0, @@ -852,14 +819,14 @@ impl Connection { } }; - self.write_message(&msg).await? + self.write_message(&msg)? } _ => {} } } } - async fn handle_basebackup_request(&mut self, timelineid: ZTimelineId) -> Result<()> { + fn handle_basebackup_request(&mut self, timelineid: ZTimelineId) -> Result<()> { // check that the timeline exists let pcache = page_cache::get_or_restore_pagecache(&self.conf, timelineid); if pcache.is_err() { @@ -870,11 +837,11 @@ impl Connection { /* switch client to COPYOUT */ let stream = &mut self.stream; - stream.write_u8(b'H').await?; - stream.write_i32(4 + 1 + 2).await?; - stream.write_u8(0).await?; /* copy_is_binary */ - stream.write_i16(0).await?; /* numAttributes */ - stream.flush().await?; + stream.write_u8(b'H')?; + stream.write_i32::(4 + 1 + 2)?; + stream.write_u8(0)?; /* copy_is_binary */ + stream.write_i16::(0)?; /* numAttributes */ + stream.flush()?; info!("sent CopyOut"); /* Send a tarball of the latest snapshot on the timeline */ @@ -882,49 +849,16 @@ impl Connection { // find latest snapshot let snapshotlsn = restore_local_repo::find_latest_snapshot(&self.conf, timelineid).unwrap(); - // Stream it - let (s, mut r) = mpsc::channel(5); - - let f_tar = task::spawn_blocking(move || { - basebackup::send_snapshot_tarball(&mut CopyDataSink(s), timelineid, snapshotlsn)?; - Ok(()) - }); - let f_tar2 = async { - let joinres = f_tar.await; - - if let Err(joinreserr) = joinres { - return Err(io::Error::new(io::ErrorKind::InvalidData, joinreserr)); - } - joinres.unwrap() - }; - - let f_pump = async move { - loop { - let buf = r.recv().await; - if buf.is_none() { - break; - } - let buf = buf.unwrap(); - - // CopyData - stream.write_u8(b'd').await?; - stream.write_u32((4 + buf.len()) as u32).await?; - stream.write_all(&buf).await?; - trace!("CopyData sent for {} bytes!", buf.len()); - - // FIXME: flush isn't really required, but makes it easier - // to view in wireshark - stream.flush().await?; - } - Ok(()) - }; - - tokio::try_join!(f_tar2, f_pump)?; + basebackup::send_snapshot_tarball( + &mut CopyDataSink { stream: stream }, + timelineid, + snapshotlsn, + )?; // CopyDone - self.stream.write_u8(b'c').await?; - self.stream.write_u32(4).await?; - self.stream.flush().await?; + self.stream.write_u8(b'c')?; + self.stream.write_u32::(4)?; + self.stream.flush()?; debug!("CopyDone sent!"); // FIXME: I'm getting an error from the tokio copyout driver without this. @@ -936,15 +870,28 @@ impl Connection { } } -struct CopyDataSink(mpsc::Sender); +/// +/// A std::io::Write implementation that wraps all data written to it in CopyData +/// messages. +/// +struct CopyDataSink<'a> { + stream: &'a mut BufWriter, +} -impl std::io::Write for CopyDataSink { +impl<'a> std::io::Write for CopyDataSink<'a> { fn write(&mut self, data: &[u8]) -> std::result::Result { - let buf = Bytes::copy_from_slice(data); - - if let Err(e) = self.0.blocking_send(buf) { - return Err(io::Error::new(io::ErrorKind::Other, e)); - } + // CopyData + // FIXME: if the input is large, we should split it into multiple messages. + // Not sure what the threshold should be, but the ultimate hard limit is that + // the length cannot exceed u32. + self.stream.write_u8(b'd')?; + self.stream.write_u32::((4 + data.len()) as u32)?; + self.stream.write_all(&data)?; + trace!("CopyData sent for {} bytes!", data.len()); + + // FIXME: flush isn't really required, but makes it easier + // to view in wireshark + self.stream.flush()?; Ok(data.len()) } diff --git a/pageserver/src/walreceiver.rs b/pageserver/src/walreceiver.rs index 5ef5f1cf02e1..3ab75ee02cd0 100644 --- a/pageserver/src/walreceiver.rs +++ b/pageserver/src/walreceiver.rs @@ -26,8 +26,9 @@ use std::path::PathBuf; use std::str::FromStr; use std::sync::Mutex; use std::thread; -use tokio::runtime; -use tokio::time::{sleep, Duration}; +use std::thread::sleep; +use std::time::Duration; +use tokio::runtime::Runtime; use tokio_postgres::replication::{PgTimestamp, ReplicationStream}; use tokio_postgres::{NoTls, SimpleQueryMessage, SimpleQueryRow}; use tokio_stream::StreamExt; @@ -95,30 +96,38 @@ fn thread_main(conf: &PageServerConf, timelineid: ZTimelineId) { timelineid ); - let runtime = runtime::Builder::new_current_thread() + // We need a tokio runtime to call the rust-postgres copy_both function. + // Most functions in the rust-postgres driver have a blocking wrapper, + // but copy_both does not (TODO: the copy_both support is still work-in-progress + // as of this writing. Check later if that has changed, or implement the + // wrapper ourselves in rust-postgres) + let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap(); - runtime.block_on(async { - loop { - // Look up the current WAL producer address - let wal_producer_connstr = get_wal_producer_connstr(timelineid); - - let res = walreceiver_main(conf, timelineid, &wal_producer_connstr).await; - - if let Err(e) = res { - info!( - "WAL streaming connection failed ({}), retrying in 1 second", - e - ); - sleep(Duration::from_secs(1)).await; - } + // + // Make a connection to the WAL safekeeper, or directly to the primary PostgreSQL server, + // and start streaming WAL from it. If the connection is lost, keep retrying. + // + loop { + // Look up the current WAL producer address + let wal_producer_connstr = get_wal_producer_connstr(timelineid); + + let res = walreceiver_main(&runtime, conf, timelineid, &wal_producer_connstr); + + if let Err(e) = res { + info!( + "WAL streaming connection failed ({}), retrying in 1 second", + e + ); + sleep(Duration::from_secs(1)); } - }); + } } -async fn walreceiver_main( +fn walreceiver_main( + runtime: &Runtime, conf: &PageServerConf, timelineid: ZTimelineId, wal_producer_connstr: &str, @@ -126,18 +135,19 @@ async fn walreceiver_main( // Connect to the database in replication mode. info!("connecting to {:?}", wal_producer_connstr); let connect_cfg = format!("{} replication=true", wal_producer_connstr); - let (rclient, connection) = tokio_postgres::connect(&connect_cfg, NoTls).await?; + + let (rclient, connection) = runtime.block_on(tokio_postgres::connect(&connect_cfg, NoTls))?; info!("connected!"); // The connection object performs the actual communication with the database, // so spawn it off to run on its own. - tokio::spawn(async move { + runtime.spawn(async move { if let Err(e) = connection.await { error!("connection error: {}", e); } }); - let identify = identify_system(&rclient).await?; + let identify = identify_system(runtime, &rclient)?; info!("{:?}", identify); let end_of_wal = Lsn::from(u64::from(identify.xlogpos)); let mut caught_up = false; @@ -174,14 +184,15 @@ async fn walreceiver_main( ); let query = format!("START_REPLICATION PHYSICAL {}", startpoint); - let copy_stream = rclient.copy_both_simple::(&query).await?; + + let copy_stream = runtime.block_on(rclient.copy_both_simple::(&query))?; let physical_stream = ReplicationStream::new(copy_stream); tokio::pin!(physical_stream); let mut waldecoder = WalStreamDecoder::new(startpoint); - while let Some(replication_message) = physical_stream.next().await { + while let Some(replication_message) = runtime.block_on(physical_stream.next()) { match replication_message? { ReplicationMessage::XLogData(xlog_data) => { // Pass the WAL data to the decoder, and see if we can decode @@ -309,10 +320,11 @@ async fn walreceiver_main( let ts = PgTimestamp::now()?; const NO_REPLY: u8 = 0u8; - physical_stream - .as_mut() - .standby_status_update(write_lsn, flush_lsn, apply_lsn, ts, NO_REPLY) - .await?; + runtime.block_on( + physical_stream + .as_mut() + .standby_status_update(write_lsn, flush_lsn, apply_lsn, ts, NO_REPLY), + )?; } } _ => (), @@ -341,9 +353,12 @@ pub struct IdentifySystem { pub struct IdentifyError; /// Run the postgres `IDENTIFY_SYSTEM` command -pub async fn identify_system(client: &tokio_postgres::Client) -> Result { +pub fn identify_system( + runtime: &Runtime, + client: &tokio_postgres::Client, +) -> Result { let query_str = "IDENTIFY_SYSTEM"; - let response = client.simple_query(query_str).await?; + let response = runtime.block_on(client.simple_query(query_str))?; // get(N) from row, then parse it as some destination type. fn get_parse(row: &SimpleQueryRow, idx: usize) -> Result diff --git a/pageserver/src/walredo.rs b/pageserver/src/walredo.rs index f170e1800822..abb965f1f6c0 100644 --- a/pageserver/src/walredo.rs +++ b/pageserver/src/walredo.rs @@ -24,13 +24,13 @@ use std::io::prelude::*; use std::io::Error; use std::path::PathBuf; use std::process::Stdio; +use std::sync::mpsc; use std::sync::{Arc, Mutex}; use std::time::Duration; use std::time::Instant; use tokio::io::AsyncBufReadExt; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::process::{Child, ChildStdin, ChildStdout, Command}; -use tokio::sync::{mpsc, oneshot}; use tokio::time::timeout; use zenith_utils::lsn::Lsn; @@ -52,8 +52,8 @@ pub struct WalRedoManager { conf: PageServerConf, timelineid: ZTimelineId, - request_tx: mpsc::UnboundedSender, - request_rx: Mutex>>, + request_tx: Mutex>, + request_rx: Mutex>>, } struct WalRedoManagerInternal { @@ -61,7 +61,7 @@ struct WalRedoManagerInternal { timelineid: ZTimelineId, pcache: Arc, - request_rx: mpsc::UnboundedReceiver, + request_rx: mpsc::Receiver, } #[derive(Debug)] @@ -69,7 +69,7 @@ struct WalRedoRequest { tag: BufferTag, lsn: Lsn, - response_channel: oneshot::Sender>, + response_channel: mpsc::Sender>, } /// An error happened in WAL redo @@ -89,12 +89,12 @@ impl WalRedoManager { /// This only initializes the struct. You need to call WalRedoManager::launch to /// start the thread that processes the requests. pub fn new(conf: &PageServerConf, timelineid: ZTimelineId) -> WalRedoManager { - let (tx, rx) = mpsc::unbounded_channel(); + let (tx, rx) = mpsc::channel(); WalRedoManager { conf: conf.clone(), timelineid, - request_tx: tx, + request_tx: Mutex::new(tx), request_rx: Mutex::new(Some(rx)), } } @@ -114,22 +114,13 @@ impl WalRedoManager { let _walredo_thread = std::thread::Builder::new() .name("WAL redo thread".into()) .spawn(move || { - // We block on waiting for requests on the walredo request channel, but - // use async I/O to communicate with the child process. Initialize the - // runtime for the async part. - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - let mut internal = WalRedoManagerInternal { _conf: conf_copy, timelineid: timelineid, pcache: pcache, request_rx: request_rx, }; - - runtime.block_on(internal.wal_redo_main()); + internal.wal_redo_main(); }) .unwrap(); } @@ -138,9 +129,9 @@ impl WalRedoManager { /// Request the WAL redo manager to apply WAL records, to reconstruct the page image /// of the given page version. /// - pub async fn request_redo(&self, tag: BufferTag, lsn: Lsn) -> Result { + pub fn request_redo(&self, tag: BufferTag, lsn: Lsn) -> Result { // Create a channel where to receive the response - let (tx, rx) = oneshot::channel::>(); + let (tx, rx) = mpsc::channel::>(); let request = WalRedoRequest { tag, @@ -149,10 +140,12 @@ impl WalRedoManager { }; self.request_tx + .lock() + .unwrap() .send(request) .expect("could not send WAL redo request"); - rx.await + rx.recv() .expect("could not receive response to WAL redo request") } } @@ -164,9 +157,17 @@ impl WalRedoManagerInternal { // // Main entry point for the WAL applicator thread. // - async fn wal_redo_main(&mut self) { + fn wal_redo_main(&mut self) { info!("WAL redo thread started {}", self.timelineid); + // We block on waiting for requests on the walredo request channel, but + // use async I/O to communicate with the child process. Initialize the + // runtime for the async part. + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + // Loop forever, handling requests as they come. loop { let mut process: WalRedoProcess; @@ -174,7 +175,7 @@ impl WalRedoManagerInternal { info!("launching WAL redo postgres process {}", self.timelineid); - process = WalRedoProcess::launch(&datadir).await.unwrap(); + process = runtime.block_on(WalRedoProcess::launch(&datadir)).unwrap(); info!("WAL redo postgres started"); // Pretty arbitrarily, reuse the same Postgres process for 100000 requests. @@ -182,9 +183,9 @@ impl WalRedoManagerInternal { // using up all shared buffers in Postgres's shared buffer cache; we don't // want to write any pages to disk in the WAL redo process. for _i in 1..100000 { - let request = self.request_rx.recv().await.unwrap(); + let request = self.request_rx.recv().unwrap(); - let result = self.handle_apply_request(&process, &request).await; + let result = runtime.block_on(self.handle_apply_request(&process, &request)); let result_ok = result.is_ok(); // Send the result to the requester @@ -202,11 +203,13 @@ impl WalRedoManagerInternal { // Time to kill the 'postgres' process. A new one will be launched on next // iteration of the loop. + // + // TODO: SIGKILL if needed info!("killing WAL redo postgres process"); - let _ = process.stdin.get_mut().shutdown().await; + let _ = process.stdin.get_mut().shutdown(); let mut child = process.child; drop(process.stdin); - let _ = child.wait().await; + let _ = child.wait(); } } @@ -441,6 +444,13 @@ impl WalRedoProcess { let mut stdin = self.stdin.borrow_mut(); let mut stdout = self.stdout.borrow_mut(); + // We do three things simultaneously: send the old base image and WAL records to + // the child process's stdin, read the result from child's stdout, and forward any logging + // information that the child writes to its stderr to the page server's log. + // + // 'f_stdin' handles writing the base image and WAL records to the child process. + // 'f_stdout' below reads the result back. And 'f_stderr', which was spawned into the + // tokio runtime in the 'launch' function already, forwards the logging. let f_stdin = async { // Send base image, if any. (If the record initializes the page, previous page // version is not needed.) @@ -487,10 +497,6 @@ impl WalRedoProcess { Ok::<[u8; 8192], Error>(buf) }; - // Kill the process. This closes its stdin, which should signal the process - // to terminate. TODO: SIGKILL if needed - //child.wait(); - let res = futures::try_join!(f_stdout, f_stdin)?; let buf = res.0; diff --git a/zenith_utils/Cargo.toml b/zenith_utils/Cargo.toml index a26a772c9778..ee549ab2f951 100644 --- a/zenith_utils/Cargo.toml +++ b/zenith_utils/Cargo.toml @@ -5,8 +5,4 @@ authors = ["Eric Seppanen "] edition = "2018" [dependencies] -tokio = { version = "1.5", features = ["sync", "time" ] } thiserror = "1" - -[dev-dependencies] -tokio = { version = "1.5", features = ["macros", "rt"] } diff --git a/zenith_utils/src/seqwait.rs b/zenith_utils/src/seqwait.rs index b4f3cdd45424..409090256e4d 100644 --- a/zenith_utils/src/seqwait.rs +++ b/zenith_utils/src/seqwait.rs @@ -1,12 +1,12 @@ #![warn(missing_docs)] -use std::collections::BTreeMap; +use std::cmp::{Eq, Ordering, PartialOrd}; +use std::collections::BinaryHeap; use std::fmt::Debug; use std::mem; +use std::sync::mpsc::{channel, Receiver, Sender}; use std::sync::Mutex; use std::time::Duration; -use tokio::sync::watch::{channel, Receiver, Sender}; -use tokio::time::timeout; /// An error happened while waiting for a number #[derive(Debug, PartialEq, thiserror::Error)] @@ -23,14 +23,44 @@ struct SeqWaitInt where T: Ord, { - waiters: BTreeMap, Receiver<()>)>, + waiters: BinaryHeap>, current: T, shutdown: bool, } +struct Waiter +where + T: Ord, +{ + wake_num: T, // wake me when this number arrives ... + wake_channel: Sender<()>, // ... by sending a message to this channel +} + +// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here +// to get that. +impl PartialOrd for Waiter { + fn partial_cmp(&self, other: &Self) -> Option { + other.wake_num.partial_cmp(&self.wake_num) + } +} + +impl Ord for Waiter { + fn cmp(&self, other: &Self) -> Ordering { + other.wake_num.cmp(&self.wake_num) + } +} + +impl PartialEq for Waiter { + fn eq(&self, other: &Self) -> bool { + other.wake_num == self.wake_num + } +} + +impl Eq for Waiter {} + /// A tool for waiting on a sequence number /// -/// This provides a way to await the arrival of a number. +/// This provides a way to wait the arrival of a number. /// As soon as the number arrives by another caller calling /// [`advance`], then the waiter will be woken up. /// @@ -56,7 +86,7 @@ where /// Create a new `SeqWait`, initialized to a particular number pub fn new(starting_num: T) -> Self { let internal = SeqWaitInt { - waiters: BTreeMap::new(), + waiters: BinaryHeap::new(), current: starting_num, shutdown: false, }; @@ -92,29 +122,12 @@ where /// /// This call won't complete until someone has called `advance` /// with a number greater than or equal to the one we're waiting for. - pub async fn wait_for(&self, num: T) -> Result<(), SeqWaitError> { - let mut rx = { - let mut internal = self.internal.lock().unwrap(); - if internal.current >= num { - return Ok(()); - } - if internal.shutdown { - return Err(SeqWaitError::Shutdown); - } - - // If we already have a channel for waiting on this number, reuse it. - if let Some((_, rx)) = internal.waiters.get_mut(&num) { - // an Err from changed() means the sender was dropped. - rx.clone() - } else { - // Create a new channel. - let (tx, rx) = channel(()); - internal.waiters.insert(num, (tx, rx.clone())); - rx - } - // Drop the lock as we exit this scope. - }; - rx.changed().await.map_err(|_| SeqWaitError::Shutdown) + pub fn wait_for(&self, num: T) -> Result<(), SeqWaitError> { + match self.queue_for_wait(num) { + Ok(None) => Ok(()), + Ok(Some(rx)) => rx.recv().map_err(|_| SeqWaitError::Shutdown), + Err(e) => Err(e), + } } /// Wait for a number to arrive @@ -124,14 +137,36 @@ where /// /// If that hasn't happened after the specified timeout duration, /// [`SeqWaitError::Timeout`] will be returned. - pub async fn wait_for_timeout( - &self, - num: T, - timeout_duration: Duration, - ) -> Result<(), SeqWaitError> { - timeout(timeout_duration, self.wait_for(num)) - .await - .unwrap_or(Err(SeqWaitError::Timeout)) + pub fn wait_for_timeout(&self, num: T, timeout_duration: Duration) -> Result<(), SeqWaitError> { + match self.queue_for_wait(num) { + Ok(None) => Ok(()), + Ok(Some(rx)) => rx.recv_timeout(timeout_duration).map_err(|e| match e { + std::sync::mpsc::RecvTimeoutError::Timeout => SeqWaitError::Timeout, + std::sync::mpsc::RecvTimeoutError::Disconnected => SeqWaitError::Shutdown, + }), + Err(e) => Err(e), + } + } + + /// Register and return a channel that will be notified when a number arrives, + /// or None, if it has already arrived. + fn queue_for_wait(&self, num: T) -> Result>, SeqWaitError> { + let mut internal = self.internal.lock().unwrap(); + if internal.current >= num { + return Ok(None); + } + if internal.shutdown { + return Err(SeqWaitError::Shutdown); + } + + // Create a new channel. + let (tx, rx) = channel(); + internal.waiters.push(Waiter { + wake_num: num, + wake_channel: tx, + }); + // Drop the lock as we exit this scope. + Ok(Some(rx)) } /// Announce a new number has arrived @@ -152,22 +187,19 @@ where } internal.current = num; - // split_off will give me all the high-numbered waiters, - // so split and then swap. Everything at or above `num` - // stays. - let mut split = internal.waiters.split_off(&num); - std::mem::swap(&mut split, &mut internal.waiters); - - // `split_at` didn't get the value at `num`; if it's - // there take that too. - if let Some(sleeper) = internal.waiters.remove(&num) { - split.insert(num, sleeper); + // Pop all waiters <= num from the heap. Collect them in a vector, and + // wake them up after releasing the lock. + let mut wake_these = Vec::new(); + while let Some(n) = internal.waiters.peek() { + if n.wake_num > num { + break; + } + wake_these.push(internal.waiters.pop().unwrap().wake_channel); } - - split + wake_these }; - for (_wake_num, (tx, _rx)) in wake_these { + for tx in wake_these { // This can fail if there are no receivers. // We don't care; discard the error. let _ = tx.send(()); @@ -179,38 +211,40 @@ where mod tests { use super::*; use std::sync::Arc; - use tokio::time::{sleep, Duration}; + use std::thread::sleep; + use std::thread::spawn; + use std::time::Duration; - #[tokio::test] - async fn seqwait() { + #[test] + fn seqwait() { let seq = Arc::new(SeqWait::new(0)); let seq2 = Arc::clone(&seq); let seq3 = Arc::clone(&seq); - tokio::spawn(async move { - seq2.wait_for(42).await.expect("wait_for 42"); + spawn(move || { + seq2.wait_for(42).expect("wait_for 42"); seq2.advance(100); - seq2.wait_for(999).await.expect_err("no 999"); + seq2.wait_for(999).expect_err("no 999"); }); - tokio::spawn(async move { - seq3.wait_for(42).await.expect("wait_for 42"); - seq3.wait_for(0).await.expect("wait_for 0"); + spawn(move || { + seq3.wait_for(42).expect("wait_for 42"); + seq3.wait_for(0).expect("wait_for 0"); }); - sleep(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)); seq.advance(99); - seq.wait_for(100).await.expect("wait_for 100"); + seq.wait_for(100).expect("wait_for 100"); seq.shutdown(); } - #[tokio::test] - async fn seqwait_timeout() { + #[test] + fn seqwait_timeout() { let seq = Arc::new(SeqWait::new(0)); let seq2 = Arc::clone(&seq); - tokio::spawn(async move { + spawn(move || { let timeout = Duration::from_millis(1); - let res = seq2.wait_for_timeout(42, timeout).await; + let res = seq2.wait_for_timeout(42, timeout); assert_eq!(res, Err(SeqWaitError::Timeout)); }); - sleep(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)); // This will attempt to wake, but nothing will happen // because the waiter already dropped its Receiver. seq.advance(99); From bc652e965e8426fbf95e3fe345a72586cea0d4a1 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Mon, 26 Apr 2021 13:30:10 +0300 Subject: [PATCH 2/2] Save old 'async' version of SeqWait, in case we need it later. It is currently unused, and is not built as part of 'cargo build', but seems like a shame to throw it away completely. --- zenith_utils/src/lib.rs | 3 + zenith_utils/src/seqwait_async.rs | 224 ++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+) create mode 100644 zenith_utils/src/seqwait_async.rs diff --git a/zenith_utils/src/lib.rs b/zenith_utils/src/lib.rs index fb4d415f22b1..21bfc73930cb 100644 --- a/zenith_utils/src/lib.rs +++ b/zenith_utils/src/lib.rs @@ -5,3 +5,6 @@ pub mod lsn; /// SeqWait allows waiting for a future sequence number to arrive pub mod seqwait; + +// Async version of SeqWait. Currently unused. +// pub mod seqwait_async; diff --git a/zenith_utils/src/seqwait_async.rs b/zenith_utils/src/seqwait_async.rs new file mode 100644 index 000000000000..09138e9dd4e6 --- /dev/null +++ b/zenith_utils/src/seqwait_async.rs @@ -0,0 +1,224 @@ +/// +/// Async version of 'seqwait.rs' +/// +/// NOTE: This is currently unused. If you need this, you'll need to uncomment this in lib.rs. +/// + +#![warn(missing_docs)] + +use std::collections::BTreeMap; +use std::fmt::Debug; +use std::mem; +use std::sync::Mutex; +use std::time::Duration; +use tokio::sync::watch::{channel, Receiver, Sender}; +use tokio::time::timeout; + +/// An error happened while waiting for a number +#[derive(Debug, PartialEq, thiserror::Error)] +#[error("SeqWaitError")] +pub enum SeqWaitError { + /// The wait timeout was reached + Timeout, + /// [`SeqWait::shutdown`] was called + Shutdown, +} + +/// Internal components of a `SeqWait` +struct SeqWaitInt +where + T: Ord, +{ + waiters: BTreeMap, Receiver<()>)>, + current: T, + shutdown: bool, +} + +/// A tool for waiting on a sequence number +/// +/// This provides a way to await the arrival of a number. +/// As soon as the number arrives by another caller calling +/// [`advance`], then the waiter will be woken up. +/// +/// This implementation takes a blocking Mutex on both [`wait_for`] +/// and [`advance`], meaning there may be unexpected executor blocking +/// due to thread scheduling unfairness. There are probably better +/// implementations, but we can probably live with this for now. +/// +/// [`wait_for`]: SeqWait::wait_for +/// [`advance`]: SeqWait::advance +/// +pub struct SeqWait +where + T: Ord, +{ + internal: Mutex>, +} + +impl SeqWait +where + T: Ord + Debug + Copy, +{ + /// Create a new `SeqWait`, initialized to a particular number + pub fn new(starting_num: T) -> Self { + let internal = SeqWaitInt { + waiters: BTreeMap::new(), + current: starting_num, + shutdown: false, + }; + SeqWait { + internal: Mutex::new(internal), + } + } + + /// Shut down a `SeqWait`, causing all waiters (present and + /// future) to return an error. + pub fn shutdown(&self) { + let waiters = { + // Prevent new waiters; wake all those that exist. + // Wake everyone with an error. + let mut internal = self.internal.lock().unwrap(); + + // This will steal the entire waiters map. + // When we drop it all waiters will be woken. + mem::take(&mut internal.waiters) + + // Drop the lock as we exit this scope. + }; + + // When we drop the waiters list, each Receiver will + // be woken with an error. + // This drop doesn't need to be explicit; it's done + // here to make it easier to read the code and understand + // the order of events. + drop(waiters); + } + + /// Wait for a number to arrive + /// + /// This call won't complete until someone has called `advance` + /// with a number greater than or equal to the one we're waiting for. + pub async fn wait_for(&self, num: T) -> Result<(), SeqWaitError> { + let mut rx = { + let mut internal = self.internal.lock().unwrap(); + if internal.current >= num { + return Ok(()); + } + if internal.shutdown { + return Err(SeqWaitError::Shutdown); + } + + // If we already have a channel for waiting on this number, reuse it. + if let Some((_, rx)) = internal.waiters.get_mut(&num) { + // an Err from changed() means the sender was dropped. + rx.clone() + } else { + // Create a new channel. + let (tx, rx) = channel(()); + internal.waiters.insert(num, (tx, rx.clone())); + rx + } + // Drop the lock as we exit this scope. + }; + rx.changed().await.map_err(|_| SeqWaitError::Shutdown) + } + + /// Wait for a number to arrive + /// + /// This call won't complete until someone has called `advance` + /// with a number greater than or equal to the one we're waiting for. + /// + /// If that hasn't happened after the specified timeout duration, + /// [`SeqWaitError::Timeout`] will be returned. + pub async fn wait_for_timeout( + &self, + num: T, + timeout_duration: Duration, + ) -> Result<(), SeqWaitError> { + timeout(timeout_duration, self.wait_for(num)) + .await + .unwrap_or(Err(SeqWaitError::Timeout)) + } + + /// Announce a new number has arrived + /// + /// All waiters at this value or below will be woken. + /// + /// `advance` will panic if you send it a lower number than + /// a previous call. + pub fn advance(&self, num: T) { + let wake_these = { + let mut internal = self.internal.lock().unwrap(); + + if internal.current > num { + panic!( + "tried to advance backwards, from {:?} to {:?}", + internal.current, num + ); + } + internal.current = num; + + // split_off will give me all the high-numbered waiters, + // so split and then swap. Everything at or above `num` + // stays. + let mut split = internal.waiters.split_off(&num); + std::mem::swap(&mut split, &mut internal.waiters); + + // `split_at` didn't get the value at `num`; if it's + // there take that too. + if let Some(sleeper) = internal.waiters.remove(&num) { + split.insert(num, sleeper); + } + + split + }; + + for (_wake_num, (tx, _rx)) in wake_these { + // This can fail if there are no receivers. + // We don't care; discard the error. + let _ = tx.send(()); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + use tokio::time::{sleep, Duration}; + + #[tokio::test] + async fn seqwait() { + let seq = Arc::new(SeqWait::new(0)); + let seq2 = Arc::clone(&seq); + let seq3 = Arc::clone(&seq); + tokio::spawn(async move { + seq2.wait_for(42).await.expect("wait_for 42"); + seq2.advance(100); + seq2.wait_for(999).await.expect_err("no 999"); + }); + tokio::spawn(async move { + seq3.wait_for(42).await.expect("wait_for 42"); + seq3.wait_for(0).await.expect("wait_for 0"); + }); + sleep(Duration::from_secs(1)).await; + seq.advance(99); + seq.wait_for(100).await.expect("wait_for 100"); + seq.shutdown(); + } + + #[tokio::test] + async fn seqwait_timeout() { + let seq = Arc::new(SeqWait::new(0)); + let seq2 = Arc::clone(&seq); + tokio::spawn(async move { + let timeout = Duration::from_millis(1); + let res = seq2.wait_for_timeout(42, timeout).await; + assert_eq!(res, Err(SeqWaitError::Timeout)); + }); + sleep(Duration::from_secs(1)).await; + // This will attempt to wake, but nothing will happen + // because the waiter already dropped its Receiver. + seq.advance(99); + } +}