diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index bd3f9e410..f2c605b18 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -1557,6 +1557,7 @@ int WslaShell(_In_ std::wstring_view commandLine) parser.AddArgument(reinterpret_cast(settings.EnableDnsTunneling), L"--dns-tunneling"); parser.AddArgument(Integer(settings.MemoryMb), L"--memory"); parser.AddArgument(Integer(settings.CpuCount), L"--cpu"); + parser.AddArgument(Integer(reinterpret_cast(settings.NetworkingMode)), L"--networking-mode"); parser.AddArgument(Utf8String(fsType), L"--fstype"); parser.AddArgument(containerRootVhd, L"--container-vhd"); parser.AddArgument(help, L"--help"); @@ -1565,13 +1566,23 @@ int WslaShell(_In_ std::wstring_view commandLine) if (help) { const auto usage = std::format( - LR"({} --wsla [--vhd ] [--shell ] [--memory ] [--cpu ] [--dns-tunneling] [--fstype ] [--container-vhd ] [--help])", + LR"({} --wsla [--vhd ] [--shell ] [--memory ] [--cpu ] [--dns-tunneling] [--networking-mode ] [--fstype ] [--container-vhd ] [--help])", WSL_BINARY_NAME); wprintf(L"%ls\n", usage.c_str()); return 1; } + switch (settings.NetworkingMode) + { + case WSLANetworkingMode::WSLANetworkingModeNone: + case WSLANetworkingMode::WSLANetworkingModeNAT: + case WSLANetworkingMode::WSLANetworkingModeVirtioProxy: + break; + default: + THROW_HR(E_INVALIDARG); + } + if (!containerRootVhd.empty()) { settings.ContainerRootVhd = containerRootVhd.c_str(); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index 0c8265cfc..aa56004af 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -16,9 +16,9 @@ Module Name: #include #include #include "hcs_schema.h" +#include "VirtioNetworking.h" #include "NatNetworking.h" #include "WSLAUserSession.h" -#include "DnsResolver.h" #include "ServiceProcessLauncher.h" using namespace wsl::windows::common; @@ -94,6 +94,12 @@ WSLAVirtualMachine::~WSLAVirtualMachine() WSL_LOG("WSLATerminateVm", TraceLoggingValue(forceTerminate, "forced"), TraceLoggingValue(m_running, "running")); + // Shutdown DeviceHostProxy before resetting compute system + if (m_guestDeviceManager) + { + m_guestDeviceManager->Shutdown(); + } + m_computeSystem.reset(); for (const auto& e : m_attachedDisks) @@ -308,6 +314,13 @@ void WSLAVirtualMachine::Start() auto runtimeId = wsl::windows::common::hcs::GetRuntimeId(m_computeSystem.get()); WI_ASSERT(IsEqualGUID(m_vmId, runtimeId)); + // Initialize DeviceHostProxy for virtio device support. + // N.B. This is currently only needed for VirtioProxy networking mode but would also be needed for virtiofs. + if (m_settings.NetworkingMode == WSLANetworkingModeVirtioProxy) + { + m_guestDeviceManager = std::make_shared(m_vmIdString, m_vmId); + } + wsl::windows::common::hcs::RegisterCallback(m_computeSystem.get(), &s_OnExit, this); wsl::windows::common::hcs::StartComputeSystem(m_computeSystem.get(), json.c_str()); @@ -445,59 +458,71 @@ CATCH_LOG(); void WSLAVirtualMachine::ConfigureNetworking() { - if (m_settings.NetworkingMode == WSLANetworkingModeNone) + switch (m_settings.NetworkingMode) { + case WSLANetworkingModeNone: return; + case WSLANetworkingModeNAT: + case WSLANetworkingModeVirtioProxy: + break; + default: + THROW_HR_MSG(E_INVALIDARG, "Invalid networking mode: %lu", m_settings.NetworkingMode); } - else if (m_settings.NetworkingMode == WSLANetworkingModeNAT) - { - // Launch GNS - std::vector fds(1); - fds[0].Fd = -1; - fds[0].Type = WSLAFdType::WSLAFdTypeDefault; - - std::vector cmd{"/gns", LX_INIT_GNS_SOCKET_ARG}; - // If DNS tunnelling is enabled, use an additional for its channel. - if (m_settings.EnableDnsTunneling) - { - fds.emplace_back(WSLA_PROCESS_FD{.Fd = -1, .Type = WSLAFdType::WSLAFdTypeDefault}); - THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods()); - } + // Launch GNS + std::vector fds(1); + fds[0].Fd = -1; + fds[0].Type = WSLAFdType::WSLAFdTypeDefault; - WSLA_PROCESS_OPTIONS options{}; - options.Executable = "/init"; - options.Fds = fds.data(); - options.FdsCount = static_cast(fds.size()); + std::vector cmd{"/gns", LX_INIT_GNS_SOCKET_ARG}; - // Because the file descriptors numbers aren't known in advance, the command line needs to be generated after the file - // descriptors are allocated. + // If DNS tunnelling is enabled, use an additional for its channel. + if (m_settings.EnableDnsTunneling) + { + THROW_HR_IF_MSG( + E_NOTIMPL, + m_settings.NetworkingMode == WSLANetworkingModeVirtioProxy, + "DNS tunneling not currently supported for VirtioProxy"); - std::string socketFdArg; - std::string dnsFdArg; - int gnsChannelFd = -1; - int dnsChannelFd = -1; - auto prepareCommandLine = [&](const auto& sockets) { - gnsChannelFd = sockets[0].Fd; - socketFdArg = std::to_string(gnsChannelFd); - cmd.emplace_back(socketFdArg.c_str()); + fds.emplace_back(WSLA_PROCESS_FD{.Fd = -1, .Type = WSLAFdType::WSLAFdTypeDefault}); + THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods()); + } - if (sockets.size() > 1) - { - dnsChannelFd = sockets[1].Fd; - dnsFdArg = std::to_string(dnsChannelFd); - cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG); - cmd.emplace_back(dnsFdArg.c_str()); - cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP); - cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS); - } + WSLA_PROCESS_OPTIONS options{}; + options.Executable = "/init"; + options.Fds = fds.data(); + options.FdsCount = static_cast(fds.size()); + + // Because the file descriptors numbers aren't known in advance, the command line needs to be generated after the file + // descriptors are allocated. + std::string socketFdArg; + std::string dnsFdArg; + int gnsChannelFd = -1; + int dnsChannelFd = -1; + auto prepareCommandLine = [&](const auto& sockets) { + gnsChannelFd = sockets[0].Fd; + socketFdArg = std::to_string(gnsChannelFd); + cmd.emplace_back(socketFdArg.c_str()); + + if (sockets.size() > 1) + { + dnsChannelFd = sockets[1].Fd; + dnsFdArg = std::to_string(dnsChannelFd); + cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG); + cmd.emplace_back(dnsFdArg.c_str()); + cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP); + cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS); + } - options.CommandLine = cmd.data(); - options.CommandLineCount = static_cast(cmd.size()); - }; + options.CommandLine = cmd.data(); + options.CommandLineCount = static_cast(cmd.size()); + }; - auto process = CreateLinuxProcess(options, nullptr, prepareCommandLine); + auto process = CreateLinuxProcess(options, nullptr, prepareCommandLine); + auto gnsChannel = wsl::core::GnsChannel(wil::unique_socket{(SOCKET)process->GetStdHandle(gnsChannelFd).release()}); + if (m_settings.NetworkingMode == WSLANetworkingModeNAT) + { // TODO: refactor this to avoid using wsl config static wsl::core::Config config(nullptr); @@ -510,18 +535,18 @@ void WSLAVirtualMachine::ConfigureNetworking() m_networkEngine = std::make_unique( m_computeSystem.get(), wsl::core::NatNetworking::CreateNetwork(config), - wil::unique_socket{(SOCKET)process->GetStdHandle(gnsChannelFd).release()}, + std::move(gnsChannel), config, dnsChannelFd != -1 ? wil::unique_socket{(SOCKET)process->GetStdHandle(dnsChannelFd).release()} : wil::unique_socket{}); - - m_networkEngine->Initialize(); - - LaunchPortRelay(); } else { - THROW_HR_MSG(E_INVALIDARG, "Invalid networking mode: %lu", m_settings.NetworkingMode); + m_networkEngine = std::make_unique(std::move(gnsChannel), true, m_guestDeviceManager, m_userToken); } + + m_networkEngine->Initialize(); + + LaunchPortRelay(); } void CALLBACK WSLAVirtualMachine::s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context) diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h index a2ed05323..5620b0c92 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -16,6 +16,8 @@ Module Name: #include "INetworkingEngine.h" #include "hcs.hpp" #include "Dmesg.h" +#include "DnsResolver.h" +#include "GuestDeviceManager.h" #include "WSLAApi.h" #include "WSLAProcess.h" @@ -120,7 +122,7 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine int m_coldDiscardShiftSize{}; bool m_running = false; PSID m_userSid{}; - wil::unique_handle m_userToken; + wil::shared_handle m_userToken; std::wstring m_debugShellPipe; std::mutex m_trackedProcessesLock; @@ -133,6 +135,7 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine bool m_vmSavedStateCaptured = false; bool m_crashLogCaptured = false; + std::shared_ptr m_guestDeviceManager; std::shared_ptr m_dmesgCollector; wil::unique_event m_vmExitEvent{wil::EventOptions::ManualReset}; wil::unique_event m_vmTerminatingEvent{wil::EventOptions::ManualReset}; diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 33c055303..4a9049081 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -209,7 +209,8 @@ interface IWSLAVirtualMachine : IUnknown typedef enum _WSLANetworkingMode { WSLANetworkingModeNone, - WSLANetworkingModeNAT + WSLANetworkingModeNAT, + WSLANetworkingModeVirtioProxy } WSLANetworkingMode; typedef diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index f7cb14635..d07c84862 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -380,6 +380,31 @@ class WSLATests VERIFY_ARE_EQUAL(result.Output[1], std::format("nameserver {}\n", LX_INIT_DNS_TUNNELING_IP_ADDRESS)); } + TEST_METHOD(VirtioProxyNetworking) + { + WSL2_TEST_ONLY(); + + VIRTUAL_MACHINE_SETTINGS settings{}; + settings.CpuCount = 4; + settings.DisplayName = L"WSLA"; + settings.MemoryMb = 2048; + settings.BootTimeoutMs = 30 * 1000; + settings.NetworkingMode = WSLANetworkingModeVirtioProxy; + settings.RootVhd = testVhd.c_str(); + + auto session = CreateSession(settings); + + // Validate that eth0 has an ip address + ExpectCommandResult( + session.get(), + {"/bin/bash", + "-c", + "ip a show dev eth0 | grep -iF 'inet ' | grep -E '[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}'"}, + 0); + + ExpectCommandResult(session.get(), {"/bin/grep", "-iF", "nameserver", "/etc/resolv.conf"}, 0); + } + TEST_METHOD(OpenFiles) { WSL2_TEST_ONLY();