diff --git a/vm/devices/vmbus/vmbus_server/src/channels.rs b/vm/devices/vmbus/vmbus_server/src/channels.rs
index a85ddb503f..a1b254d917 100644
--- a/vm/devices/vmbus/vmbus_server/src/channels.rs
+++ b/vm/devices/vmbus/vmbus_server/src/channels.rs
@@ -326,7 +326,6 @@ pub struct ModifyConnectionRequest {
     pub monitor_page: Update<MonitorPageGpas>,
     pub interrupt_page: Update<u64>,
     pub target_message_vp: Option<u32>,
-    pub force: bool,
     pub notify_relay: bool,
 }
 
@@ -338,7 +337,6 @@ impl Default for ModifyConnectionRequest {
             monitor_page: Update::Unchanged,
             interrupt_page: Update::Unchanged,
             target_message_vp: None,
-            force: false,
             notify_relay: true,
         }
     }
@@ -398,8 +396,8 @@ enum RestoreState {
     /// The channel has been offered newly this session.
    New,
     /// The channel was in the saved state and has been re-offered this session,
-    /// but restore_channel has not yet been called on it, and post_restore has
-    /// not yet been called.
+    /// but restore_channel has not yet been called on it, and
+    /// revoke_unclaimed_channels has not yet been called.
     Restoring,
     /// The channel was in the saved state but has not yet been re-offered this
     /// session.
@@ -1593,7 +1591,9 @@ impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
         Ok(())
     }
 
-    pub fn post_restore(&mut self) -> Result<(), RestoreError> {
+    /// Revokes and reoffers channels to the guest, depending on their [`RestoreState`].
+    /// This function should be called after [`ServerWithNotifier::restore`].
+    pub fn revoke_unclaimed_channels(&mut self) {
         for (offer_id, channel) in self.inner.channels.iter_mut() {
             match channel.restore_state {
                 RestoreState::Restored => {
@@ -1603,7 +1603,7 @@ impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
                     // This is a fresh channel offer, not in the saved state.
                     // Send the offer to the guest if it has not already been
                     // sent (which could have happened if the channel was
-                    // offered after restore() but before post_restore()).
+                    // offered after restore() but before revoke_unclaimed_channels()).
                     if let ConnectionState::Connected(info) = &self.inner.state {
                         if matches!(channel.state, ChannelState::ClientReleased) {
                             channel.prepare_channel(
@@ -1673,42 +1673,7 @@ impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
             }
         }
 
-        // Restore server state, and resend server notifications if needed. If these notifications
-        // were processed before the save, it's harmless as the values will be the same.
-        let request = match self.inner.state {
-            ConnectionState::Connecting {
-                info,
-                next_action: _,
-            } => Some(ModifyConnectionRequest {
-                version: Some(info.version.version as u32),
-                interrupt_page: info.interrupt_page.into(),
-                monitor_page: info.monitor_page.into(),
-                target_message_vp: Some(info.target_message_vp),
-                force: true,
-                notify_relay: true,
-            }),
-            ConnectionState::Connected(info) => Some(ModifyConnectionRequest {
-                version: None,
-                monitor_page: info.monitor_page.into(),
-                interrupt_page: info.interrupt_page.into(),
-                target_message_vp: Some(info.target_message_vp),
-                force: true,
-                // If the save didn't happen while modifying, the relay doesn't need to be notified
-                // of this info as it doesn't constitute a change, we're just restoring existing
-                // connection state.
-                notify_relay: info.modifying,
-            }),
-            // No action needed for these states; if disconnecting, check_disconnected will resend
-            // the reset request if needed.
-            ConnectionState::Disconnected | ConnectionState::Disconnecting { .. } => None,
-        };
-
-        if let Some(request) = request {
-            self.notifier.modify_connection(request)?;
-        }
-
         self.check_disconnected();
-        Ok(())
     }
 
     /// Initiates a state reset and a closing of all channels.
@@ -1755,7 +1720,7 @@ impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
         assert!(!matches!(channel.state, ChannelState::Revoked));
         // This channel was previously offered to the guest in the saved
         // state. Match this back up to handle future calls to
-        // restore_channel and post_restore.
+        // restore_channel and revoke_unclaimed_channels.
         channel.restore_state = RestoreState::Restoring;
 
         // The relay can specify a host-determined monitor ID, which needs to match what's
@@ -2252,7 +2217,6 @@ impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
             monitor_page: monitor_page.into(),
             interrupt_page: request.interrupt_page.into(),
             target_message_vp: Some(request.target_message_vp),
-            force: false,
             notify_relay: true,
         }) {
             tracelimit::error_ratelimited!(?err, "server failed to change state");
@@ -4924,9 +4888,8 @@ mod tests {
         let state = env.server.save();
         env.c().reset();
         assert!(env.notifier.is_reset());
-        env.server.restore(state).unwrap();
+        env.c().restore(state).unwrap();
         env.c().restore_channel(offer_id1, false).unwrap();
-        env.c().post_restore().unwrap();
     }
 
     #[test]
@@ -5012,7 +4975,7 @@ mod tests {
         env.c().revoke_channel(offer_id5);
         env.c().revoke_channel(offer_id6);
 
-        env.server.restore(state.clone()).unwrap();
+        env.c().restore(state.clone()).unwrap();
 
         env.c().revoke_channel(offer_id1);
         env.c().revoke_channel(offer_id4);
@@ -5028,7 +4991,7 @@ mod tests {
             ChannelState::Reoffered
         ));
 
-        env.c().post_restore().unwrap();
+        env.c().revoke_unclaimed_channels();
 
         assert_eq!(env.notifier.monitor_page, Some(expected_monitor));
         assert_eq!(env.notifier.target_message_vp, Some(0));
@@ -5055,9 +5018,8 @@ mod tests {
         env.complete_reset();
         env.notifier.check_reset();
 
-        env.server.restore(state).unwrap();
+        env.c().restore(state).unwrap();
         env.c().restore_channel(offer_id3, false).unwrap();
-        env.c().post_restore().unwrap();
         assert_eq!(env.notifier.monitor_page, Some(expected_monitor));
         assert_eq!(env.notifier.target_message_vp, Some(0));
     }
@@ -5085,9 +5047,8 @@ mod tests {
         env.complete_connect();
         env.notifier.check_reset();
 
-        env.server.restore(state).unwrap();
+        env.c().restore(state).unwrap();
         env.c().restore_channel(offer_id1, false).unwrap();
-        env.c().post_restore().unwrap();
         assert_eq!(
             env.notifier.monitor_page,
             Some(MonitorPageGpas {
@@ -5108,7 +5069,6 @@ mod tests {
                 }),
                 interrupt_page: Update::Reset,
                 target_message_vp: Some(0),
-                force: true,
                 ..Default::default()
             }
         );
@@ -5148,8 +5108,8 @@ mod tests {
         let state = env.server.save();
         env.c().reset();
         env.notifier.check_reset();
-        env.server.restore(state).unwrap();
-        env.c().post_restore().unwrap();
+
+        env.c().restore(state).unwrap();
 
         // Restore should have resent the request.
         let request = env.next_action();
@@ -5162,7 +5122,6 @@ mod tests {
                 }),
                 interrupt_page: Update::Reset,
                 target_message_vp: Some(0),
-                force: true,
                 ..Default::default()
             }
         );
@@ -5204,13 +5163,13 @@ mod tests {
         let offer_id1 = env.offer(1);
         let offer_id2 = env.offer(2);
         let offer_id3 = env.offer(3);
-        env.server.restore(state).unwrap();
+
+        env.c().restore(state).unwrap();
 
         // This will panic if the reserved channel was not restored.
         env.c().restore_channel(offer_id1, true).unwrap();
         env.c().restore_channel(offer_id2, false).unwrap();
         env.c().restore_channel(offer_id3, false).unwrap();
-        env.c().post_restore().unwrap();
 
         // Make sure the gpadl was restored as well.
         assert!(env.server.gpadls.contains_key(&(GpadlId(1), offer_id1)));
@@ -5258,11 +5217,11 @@ mod tests {
         let offer_id1 = env.offer(1);
         let offer_id2 = env.offer(2);
         let offer_id3 = env.offer(3);
-        env.server.restore(state).unwrap();
+
+        env.c().restore(state).unwrap();
         env.c().restore_channel(offer_id1, false).unwrap();
         env.c().restore_channel(offer_id2, true).unwrap();
         env.c().restore_channel(offer_id3, true).unwrap();
-        env.c().post_restore().unwrap();
 
         // The messages should be pending again.
         assert!(env.server.has_pending_messages());
diff --git a/vm/devices/vmbus/vmbus_server/src/channels/saved_state.rs b/vm/devices/vmbus/vmbus_server/src/channels/saved_state.rs
index 1fed5c8e2e..d84fd169ff 100644
--- a/vm/devices/vmbus/vmbus_server/src/channels/saved_state.rs
+++ b/vm/devices/vmbus/vmbus_server/src/channels/saved_state.rs
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 use super::MnfUsage;
+use super::Notifier;
 use super::OfferError;
 use super::OfferParamsInternal;
 use super::OfferedInfo;
@@ -21,52 +22,6 @@ use vmbus_ring::gparange::MultiPagedRangeBuf;
 use vmcore::monitor::MonitorId;
 
 impl super::Server {
-    /// Restores state.
-    ///
-    /// This may be called before or after channels have been offered. After
-    /// calling this routine, [`ServerWithNotifier::restore_channel`] should be
-    /// called for each channel to be restored, possibly interleaved with
-    /// additional calls to offer or revoke channels.
-    ///
-    /// Once all channels are in the appropriate state,
-    /// [`ServerWithNotifier::post_restore`] should be called. This will revoke
-    /// any channels that were in the saved state but were not restored via
-    /// `restore_channel`.
-    pub fn restore(&mut self, saved: SavedState) -> Result<(), RestoreError> {
-        tracing::trace!(?saved, "restoring channel state");
-
-        if let Some(saved) = saved.state {
-            self.state = saved.connection.restore()?;
-
-            for saved_channel in saved.channels {
-                self.restore_one_channel(saved_channel)?;
-            }
-
-            for saved_gpadl in saved.gpadls {
-                self.restore_one_gpadl(saved_gpadl)?;
-            }
-        } else if let Some(saved) = saved.disconnected_state {
-            self.state = super::ConnectionState::Disconnected;
-            for saved_channel in saved.reserved_channels {
-                self.restore_one_channel(saved_channel)?;
-            }
-
-            for saved_gpadl in saved.reserved_gpadls {
-                self.restore_one_gpadl(saved_gpadl)?;
-            }
-        }
-
-        self.pending_messages
-            .0
-            .reserve(saved.pending_messages.len());
-
-        for message in saved.pending_messages {
-            self.pending_messages.0.push_back(message.restore()?);
-        }
-
-        Ok(())
-    }
-
     fn restore_one_channel(&mut self, saved_channel: Channel) -> Result<(), RestoreError> {
         let (info, stub_offer, state) = saved_channel.restore()?;
         if let Some((offer_id, channel)) = self.channels.get_by_key_mut(&saved_channel.key) {
@@ -211,6 +166,88 @@ impl super::Server {
     }
 }
 
+impl<'a, N: 'a + Notifier> super::ServerWithNotifier<'a, N> {
+    /// Restores state.
+    ///
+    /// This may be called before or after channels have been offered. After
+    /// calling this routine, [`super::ServerWithNotifier::restore_channel`]
+    /// should be called for each channel to be restored, possibly interleaved
+    /// with additional calls to offer or revoke channels.
+    ///
+    /// Once all channels are in the appropriate state,
+    /// [`super::ServerWithNotifier::revoke_unclaimed_channels`] should be
+    /// called. This will revoke any channels that were in the saved state but
+    /// were not restored via `restore_channel`.
+    pub fn restore(&mut self, saved: SavedState) -> Result<(), RestoreError> {
+        tracing::trace!(?saved, "restoring channel state");
+
+        if let Some(saved) = saved.state {
+            self.inner.state = saved.connection.restore()?;
+
+            // Restore server state, and resend server notifications if needed. If these notifications
+            // were processed before the save, it's harmless as the values will be the same.
+            let request = match self.inner.state {
+                super::ConnectionState::Connecting {
+                    info,
+                    next_action: _,
+                } => Some(super::ModifyConnectionRequest {
+                    version: Some(info.version.version as u32),
+                    interrupt_page: info.interrupt_page.into(),
+                    monitor_page: info.monitor_page.into(),
+                    target_message_vp: Some(info.target_message_vp),
+                    notify_relay: true,
+                }),
+                super::ConnectionState::Connected(info) => Some(super::ModifyConnectionRequest {
+                    version: None,
+                    monitor_page: info.monitor_page.into(),
+                    interrupt_page: info.interrupt_page.into(),
+                    target_message_vp: Some(info.target_message_vp),
+                    // If the save didn't happen while modifying, the relay doesn't need to be notified
+                    // of this info as it doesn't constitute a change, we're just restoring existing
+                    // connection state.
+                    notify_relay: info.modifying,
+                }),
+                // No action needed for these states; if disconnecting, check_disconnected will resend
+                // the reset request if needed.
+                super::ConnectionState::Disconnected
+                | super::ConnectionState::Disconnecting { .. } => None,
+            };
+
+            if let Some(request) = request {
+                self.notifier.modify_connection(request)?;
+            }
+
+            for saved_channel in saved.channels {
+                self.inner.restore_one_channel(saved_channel)?;
+            }
+
+            for saved_gpadl in saved.gpadls {
+                self.inner.restore_one_gpadl(saved_gpadl)?;
+            }
+        } else if let Some(saved) = saved.disconnected_state {
+            self.inner.state = super::ConnectionState::Disconnected;
+            for saved_channel in saved.reserved_channels {
+                self.inner.restore_one_channel(saved_channel)?;
+            }
+
+            for saved_gpadl in saved.reserved_gpadls {
+                self.inner.restore_one_gpadl(saved_gpadl)?;
+            }
+        }
+
+        self.inner
+            .pending_messages
+            .0
+            .reserve(saved.pending_messages.len());
+
+        for message in saved.pending_messages {
+            self.inner.pending_messages.0.push_back(message.restore()?);
+        }
+
+        Ok(())
+    }
+}
+
 #[derive(Debug, Error)]
 pub enum RestoreError {
     #[error(transparent)]
diff --git a/vm/devices/vmbus/vmbus_server/src/lib.rs b/vm/devices/vmbus/vmbus_server/src/lib.rs
index 10aa8ac96d..700eeb2b67 100644
--- a/vm/devices/vmbus/vmbus_server/src/lib.rs
+++ b/vm/devices/vmbus/vmbus_server/src/lib.rs
@@ -207,7 +207,6 @@ enum VmbusRequest {
     Inspect(inspect::Deferred),
     Save(Rpc<(), SavedState>),
     Restore(Rpc<SavedState, Result<(), RestoreError>>),
-    PostRestore(Rpc<(), Result<(), RestoreError>>),
     Start,
     Stop(Rpc<(), ()>),
 }
@@ -559,13 +558,6 @@ impl VmbusServer {
             .unwrap()
     }
 
-    pub async fn post_restore(&self) -> Result<(), RestoreError> {
-        self.task_send
-            .call(VmbusRequest::PostRestore, ())
-            .await
-            .unwrap()
-    }
-
     /// Stop the control plane.
     pub async fn stop(&self) {
         self.task_send.call(VmbusRequest::Stop, ()).await.unwrap()
@@ -958,11 +950,10 @@ impl ServerTask {
                 }),
                 VmbusRequest::Restore(rpc) => rpc.handle_sync(|state| {
                     self.unstick_on_start = !state.lost_synic_bug_fixed;
-                    self.server.restore(state.server)
+                    self.server
+                        .with_notifier(&mut self.inner)
+                        .restore(state.server)
                 }),
-                VmbusRequest::PostRestore(rpc) => {
-                    rpc.handle_sync(|()| self.server.with_notifier(&mut self.inner).post_restore())
-                }
                 VmbusRequest::Stop(rpc) => rpc.handle_sync(|()| {
                     if self.inner.running {
                         self.inner.running = false;
@@ -971,6 +962,9 @@ impl ServerTask {
                 VmbusRequest::Start => {
                     if !self.inner.running {
                         self.inner.running = true;
+                        self.server
+                            .with_notifier(&mut self.inner)
+                            .revoke_unclaimed_channels();
                         if self.unstick_on_start {
                             tracing::info!(
                                 "lost synic bug fix is not in yet, call unstick_channels to mitigate the issue."
@@ -1397,7 +1391,7 @@ impl Notifier for ServerTaskInner {
         self.map_interrupt_page(request.interrupt_page)
             .context("Failed to map interrupt page.")?;
 
-        self.set_monitor_page(request.monitor_page, request.force)
+        self.set_monitor_page(request.monitor_page)
             .context("Failed to map monitor page.")?;
 
         if let Some(vp) = request.target_message_vp {
@@ -1710,33 +1704,26 @@ impl ServerTaskInner {
         Ok(())
     }
 
-    fn set_monitor_page(
-        &mut self,
-        monitor_page: Update<MonitorPageGpas>,
-        force: bool,
-    ) -> anyhow::Result<()> {
+    fn set_monitor_page(&mut self, monitor_page: Update<MonitorPageGpas>) -> anyhow::Result<()> {
         let monitor_page = match monitor_page {
             Update::Unchanged => return Ok(()),
             Update::Reset => None,
             Update::Set(value) => Some(value),
         };
 
-        // Force is used by restore because there may be restored channels in the open state.
         // TODO: can this check be moved into channels.rs?
-        if !force
-            && self.channels.iter().any(|(_, c)| {
-                matches!(
-                    &c.state,
-                    ChannelState::Open {
-                        open_params,
-                        ..
-                    } | ChannelState::Opening {
-                        open_params,
-                        ..
-                    } if open_params.monitor_info.is_some()
-                )
-            })
-        {
+        if self.channels.iter().any(|(_, c)| {
+            matches!(
+                &c.state,
+                ChannelState::Open {
+                    open_params,
+                    ..
+                } | ChannelState::Opening {
+                    open_params,
+                    ..
+                } if open_params.monitor_info.is_some()
+            )
+        }) {
             anyhow::bail!("attempt to change monitor page while open channels using mnf");
         }
 
@@ -2407,7 +2394,6 @@ mod tests {
         // will be repeated. This must not panic.
         env.vmbus.restore(saved_state).await.unwrap();
         channel.restore().await;
-        env.vmbus.post_restore().await.unwrap();
         env.vmbus.start();
 
         // Handle the teardown after restore.
diff --git a/vmm_core/src/vmbus_unit.rs b/vmm_core/src/vmbus_unit.rs
index 54ec84ecaa..bcb6b0c2cb 100644
--- a/vmm_core/src/vmbus_unit.rs
+++ b/vmm_core/src/vmbus_unit.rs
@@ -100,11 +100,6 @@ impl StateUnit for &'_ VmbusServerUnit {
             .await
             .map_err(|err| RestoreError::Other(err.into()))
     }
-
-    async fn post_restore(&mut self) -> anyhow::Result<()> {
-        self.0.post_restore().await?;
-        Ok(())
-    }
 }
 
 /// A type wrapping a [`ChannelHandle`] and implementing [`StateUnit`].
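
Note on the updated save/restore flow: with `PostRestore` removed, `VmbusServer::restore` itself replays the connection state to the notifier (the work `post_restore` used to do), and `Start` now invokes `revoke_unclaimed_channels`. A minimal sketch of the resulting caller sequence, modeled on the updated test in lib.rs (the `vmbus`, `channel`, and `saved_state` bindings are that test's fixtures, not new API):

    // Restore server state; this now also resends any needed
    // ModifyConnectionRequest that post_restore() previously replayed.
    vmbus.restore(saved_state).await.unwrap();
    // Claim each channel that was present in the saved state.
    channel.restore().await;
    // start() now runs revoke_unclaimed_channels(), revoking any saved
    // channels that were never claimed via restore_channel().
    vmbus.start();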