From 74bd001a52609ab7359f32a2cd8439e997f5ff66 Mon Sep 17 00:00:00 2001 From: John Starks Date: Tue, 25 Mar 2025 21:26:15 +0000 Subject: [PATCH 1/6] add tests --- vm/devices/vmbus/vmbus_client/src/lib.rs | 97 ++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 5 deletions(-) diff --git a/vm/devices/vmbus/vmbus_client/src/lib.rs b/vm/devices/vmbus/vmbus_client/src/lib.rs index dc6ae0d732..6054c8469b 100644 --- a/vm/devices/vmbus/vmbus_client/src/lib.rs +++ b/vm/devices/vmbus/vmbus_client/src/lib.rs @@ -2517,7 +2517,35 @@ mod tests { }, )); c0.revoke_recv.await.unwrap(); - server.stop_client(&mut client).await; + let rpc = c0.request_send.call( + ChannelRequest::Modify, + ModifyRequest::TargetVp { target_vp: 1 }, + ); + + check_message( + server.next().await.unwrap(), + protocol::ModifyChannel { + channel_id: ChannelId(0), + target_vp: 1, + }, + ); + + let client_stop = client.stop(); + let server_stop = async { + server.send(in_msg( + MessageType::MODIFY_CHANNEL_RESPONSE, + protocol::ModifyChannelResponse { + channel_id: ChannelId(0), + status: protocol::STATUS_SUCCESS, + }, + )); + check_message(server.next().await.unwrap(), protocol::Pause); + server.send(in_msg(MessageType::PAUSE_RESPONSE, protocol::PauseResponse)); + }; + (client_stop, server_stop).join().await; + + rpc.await.unwrap(); + let s0 = client.save().await; let builder = client.sever().await; let mut client = builder.build(&driver); @@ -2684,7 +2712,7 @@ mod tests { let (mut server, mut client) = test_init(&driver); let channel = server.get_channel(&mut client).await; let channel_id = ChannelId(0); - let gpadl_id = GpadlId(1); + for gpadl_id in [1, 2, 3].map(GpadlId) { let recv = channel.request_send.call( ChannelRequest::Gpadl, GpadlRequest { @@ -2715,16 +2743,17 @@ mod tests { )); recv.await.unwrap().unwrap(); + } let rpc = channel .request_send - .call(ChannelRequest::TeardownGpadl, gpadl_id); + .call(ChannelRequest::TeardownGpadl, GpadlId(1)); check_message( server.next().await.unwrap(), protocol::GpadlTeardown { channel_id, - gpadl_id, + gpadl_id: GpadlId(1), }, ); @@ -2733,11 +2762,69 @@ mod tests { protocol::RescindChannelOffer { channel_id }, )); + let recv = channel.request_send.call_failable( + ChannelRequest::Gpadl, + GpadlRequest { + id: GpadlId(4), + count: 1, + buf: vec![3], + }, + ); + + check_message_with_data( + server.next().await.unwrap(), + protocol::GpadlHeader { + channel_id, + gpadl_id: GpadlId(4), + len: 8, + count: 1, + }, + 0x3u64.as_bytes(), + ); + + server.send(in_msg( + MessageType::GPADL_CREATED, + protocol::GpadlCreated { + channel_id, + gpadl_id: GpadlId(4), + status: protocol::STATUS_UNSUCCESSFUL, + }, + )); + + server.send(in_msg( + MessageType::GPADL_TORNDOWN, + protocol::GpadlTorndown { + gpadl_id: GpadlId(1), + }, + )); + rpc.await.unwrap(); + recv.await.unwrap_err(); channel.revoke_recv.await.unwrap(); + + let rpc = channel + .request_send + .call(ChannelRequest::TeardownGpadl, GpadlId(2)); drop(channel.request_send); + check_message( + server.next().await.unwrap(), + protocol::GpadlTeardown { + channel_id, + gpadl_id: GpadlId(2), + }, + ); + + server.send(in_msg( + MessageType::GPADL_TORNDOWN, + protocol::GpadlTorndown { + gpadl_id: GpadlId(2), + }, + )); + + rpc.await.unwrap(); + check_message( server.next().await.unwrap(), protocol::RelIdReleased { channel_id }, @@ -2905,7 +2992,7 @@ mod tests { channel .request_send - .call( + .call_failable( ChannelRequest::Open, OpenRequest { open_data: OpenData { From 59e9c22633ab6f2ec7d098dec4798af059b95860 Mon Sep 17 00:00:00 2001 From: John Starks Date: Tue, 25 Mar 2025 21:26:20 +0000 Subject: [PATCH 2/6] fix tests --- vm/devices/vmbus/vmbus_client/src/lib.rs | 530 +++++++++--------- .../vmbus/vmbus_client/src/saved_state.rs | 36 +- 2 files changed, 280 insertions(+), 286 deletions(-) diff --git a/vm/devices/vmbus/vmbus_client/src/lib.rs b/vm/devices/vmbus/vmbus_client/src/lib.rs index 6054c8469b..74007c9654 100644 --- a/vm/devices/vmbus/vmbus_client/src/lib.rs +++ b/vm/devices/vmbus/vmbus_client/src/lib.rs @@ -28,6 +28,8 @@ use std::collections::hash_map; use std::convert::TryInto; use std::future::Future; use std::future::poll_fn; +use std::ops::Deref; +use std::ops::DerefMut; use std::pin::pin; use std::sync::Arc; use std::task::Context; @@ -151,7 +153,6 @@ impl VmbusClientBuilder { poster: self.msg_client, queued: VecDeque::new(), }, - channels: ChannelList::default(), teardown_gpadls: HashMap::new(), channel_requests: SelectAll::new(), synic: SynicState { @@ -162,6 +163,7 @@ impl VmbusClientBuilder { let mut task = ClientTask { inner, + channels: ChannelList::default(), task_recv, running: false, msg_source: self.msg_source, @@ -334,14 +336,15 @@ pub struct OpenOutput { impl std::fmt::Display for ChannelRequest { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ChannelRequest::Open(_) => write!(fmt, "Open"), - ChannelRequest::Close(_) => write!(fmt, "Close"), - ChannelRequest::Restore(_) => write!(fmt, "Restore"), - ChannelRequest::Gpadl(_) => write!(fmt, "Gpadl"), - ChannelRequest::TeardownGpadl(_) => write!(fmt, "TeardownGpadl"), - ChannelRequest::Modify(_) => write!(fmt, "Modify"), - } + let s = match self { + ChannelRequest::Open(_) => "Open", + ChannelRequest::Close(_) => "Close", + ChannelRequest::Restore(_) => "Restore", + ChannelRequest::Gpadl(_) => "Gpadl", + ChannelRequest::TeardownGpadl(_) => "TeardownGpadl", + ChannelRequest::Modify(_) => "Modify", + }; + fmt.pad(s) } } @@ -390,12 +393,13 @@ enum ClientRequest { impl std::fmt::Display for ClientRequest { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ClientRequest::Connect(..) => write!(fmt, "Connect"), - ClientRequest::Unload { .. } => write!(fmt, "Unload"), - ClientRequest::Modify(..) => write!(fmt, "Modify"), - ClientRequest::HvsockConnect(..) => write!(fmt, "HvsockConnect"), - } + let s = match self { + ClientRequest::Connect(..) => "Connect", + ClientRequest::Unload { .. } => "Unload", + ClientRequest::Modify(..) => "Modify", + ClientRequest::HvsockConnect(..) => "HvsockConnect", + }; + fmt.pad(s) } } @@ -457,13 +461,14 @@ impl ClientState { impl std::fmt::Display for ClientState { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ClientState::Disconnected => write!(fmt, "Disconnected"), - ClientState::Connecting { .. } => write!(fmt, "Connecting"), - ClientState::Connected { .. } => write!(fmt, "Connected"), - ClientState::RequestingOffers { .. } => write!(fmt, "RequestingOffers"), - ClientState::Disconnecting { .. } => write!(fmt, "Disconnecting"), - } + let s = match self { + ClientState::Disconnected => "Disconnected", + ClientState::Connecting { .. } => "Connecting", + ClientState::Connected { .. } => "Connected", + ClientState::RequestingOffers { .. } => "RequestingOffers", + ClientState::Disconnecting { .. } => "Disconnecting", + }; + fmt.pad(s) } } @@ -516,39 +521,47 @@ enum ChannelState { #[inspect(skip)] redirected_event: Option, }, + /// The channel has been revoked by the server. + Revoked, } impl std::fmt::Display for ChannelState { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ChannelState::Opening { .. } => write!(fmt, "Opening"), - ChannelState::Offered => write!(fmt, "Offered"), - ChannelState::Opened { .. } => write!(fmt, "Opened"), - ChannelState::Restored => write!(fmt, "Restored"), - } + let s = match self { + ChannelState::Opening { .. } => "Opening", + ChannelState::Offered => "Offered", + ChannelState::Opened { .. } => "Opened", + ChannelState::Restored => "Restored", + ChannelState::Revoked => "Revoked", + }; + fmt.pad(s) } } -#[derive(Inspect)] +#[derive(Debug, Inspect)] struct Channel { offer: protocol::OfferChannel, // When dropped, notifies the caller the channel has been revoked. #[inspect(skip)] - revoke_send: mesh::OneshotSender<()>, + revoke_send: Option>, state: ChannelState, #[inspect(with = "|x| x.is_some()")] modify_response_send: Option>, #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")] gpadls: HashMap, - released: bool, + is_client_released: bool, } -impl std::fmt::Debug for Channel { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - fmt.debug_struct("Channel") - .field("offer", &self.offer) - .field("state", &self.state) - .finish() +impl Channel { + fn pending_request(&self) -> Option<&'static str> { + if self.modify_response_send.is_some() { + return Some("modify"); + } + self.gpadls.iter().find_map(|(_, gpadl)| match gpadl { + GpadlState::Offered(_) => Some("creating gpadl"), + GpadlState::Created => None, + GpadlState::TearingDown { .. } => Some("tearing down gpadl"), + }) } } @@ -556,6 +569,7 @@ impl std::fmt::Debug for Channel { struct ClientTask { #[inspect(flatten)] inner: ClientTaskInner, + channels: ChannelList, state: ClientState, hvsock_tracker: hvsock::HvsockRequestTracker, running: bool, @@ -730,22 +744,22 @@ impl ClientTask { offer: protocol::OfferChannel, state: ChannelState, ) -> Result { - if self.inner.channels.0.contains_key(&offer.channel_id) { + if self.channels.0.contains_key(&offer.channel_id) { anyhow::bail!("channel {:?} exists", offer.channel_id); } let (request_send, request_recv) = mesh::channel(); let (revoke_send, revoke_recv) = mesh::oneshot(); - self.inner.channels.0.insert( + self.channels.0.insert( offer.channel_id, - Some(Channel { - revoke_send, + Channel { + revoke_send: Some(revoke_send), offer, state, modify_response_send: None, gpadls: HashMap::new(), - released: false, - }), + is_client_released: false, + }, ); self.inner @@ -787,20 +801,11 @@ impl ClientTask { } } - fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) { + fn handle_rescind(&mut self, rescind: protocol::RescindChannelOffer) -> TriedRelease { tracing::info!(state = %self.state, channel_id = rescind.channel_id.0, "received rescind"); - let hash_map::Entry::Occupied(mut entry) = self.inner.channels.0.entry(rescind.channel_id) - else { - panic!("rescind for unknown channel id {:?}", rescind.channel_id) - }; - - let channel = entry - .get_mut() - .take() - .unwrap_or_else(|| panic!("channel id {:?} already revoked", rescind.channel_id)); - - let event_flag = match channel.state { + let mut channel = self.channels.get_mut(rescind.channel_id); + let event_flag = match std::mem::replace(&mut channel.state, ChannelState::Revoked) { ChannelState::Offered => None, ChannelState::Opening { connection_id: _, @@ -817,40 +822,18 @@ impl ClientTask { redirected_event_flag, redirected_event: _, } => redirected_event_flag, + ChannelState::Revoked => { + panic!("channel id {:?} already revoked", rescind.channel_id); + } }; if let Some(event_flag) = event_flag { self.inner.synic.free_event_flag(event_flag); } - // Teardown all remaining gpadls for this channel. We don't care about GpadlTorndown - // responses at this point. - for (gpadl_id, gpadl_state) in channel.gpadls { - match gpadl_state { - GpadlState::Offered(rpc) => { - rpc.fail(anyhow::anyhow!("channel revoked")); - } - GpadlState::Created => {} - GpadlState::TearingDown { rpcs } => { - self.inner.teardown_gpadls.remove(&gpadl_id).unwrap(); - for rpc in rpcs { - rpc.complete(()); - } - } - } - } - // Drop the channel and send the revoked message to the client. - channel.revoke_send.send(()); - - // Tell the host we're not referencing the client ID anymore, if we are - // not. Otherwise, we will send the released message to the host when - // the client is done with the channel. - if channel.released { - self.inner.messages.send(&protocol::RelIdReleased { - channel_id: rescind.channel_id, - }); - entry.remove(); - } + channel.revoke_send.take().unwrap().send(()); + + channel.try_release(&mut self.inner.messages) } fn handle_offers_delivered(&mut self) { @@ -879,46 +862,33 @@ impl ClientTask { } } - fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) { - let channel = self - .inner - .channels - .get_for_channel_message(request.channel_id); + fn handle_gpadl_created(&mut self, request: protocol::GpadlCreated) -> TriedRelease { + let mut channel = self.channels.get_mut(request.channel_id); let Some(gpadl_state) = channel.gpadls.get_mut(&request.gpadl_id) else { - tracing::warn!( - gpadl_id = request.gpadl_id.0, - "GpadlCreated for unknown gpadl" - ); - - return; + panic!("GpadlCreated for unknown gpadl {:#x}", request.gpadl_id.0); }; let rpc = match std::mem::replace(gpadl_state, GpadlState::Created) { GpadlState::Offered(rpc) => rpc, old_state => { - *gpadl_state = old_state; - tracing::warn!( - gpadl_id = request.gpadl_id.0, - channel_id = request.channel_id.0, - ?gpadl_state, - "Invalid state for GpadlCreated" + panic!( + "invalid state {old_state:?} for gpadl {:#x}:{:#x}", + request.channel_id.0, request.gpadl_id.0 ); - - return; } }; let gpadl_created = request.status == protocol::STATUS_SUCCESS; - if !gpadl_created { + if gpadl_created { + rpc.complete(Ok(())); + } else { channel.gpadls.remove(&request.gpadl_id).unwrap(); rpc.fail(anyhow::anyhow!( "gpadl creation failed: {:#x}", request.status )); - return; }; - - rpc.complete(Ok(())); + channel.try_release(&mut self.inner.messages) } fn handle_open_result(&mut self, result: protocol::OpenResult) { @@ -928,10 +898,7 @@ impl ClientTask { "received open result" ); - let channel = self - .inner - .channels - .get_for_channel_message(result.channel_id); + let mut channel = self.channels.get_mut(result.channel_id); let channel_opened = result.status == protocol::STATUS_SUCCESS as u32; let old_state = std::mem::replace(&mut channel.state, ChannelState::Offered); @@ -971,13 +938,9 @@ impl ClientTask { })); } - fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) { + fn handle_gpadl_torndown(&mut self, request: protocol::GpadlTorndown) -> TriedRelease { let Some(channel_id) = self.inner.teardown_gpadls.remove(&request.gpadl_id) else { - tracing::warn!( - gpadl_id = request.gpadl_id.0, - "Unknown ID or invalid state for GpadlTorndown" - ); - return; + panic!("gpadl {:#x} not in teardown list", request.gpadl_id.0); }; tracing::debug!( @@ -986,7 +949,7 @@ impl ClientTask { "Received GpadlTorndown" ); - let channel = self.inner.channels.get_for_channel_message(channel_id); + let mut channel = self.channels.get_mut(channel_id); let gpadl_state = channel .gpadls .remove(&request.gpadl_id) @@ -999,6 +962,7 @@ impl ClientTask { for rpc in rpcs { rpc.complete(()); } + channel.try_release(&mut self.inner.messages) } fn handle_unload_complete(&mut self) { @@ -1021,22 +985,20 @@ impl ClientTask { } } - fn handle_modify_channel_response(&mut self, response: protocol::ModifyChannelResponse) { - let Some(sender) = self - .inner - .channels - .get_for_channel_message(response.channel_id) - .modify_response_send - .take() - else { - tracing::warn!( - channel_id = response.channel_id.0, - "unexpected modify channel response" + fn handle_modify_channel_response( + &mut self, + response: protocol::ModifyChannelResponse, + ) -> TriedRelease { + let mut channel = self.channels.get_mut(response.channel_id); + let Some(sender) = channel.modify_response_send.take() else { + panic!( + "unexpected modify channel response for channel {:#x}", + response.channel_id.0 ); - return; }; sender.complete(response.status); + channel.try_release(&mut self.inner.messages) } fn handle_tl_connect_result(&mut self, response: protocol::TlConnectResult) { @@ -1082,7 +1044,7 @@ impl ClientTask { self.handle_rescind(rescind); } Message::ModifyChannelResponse(response, ..) => { - self.handle_modify_channel_response(response) + self.handle_modify_channel_response(response); } Message::TlConnectResult(response, ..) => self.handle_tl_connect_result(response), // Unsupported messages. @@ -1123,10 +1085,17 @@ impl ClientTask { channel_id: ChannelId, rpc: FailableRpc, ) { - let channel = self.inner.channels.get_for_caller_request(channel_id); - if !matches!(channel.state, ChannelState::Offered) { - rpc.fail(anyhow::anyhow!("invalid channel state: {}", channel.state)); - return; + let mut channel = self.channels.get_mut(channel_id); + match &channel.state { + ChannelState::Offered => {} + ChannelState::Revoked => { + rpc.fail(anyhow::anyhow!("channel revoked")); + return; + } + state => { + rpc.fail(anyhow::anyhow!("invalid channel state: {}", state)); + return; + } } tracing::info!(channel_id = channel_id.0, "opening channel on host"); @@ -1216,7 +1185,7 @@ impl ClientTask { channel_id: ChannelId, request: RestoreRequest, ) -> Result { - let channel = self.inner.channels.get_for_caller_request(channel_id); + let mut channel = self.channels.get_mut(channel_id); if !matches!(channel.state, ChannelState::Restored) { anyhow::bail!("invalid channel state: {}", channel.state); } @@ -1248,7 +1217,7 @@ impl ClientTask { fn handle_gpadl(&mut self, channel_id: ChannelId, rpc: FailableRpc) { let (request, rpc) = rpc.split(); - let channel = self.inner.channels.get_for_caller_request(channel_id); + let mut channel = self.channels.get_mut(channel_id); if channel .gpadls .insert(request.id, GpadlState::Offered(rpc)) @@ -1302,7 +1271,7 @@ impl ClientTask { fn handle_gpadl_teardown(&mut self, channel_id: ChannelId, rpc: Rpc) { let (gpadl_id, rpc) = rpc.split(); - let channel = self.inner.channels.get_for_caller_request(channel_id); + let mut channel = self.channels.get_mut(channel_id); let Some(gpadl_state) = channel.gpadls.get_mut(&gpadl_id) else { tracing::warn!( gpadl_id = gpadl_id.0, @@ -1345,27 +1314,8 @@ impl ClientTask { } fn handle_close_channel(&mut self, channel_id: ChannelId) { - let channel = self.inner.channels.get_for_caller_request(channel_id); - if let ChannelState::Opened { - redirected_event_flag, - .. - } = channel.state - { - if let Some(flag) = redirected_event_flag { - self.inner.synic.free_event_flag(flag); - } - tracing::info!(channel_id = channel_id.0, "closing channel on host"); - self.inner - .messages - .send(&protocol::CloseChannel { channel_id }); - channel.state = ChannelState::Offered; - } else { - tracing::warn!( - id = %channel_id.0, - channel_state = %channel.state, - "invalid channel state for close channel" - ); - } + let mut channel = self.channels.get_mut(channel_id); + self.inner.close_channel(channel_id, &mut channel); } fn handle_modify_channel(&mut self, channel_id: ChannelId, rpc: Rpc) { @@ -1373,7 +1323,7 @@ impl ClientTask { // ModifyChannelResponse. This means we don't need to worry about sending a ChannelResponse // if that weren't supported. assert!(self.check_version(Version::Iron)); - let channel = self.inner.channels.get_for_channel_message(channel_id); + let mut channel = self.channels.get_mut(channel_id); if channel.modify_response_send.is_some() { panic!("duplicate channel modify request {channel_id:?}"); } @@ -1391,24 +1341,6 @@ impl ClientTask { } fn handle_channel_request(&mut self, channel_id: ChannelId, request: ChannelRequest) { - match self.inner.channels.0.get(&channel_id) { - Some(Some(channel)) => { - tracing::trace!( - id = %channel_id.0, - %request, - state = %channel.state, - "received client request" - ); - } - Some(None) => { - tracelimit::info_ratelimited!(id = %channel_id.0, %request, "request for revoked channel"); - return; - } - None => { - panic!("request {} for missing channel {:?}", request, channel_id); - } - }; - match request { ChannelRequest::Open(rpc) => self.handle_open_channel(channel_id, rpc), ChannelRequest::Restore(rpc) => { @@ -1439,35 +1371,18 @@ impl ClientTask { } /// Makes sure a channel is closed if the channel request stream was dropped. - fn handle_device_removal(&mut self, channel_id: ChannelId) { - let hash_map::Entry::Occupied(mut entry) = self.inner.channels.0.entry(channel_id) else { - panic!("channel {:?} does not exist", channel_id); - }; - - match entry.get_mut() { - Some(channel) => { - // The channel is still offered. Remember that the user is gone so - // that we can release the channel ID immediately on revoke. - channel.released = true; - // Close the channel if it is still open. - if let ChannelState::Opened { .. } = channel.state { - tracing::warn!( - channel_id = channel_id.0, - "Channel dropped without closing first" - ); - self.handle_close_channel(channel_id); - } - } - None => { - // The channel has already been revoked. Tell the host we're not - // referencing the client ID anymore. - self.inner - .messages - .send(&protocol::RelIdReleased { channel_id }); - - entry.remove(); - } + fn handle_device_removal(&mut self, channel_id: ChannelId) -> TriedRelease { + let mut channel = self.channels.get_mut(channel_id); + channel.is_client_released = true; + // Close the channel if it is still open. + if let ChannelState::Opened { .. } = channel.state { + tracing::warn!( + channel_id = channel_id.0, + "Channel dropped without closing first" + ); + self.inner.close_channel(channel_id, &mut channel); } + channel.try_release(&mut self.inner.messages) } /// Determines if the client is connected with at least the specified version. @@ -1485,6 +1400,20 @@ impl ClientTask { assert!(self.running); loop { + // Wait until there are no more channels waiting for responses. This + // is necessary to ensure that the saved state does not have to + // support encoding revoked channels for which we are waiting for + // GPADL or modify responses. + if let Some((id, request)) = self.channels.revoked_channel_with_pending_request() { + tracing::info!( + channel_id = id.0, + request, + "waiting for responses for channel" + ); + assert!(self.process_next_message().await); + continue; + } + if self.can_pause_resume() { // Send a pause and flush any queued messages to ensure the host // sees it. @@ -1505,25 +1434,8 @@ impl ClientTask { self.msg_source.pause_message_stream(); } - // Process messages until we hit EOF. - tracing::debug!("draining messages"); - let mut buf = [0; protocol::MAX_MESSAGE_SIZE]; - loop { - let size = self - .msg_source - .recv(&mut buf) - .await - .expect("Fatal error reading messages from synic"); - - if size == 0 { - break; - } - - if !self.handle_synic_message(&buf[..size]) { - // Received a pause response message. We won't receive - // any more messages until we send a resume message. - break; - } + while self.process_next_message().await { + // Continue processing messages until we hit EOF. } // Flush any pending outgoing messages. This needs to be done with @@ -1551,6 +1463,23 @@ impl ClientTask { self.running = false; } + async fn process_next_message(&mut self) -> bool { + // Process messages until we hit EOF. + tracing::debug!("draining messages"); + let mut buf = [0; protocol::MAX_MESSAGE_SIZE]; + let size = self + .msg_source + .recv(&mut buf) + .await + .expect("Fatal error reading messages from synic"); + + if size == 0 { + return false; + } + + self.handle_synic_message(&buf[..size]) + } + /// Returns whether the server supports in-band messages to pause/resume the /// message stream. /// @@ -1616,7 +1545,9 @@ impl ClientTask { r = channel_requests => { match r.unwrap() { (id, Some(request)) => self.handle_channel_request(id, request), - (id, _) => self.handle_device_removal(id), + (id, _) => { + self.handle_device_removal(id); + } } } r = message_recv => { @@ -1639,6 +1570,29 @@ impl ClientTask { } } +impl ClientTaskInner { + fn close_channel(&mut self, channel_id: ChannelId, channel: &mut Channel) { + if let ChannelState::Opened { + redirected_event_flag, + .. + } = channel.state + { + if let Some(flag) = redirected_event_flag { + self.synic.free_event_flag(flag); + } + tracing::info!(channel_id = channel_id.0, "closing channel on host"); + self.messages.send(&protocol::CloseChannel { channel_id }); + channel.state = ChannelState::Offered; + } else { + tracing::warn!( + id = %channel_id.0, + channel_state = %channel.state, + "invalid channel state for close channel" + ); + } + } +} + #[derive(Debug, Inspect)] #[inspect(external_tag)] enum GpadlState { @@ -1718,7 +1672,6 @@ impl OutgoingMessages { #[derive(Inspect)] struct ClientTaskInner { messages: OutgoingMessages, - channels: ChannelList, #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] teardown_gpadls: HashMap, #[inspect(skip)] @@ -1737,27 +1690,68 @@ struct SynicState { #[derive(Inspect, Default)] #[inspect(transparent)] struct ChannelList( - #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] - HashMap>, + #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|id| id.0)")] HashMap, ); +/// A reference to a channel that can be used to remove the channel from the map +/// as well. +struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>); + +/// A tag value used to prove that [`ChannelRef::try_release`] has been called. +/// This is useful to put as the return value for methods that could possibly +/// transition a channel into a fully released state. +struct TriedRelease(()); + +impl ChannelRef<'_> { + /// If the channel has been fully released (revoked, released by the client, + /// no pending requests), notifes the server and removes this channel from + /// the map. + fn try_release(self, messages: &mut OutgoingMessages) -> TriedRelease { + if self.is_client_released + && matches!(self.state, ChannelState::Revoked) + && self.pending_request().is_none() + { + let channel_id = *self.0.key(); + tracelimit::info_ratelimited!(channel_id = channel_id.0, "releasing channel"); + messages.send(&protocol::RelIdReleased { channel_id }); + self.0.remove(); + } + TriedRelease(()) + } +} + +impl Deref for ChannelRef<'_> { + type Target = Channel; + + fn deref(&self) -> &Self::Target { + self.0.get() + } +} + +impl DerefMut for ChannelRef<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.get_mut() + } +} + impl ChannelList { - #[track_caller] - fn get_for_channel_message(&mut self, channel_id: ChannelId) -> &mut Channel { - self.0 - .get_mut(&channel_id) - .unwrap_or_else(|| panic!("channel {channel_id:?} not found")) - .as_mut() - .unwrap_or_else(|| panic!("channel {channel_id:?} was revoked")) + fn revoked_channel_with_pending_request(&self) -> Option<(ChannelId, &'static str)> { + self.0.iter().find_map(|(&id, channel)| { + if !matches!(channel.state, ChannelState::Revoked) { + return None; + } + Some((id, channel.pending_request()?)) + }) } #[track_caller] - fn get_for_caller_request(&mut self, channel_id: ChannelId) -> &mut Channel { - self.0 - .get_mut(&channel_id) - .unwrap_or_else(|| panic!("channel {channel_id:?} not found")) - .as_mut() - .expect("should have been validated already") + fn get_mut(&mut self, channel_id: ChannelId) -> ChannelRef<'_> { + match self.0.entry(channel_id) { + hash_map::Entry::Occupied(entry) => ChannelRef(entry), + hash_map::Entry::Vacant(_) => { + panic!("channel {:?} not found", channel_id); + } + } } } @@ -2713,36 +2707,36 @@ mod tests { let channel = server.get_channel(&mut client).await; let channel_id = ChannelId(0); for gpadl_id in [1, 2, 3].map(GpadlId) { - let recv = channel.request_send.call( - ChannelRequest::Gpadl, - GpadlRequest { - id: gpadl_id, - count: 1, - buf: vec![3], - }, - ); + let recv = channel.request_send.call( + ChannelRequest::Gpadl, + GpadlRequest { + id: gpadl_id, + count: 1, + buf: vec![3], + }, + ); - check_message_with_data( - server.next().await.unwrap(), - protocol::GpadlHeader { - channel_id, - gpadl_id, - len: 8, - count: 1, - }, - 0x3u64.as_bytes(), - ); + check_message_with_data( + server.next().await.unwrap(), + protocol::GpadlHeader { + channel_id, + gpadl_id, + len: 8, + count: 1, + }, + 0x3u64.as_bytes(), + ); - server.send(in_msg( - MessageType::GPADL_CREATED, - protocol::GpadlCreated { - channel_id, - gpadl_id, - status: protocol::STATUS_SUCCESS, - }, - )); + server.send(in_msg( + MessageType::GPADL_CREATED, + protocol::GpadlCreated { + channel_id, + gpadl_id, + status: protocol::STATUS_SUCCESS, + }, + )); - recv.await.unwrap().unwrap(); + recv.await.unwrap().unwrap(); } let rpc = channel diff --git a/vm/devices/vmbus/vmbus_client/src/saved_state.rs b/vm/devices/vmbus/vmbus_client/src/saved_state.rs index bfc4c8a6ad..e71b3e41bb 100644 --- a/vm/devices/vmbus/vmbus_client/src/saved_state.rs +++ b/vm/devices/vmbus/vmbus_client/src/saved_state.rs @@ -49,12 +49,14 @@ impl super::ClientTask { } }, channels: self - .inner .channels .0 .iter() .filter_map(|(&id, v)| { - let Some(v) = v else { + let Some(state) = ChannelState::save(&v.state) else { + if let Some(request) = v.pending_request() { + panic!("revoked channel {id} has pending request '{request}' that should be drained", id = id.0); + } // The channel has been revoked, but the user is not // done with it. The channel won't be available for use // when we restore, so don't save it, but do save a @@ -75,23 +77,20 @@ impl super::ClientTask { tracing::info!(%key, %v.state, "channel saved"); Some(Channel { id: id.0, - state: ChannelState::save(&v.state), + state, offer: v.offer.into(), }) }) .collect(), gpadls: self - .inner .channels .0 .iter() .flat_map(|(channel_id, channel)| { - channel.iter().flat_map(|c| { - c.gpadls.iter().map(|(gpadl_id, gpadl_state)| Gpadl { - gpadl_id: gpadl_id.0, - channel_id: channel_id.0, - state: GpadlState::save(gpadl_state), - }) + channel.gpadls.iter().map(|(gpadl_id, gpadl_state)| Gpadl { + gpadl_id: gpadl_id.0, + channel_id: channel_id.0, + state: GpadlState::save(gpadl_state), }) }) .collect(), @@ -157,11 +156,9 @@ impl super::ClientTask { let tearing_down = matches!(gpadl_state, super::GpadlState::TearingDown { .. }); let channel = self - .inner .channels .0 .get_mut(&channel_id) - .and_then(|v| v.as_mut()) .ok_or(RestoreError::GpadlForUnknownChannelId(channel_id.0))?; if channel.gpadls.insert(gpadl_id, gpadl_state).is_some() { @@ -197,8 +194,7 @@ impl super::ClientTask { assert!(!self.running); // Close restored channels that have not been claimed. - for (&channel_id, channel) in &mut self.inner.channels.0 { - let Some(channel) = channel else { continue }; + for (&channel_id, channel) in &mut self.channels.0 { if let super::ChannelState::Restored = channel.state { tracing::info!( channel_id = channel_id.0, @@ -293,14 +289,18 @@ pub enum ChannelState { } impl ChannelState { - fn save(state: &super::ChannelState) -> Self { - match state { + fn save(state: &super::ChannelState) -> Option { + let s = match state { super::ChannelState::Offered => Self::Offered, super::ChannelState::Opening { .. } => { unreachable!("Cannot save channel in opening state.") } - super::ChannelState::Restored | super::ChannelState::Opened { .. } => Self::Opened, - } + super::ChannelState::Restored | super::ChannelState::Opened { .. } => { + Self::Opened + } + super::ChannelState::Revoked => return None, + }; + Some(s) } fn restore(self) -> super::ChannelState { From 39a59c7ef5a79f106d72c9f244e7abb6d631d5ec Mon Sep 17 00:00:00 2001 From: John Starks Date: Wed, 26 Mar 2025 18:18:16 +0000 Subject: [PATCH 3/6] fix pause --- vm/devices/vmbus/vmbus_client/src/lib.rs | 146 ++++++++++++++--------- 1 file changed, 88 insertions(+), 58 deletions(-) diff --git a/vm/devices/vmbus/vmbus_client/src/lib.rs b/vm/devices/vmbus/vmbus_client/src/lib.rs index 74007c9654..b11926e3db 100644 --- a/vm/devices/vmbus/vmbus_client/src/lib.rs +++ b/vm/devices/vmbus/vmbus_client/src/lib.rs @@ -14,6 +14,7 @@ use futures::FutureExt; use futures::StreamExt; use futures::future::OptionFuture; use futures::stream::SelectAll; +use futures_concurrency::future::Race; use guid::Guid; use inspect::Inspect; use mesh::rpc::FailableRpc; @@ -34,7 +35,6 @@ use std::pin::pin; use std::sync::Arc; use std::task::Context; use std::task::Poll; -use std::task::ready; use thiserror::Error; use vmbus_async::async_dgram::AsyncRecv; use vmbus_async::async_dgram::AsyncRecvExt; @@ -152,6 +152,7 @@ impl VmbusClientBuilder { messages: OutgoingMessages { poster: self.msg_client, queued: VecDeque::new(), + state: OutgoingMessageState::Paused, }, teardown_gpadls: HashMap::new(), channel_requests: SelectAll::new(), @@ -1393,6 +1394,7 @@ impl ClientTask { fn handle_start(&mut self) { assert!(!self.running); self.msg_source.resume_message_stream(); + self.inner.messages.resume(); self.running = true; } @@ -1400,62 +1402,45 @@ impl ClientTask { assert!(self.running); loop { - // Wait until there are no more channels waiting for responses. This - // is necessary to ensure that the saved state does not have to - // support encoding revoked channels for which we are waiting for - // GPADL or modify responses. - if let Some((id, request)) = self.channels.revoked_channel_with_pending_request() { - tracing::info!( + // Process messages until there are no more channels waiting for + // responses. This is necessary to ensure that the saved state does + // not have to support encoding revoked channels for which we are + // waiting for GPADL or modify responses. + while let Some((id, request)) = self.channels.revoked_channel_with_pending_request() { + tracelimit::info_ratelimited!( channel_id = id.0, request, "waiting for responses for channel" ); assert!(self.process_next_message().await); - continue; } if self.can_pause_resume() { - // Send a pause and flush any queued messages to ensure the host - // sees it. - self.inner.messages.send(&protocol::Pause {}); - self.inner.messages.flush_messages().await; - // Push the resume message onto the queue now. This ensures the - // resume message is sent before any other messages, that new - // messages sent during processing below will be queued rather - // than sent immediately, and it means we don't need to save the - // paused state in the saved state. - self.inner - .messages - .queued - .push_back(OutgoingMessage::new(&protocol::Resume)); + self.inner.messages.pause(); } else { // Mask the sint to pause the message stream. The host will // retry any queued messages after the sint is unmasked. self.msg_source.pause_message_stream(); + self.inner.messages.force_pause(); } - while self.process_next_message().await { - // Continue processing messages until we hit EOF. - } - - // Flush any pending outgoing messages. This needs to be done with - // the incoming message stream active; otherwise, the host may stop - // reading our sent messages. - // - // FUTURE: We can save these pending messages instead, but older - // versions of OpenHCL cannot restore them. Remove this code once - // those older versions are no longer supported (e.g. late 2025). - // - // When pause/resume is supported, we assume that we can save - // pending messages safely, though, since a rollback to a version - // that doesn't support pause/resume will not be able to restore the - // paused state anyway. - if self.inner.messages.is_empty() || self.can_pause_resume() { + // Continue processing messages until we hit EOF or get a pause + // response. + while self.process_next_message().await {} + + // Ensure there are still no pending requests. If there are, resume + // and go around again. + if self + .channels + .revoked_channel_with_pending_request() + .is_none() + { break; } - tracing::info!("flushing outgoing messages"); - self.msg_source.resume_message_stream(); - self.inner.messages.flush_messages().await; + if !self.can_pause_resume() { + self.msg_source.resume_message_stream(); + } + self.inner.messages.resume(); } tracing::debug!("messages drained"); @@ -1464,19 +1449,21 @@ impl ClientTask { } async fn process_next_message(&mut self) -> bool { - // Process messages until we hit EOF. - tracing::debug!("draining messages"); let mut buf = [0; protocol::MAX_MESSAGE_SIZE]; - let size = self - .msg_source - .recv(&mut buf) + let recv = self.msg_source.recv(&mut buf); + // Concurrently flush until there is no more work to do, since pending + // messages may be blocking responses from the host. + let flush = async { + self.inner.messages.flush_messages().await; + std::future::pending().await + }; + let size = (recv, flush) + .race() .await .expect("Fatal error reading messages from synic"); - if size == 0 { return false; } - self.handle_synic_message(&buf[..size]) } @@ -1613,6 +1600,14 @@ struct OutgoingMessages { poster: Box, #[inspect(with = "|x| x.len()")] queued: VecDeque, + state: OutgoingMessageState, +} + +#[derive(Inspect, PartialEq, Eq, Debug)] +enum OutgoingMessageState { + Running, + SendingPauseMessage, + Paused, } impl OutgoingMessages { @@ -1632,7 +1627,7 @@ impl OutgoingMessages { ) { tracing::trace!(typ = ?T::MESSAGE_TYPE, "Sending message to host"); let msg = OutgoingMessage::with_data(msg, data); - if self.queued.is_empty() { + if self.queued.is_empty() && self.state == OutgoingMessageState::Running { let r = self.poster.poll_post_message( &mut Context::from_waker(std::task::Waker::noop()), protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID, @@ -1648,20 +1643,55 @@ impl OutgoingMessages { } async fn flush_messages(&mut self) { - poll_fn(|cx| { - while let Some(msg) = self.queued.front() { - ready!(self.poster.poll_post_message( + let mut send = async |msg: &OutgoingMessage| { + poll_fn(|cx| { + self.poster.poll_post_message( cx, protocol::VMBUS_MESSAGE_REDIRECT_CONNECTION_ID, 1, msg.data(), - )); - tracing::trace!("sent queued message"); - self.queued.pop_front(); + ) + }) + .await + }; + match self.state { + OutgoingMessageState::Running => { + while let Some(msg) = self.queued.front() { + send(msg).await; + tracing::trace!("sent queued message"); + self.queued.pop_front(); + } } - Poll::Ready(()) - }) - .await + OutgoingMessageState::SendingPauseMessage => { + send(&OutgoingMessage::new(&protocol::Pause)).await; + tracing::trace!("sent pause message"); + self.state = OutgoingMessageState::Paused; + } + OutgoingMessageState::Paused => {} + } + } + + /// Pause by sending a pause message to the host. This will cause the host + /// to stop sending messages after sending a pause response. + fn pause(&mut self) { + assert_eq!(self.state, OutgoingMessageState::Running); + self.state = OutgoingMessageState::SendingPauseMessage; + // Queue a resume message to be sent later. + self.queued + .push_front(OutgoingMessage::new(&protocol::Resume)); + } + + /// Force a pause by setting the state to Paused. This is used when the + /// host does not support in-band pause/resume messages, in which case + /// the SINT is masked to force the host to stop sending messages. + fn force_pause(&mut self) { + assert_eq!(self.state, OutgoingMessageState::Running); + self.state = OutgoingMessageState::Paused; + } + + fn resume(&mut self) { + assert_eq!(self.state, OutgoingMessageState::Paused); + self.state = OutgoingMessageState::Running; } fn is_empty(&self) -> bool { From 0e6c7a1150023e87229a83d80923351c984b4b34 Mon Sep 17 00:00:00 2001 From: John Starks Date: Wed, 26 Mar 2025 18:31:08 +0000 Subject: [PATCH 4/6] remove out of date comment --- vm/devices/vmbus/vmbus_client/src/saved_state.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/vm/devices/vmbus/vmbus_client/src/saved_state.rs b/vm/devices/vmbus/vmbus_client/src/saved_state.rs index e71b3e41bb..e0be2e6139 100644 --- a/vm/devices/vmbus/vmbus_client/src/saved_state.rs +++ b/vm/devices/vmbus/vmbus_client/src/saved_state.rs @@ -240,9 +240,6 @@ pub struct SavedState { pub channels: Vec, #[mesh(3)] pub gpadls: Vec, - /// Added in Feb 2025, but not yet used in practice (we flush pending - /// messages during stop) since we need to support restoring on older - /// versions. #[mesh(4)] pub pending_messages: Vec, } From 722e876f2a08bc952eb60a8dc981a512666f6258 Mon Sep 17 00:00:00 2001 From: John Starks Date: Wed, 9 Apr 2025 01:07:33 +0000 Subject: [PATCH 5/6] fmt --- vm/devices/vmbus/vmbus_client/src/saved_state.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vm/devices/vmbus/vmbus_client/src/saved_state.rs b/vm/devices/vmbus/vmbus_client/src/saved_state.rs index e0be2e6139..29dbba83ff 100644 --- a/vm/devices/vmbus/vmbus_client/src/saved_state.rs +++ b/vm/devices/vmbus/vmbus_client/src/saved_state.rs @@ -292,9 +292,7 @@ impl ChannelState { super::ChannelState::Opening { .. } => { unreachable!("Cannot save channel in opening state.") } - super::ChannelState::Restored | super::ChannelState::Opened { .. } => { - Self::Opened - } + super::ChannelState::Restored | super::ChannelState::Opened { .. } => Self::Opened, super::ChannelState::Revoked => return None, }; Some(s) From c4d74f29553af9a6033d0d95f2ad5c25fd9cfb4b Mon Sep 17 00:00:00 2001 From: John Starks Date: Thu, 10 Apr 2025 20:03:55 +0000 Subject: [PATCH 6/6] comment --- vm/devices/vmbus/vmbus_client/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vm/devices/vmbus/vmbus_client/src/lib.rs b/vm/devices/vmbus/vmbus_client/src/lib.rs index b11926e3db..49fafaeb19 100644 --- a/vm/devices/vmbus/vmbus_client/src/lib.rs +++ b/vm/devices/vmbus/vmbus_client/src/lib.rs @@ -1727,9 +1727,9 @@ struct ChannelList( /// as well. struct ChannelRef<'a>(hash_map::OccupiedEntry<'a, ChannelId, Channel>); -/// A tag value used to prove that [`ChannelRef::try_release`] has been called. -/// This is useful to put as the return value for methods that could possibly -/// transition a channel into a fully released state. +/// A tag value used to indicate that [`ChannelRef::try_release`] has been called. +/// This is useful as a return value for methods that might transition a channel +/// into a fully released state. struct TriedRelease(()); impl ChannelRef<'_> {