diff --git a/Cargo.toml b/Cargo.toml index 14c18c2a..1232de25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,4 +3,5 @@ members = [ "crates/vhost", "crates/vhost-user-backend", + "crates/vhost-user-frontend", ] diff --git a/crates/vhost-user-frontend/Cargo.toml b/crates/vhost-user-frontend/Cargo.toml new file mode 100644 index 00000000..233ee88a --- /dev/null +++ b/crates/vhost-user-frontend/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "vhost-user-frontend" +version = "0.1.0" +authors = ["Viresh Kumar "] +keywords = ["vhost-user", "virtio", "frontend"] +description = "vhost user frontend" +license = "Apache-2.0 OR BSD-3-Clause" +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +epoll = "4.3.1" +libc = "0.2.118" +log = "0.4.14" +seccompiler = "0.2.0" +thiserror = "1.0" +vhost = { path = "../vhost/", version = "0.5", features = ["vhost-user-master", "vhost-kern", "vhost-user-slave"] } +virtio-bindings = { version = ">=0.1.0", features = ["virtio-v5_0_0"] } +virtio-queue = "0.6" +vm-memory = { version = "0.9", features = ["backend-mmap", "backend-atomic", "backend-bitmap"] } +vmm-sys-util = "0.10.0" diff --git a/crates/vhost-user-frontend/src/device.rs b/crates/vhost-user-frontend/src/device.rs new file mode 100644 index 00000000..349ad0f2 --- /dev/null +++ b/crates/vhost-user-frontend/src/device.rs @@ -0,0 +1,308 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-3-Clause file. 
+//
+// Copyright © 2019 Intel Corporation
+//
+// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
+
+use crate::{
+    AccessPlatform, ActivateError, ActivateResult, Error, GuestMemoryMmap, GuestRegionMmap,
+    VirtioDeviceType, VIRTIO_F_RING_INDIRECT_DESC,
+};
+use libc::EFD_NONBLOCK;
+use std::collections::HashMap;
+use std::io::Write;
+use std::num::Wrapping;
+use std::sync::{atomic::AtomicBool, Arc, Barrier};
+use std::thread;
+use virtio_queue::Queue;
+use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestUsize};
+use vmm_sys_util::eventfd::EventFd;
+
+/// Kind of interrupt passed to [`VirtioInterrupt::trigger`] and
+/// [`VirtioInterrupt::notifier`].
+pub enum VirtioInterruptType {
+    Config,
+    // NOTE(review): the u16 is presumably the queue index; confirm with callers.
+    Queue(u16),
+}
+
+/// Interface used by devices to signal interrupts.
+pub trait VirtioInterrupt: Send + Sync {
+    /// Triggers the interrupt identified by `int_type`.
+    fn trigger(&self, int_type: VirtioInterruptType) -> std::result::Result<(), std::io::Error>;
+    /// Returns the `EventFd` associated with `_int_type`, if any.
+    /// The default implementation has none.
+    fn notifier(&self, _int_type: VirtioInterruptType) -> Option<EventFd> {
+        None
+    }
+}
+
+#[derive(Clone)]
+pub struct UserspaceMapping {
+    pub host_addr: u64,
+    pub mem_slot: u32,
+    pub addr: GuestAddress,
+    pub len: GuestUsize,
+    pub mergeable: bool,
+}
+
+#[derive(Clone)]
+pub struct VirtioSharedMemory {
+    pub offset: u64,
+    pub len: u64,
+}
+
+#[derive(Clone)]
+pub struct VirtioSharedMemoryList {
+    pub host_addr: u64,
+    pub mem_slot: u32,
+    pub addr: GuestAddress,
+    pub len: GuestUsize,
+    pub region_list: Vec<VirtioSharedMemory>,
+}
+
+/// Trait for virtio devices to be driven by a virtio transport.
+///
+/// The lifecycle of a virtio device is to be moved to a virtio transport, which will then query the
+/// device. Once the guest driver has configured the device, `VirtioDevice::activate` will be called
+/// and all the events, memory, and queues for device operation will be moved into the device.
+/// Optionally, a virtio device can implement device reset in which it returns said resources and
+/// resets its internal state.
+pub trait VirtioDevice: Send {
+    /// The virtio device type.
+    fn device_type(&self) -> u32;
+
+    /// The maximum size of each queue that this device supports.
+    fn queue_max_sizes(&self) -> &[u16];
+
+    /// The set of feature bits that this device supports.
+    fn features(&self) -> u64 {
+        0
+    }
+
+    /// Acknowledges that this set of features should be enabled.
+    fn ack_features(&mut self, value: u64) {
+        let _ = value;
+    }
+
+    /// Reads this device configuration space at `offset`.
+    ///
+    /// The default implementation only logs a warning and leaves `_data` untouched.
+    fn read_config(&self, _offset: u64, _data: &mut [u8]) {
+        warn!(
+            "No readable configuration fields for {}",
+            VirtioDeviceType::from(self.device_type())
+        );
+    }
+
+    /// Writes to this device configuration space at `offset`.
+    ///
+    /// The default implementation only logs a warning.
+    fn write_config(&mut self, _offset: u64, _data: &[u8]) {
+        warn!(
+            "No writable configuration fields for {}",
+            VirtioDeviceType::from(self.device_type())
+        );
+    }
+
+    /// Activates this device for real usage.
+    fn activate(
+        &mut self,
+        mem: GuestMemoryAtomic<GuestMemoryMmap>,
+        interrupt_evt: Arc<dyn VirtioInterrupt>,
+        queues: Vec<(usize, Queue, EventFd)>,
+    ) -> ActivateResult;
+
+    /// Optionally deactivates this device and returns ownership of the guest memory map, interrupt
+    /// event, and queue events.
+    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
+        None
+    }
+
+    /// Returns the list of shared memory regions required by the device.
+    fn get_shm_regions(&self) -> Option<VirtioSharedMemoryList> {
+        None
+    }
+
+    /// Updates the list of shared memory regions required by the device.
+    ///
+    /// # Panics
+    ///
+    /// The default implementation is `unimplemented!()` and always panics.
+    fn set_shm_regions(
+        &mut self,
+        _shm_regions: VirtioSharedMemoryList,
+    ) -> std::result::Result<(), Error> {
+        std::unimplemented!()
+    }
+
+    /// Some devices may need to do some explicit shutdown work. This method
+    /// may be implemented to do this. The VMM should call shutdown() on
+    /// every device as part of shutting down the VM. Acting on the device
+    /// after a shutdown() can lead to unpredictable results.
+    fn shutdown(&mut self) {}
+
+    /// Hook receiving a guest memory region; the default implementation
+    /// ignores it and returns `Ok(())`.
+    fn add_memory_region(
+        &mut self,
+        _region: &Arc<GuestRegionMmap>,
+    ) -> std::result::Result<(), Error> {
+        Ok(())
+    }
+
+    /// Returns the list of userspace mappings associated with this device.
+    fn userspace_mappings(&self) -> Vec<UserspaceMapping> {
+        Vec::new()
+    }
+
+    /// Return the counters that this device exposes
+    fn counters(&self) -> Option<HashMap<&'static str, Wrapping<u64>>> {
+        None
+    }
+
+    /// Helper to allow common implementation of read_config
+    ///
+    /// Copies `data.len()` bytes of `config`, starting at `offset`, into `data`.
+    /// An access that does not fit inside `config` is logged and ignored.
+    fn read_config_from_slice(&self, config: &[u8], offset: u64, mut data: &mut [u8]) {
+        let config_len = config.len() as u64;
+        let data_len = data.len() as u64;
+        // Use `checked_add` for the bounds test itself: `offset` may be close
+        // to u64::MAX and must be rejected rather than overflow the check.
+        let end = match offset.checked_add(data_len) {
+            Some(end) if end <= config_len => end,
+            _ => {
+                error!(
+                    "Out-of-bound access to configuration: config_len = {} offset = {:x} length = {} for {}",
+                    config_len,
+                    offset,
+                    data_len,
+                    self.device_type()
+                );
+                return;
+            }
+        };
+        // Cannot fail: the source range is exactly `data.len()` bytes long.
+        data.write_all(&config[offset as usize..end as usize])
+            .unwrap();
+    }
+
+    /// Helper to allow common implementation of write_config
+    ///
+    /// Copies `data` into `config` starting at `offset`. An access that does
+    /// not fit inside `config` is logged and ignored.
+    fn write_config_helper(&self, config: &mut [u8], offset: u64, data: &[u8]) {
+        let config_len = config.len() as u64;
+        let data_len = data.len() as u64;
+        // Same overflow-safe bounds test as in `read_config_from_slice`.
+        let end = match offset.checked_add(data_len) {
+            Some(end) if end <= config_len => end,
+            _ => {
+                error!(
+                    "Out-of-bound access to configuration: config_len = {} offset = {:x} length = {} for {}",
+                    config_len,
+                    offset,
+                    data_len,
+                    self.device_type()
+                );
+                return;
+            }
+        };
+
+        // The destination window is sized from the data being written, so
+        // `write_all` cannot run out of room.
+        let mut offset_config = &mut config[offset as usize..end as usize];
+        offset_config.write_all(data).unwrap();
+    }
+
+    /// Set the access platform trait to let the device perform address
+    /// translations if needed.
+    fn set_access_platform(&mut self, _access_platform: Arc<dyn AccessPlatform>) {}
+}
+
+/// Trait providing address translation the same way a physical DMA remapping
+/// table would provide translation between an IOVA and a physical address.
+/// The goal of this trait is to be used by virtio devices to perform the
+/// address translation before they try to read from the guest physical address.
+/// On the other side, the implementation itself should be provided by the code
+/// emulating the IOMMU for the guest.
+pub trait DmaRemapping {
+    /// Provide a way to translate GVA address ranges into GPAs.
+    fn translate_gva(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error>;
+    /// Provide a way to translate GPA address ranges into GVAs.
+    fn translate_gpa(&self, id: u32, addr: u64) -> std::result::Result<u64, std::io::Error>;
+}
+
+/// Structure to handle device state common to all devices
+#[derive(Default)]
+pub struct VirtioCommon {
+    pub avail_features: u64,
+    pub acked_features: u64,
+    pub kill_evt: Option<EventFd>,
+    pub interrupt_cb: Option<Arc<dyn VirtioInterrupt>>,
+    pub pause_evt: Option<EventFd>,
+    pub paused: Arc<AtomicBool>,
+    pub paused_sync: Option<Arc<Barrier>>,
+    pub epoll_threads: Option<Vec<thread::JoinHandle<()>>>,
+    pub queue_sizes: Vec<u16>,
+    pub device_type: u32,
+    pub min_queues: u16,
+    pub access_platform: Option<Arc<dyn AccessPlatform>>,
+}
+
+impl VirtioCommon {
+    /// Returns true when bit number `feature` is set in `acked_features`.
+    pub fn feature_acked(&self, feature: u64) -> bool {
+        // A bit index of 64 or more cannot be present in a u64 bitmap; report
+        // it as not acked instead of overflowing the shift.
+        feature < 64 && self.acked_features & (1u64 << feature) != 0
+    }
+
+    /// Records `value` as acknowledged, dropping (with a warning) every bit
+    /// that is not part of `avail_features`.
+    pub fn ack_features(&mut self, value: u64) {
+        let mut v = value;
+        // Check if the guest is ACK'ing a feature that we didn't claim to have.
+        let unrequested_features = v & !self.avail_features;
+        if unrequested_features != 0 {
+            warn!("Received acknowledge request for unknown feature.");
+
+            // Don't count these features as acked.
+            v &= !unrequested_features;
+        }
+        self.acked_features |= v;
+    }
+
+    /// Checks the number of queues against `min_queues`, creates the kill and
+    /// pause eventfds and stores a clone of `interrupt_cb`.
+    pub fn activate(
+        &mut self,
+        queues: &[(usize, Queue, EventFd)],
+        interrupt_cb: &Arc<dyn VirtioInterrupt>,
+    ) -> ActivateResult {
+        if queues.len() < self.min_queues.into() {
+            error!(
+                "Number of enabled queues lower than min: {} vs {}",
+                queues.len(),
+                self.min_queues
+            );
+            return Err(ActivateError::BadActivate);
+        }
+
+        let kill_evt = EventFd::new(EFD_NONBLOCK).map_err(|e| {
+            error!("failed creating kill EventFd: {}", e);
+            ActivateError::BadActivate
+        })?;
+        self.kill_evt = Some(kill_evt);
+
+        let pause_evt = EventFd::new(EFD_NONBLOCK).map_err(|e| {
+            error!("failed creating pause EventFd: {}", e);
+            ActivateError::BadActivate
+        })?;
+        self.pause_evt = Some(pause_evt);
+
+        // Save the interrupt EventFD as we need to return it on reset
+        // but clone it to pass into the thread.
+        self.interrupt_cb = Some(interrupt_cb.clone());
+
+        Ok(())
+    }
+
+    /// Signals the kill event, joins the epoll threads and hands back the
+    /// stored interrupt callback.
+    ///
+    /// Returns `None` when no callback is stored (for example when the device
+    /// was never activated) instead of panicking.
+    pub fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
+        if let Some(kill_evt) = self.kill_evt.take() {
+            // Ignore the result because there is nothing we can do about it.
+            let _ = kill_evt.write(1);
+        }
+
+        if let Some(mut threads) = self.epoll_threads.take() {
+            for t in threads.drain(..) {
+                if let Err(e) = t.join() {
+                    error!("Error joining thread: {:?}", e);
+                }
+            }
+        }
+
+        // Return the interrupt
+        self.interrupt_cb.take()
+    }
+
+    /// Returns duplicates of the kill and pause eventfds, in that order.
+    ///
+    /// # Panics
+    ///
+    /// Panics if either eventfd is absent (they are created by `activate`)
+    /// or cannot be duplicated.
+    pub fn dup_eventfds(&self) -> (EventFd, EventFd) {
+        (
+            self.kill_evt.as_ref().unwrap().try_clone().unwrap(),
+            self.pause_evt.as_ref().unwrap().try_clone().unwrap(),
+        )
+    }
+
+    pub fn set_access_platform(&mut self, access_platform: Arc<dyn AccessPlatform>) {
+        self.access_platform = Some(access_platform);
+        // Indirect descriptors feature is not supported when the device
+        // requires the addresses held by the descriptors to be translated.
+        self.avail_features &= !(1 << VIRTIO_F_RING_INDIRECT_DESC);
+    }
+}
diff --git a/crates/vhost-user-frontend/src/epoll_helper.rs b/crates/vhost-user-frontend/src/epoll_helper.rs
new file mode 100644
index 00000000..2588ef0a
--- /dev/null
+++ b/crates/vhost-user-frontend/src/epoll_helper.rs
@@ -0,0 +1,172 @@
+// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-3-Clause file.
+//
+// Copyright © 2020 Intel Corporation
+//
+// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
+
+use std::fs::File;
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, Barrier};
+use std::thread;
+use vmm_sys_util::eventfd::EventFd;
+
+/// Owns an epoll file descriptor together with a duplicate of the pause eventfd.
+pub struct EpollHelper {
+    pause_evt: EventFd,
+    epoll_file: File,
+}
+
+#[derive(Debug)]
+pub enum EpollHelperError {
+    CreateFd(std::io::Error),
+    Ctl(std::io::Error),
+    IoError(std::io::Error),
+    Wait(std::io::Error),
+    QueueRingIndex(virtio_queue::Error),
+}
+
+pub const EPOLL_HELPER_EVENT_PAUSE: u16 = 0;
+pub const EPOLL_HELPER_EVENT_KILL: u16 = 1;
+pub const EPOLL_HELPER_EVENT_LAST: u16 = 15;
+
+pub trait EpollHelperHandler {
+    // Return true if the loop execution should be stopped
+    fn handle_event(&mut self, helper: &mut EpollHelper, event: &epoll::Event) -> bool;
+}
+
+impl EpollHelper {
+    /// Creates the epoll instance and registers `kill_evt` and `pause_evt`
+    /// under `EPOLL_HELPER_EVENT_KILL` and `EPOLL_HELPER_EVENT_PAUSE`.
+    ///
+    /// # Errors
+    ///
+    /// `CreateFd` if the epoll instance cannot be created, `IoError` if
+    /// `pause_evt` cannot be duplicated, `Ctl` if registration fails.
+    pub fn new(
+        kill_evt: &EventFd,
+        pause_evt: &EventFd,
+    ) -> std::result::Result<Self, EpollHelperError> {
+        // Create the epoll file descriptor
+        let epoll_fd = epoll::create(true).map_err(EpollHelperError::CreateFd)?;
+        // Use 'File' to enforce closing on 'epoll_fd'
+        // SAFETY: `epoll_fd` was just returned by `epoll::create` and is not
+        // stored anywhere else, so `File` becomes its only owner.
+        let epoll_file = unsafe { File::from_raw_fd(epoll_fd) };
+
+        let mut helper = Self {
+            // Report a failed duplication through the error enum rather
+            // than panicking.
+            pause_evt: pause_evt.try_clone().map_err(EpollHelperError::IoError)?,
+            epoll_file,
+        };
+
+        helper.add_event(kill_evt.as_raw_fd(), EPOLL_HELPER_EVENT_KILL)?;
+        helper.add_event(pause_evt.as_raw_fd(), EPOLL_HELPER_EVENT_PAUSE)?;
+        Ok(helper)
+    }
+
+    /// Registers `fd` for `EPOLLIN` with `id` as the event data.
+    pub fn add_event(&mut self, fd: RawFd, id: u16) -> std::result::Result<(), EpollHelperError> {
+        self.add_event_custom(fd, id, epoll::Events::EPOLLIN)
+    }
+
+    /// Registers `fd` for the events in `evts` with `id` as the event data.
+    pub fn add_event_custom(
+        &mut self,
+        fd: RawFd,
+        id: u16,
+        evts: epoll::Events,
+    ) -> std::result::Result<(), EpollHelperError> {
+        epoll::ctl(
+            self.epoll_file.as_raw_fd(),
+            epoll::ControlOptions::EPOLL_CTL_ADD,
+            fd,
+            epoll::Event::new(evts, id.into()),
+        )
+        .map_err(EpollHelperError::Ctl)
+    }
+
+    /// Removes `fd` from the epoll set.
+    pub fn del_event_custom(
+        &mut self,
+        fd: RawFd,
+        id: u16,
+        evts: epoll::Events,
+    ) -> std::result::Result<(), EpollHelperError> {
+        epoll::ctl(
+            self.epoll_file.as_raw_fd(),
+            epoll::ControlOptions::EPOLL_CTL_DEL,
+            fd,
+            epoll::Event::new(evts, id.into()),
+        )
+        .map_err(EpollHelperError::Ctl)
+    }
+
+    /// Runs the epoll loop until a kill event arrives or `handler` asks to stop.
+    ///
+    /// Kill and pause events are handled here; every other event is passed
+    /// to `handler`.
+    pub fn run(
+        &mut self,
+        paused: Arc<AtomicBool>,
+        paused_sync: Arc<Barrier>,
+        handler: &mut dyn EpollHelperHandler,
+    ) -> std::result::Result<(), EpollHelperError> {
+        const EPOLL_EVENTS_LEN: usize = 100;
+        let mut events = vec![epoll::Event::new(epoll::Events::empty(), 0); EPOLL_EVENTS_LEN];
+
+        // Before jumping into the epoll loop, check if the device is expected
+        // to be in a paused state. This is helpful for the restore code path
+        // as the device thread should not start processing anything before the
+        // device has been resumed.
+        while paused.load(Ordering::SeqCst) {
+            thread::park();
+        }
+
+        loop {
+            let num_events = match epoll::wait(self.epoll_file.as_raw_fd(), -1, &mut events[..]) {
+                Ok(res) => res,
+                Err(e) => {
+                    if e.kind() == std::io::ErrorKind::Interrupted {
+                        // It's well defined from the epoll_wait() syscall
+                        // documentation that the epoll loop can be interrupted
+                        // before any of the requested events occurred or the
+                        // timeout expired. In both those cases, epoll_wait()
+                        // returns an error of type EINTR, but this should not
+                        // be considered as a regular error. Instead it is more
+                        // appropriate to retry, by calling into epoll_wait().
+                        continue;
+                    }
+                    return Err(EpollHelperError::Wait(e));
+                }
+            };
+
+            for event in events.iter().take(num_events) {
+                let ev_type = event.data as u16;
+
+                match ev_type {
+                    EPOLL_HELPER_EVENT_KILL => {
+                        info!("KILL_EVENT received, stopping epoll loop");
+                        return Ok(());
+                    }
+                    EPOLL_HELPER_EVENT_PAUSE => {
+                        info!("PAUSE_EVENT received, pausing epoll loop");
+
+                        // Acknowledge the pause is effective by using the
+                        // paused_sync barrier.
+                        paused_sync.wait();
+
+                        // We loop here to handle spurious park() returns.
+                        // Until we have not resumed, the paused boolean will
+                        // be true.
+                        while paused.load(Ordering::SeqCst) {
+                            thread::park();
+                        }
+
+                        // Drain pause event after the device has been resumed.
+                        // This ensures the pause event has been seen by each
+                        // thread related to this virtio device.
+                        let _ = self.pause_evt.read();
+                    }
+                    _ => {
+                        if handler.handle_event(self, event) {
+                            return Ok(());
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+impl AsRawFd for EpollHelper {
+    fn as_raw_fd(&self) -> RawFd {
+        self.epoll_file.as_raw_fd()
+    }
+}
diff --git a/crates/vhost-user-frontend/src/generic.rs b/crates/vhost-user-frontend/src/generic.rs
new file mode 100644
index 00000000..89804d2d
--- /dev/null
+++ b/crates/vhost-user-frontend/src/generic.rs
@@ -0,0 +1,293 @@
+// Copyright 2022 Linaro Ltd. All Rights Reserved.
+// Viresh Kumar
+//
+// SPDX-License-Identifier: Apache-2.0
+
+use seccompiler::SeccompAction;
+use std::sync::{Arc, Barrier, Mutex};
+use std::thread;
+use std::vec::Vec;
+
+use vhost::vhost_user::message::{
+    VhostUserConfigFlags, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
+};
+use vhost::vhost_user::{MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler};
+use virtio_queue::Queue;
+use vm_memory::GuestMemoryAtomic;
+use vmm_sys_util::eventfd::EventFd;
+
+use crate::{
+    spawn_virtio_thread, ActivateResult, Error, GuestMemoryMmap, GuestRegionMmap, Result, Thread,
+    VhostUserCommon, VhostUserConfig, VhostUserHandle, VirtioCommon, VirtioDevice,
+    VirtioDeviceType, VirtioInterrupt,
+};
+
+const MIN_NUM_QUEUES: usize = 1;
+
+/// Snapshot of the negotiated state, produced by `Generic::state` and
+/// consumed by `Generic::set_state`.
+pub struct State {
+    pub avail_features: u64,
+    pub acked_features: u64,
+    pub acked_protocol_features: u64,
+    pub vu_num_queues: usize,
+}
+
+struct SlaveReqHandler {}
+impl VhostUserMasterReqHandler for SlaveReqHandler {}
+
+/// vhost-user frontend device whose virtio device type is chosen at
+/// construction time.
+pub struct Generic {
+    common: VirtioCommon,
+    vu_common: VhostUserCommon,
+    id: String,
+    guest_memory: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
+    epoll_thread: Option<thread::JoinHandle<()>>,
+    seccomp_action: SeccompAction,
+    exit_evt: EventFd,
+    device_features: u64,
+    num_queues: u32,
+    name: String,
+}
+
+impl Generic {
+    /// Create a new generic vhost-user device of type `device_type`.
+    ///
+    /// Connects to the backend socket given in `vu_cfg` and reads the
+    /// backend's device features.
+    pub fn new(
+        vu_cfg: VhostUserConfig,
+        seccomp_action: SeccompAction,
+        exit_evt: EventFd,
+        device_type: VirtioDeviceType,
+    ) -> Result<Generic> {
+        let num_queues = vu_cfg.num_queues;
+
+        let vu =
+            VhostUserHandle::connect_vhost_user(false, &vu_cfg.socket, num_queues as u64, false)?;
+        let device_features = vu.device_features()?;
+
+        Ok(Generic {
+            common: VirtioCommon {
+                device_type: device_type as u32,
+                queue_sizes: vec![vu_cfg.queue_size; num_queues],
+                avail_features: 0,
+                acked_features: 0,
+                paused_sync: Some(Arc::new(Barrier::new(2))),
+                min_queues: MIN_NUM_QUEUES as u16,
+                ..Default::default()
+            },
+            vu_common: VhostUserCommon {
+                vu: Some(Arc::new(Mutex::new(vu))),
+                acked_protocol_features: 0,
+                socket_path: vu_cfg.socket,
+                vu_num_queues: num_queues,
+                ..Default::default()
+            },
+            id: "generic_device".to_string(),
+            guest_memory: None,
+            epoll_thread: None,
+            seccomp_action,
+            exit_evt,
+            device_features,
+            // Filled in by `negotiate_features`.
+            num_queues: 0,
+            name: String::from(device_type),
+        })
+    }
+
+    /// Feature bits reported by the backend when the device was created.
+    pub fn device_features(&self) -> u64 {
+        self.device_features
+    }
+
+    /// Name derived from the device type given to `new`.
+    pub fn name(&self) -> String {
+        self.name.clone()
+    }
+
+    /// Negotiates virtio and vhost-user protocol features with the backend
+    /// and returns `(acked_features, acked_protocol_features)`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::BadQueueNum` when more queues were configured than the
+    /// backend supports, or the error reported by the backend exchange.
+    pub fn negotiate_features(&mut self, avail_features: u64) -> Result<(u64, u64)> {
+        let mut vu = self.vu_common.vu.as_ref().unwrap().lock().unwrap();
+        let avail_protocol_features = VhostUserProtocolFeatures::MQ
+            | VhostUserProtocolFeatures::CONFIG
+            | VhostUserProtocolFeatures::REPLY_ACK;
+
+        // Virtio spec says following for VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits():
+        //
+        // Bit 30 is used by qemu’s implementation to check for experimental early versions of
+        // virtio which did not perform correct feature negotiation, and SHOULD NOT be negotiated.
+        //
+        // And so Linux clears it in available features. Lets set it forcefully here to make things
+        // work.
+
+        let avail_features = avail_features | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+
+        let (acked_features, acked_protocol_features) =
+            vu.negotiate_features_vhost_user(avail_features, avail_protocol_features)?;
+
+        let backend_num_queues =
+            if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 {
+                vu.socket_handle()
+                    .get_queue_num()
+                    .map_err(Error::VhostUserGetQueueMaxNum)? as usize
+            } else {
+                MIN_NUM_QUEUES
+            };
+
+        if self.vu_common.vu_num_queues > backend_num_queues {
+            error!("vhost-user-device requested too many queues ({}) since the backend only supports {}\n",
+                self.vu_common.vu_num_queues, backend_num_queues);
+            return Err(Error::BadQueueNum);
+        }
+
+        self.common.acked_features = acked_features;
+        self.vu_common.acked_protocol_features = acked_protocol_features;
+        self.num_queues = backend_num_queues as u32;
+
+        Ok((acked_features, acked_protocol_features))
+    }
+
+    /// Returns a snapshot of the negotiated features and queue count.
+    pub fn state(&self) -> State {
+        State {
+            avail_features: self.common.avail_features,
+            acked_features: self.common.acked_features,
+            acked_protocol_features: self.vu_common.acked_protocol_features,
+            vu_num_queues: self.vu_common.vu_num_queues,
+        }
+    }
+
+    /// Restores the fields saved by `state` and re-establishes the backend
+    /// connection; a failure to reconnect is logged, not returned.
+    pub fn set_state(&mut self, state: &State) {
+        self.common.avail_features = state.avail_features;
+        self.common.acked_features = state.acked_features;
+        self.vu_common.acked_protocol_features = state.acked_protocol_features;
+        self.vu_common.vu_num_queues = state.vu_num_queues;
+
+        if let Err(e) = self
+            .vu_common
+            .restore_backend_connection(self.common.acked_features)
+        {
+            error!(
+                "Failed restoring connection with vhost-user backend: {:?}",
+                e
+            );
+        }
+    }
+}
+
+impl Drop for Generic {
+    fn drop(&mut self) {
+        if let Some(kill_evt) = self.common.kill_evt.take() {
+            if let Err(e) = kill_evt.write(1) {
+                error!("failed to kill vhost-user device: {:?}", e);
+            }
+        }
+    }
+}
+
+impl VirtioDevice for Generic {
+    fn device_type(&self) -> u32 {
+        self.common.device_type
+    }
+
+    fn queue_max_sizes(&self) -> &[u16] {
+        &self.common.queue_sizes
+    }
+
+    fn features(&self) -> u64 {
+        self.common.avail_features
+    }
+
+    fn ack_features(&mut self, value: u64) {
+        self.common.ack_features(value)
+    }
+
+    /// Reads `data.len()` bytes of configuration space from the backend.
+    ///
+    /// A backend failure, or a reply of unexpected length, is logged and
+    /// leaves `data` untouched instead of panicking.
+    fn read_config(&self, offset: u64, data: &mut [u8]) {
+        let mut vu = self.vu_common.vu.as_ref().unwrap().lock().unwrap();
+        let len = data.len();
+        let config_space: Vec<u8> = vec![0u8; len];
+        match vu.socket_handle().get_config(
+            offset as u32,
+            len as u32,
+            VhostUserConfigFlags::WRITABLE,
+            config_space.as_slice(),
+        ) {
+            // `copy_from_slice` requires equal lengths, so check first.
+            Ok((_, config_space)) if config_space.len() == len => {
+                data.copy_from_slice(config_space.as_slice())
+            }
+            Ok((_, config_space)) => error!(
+                "Unexpected config space length from vhost-user backend: {} vs {}",
+                config_space.len(),
+                len
+            ),
+            Err(e) => error!(
+                "Failed reading config space from vhost-user backend: {:?}",
+                e
+            ),
+        }
+    }
+
+    /// Writes `data` to the backend configuration space at `offset`; a
+    /// backend failure is logged instead of panicking.
+    fn write_config(&mut self, offset: u64, data: &[u8]) {
+        let mut vu = self.vu_common.vu.as_ref().unwrap().lock().unwrap();
+        if let Err(e) =
+            vu.socket_handle()
+                .set_config(offset as u32, VhostUserConfigFlags::WRITABLE, data)
+        {
+            error!(
+                "Failed writing config space to vhost-user backend: {:?}",
+                e
+            );
+        }
+    }
+
+    fn activate(
+        &mut self,
+        mem: GuestMemoryAtomic<GuestMemoryMmap>,
+        interrupt: Arc<dyn VirtioInterrupt>,
+        queues: Vec<(usize, Queue, EventFd)>,
+    ) -> ActivateResult {
+        self.common.activate(&queues, &interrupt)?;
+        self.guest_memory = Some(mem.clone());
+
+        let slave_req_handler: Option<MasterReqHandler<SlaveReqHandler>> = None;
+
+        // Run a dedicated thread for handling potential reconnections with
+        // the backend.
+        let (kill_evt, pause_evt) = self.common.dup_eventfds();
+
+        let mut handler = self.vu_common.activate(
+            mem,
+            queues,
+            interrupt,
+            self.common.acked_features,
+            slave_req_handler,
+            kill_evt,
+            pause_evt,
+        )?;
+
+        let paused = self.common.paused.clone();
+        let paused_sync = self.common.paused_sync.clone();
+
+        let mut epoll_threads = Vec::new();
+
+        spawn_virtio_thread(
+            &self.id,
+            &self.seccomp_action,
+            Thread::VirtioVhostBlock,
+            &mut epoll_threads,
+            &self.exit_evt,
+            move || {
+                if let Err(e) = handler.run(paused, paused_sync.unwrap()) {
+                    error!("Error running worker: {:?}", e);
+                }
+            },
+        )?;
+        self.epoll_thread = Some(epoll_threads.remove(0));
+
+        Ok(())
+    }
+
+    /// Resets the backend, signals the kill event and returns the stored
+    /// interrupt callback; `None` if the backend reset fails or no callback
+    /// is stored.
+    fn reset(&mut self) -> Option<Arc<dyn VirtioInterrupt>> {
+        if let Some(vu) = &self.vu_common.vu {
+            if let Err(e) = vu.lock().unwrap().reset_vhost_user() {
+                error!("Failed to reset vhost-user daemon: {:?}", e);
+                return None;
+            }
+        }
+
+        if let Some(kill_evt) = self.common.kill_evt.take() {
+            // Ignore the result because there is nothing we can do about it.
+            let _ = kill_evt.write(1);
+        }
+
+        // Return the interrupt
+        self.common.interrupt_cb.take()
+    }
+
+    fn shutdown(&mut self) {
+        self.vu_common.shutdown()
+    }
+
+    fn add_memory_region(
+        &mut self,
+        region: &Arc<GuestRegionMmap>,
+    ) -> std::result::Result<(), Error> {
+        self.vu_common.add_memory_region(&self.guest_memory, region)
+    }
+}
diff --git a/crates/vhost-user-frontend/src/lib.rs b/crates/vhost-user-frontend/src/lib.rs
new file mode 100644
index 00000000..9cdd5be8
--- /dev/null
+++ b/crates/vhost-user-frontend/src/lib.rs
@@ -0,0 +1,250 @@
+// Copyright 2018 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-3-Clause file.
+//
+// Copyright © 2019 Intel Corporation
+//
+// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
+
+#[macro_use]
+extern crate log;
+
+mod device;
+mod epoll_helper;
+mod generic;
+mod seccomp_filters;
+mod thread_helper;
+mod vhost_user;
+
+pub use crate::device::*;
+pub use crate::epoll_helper::*;
+pub use crate::generic::*;
+pub use crate::seccomp_filters::*;
+pub(crate) use crate::thread_helper::*;
+pub use crate::vhost_user::*;
+
+use std::fmt::{self, Debug};
+
+use virtio_queue::{Queue, QueueT};
+use vm_memory::{bitmap::AtomicBitmap, GuestAddress, GuestMemory};
+
+pub type GuestMemoryMmap = vm_memory::GuestMemoryMmap<AtomicBitmap>;
+pub type GuestRegionMmap = vm_memory::GuestRegionMmap<AtomicBitmap>;
+pub type MmapRegion = vm_memory::MmapRegion<AtomicBitmap>;
+
+const VIRTIO_F_RING_INDIRECT_DESC: u32 = 28;
+const VIRTIO_F_RING_EVENT_IDX: u32 = 29;
+const VIRTIO_F_VERSION_1: u32 = 32;
+#[allow(dead_code)]
+const VIRTIO_F_IOMMU_PLATFORM: u32 = 33;
+const VIRTIO_F_IN_ORDER: u32 = 35;
+const VIRTIO_F_ORDER_PLATFORM: u32 = 36;
+#[allow(dead_code)]
+const VIRTIO_F_SR_IOV: u32 = 37;
+const VIRTIO_F_NOTIFICATION_DATA: u32 = 38;
+
+#[derive(Debug)]
+pub enum ActivateError {
+    EpollCtl(std::io::Error),
+    BadActivate,
+    /// Queue number is not correct
+    BadQueueNum,
+    /// Failed to clone Kill event fd
+    CloneKillEventFd,
+    /// Failed to clone exit event fd
+    CloneExitEventFd(std::io::Error),
+    /// Failed to spawn thread
+    ThreadSpawn(std::io::Error),
+    /// Failed to create Vhost-user interrupt eventfd
+    VhostIrqCreate,
+    /// Failed to setup vhost-user-fs daemon.
+    VhostUserFsSetup(vhost_user::Error),
+    /// Failed to setup vhost-user-net daemon.
+    VhostUserNetSetup(vhost_user::Error),
+    /// Failed to setup vhost-user-blk daemon.
+    VhostUserBlkSetup(vhost_user::Error),
+    /// Failed to reset vhost-user daemon.
+    VhostUserReset(vhost_user::Error),
+    /// Cannot create seccomp filter
+    CreateSeccompFilter(seccompiler::Error),
+    /// Cannot create rate limiter
+    CreateRateLimiter(std::io::Error),
+}
+
+pub type ActivateResult = std::result::Result<(), ActivateError>;
+
+// Types taken from linux/virtio_ids.h
+#[derive(Copy, Clone, Debug)]
+#[allow(dead_code)]
+#[allow(non_camel_case_types)]
+#[repr(C)]
+pub enum VirtioDeviceType {
+    Net = 1,
+    Block = 2,
+    Console = 3,
+    Rng = 4,
+    Balloon = 5,
+    Fs9P = 9,
+    Gpu = 16,
+    Input = 18,
+    Vsock = 19,
+    Iommu = 23,
+    Mem = 24,
+    Fs = 26,
+    Pmem = 27,
+    I2c = 34,
+    Watchdog = 35, // Temporary until official number allocated
+    Gpio = 41,
+    Unknown = 0xFF,
+}
+
+/// Maps a numeric virtio device id to its variant; ids without a variant
+/// map to `Unknown`.
+impl From<u32> for VirtioDeviceType {
+    fn from(t: u32) -> Self {
+        match t {
+            1 => VirtioDeviceType::Net,
+            2 => VirtioDeviceType::Block,
+            3 => VirtioDeviceType::Console,
+            4 => VirtioDeviceType::Rng,
+            5 => VirtioDeviceType::Balloon,
+            9 => VirtioDeviceType::Fs9P,
+            16 => VirtioDeviceType::Gpu,
+            18 => VirtioDeviceType::Input,
+            19 => VirtioDeviceType::Vsock,
+            23 => VirtioDeviceType::Iommu,
+            24 => VirtioDeviceType::Mem,
+            26 => VirtioDeviceType::Fs,
+            27 => VirtioDeviceType::Pmem,
+            34 => VirtioDeviceType::I2c,
+            35 => VirtioDeviceType::Watchdog,
+            41 => VirtioDeviceType::Gpio,
+            _ => VirtioDeviceType::Unknown,
+        }
+    }
+}
+
+/// Parses a device name; unrecognised names map to `Unknown`.
+impl From<&str> for VirtioDeviceType {
+    fn from(t: &str) -> Self {
+        match t {
+            "net" => VirtioDeviceType::Net,
+            "block" => VirtioDeviceType::Block,
+            "console" => VirtioDeviceType::Console,
+            "rng" => VirtioDeviceType::Rng,
+            "balloon" => VirtioDeviceType::Balloon,
+            // "9p" is what `String::from(VirtioDeviceType::Fs9P)` produces;
+            // accept it as well so the name round-trips.
+            "fs9p" | "9p" => VirtioDeviceType::Fs9P,
+            "gpu" => VirtioDeviceType::Gpu,
+            "input" => VirtioDeviceType::Input,
+            "vsock" => VirtioDeviceType::Vsock,
+            "iommu" => VirtioDeviceType::Iommu,
+            "mem" => VirtioDeviceType::Mem,
+            "fs" => VirtioDeviceType::Fs,
+            "pmem" => VirtioDeviceType::Pmem,
+            "i2c" => VirtioDeviceType::I2c,
+            "watchdog" => VirtioDeviceType::Watchdog,
+            "gpio" => VirtioDeviceType::Gpio,
+            _ => VirtioDeviceType::Unknown,
+        }
+    }
+}
+
+/// Produces the device name used in logs and by `Display`.
+impl From<VirtioDeviceType> for String {
+    fn from(t: VirtioDeviceType) -> String {
+        match t {
+            VirtioDeviceType::Net => "net",
+            VirtioDeviceType::Block => "block",
+            VirtioDeviceType::Console => "console",
+            VirtioDeviceType::Rng => "rng",
+            VirtioDeviceType::Balloon => "balloon",
+            VirtioDeviceType::Gpu => "gpu",
+            VirtioDeviceType::Fs9P => "9p",
+            VirtioDeviceType::Input => "input",
+            VirtioDeviceType::Vsock => "vsock",
+            VirtioDeviceType::Iommu => "iommu",
+            VirtioDeviceType::Mem => "mem",
+            VirtioDeviceType::Fs => "fs",
+            VirtioDeviceType::Pmem => "pmem",
+            VirtioDeviceType::I2c => "i2c",
+            VirtioDeviceType::Watchdog => "watchdog",
+            VirtioDeviceType::Gpio => "gpio",
+            VirtioDeviceType::Unknown => "UNKNOWN",
+        }
+        .to_string()
+    }
+}
+
+// In order to use the `{}` marker, the trait `fmt::Display` must be implemented
+// manually for the type VirtioDeviceType.
+impl fmt::Display for VirtioDeviceType {
+    // This trait requires `fmt` with this exact signature.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", String::from(*self))
+    }
+}
+
+impl VirtioDeviceType {
+    /// Returns (number, size) of all queues.
+    ///
+    /// Only `I2c` and `Gpio` have values here; every other type yields `(0, 0)`.
+    pub fn queue_num_and_size(&self) -> (usize, usize) {
+        // The match is exhaustive on purpose: adding a variant to the enum
+        // must force a decision here instead of falling into a wildcard.
+        match *self {
+            VirtioDeviceType::I2c => (1, 1024),
+            VirtioDeviceType::Gpio => (2, 256),
+            VirtioDeviceType::Net
+            | VirtioDeviceType::Block
+            | VirtioDeviceType::Console
+            | VirtioDeviceType::Rng
+            | VirtioDeviceType::Balloon
+            | VirtioDeviceType::Gpu
+            | VirtioDeviceType::Fs9P
+            | VirtioDeviceType::Input
+            | VirtioDeviceType::Vsock
+            | VirtioDeviceType::Iommu
+            | VirtioDeviceType::Mem
+            | VirtioDeviceType::Fs
+            | VirtioDeviceType::Pmem
+            | VirtioDeviceType::Watchdog
+            | VirtioDeviceType::Unknown => (0, 0),
+        }
+    }
+}
+
+/// Trait for devices with access to data in memory being limited and/or
+/// translated.
+pub trait AccessPlatform: Send + Sync + Debug {
+    /// Provide a way to translate GVA address ranges into GPAs.
+    fn translate_gva(&self, base: u64, size: u64) -> std::result::Result<u64, std::io::Error>;
+    /// Provide a way to translate GPA address ranges into GVAs.
+    fn translate_gpa(&self, base: u64, size: u64) -> std::result::Result<u64, std::io::Error>;
+}
+
+/// Helper for cloning a Queue since QueueState doesn't derive Clone
+///
+/// # Panics
+///
+/// Panics if a queue of the same maximum size cannot be created or if one of
+/// the ring addresses of `queue` is rejected by the new queue.
+pub fn clone_queue(queue: &Queue) -> Queue {
+    let mut q = Queue::new(queue.max_size()).unwrap();
+
+    q.set_next_avail(queue.next_avail());
+    q.set_next_used(queue.next_used());
+    q.set_event_idx(queue.event_idx_enabled());
+    q.set_size(queue.size());
+    q.set_ready(queue.ready());
+    q.try_set_desc_table_address(GuestAddress(queue.desc_table()))
+        .unwrap();
+    q.try_set_avail_ring_address(GuestAddress(queue.avail_ring()))
+        .unwrap();
+    q.try_set_used_ring_address(GuestAddress(queue.used_ring()))
+        .unwrap();
+
+    q
+}
+
+/// Converts an absolute guest address into a host pointer after verifying
+/// that `size` bytes starting at `addr` form a valid range in `mem`.
+///
+/// Returns `None` if `mem.check_range(addr, size)` rejects the range or if
+/// `addr` cannot be translated to a host address.
+pub fn get_host_address_range<M: GuestMemory>(
+    mem: &M,
+    addr: GuestAddress,
+    size: usize,
+) -> Option<*mut u8> {
+    if mem.check_range(addr, size) {
+        // The function already reports failure through `Option`, so a failed
+        // translation becomes `None` rather than a panic.
+        mem.get_host_address(addr).ok()
+    } else {
+        None
+    }
+}
diff --git a/crates/vhost-user-frontend/src/seccomp_filters.rs b/crates/vhost-user-frontend/src/seccomp_filters.rs
new file mode 100644
index 00000000..6044231b
--- /dev/null
+++ b/crates/vhost-user-frontend/src/seccomp_filters.rs
@@ -0,0 +1,291 @@
+// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Copyright © 2020 Intel Corporation
+//
+// SPDX-License-Identifier: Apache-2.0
+
+use seccompiler::{
+    BpfProgram, Error, SeccompAction, SeccompCmpArgLen as ArgLen, SeccompCmpOp::Eq,
+    SeccompCondition as Cond, SeccompFilter, SeccompRule,
+};
+use std::convert::TryInto;
+
+/// Identifies the kind of device thread a seccomp filter is built for.
+pub enum Thread {
+    VirtioBalloon,
+    VirtioBlock,
+    VirtioConsole,
+    VirtioIommu,
+    VirtioMem,
+    VirtioNet,
+    VirtioNetCtl,
+    VirtioPmem,
+    VirtioRng,
+    VirtioVhostBlock,
+    VirtioVhostFs,
+    VirtioVhostNet,
+    VirtioVhostNetCtl,
+    VirtioVsock,
+    VirtioWatchdog,
+}
+
+/// Shorthand for chaining `SeccompCondition`s with the `and` operator in a `SeccompRule`.
+/// The rule will take the `Allow` action if _all_ the conditions are true.
+///
+/// [`SeccompCondition`]: struct.SeccompCondition.html
+/// [`SeccompRule`]: struct.SeccompRule.html
+macro_rules! and {
+    ($($x:expr),*) => (SeccompRule::new(vec![$($x),*]).unwrap())
+}
+
+/// Shorthand for chaining `SeccompRule`s with the `or` operator in a `SeccompFilter`.
+///
+/// [`SeccompFilter`]: struct.SeccompFilter.html
+/// [`SeccompRule`]: struct.SeccompRule.html
+macro_rules! or {
+    ($($x:expr,)*) => (vec![$($x),*]);
+    ($($x:expr),*) => (vec![$($x),*])
+}
+
+// See include/uapi/asm-generic/ioctls.h in the kernel code.
+const TIOCGWINSZ: u64 = 0x5413;
+const FIONBIO: u64 = 0x5421;
+
+// See include/uapi/linux/vfio.h in the kernel code.
+const VFIO_IOMMU_MAP_DMA: u64 = 0x3b71;
+const VFIO_IOMMU_UNMAP_DMA: u64 = 0x3b72;
+
+// See include/uapi/linux/if_tun.h in the kernel code.
+const TUNSETOFFLOAD: u64 = 0x4004_54d0;
+
+/// ioctl rules for the console thread: argument 1 must equal `TIOCGWINSZ`.
+fn create_virtio_console_ioctl_seccomp_rule() -> Vec<SeccompRule> {
+    or![and![Cond::new(1, ArgLen::Dword, Eq, TIOCGWINSZ).unwrap()]]
+}
+
+/// ioctl rules for the iommu thread: argument 1 must equal
+/// `VFIO_IOMMU_MAP_DMA` or `VFIO_IOMMU_UNMAP_DMA`.
+fn create_virtio_iommu_ioctl_seccomp_rule() -> Vec<SeccompRule> {
+    or![
+        and![Cond::new(1, ArgLen::Dword, Eq, VFIO_IOMMU_MAP_DMA).unwrap()],
+        and![Cond::new(1, ArgLen::Dword, Eq, VFIO_IOMMU_UNMAP_DMA).unwrap()],
+    ]
+}
+
+/// ioctl rules for the mem thread; identical to the iommu thread's rules,
+/// so build them in one place.
+fn create_virtio_mem_ioctl_seccomp_rule() -> Vec<SeccompRule> {
+    create_virtio_iommu_ioctl_seccomp_rule()
+}
+
+fn virtio_balloon_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![(libc::SYS_fallocate, vec![])]
+}
+
+fn virtio_block_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![
+        (libc::SYS_fallocate, vec![]),
+        (libc::SYS_fdatasync, vec![]),
+        (libc::SYS_fsync, vec![]),
+        (libc::SYS_ftruncate, vec![]),
+        (libc::SYS_getrandom, vec![]),
+        (libc::SYS_io_uring_enter, vec![]),
+        (libc::SYS_lseek, vec![]),
+        (libc::SYS_mprotect, vec![]),
+        (libc::SYS_prctl, vec![]),
+        (libc::SYS_pread64, vec![]),
+        (libc::SYS_preadv, vec![]),
+        (libc::SYS_pwritev, vec![]),
+        (libc::SYS_pwrite64, vec![]),
+        (libc::SYS_sched_getaffinity, vec![]),
+        (libc::SYS_set_robust_list, vec![]),
+        (libc::SYS_timerfd_settime, vec![]),
+    ]
+}
+
+fn virtio_console_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![
+        (libc::SYS_ioctl, create_virtio_console_ioctl_seccomp_rule()),
+        (libc::SYS_mprotect, vec![]),
+        (libc::SYS_prctl, vec![]),
+        (libc::SYS_sched_getaffinity, vec![]),
+        (libc::SYS_set_robust_list, vec![]),
+    ]
+}
+
+fn virtio_iommu_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![
+        (libc::SYS_ioctl, create_virtio_iommu_ioctl_seccomp_rule()),
+        (libc::SYS_mprotect, vec![]),
+    ]
+}
+
+fn virtio_mem_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![
+        (libc::SYS_fallocate, vec![]),
+        (libc::SYS_ioctl, create_virtio_mem_ioctl_seccomp_rule()),
+        (libc::SYS_recvfrom, vec![]),
+        (libc::SYS_sendmsg, vec![]),
+    ]
+}
+
+fn virtio_net_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![
+        (libc::SYS_readv, vec![]),
+        (libc::SYS_timerfd_settime, vec![]),
+        (libc::SYS_writev, vec![]),
+    ]
+}
+
+/// ioctl rules for the net control thread: argument 1 must equal `TUNSETOFFLOAD`.
+fn create_virtio_net_ctl_ioctl_seccomp_rule() -> Vec<SeccompRule> {
+    or![and![Cond::new(1, ArgLen::Dword, Eq, TUNSETOFFLOAD).unwrap()],]
+}
+
+fn virtio_net_ctl_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![(libc::SYS_ioctl, create_virtio_net_ctl_ioctl_seccomp_rule())]
+}
+
+fn virtio_pmem_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![(libc::SYS_fsync, vec![])]
+}
+
+fn virtio_rng_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![
+        (libc::SYS_mprotect, vec![]),
+        (libc::SYS_prctl, vec![]),
+        (libc::SYS_sched_getaffinity, vec![]),
+        (libc::SYS_set_robust_list, vec![]),
+    ]
+}
+
+fn virtio_vhost_fs_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![
+        (libc::SYS_connect, vec![]),
+        (libc::SYS_nanosleep, vec![]),
+        (libc::SYS_pread64, vec![]),
+        (libc::SYS_pwrite64, vec![]),
+        (libc::SYS_recvmsg, vec![]),
+        (libc::SYS_sendmsg, vec![]),
+        (libc::SYS_sendto, vec![]),
+        (libc::SYS_socket, vec![]),
+    ]
+}
+
+fn virtio_vhost_net_ctl_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![]
+}
+
+fn virtio_vhost_net_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![
+        (libc::SYS_accept4, vec![]),
+        (libc::SYS_bind, vec![]),
+        (libc::SYS_getcwd, vec![]),
+        (libc::SYS_listen, vec![]),
+        (libc::SYS_recvmsg, vec![]),
+        (libc::SYS_sendmsg, vec![]),
+        (libc::SYS_sendto, vec![]),
+        (libc::SYS_socket, vec![]),
+        #[cfg(target_arch = "x86_64")]
+        (libc::SYS_unlink, vec![]),
+        #[cfg(target_arch = "aarch64")]
+        (libc::SYS_unlinkat, vec![]),
+    ]
+}
+
+fn virtio_vhost_block_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![]
+}
+
+/// ioctl rules for the vsock thread: argument 1 must equal `FIONBIO`.
+fn create_vsock_ioctl_seccomp_rule() -> Vec<SeccompRule> {
+    or![and![Cond::new(1, ArgLen::Dword, Eq, FIONBIO,).unwrap()],]
+}
+
+fn virtio_vsock_thread_rules() -> Vec<(i64, Vec<SeccompRule>)> {
+    vec![
+        (libc::SYS_accept4, vec![]),
+        (libc::SYS_connect, vec![]),
+        (libc::SYS_ioctl, create_vsock_ioctl_seccomp_rule()),
+        (libc::SYS_recvfrom, vec![]),
+        (libc::SYS_socket, vec![]),
+    ]
+}
+
+fn
virtio_watchdog_thread_rules() -> Vec<(i64, Vec)> { + vec![ + (libc::SYS_mprotect, vec![]), + (libc::SYS_prctl, vec![]), + (libc::SYS_sched_getaffinity, vec![]), + (libc::SYS_set_robust_list, vec![]), + (libc::SYS_timerfd_settime, vec![]), + ] +} + +fn get_seccomp_rules(thread_type: Thread) -> Vec<(i64, Vec)> { + let mut rules = match thread_type { + Thread::VirtioBalloon => virtio_balloon_thread_rules(), + Thread::VirtioBlock => virtio_block_thread_rules(), + Thread::VirtioConsole => virtio_console_thread_rules(), + Thread::VirtioIommu => virtio_iommu_thread_rules(), + Thread::VirtioMem => virtio_mem_thread_rules(), + Thread::VirtioNet => virtio_net_thread_rules(), + Thread::VirtioNetCtl => virtio_net_ctl_thread_rules(), + Thread::VirtioPmem => virtio_pmem_thread_rules(), + Thread::VirtioRng => virtio_rng_thread_rules(), + Thread::VirtioVhostBlock => virtio_vhost_block_thread_rules(), + Thread::VirtioVhostFs => virtio_vhost_fs_thread_rules(), + Thread::VirtioVhostNet => virtio_vhost_net_thread_rules(), + Thread::VirtioVhostNetCtl => virtio_vhost_net_ctl_thread_rules(), + Thread::VirtioVsock => virtio_vsock_thread_rules(), + Thread::VirtioWatchdog => virtio_watchdog_thread_rules(), + }; + rules.append(&mut virtio_thread_common()); + rules +} + +fn virtio_thread_common() -> Vec<(i64, Vec)> { + vec![ + (libc::SYS_brk, vec![]), + (libc::SYS_clock_gettime, vec![]), + (libc::SYS_close, vec![]), + (libc::SYS_dup, vec![]), + (libc::SYS_epoll_create1, vec![]), + (libc::SYS_epoll_ctl, vec![]), + (libc::SYS_epoll_pwait, vec![]), + #[cfg(target_arch = "x86_64")] + (libc::SYS_epoll_wait, vec![]), + (libc::SYS_exit, vec![]), + (libc::SYS_futex, vec![]), + (libc::SYS_madvise, vec![]), + (libc::SYS_mmap, vec![]), + (libc::SYS_munmap, vec![]), + (libc::SYS_openat, vec![]), + (libc::SYS_read, vec![]), + (libc::SYS_rt_sigprocmask, vec![]), + (libc::SYS_rt_sigreturn, vec![]), + (libc::SYS_sigaltstack, vec![]), + (libc::SYS_write, vec![]), + ] +} + +/// Generate a BPF program based on 
the seccomp_action value +pub fn get_seccomp_filter( + seccomp_action: &SeccompAction, + thread_type: Thread, +) -> Result { + match seccomp_action { + SeccompAction::Allow => Ok(vec![]), + SeccompAction::Log => SeccompFilter::new( + get_seccomp_rules(thread_type).into_iter().collect(), + SeccompAction::Log, + SeccompAction::Allow, + std::env::consts::ARCH.try_into().unwrap(), + ) + .and_then(|filter| filter.try_into()) + .map_err(Error::Backend), + _ => SeccompFilter::new( + get_seccomp_rules(thread_type).into_iter().collect(), + SeccompAction::Trap, + SeccompAction::Allow, + std::env::consts::ARCH.try_into().unwrap(), + ) + .and_then(|filter| filter.try_into()) + .map_err(Error::Backend), + } +} diff --git a/crates/vhost-user-frontend/src/thread_helper.rs b/crates/vhost-user-frontend/src/thread_helper.rs new file mode 100644 index 00000000..412bd77a --- /dev/null +++ b/crates/vhost-user-frontend/src/thread_helper.rs @@ -0,0 +1,56 @@ +// Copyright © 2021 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +use crate::{get_seccomp_filter, ActivateError, Thread}; +use seccompiler::{apply_filter, SeccompAction}; +use std::{ + panic::AssertUnwindSafe, + thread::{self, JoinHandle}, +}; +use vmm_sys_util::eventfd::EventFd; + +pub(crate) fn spawn_virtio_thread( + name: &str, + seccomp_action: &SeccompAction, + thread_type: Thread, + epoll_threads: &mut Vec>, + exit_evt: &EventFd, + f: F, +) -> Result<(), ActivateError> +where + F: FnOnce(), + F: Send + 'static, +{ + let seccomp_filter = get_seccomp_filter(seccomp_action, thread_type) + .map_err(ActivateError::CreateSeccompFilter)?; + + let thread_exit_evt = exit_evt + .try_clone() + .map_err(ActivateError::CloneExitEventFd)?; + let thread_name = name.to_string(); + + thread::Builder::new() + .name(name.to_string()) + .spawn(move || { + if !seccomp_filter.is_empty() { + if let Err(e) = apply_filter(&seccomp_filter) { + error!("Error applying seccomp filter: {:?}", e); + thread_exit_evt.write(1).ok(); + 
return; + } + } + std::panic::catch_unwind(AssertUnwindSafe(f)) + .or_else(|_| { + error!("{} thread panicked", thread_name); + thread_exit_evt.write(1) + }) + .ok(); + }) + .map(|thread| epoll_threads.push(thread)) + .map_err(|e| { + error!("Failed to spawn thread for {}: {}", name, e); + ActivateError::ThreadSpawn(e) + }) +} diff --git a/crates/vhost-user-frontend/src/vhost_user/mod.rs b/crates/vhost-user-frontend/src/vhost_user/mod.rs new file mode 100644 index 00000000..14e55840 --- /dev/null +++ b/crates/vhost-user-frontend/src/vhost_user/mod.rs @@ -0,0 +1,384 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + clone_queue, ActivateError, EpollHelper, EpollHelperError, EpollHelperHandler, GuestMemoryMmap, + GuestRegionMmap, VirtioInterrupt, EPOLL_HELPER_EVENT_LAST, VIRTIO_F_IN_ORDER, + VIRTIO_F_NOTIFICATION_DATA, VIRTIO_F_ORDER_PLATFORM, VIRTIO_F_RING_EVENT_IDX, + VIRTIO_F_RING_INDIRECT_DESC, VIRTIO_F_VERSION_1, +}; +use std::fmt::Debug; +use std::io; +use std::ops::Deref; +use std::os::unix::io::AsRawFd; +use std::sync::{atomic::AtomicBool, Arc, Barrier, Mutex}; +use vhost::vhost_user::message::{ + VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures, +}; +use vhost::vhost_user::{MasterReqHandler, VhostUserMasterReqHandler}; +use vhost::Error as VhostError; +use virtio_queue::Error as QueueError; +use virtio_queue::Queue; +use vm_memory::{mmap::MmapRegionError, Error as MmapError, GuestAddressSpace, GuestMemoryAtomic}; +use vmm_sys_util::eventfd::EventFd; +pub(crate) use vu_common_ctrl::VhostUserHandle; + +pub mod vu_common_ctrl; +pub use self::vu_common_ctrl::VhostUserConfig; + +#[derive(Debug)] +pub enum Error { + /// Failed accepting connection. + AcceptConnection(io::Error), + /// Invalid available address. + AvailAddress, + /// Queue number is not correct + BadQueueNum, + /// Failed binding vhost-user socket. + BindSocket(io::Error), + /// Creating kill eventfd failed. 
+ CreateKillEventFd(io::Error), + /// Cloning kill eventfd failed. + CloneKillEventFd(io::Error), + /// Invalid descriptor table address. + DescriptorTableAddress, + /// Signal used queue failed. + FailedSignalingUsedQueue(io::Error), + /// Failed to read vhost eventfd. + MemoryRegions(MmapError), + /// Failed removing socket path + RemoveSocketPath(io::Error), + /// Failed to create master. + VhostUserCreateMaster(VhostError), + /// Failed to open vhost device. + VhostUserOpen(VhostError), + /// Connection to socket failed. + VhostUserConnect, + /// Get features failed. + VhostUserGetFeatures(VhostError), + /// Get queue max number failed. + VhostUserGetQueueMaxNum(VhostError), + /// Get protocol features failed. + VhostUserGetProtocolFeatures(VhostError), + /// Get vring base failed. + VhostUserGetVringBase(VhostError), + /// Vhost-user Backend not support vhost-user protocol. + VhostUserProtocolNotSupport, + /// Set owner failed. + VhostUserSetOwner(VhostError), + /// Reset owner failed. + VhostUserResetOwner(VhostError), + /// Set features failed. + VhostUserSetFeatures(VhostError), + /// Set protocol features failed. + VhostUserSetProtocolFeatures(VhostError), + /// Set mem table failed. + VhostUserSetMemTable(VhostError), + /// Set vring num failed. + VhostUserSetVringNum(VhostError), + /// Set vring addr failed. + VhostUserSetVringAddr(VhostError), + /// Set vring base failed. + VhostUserSetVringBase(VhostError), + /// Set vring call failed. + VhostUserSetVringCall(VhostError), + /// Set vring kick failed. + VhostUserSetVringKick(VhostError), + /// Set vring enable failed. + VhostUserSetVringEnable(VhostError), + /// Failed to create vhost eventfd. + VhostIrqCreate(io::Error), + /// Failed to read vhost eventfd. + VhostIrqRead(io::Error), + /// Failed to read vhost eventfd. + VhostUserMemoryRegion(MmapError), + /// Failed to create the master request handler from slave. + MasterReqHandlerCreation(vhost::vhost_user::Error), + /// Set slave request fd failed. 
+ VhostUserSetSlaveRequestFd(vhost::Error), + /// Add memory region failed. + VhostUserAddMemReg(VhostError), + /// Failed getting the configuration. + VhostUserGetConfig(VhostError), + /// Failed setting the configuration. + VhostUserSetConfig(VhostError), + /// Failed getting inflight shm log. + VhostUserGetInflight(VhostError), + /// Failed setting inflight shm log. + VhostUserSetInflight(VhostError), + /// Failed setting the log base. + VhostUserSetLogBase(VhostError), + /// Invalid used address. + UsedAddress, + /// Invalid features provided from vhost-user backend + InvalidFeatures, + /// Missing file descriptor for the region. + MissingRegionFd, + /// Missing IrqFd + MissingIrqFd, + /// Failed getting the available index. + GetAvailableIndex(QueueError), + /// Migration is not supported by this vhost-user device. + MigrationNotSupported, + /// Failed creating memfd. + MemfdCreate(io::Error), + /// Failed truncating the file size to the expected size. + SetFileSize(io::Error), + /// Failed to set the seals on the file. 
+ SetSeals(io::Error), + /// Failed creating new mmap region + NewMmapRegion(MmapRegionError), + /// Could not find the shm log region + MissingShmLogRegion, +} +pub type Result = std::result::Result; + +pub const DEFAULT_VIRTIO_FEATURES: u64 = 1 << VIRTIO_F_RING_INDIRECT_DESC + | 1 << VIRTIO_F_RING_EVENT_IDX + | 1 << VIRTIO_F_VERSION_1 + | 1 << VIRTIO_F_IN_ORDER + | 1 << VIRTIO_F_ORDER_PLATFORM + | 1 << VIRTIO_F_NOTIFICATION_DATA + | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + +const HUP_CONNECTION_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 1; +const SLAVE_REQ_EVENT: u16 = EPOLL_HELPER_EVENT_LAST + 2; + +#[derive(Default)] +pub struct Inflight { + pub info: VhostUserInflight, + pub fd: Option, +} + +pub struct VhostUserEpollHandler { + pub vu: Arc>, + pub mem: GuestMemoryAtomic, + pub kill_evt: EventFd, + pub pause_evt: EventFd, + pub queues: Vec<(usize, Queue, EventFd)>, + pub virtio_interrupt: Arc, + pub acked_features: u64, + pub acked_protocol_features: u64, + pub socket_path: String, + pub server: bool, + pub slave_req_handler: Option>, + pub inflight: Option, +} + +impl VhostUserEpollHandler { + pub fn run( + &mut self, + paused: Arc, + paused_sync: Arc, + ) -> std::result::Result<(), EpollHelperError> { + let mut helper = EpollHelper::new(&self.kill_evt, &self.pause_evt)?; + helper.add_event_custom( + self.vu.lock().unwrap().socket_handle().as_raw_fd(), + HUP_CONNECTION_EVENT, + epoll::Events::EPOLLHUP, + )?; + + if let Some(slave_req_handler) = &self.slave_req_handler { + helper.add_event(slave_req_handler.as_raw_fd(), SLAVE_REQ_EVENT)?; + } + + helper.run(paused, paused_sync, self)?; + + Ok(()) + } + + fn reconnect(&mut self, helper: &mut EpollHelper) -> std::result::Result<(), EpollHelperError> { + helper.del_event_custom( + self.vu.lock().unwrap().socket_handle().as_raw_fd(), + HUP_CONNECTION_EVENT, + epoll::Events::EPOLLHUP, + )?; + + let mut vhost_user = VhostUserHandle::connect_vhost_user( + self.server, + &self.socket_path, + self.queues.len() 
as u64, + true, + ) + .map_err(|e| { + EpollHelperError::IoError(std::io::Error::new( + std::io::ErrorKind::Other, + format!("failed connecting vhost-user backend{:?}", e), + )) + })?; + + // Initialize the backend + vhost_user + .reinitialize_vhost_user( + self.mem.memory().deref(), + self.queues + .iter() + .map(|(i, q, e)| (*i, clone_queue(q), e.try_clone().unwrap())) + .collect(), + &self.virtio_interrupt, + self.acked_features, + self.acked_protocol_features, + &self.slave_req_handler, + self.inflight.as_mut(), + ) + .map_err(|e| { + EpollHelperError::IoError(std::io::Error::new( + std::io::ErrorKind::Other, + format!("failed reconnecting vhost-user backend{:?}", e), + )) + })?; + + helper.add_event_custom( + vhost_user.socket_handle().as_raw_fd(), + HUP_CONNECTION_EVENT, + epoll::Events::EPOLLHUP, + )?; + + // Update vhost-user reference + let mut vu = self.vu.lock().unwrap(); + *vu = vhost_user; + + Ok(()) + } +} + +impl EpollHelperHandler for VhostUserEpollHandler { + fn handle_event(&mut self, helper: &mut EpollHelper, event: &epoll::Event) -> bool { + let ev_type = event.data as u16; + match ev_type { + HUP_CONNECTION_EVENT => { + if let Err(e) = self.reconnect(helper) { + error!("failed to reconnect vhost-user backend: {:?}", e); + return true; + } + } + SLAVE_REQ_EVENT => { + if let Some(slave_req_handler) = self.slave_req_handler.as_mut() { + if let Err(e) = slave_req_handler.handle_request() { + error!("Failed to handle request from vhost-user backend: {:?}", e); + return true; + } + } + } + _ => { + error!("Unknown event for vhost-user thread"); + return true; + } + } + + false + } +} + +#[derive(Default)] +pub struct VhostUserCommon { + pub vu: Option>>, + pub acked_protocol_features: u64, + pub socket_path: String, + pub vu_num_queues: usize, + pub migration_started: bool, + pub server: bool, +} + +impl VhostUserCommon { + #[allow(clippy::too_many_arguments)] + pub fn activate( + &mut self, + mem: GuestMemoryAtomic, + queues: Vec<(usize, Queue, 
EventFd)>, + interrupt_cb: Arc, + acked_features: u64, + slave_req_handler: Option>, + kill_evt: EventFd, + pause_evt: EventFd, + ) -> std::result::Result, ActivateError> { + let mut inflight: Option = + if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() != 0 + { + Some(Inflight::default()) + } else { + None + }; + + if self.vu.is_none() { + error!("Missing vhost-user handle"); + return Err(ActivateError::BadActivate); + } + let vu = self.vu.as_ref().unwrap(); + vu.lock() + .unwrap() + .setup_vhost_user( + &mem.memory(), + queues + .iter() + .map(|(i, q, e)| (*i, clone_queue(q), e.try_clone().unwrap())) + .collect(), + &interrupt_cb, + acked_features, + &slave_req_handler, + inflight.as_mut(), + ) + .map_err(ActivateError::VhostUserBlkSetup)?; + + Ok(VhostUserEpollHandler { + vu: vu.clone(), + mem, + kill_evt, + pause_evt, + queues, + virtio_interrupt: interrupt_cb, + acked_features, + acked_protocol_features: self.acked_protocol_features, + socket_path: self.socket_path.clone(), + server: self.server, + slave_req_handler, + inflight, + }) + } + + pub fn restore_backend_connection(&mut self, acked_features: u64) -> Result<()> { + let mut vu = VhostUserHandle::connect_vhost_user( + self.server, + &self.socket_path, + self.vu_num_queues as u64, + false, + )?; + + vu.set_protocol_features_vhost_user(acked_features, self.acked_protocol_features)?; + + self.vu = Some(Arc::new(Mutex::new(vu))); + + Ok(()) + } + + pub fn shutdown(&mut self) { + if let Some(vu) = &self.vu { + let _ = unsafe { libc::close(vu.lock().unwrap().socket_handle().as_raw_fd()) }; + } + + // Remove socket path if needed + if self.server { + let _ = std::fs::remove_file(&self.socket_path); + } + } + + pub fn add_memory_region( + &mut self, + guest_memory: &Option>, + region: &Arc, + ) -> std::result::Result<(), Error> { + if let Some(vu) = &self.vu { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() + != 0 + { + return 
vu.lock().unwrap().add_memory_region(region); + } else if let Some(guest_memory) = guest_memory { + return vu + .lock() + .unwrap() + .update_mem_table(guest_memory.memory().deref()); + } + } + Ok(()) + } +} diff --git a/crates/vhost-user-frontend/src/vhost_user/vu_common_ctrl.rs b/crates/vhost-user-frontend/src/vhost_user/vu_common_ctrl.rs new file mode 100644 index 00000000..18536725 --- /dev/null +++ b/crates/vhost-user-frontend/src/vhost_user/vu_common_ctrl.rs @@ -0,0 +1,575 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{Error, Result}; +use crate::vhost_user::Inflight; +use crate::{ + get_host_address_range, GuestMemoryMmap, GuestRegionMmap, MmapRegion, VirtioInterrupt, + VirtioInterruptType, +}; +use std::convert::TryInto; +use std::ffi; +use std::fs::File; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::net::UnixListener; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread::sleep; +use std::time::{Duration, Instant}; +use std::vec::Vec; +use vhost::vhost_kern::vhost_binding::{VHOST_F_LOG_ALL, VHOST_VRING_F_LOG}; +use vhost::vhost_user::message::{ + VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, VhostUserVirtioFeatures, +}; +use vhost::vhost_user::{Master, MasterReqHandler, VhostUserMaster, VhostUserMasterReqHandler}; +use vhost::{VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData}; +use virtio_queue::{Descriptor, Queue, QueueT}; +use vm_memory::{ + Address, Error as MmapError, FileOffset, GuestAddress, GuestMemory, GuestMemoryRegion, +}; +use vmm_sys_util::eventfd::EventFd; + +// Size of a dirty page for vhost-user. 
+const VHOST_LOG_PAGE: u64 = 0x1000; + +#[derive(Debug, Clone)] +pub struct VhostUserConfig { + pub socket: String, + pub num_queues: usize, + pub queue_size: u16, +} + +#[derive(Clone)] +struct VringInfo { + config_data: VringConfigData, + used_guest_addr: u64, +} + +#[derive(Clone)] +pub struct VhostUserHandle { + vu: Master, + ready: bool, + supports_migration: bool, + shm_log: Option>, + acked_features: u64, + vrings_info: Option>, + queue_indexes: Vec, +} + +impl VhostUserHandle { + pub fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<()> { + let mut regions: Vec = Vec::new(); + for region in mem.iter() { + let (mmap_handle, mmap_offset) = match region.file_offset() { + Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()), + None => return Err(Error::VhostUserMemoryRegion(MmapError::NoMemoryRegion)), + }; + + let vhost_user_net_reg = VhostUserMemoryRegionInfo { + guest_phys_addr: region.start_addr().raw_value(), + memory_size: region.len() as u64, + userspace_addr: region.as_ptr() as u64, + mmap_offset, + mmap_handle, + }; + + regions.push(vhost_user_net_reg); + } + + self.vu + .set_mem_table(regions.as_slice()) + .map_err(Error::VhostUserSetMemTable)?; + + Ok(()) + } + + pub fn add_memory_region(&mut self, region: &Arc) -> Result<()> { + let (mmap_handle, mmap_offset) = match region.file_offset() { + Some(file_offset) => (file_offset.file().as_raw_fd(), file_offset.start()), + None => return Err(Error::MissingRegionFd), + }; + + let region = VhostUserMemoryRegionInfo { + guest_phys_addr: region.start_addr().raw_value(), + memory_size: region.len() as u64, + userspace_addr: region.as_ptr() as u64, + mmap_offset, + mmap_handle, + }; + + self.vu + .add_mem_region(®ion) + .map_err(Error::VhostUserAddMemReg) + } + + pub fn negotiate_features_vhost_user( + &mut self, + avail_features: u64, + avail_protocol_features: VhostUserProtocolFeatures, + ) -> Result<(u64, u64)> { + // Set vhost-user owner. 
+ self.vu.set_owner().map_err(Error::VhostUserSetOwner)?; + + // Get features from backend, do negotiation to get a feature collection which + // both VMM and backend support. + let backend_features = self + .vu + .get_features() + .map_err(Error::VhostUserGetFeatures)?; + let acked_features = avail_features & backend_features; + + let acked_protocol_features = + if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { + let backend_protocol_features = self + .vu + .get_protocol_features() + .map_err(Error::VhostUserGetProtocolFeatures)?; + + let acked_protocol_features = avail_protocol_features & backend_protocol_features; + + self.vu + .set_protocol_features(acked_protocol_features) + .map_err(Error::VhostUserSetProtocolFeatures)?; + + acked_protocol_features + } else { + VhostUserProtocolFeatures::empty() + }; + + if avail_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) + && acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) + { + self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); + } + + self.update_supports_migration(acked_features, acked_protocol_features.bits()); + + Ok((acked_features, acked_protocol_features.bits())) + } + + pub fn device_features(&self) -> Result { + self.vu.get_features().map_err(Error::VhostUserGetFeatures) + } + + #[allow(clippy::too_many_arguments)] + pub fn setup_vhost_user( + &mut self, + mem: &GuestMemoryMmap, + queues: Vec<(usize, Queue, EventFd)>, + virtio_interrupt: &Arc, + acked_features: u64, + slave_req_handler: &Option>, + inflight: Option<&mut Inflight>, + ) -> Result<()> { + self.vu + .set_features(acked_features) + .map_err(Error::VhostUserSetFeatures)?; + + // Update internal value after it's been sent to the backend. + self.acked_features = acked_features; + + // Let's first provide the memory table to the backend. 
+ self.update_mem_table(mem)?; + + // Send set_vring_num here, since it could tell backends, like SPDK, + // how many virt queues to be handled, which backend required to know + // at early stage. + for (queue_index, queue, _) in queues.iter() { + self.vu + .set_vring_num(*queue_index, queue.size()) + .map_err(Error::VhostUserSetVringNum)?; + } + + // Setup for inflight I/O tracking shared memory. + if let Some(inflight) = inflight { + if inflight.fd.is_none() { + let inflight_req_info = VhostUserInflight { + mmap_size: 0, + mmap_offset: 0, + num_queues: queues.len() as u16, + queue_size: queues[0].1.size(), + }; + let (info, fd) = self + .vu + .get_inflight_fd(&inflight_req_info) + .map_err(Error::VhostUserGetInflight)?; + inflight.info = info; + inflight.fd = Some(fd); + } + // Unwrapping the inflight fd is safe here since we know it can't be None. + self.vu + .set_inflight_fd(&inflight.info, inflight.fd.as_ref().unwrap().as_raw_fd()) + .map_err(Error::VhostUserSetInflight)?; + } + + let mut vrings_info = Vec::new(); + for (queue_index, queue, queue_evt) in queues.iter() { + let actual_size: usize = queue.size().try_into().unwrap(); + + let config_data = VringConfigData { + queue_max_size: queue.max_size(), + queue_size: queue.size(), + flags: 0u32, + desc_table_addr: get_host_address_range( + mem, + GuestAddress(queue.desc_table()), + actual_size * std::mem::size_of::(), + ) + .ok_or(Error::DescriptorTableAddress)? as u64, + // The used ring is {flags: u16; idx: u16; virtq_used_elem [{id: u16, len: u16}; actual_size]}, + // i.e. 4 + (4 + 4) * actual_size. + used_ring_addr: get_host_address_range( + mem, + GuestAddress(queue.used_ring()), + 4 + actual_size * 8, + ) + .ok_or(Error::UsedAddress)? as u64, + // The used ring is {flags: u16; idx: u16; elem [u16; actual_size]}, + // i.e. 4 + (2) * actual_size. + avail_ring_addr: get_host_address_range( + mem, + GuestAddress(queue.avail_ring()), + 4 + actual_size * 2, + ) + .ok_or(Error::AvailAddress)? 
as u64, + log_addr: None, + }; + + vrings_info.push(VringInfo { + config_data, + used_guest_addr: queue.used_ring(), + }); + + self.vu + .set_vring_addr(*queue_index, &config_data) + .map_err(Error::VhostUserSetVringAddr)?; + self.vu + .set_vring_base( + *queue_index, + queue + .avail_idx(mem, Ordering::Acquire) + .map_err(Error::GetAvailableIndex)? + .0, + ) + .map_err(Error::VhostUserSetVringBase)?; + + if let Some(eventfd) = + virtio_interrupt.notifier(VirtioInterruptType::Queue(*queue_index as u16)) + { + self.vu + .set_vring_call(*queue_index, &eventfd) + .map_err(Error::VhostUserSetVringCall)?; + } + + self.vu + .set_vring_kick(*queue_index, queue_evt) + .map_err(Error::VhostUserSetVringKick)?; + + self.queue_indexes.push(*queue_index); + } + + self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?; + + if let Some(slave_req_handler) = slave_req_handler { + self.vu + .set_slave_request_fd(&slave_req_handler.get_tx_raw_fd()) + .map_err(Error::VhostUserSetSlaveRequestFd)?; + } + + self.vrings_info = Some(vrings_info); + self.ready = true; + + Ok(()) + } + + fn enable_vhost_user_vrings(&mut self, queue_indexes: Vec, enable: bool) -> Result<()> { + for queue_index in queue_indexes { + self.vu + .set_vring_enable(queue_index, enable) + .map_err(Error::VhostUserSetVringEnable)?; + } + + Ok(()) + } + + pub fn reset_vhost_user(&mut self) -> Result<()> { + for queue_index in self.queue_indexes.drain(..) 
{ + self.vu + .set_vring_enable(queue_index, false) + .map_err(Error::VhostUserSetVringEnable)?; + + let _ = self + .vu + .get_vring_base(queue_index) + .map_err(Error::VhostUserGetVringBase)?; + } + + Ok(()) + } + + pub fn set_protocol_features_vhost_user( + &mut self, + acked_features: u64, + acked_protocol_features: u64, + ) -> Result<()> { + self.vu.set_owner().map_err(Error::VhostUserSetOwner)?; + self.vu + .get_features() + .map_err(Error::VhostUserGetFeatures)?; + + if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { + if let Some(acked_protocol_features) = + VhostUserProtocolFeatures::from_bits(acked_protocol_features) + { + self.vu + .set_protocol_features(acked_protocol_features) + .map_err(Error::VhostUserSetProtocolFeatures)?; + + if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) { + self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); + } + } + } + + self.update_supports_migration(acked_features, acked_protocol_features); + + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + pub fn reinitialize_vhost_user( + &mut self, + mem: &GuestMemoryMmap, + queues: Vec<(usize, Queue, EventFd)>, + virtio_interrupt: &Arc, + acked_features: u64, + acked_protocol_features: u64, + slave_req_handler: &Option>, + inflight: Option<&mut Inflight>, + ) -> Result<()> { + self.set_protocol_features_vhost_user(acked_features, acked_protocol_features)?; + + self.setup_vhost_user( + mem, + queues, + virtio_interrupt, + acked_features, + slave_req_handler, + inflight, + ) + } + + pub fn connect_vhost_user( + server: bool, + socket_path: &str, + num_queues: u64, + unlink_socket: bool, + ) -> Result { + if server { + if unlink_socket { + std::fs::remove_file(socket_path).map_err(Error::RemoveSocketPath)?; + } + + info!("Binding vhost-user listener..."); + let listener = UnixListener::bind(socket_path).map_err(Error::BindSocket)?; + info!("Waiting for incoming vhost-user connection..."); + let (stream, _) = 
listener.accept().map_err(Error::AcceptConnection)?; + + Ok(VhostUserHandle { + vu: Master::from_stream(stream, num_queues), + ready: false, + supports_migration: false, + shm_log: None, + acked_features: 0, + vrings_info: None, + queue_indexes: Vec::new(), + }) + } else { + let now = Instant::now(); + + // Retry connecting for a full minute + let err = loop { + let err = match Master::connect(socket_path, num_queues) { + Ok(m) => { + return Ok(VhostUserHandle { + vu: m, + ready: false, + supports_migration: false, + shm_log: None, + acked_features: 0, + vrings_info: None, + queue_indexes: Vec::new(), + }) + } + Err(e) => e, + }; + sleep(Duration::from_millis(100)); + + if now.elapsed().as_secs() >= 60 { + break err; + } + }; + + error!( + "Failed connecting the backend after trying for 1 minute: {:?}", + err + ); + Err(Error::VhostUserConnect) + } + } + + pub fn socket_handle(&mut self) -> &mut Master { + &mut self.vu + } + + pub fn pause_vhost_user(&mut self) -> Result<()> { + if self.ready { + self.enable_vhost_user_vrings(self.queue_indexes.clone(), false)?; + } + + Ok(()) + } + + pub fn resume_vhost_user(&mut self) -> Result<()> { + if self.ready { + self.enable_vhost_user_vrings(self.queue_indexes.clone(), true)?; + } + + Ok(()) + } + + fn update_supports_migration(&mut self, acked_features: u64, acked_protocol_features: u64) { + if (acked_features & u64::from(vhost::vhost_kern::vhost_binding::VHOST_F_LOG_ALL) != 0) + && (acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0) + { + self.supports_migration = true; + } + } + + fn update_log_base(&mut self, last_ram_addr: u64) -> Result>> { + // Create the memfd + let fd = memfd_create( + &ffi::CString::new("vhost_user_dirty_log").unwrap(), + libc::MFD_CLOEXEC | libc::MFD_ALLOW_SEALING, + ) + .map_err(Error::MemfdCreate)?; + + // Safe because we checked the file descriptor is valid + let file = unsafe { File::from_raw_fd(fd) }; + // The size of the memory mapping corresponds to the size of a 
bitmap + // covering all guest pages for addresses from 0 to the last physical + // address in guest RAM. + // A page is always 4kiB from a vhost-user perspective, and each bit is + // a page. That's how we can compute mmap_size from the last address. + let mmap_size = (last_ram_addr / (VHOST_LOG_PAGE * 8)) + 1; + let mmap_handle = file.as_raw_fd(); + + // Set shm_log region size + file.set_len(mmap_size).map_err(Error::SetFileSize)?; + + // Set the seals + let res = unsafe { + libc::fcntl( + file.as_raw_fd(), + libc::F_ADD_SEALS, + libc::F_SEAL_GROW | libc::F_SEAL_SHRINK | libc::F_SEAL_SEAL, + ) + }; + if res < 0 { + return Err(Error::SetSeals(std::io::Error::last_os_error())); + } + + // Mmap shm_log region + let region = MmapRegion::build( + Some(FileOffset::new(file, 0)), + mmap_size as usize, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_SHARED, + ) + .map_err(Error::NewMmapRegion)?; + + // Make sure we hold onto the region to prevent the mapping from being + // released. + let old_region = self.shm_log.replace(Arc::new(region)); + + // Send the shm_log fd over to the backend + let log = VhostUserDirtyLogRegion { + mmap_size, + mmap_offset: 0, + mmap_handle, + }; + self.vu + .set_log_base(0, Some(log)) + .map_err(Error::VhostUserSetLogBase)?; + + Ok(old_region) + } + + fn set_vring_logging(&mut self, enable: bool) -> Result<()> { + if let Some(vrings_info) = &self.vrings_info { + for (i, vring_info) in vrings_info.iter().enumerate() { + let mut config_data = vring_info.config_data; + config_data.flags = if enable { 1 << VHOST_VRING_F_LOG } else { 0 }; + config_data.log_addr = if enable { + Some(vring_info.used_guest_addr) + } else { + None + }; + + self.vu + .set_vring_addr(i, &config_data) + .map_err(Error::VhostUserSetVringAddr)?; + } + } + + Ok(()) + } + + pub fn start_dirty_log(&mut self, last_ram_addr: u64) -> Result<()> { + if !self.supports_migration { + return Err(Error::MigrationNotSupported); + } + + // Set the shm log region + 
self.update_log_base(last_ram_addr)?; + + // Enable VHOST_F_LOG_ALL feature + let features = self.acked_features | (1 << VHOST_F_LOG_ALL); + self.vu + .set_features(features) + .map_err(Error::VhostUserSetFeatures)?; + + // Enable dirty page logging of used ring for all queues + self.set_vring_logging(true) + } + + pub fn stop_dirty_log(&mut self) -> Result<()> { + if !self.supports_migration { + return Err(Error::MigrationNotSupported); + } + + // Disable dirty page logging of used ring for all queues + self.set_vring_logging(false)?; + + // Disable VHOST_F_LOG_ALL feature + self.vu + .set_features(self.acked_features) + .map_err(Error::VhostUserSetFeatures)?; + + // This is important here since the log region goes out of scope, + // invoking the Drop trait, hence unmapping the memory. + self.shm_log = None; + + Ok(()) + } +} + +fn memfd_create(name: &ffi::CStr, flags: u32) -> std::result::Result { + let res = unsafe { libc::syscall(libc::SYS_memfd_create, name.as_ptr(), flags) }; + + if res < 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(res as RawFd) + } +}