Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/windows/common/WslClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1557,6 +1557,7 @@ int WslaShell(_In_ std::wstring_view commandLine)
parser.AddArgument(reinterpret_cast<bool&>(settings.EnableDnsTunneling), L"--dns-tunneling");
parser.AddArgument(Integer(settings.MemoryMb), L"--memory");
parser.AddArgument(Integer(settings.CpuCount), L"--cpu");
parser.AddArgument(Integer(reinterpret_cast<int&>(settings.NetworkingMode)), L"--networking-mode");
parser.AddArgument(Utf8String(fsType), L"--fstype");
parser.AddArgument(containerRootVhd, L"--container-vhd");
parser.AddArgument(help, L"--help");
Expand All @@ -1565,13 +1566,23 @@ int WslaShell(_In_ std::wstring_view commandLine)
if (help)
{
const auto usage = std::format(
LR"({} --wsla [--vhd </path/to/vhd>] [--shell </path/to/shell>] [--memory <memory-mb>] [--cpu <cpus>] [--dns-tunneling] [--fstype <fstype>] [--container-vhd </path/to/vhd>] [--help])",
LR"({} --wsla [--vhd </path/to/vhd>] [--shell </path/to/shell>] [--memory <memory-mb>] [--cpu <cpus>] [--dns-tunneling] [--networking-mode <mode>] [--fstype <fstype>] [--container-vhd </path/to/vhd>] [--help])",
WSL_BINARY_NAME);

wprintf(L"%ls\n", usage.c_str());
return 1;
}

switch (settings.NetworkingMode)
{
case WSLANetworkingMode::WSLANetworkingModeNone:
case WSLANetworkingMode::WSLANetworkingModeNAT:
case WSLANetworkingMode::WSLANetworkingModeVirtioProxy:
break;
default:
THROW_HR(E_INVALIDARG);
}

if (!containerRootVhd.empty())
{
settings.ContainerRootVhd = containerRootVhd.c_str();
Expand Down
123 changes: 74 additions & 49 deletions src/windows/wslaservice/exe/WSLAVirtualMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ Module Name:
#include <format>
#include <filesystem>
#include "hcs_schema.h"
#include "VirtioNetworking.h"
#include "NatNetworking.h"
#include "WSLAUserSession.h"
#include "DnsResolver.h"
#include "ServiceProcessLauncher.h"

using namespace wsl::windows::common;
Expand Down Expand Up @@ -94,6 +94,12 @@ WSLAVirtualMachine::~WSLAVirtualMachine()

WSL_LOG("WSLATerminateVm", TraceLoggingValue(forceTerminate, "forced"), TraceLoggingValue(m_running, "running"));

// Shutdown DeviceHostProxy before resetting compute system
if (m_guestDeviceManager)
{
m_guestDeviceManager->Shutdown();
}

m_computeSystem.reset();

for (const auto& e : m_attachedDisks)
Expand Down Expand Up @@ -308,6 +314,13 @@ void WSLAVirtualMachine::Start()
auto runtimeId = wsl::windows::common::hcs::GetRuntimeId(m_computeSystem.get());
WI_ASSERT(IsEqualGUID(m_vmId, runtimeId));

// Initialize DeviceHostProxy for virtio device support.
// N.B. This is currently only needed for VirtioProxy networking mode but would also be needed for virtiofs.
if (m_settings.NetworkingMode == WSLANetworkingModeVirtioProxy)
{
m_guestDeviceManager = std::make_shared<GuestDeviceManager>(m_vmIdString, m_vmId);
}

wsl::windows::common::hcs::RegisterCallback(m_computeSystem.get(), &s_OnExit, this);

wsl::windows::common::hcs::StartComputeSystem(m_computeSystem.get(), json.c_str());
Expand Down Expand Up @@ -445,59 +458,71 @@ CATCH_LOG();

void WSLAVirtualMachine::ConfigureNetworking()
{
if (m_settings.NetworkingMode == WSLANetworkingModeNone)
switch (m_settings.NetworkingMode)
{
case WSLANetworkingModeNone:
return;
case WSLANetworkingModeNAT:
case WSLANetworkingModeVirtioProxy:
break;
default:
THROW_HR_MSG(E_INVALIDARG, "Invalid networking mode: %lu", m_settings.NetworkingMode);
}
else if (m_settings.NetworkingMode == WSLANetworkingModeNAT)
{
// Launch GNS
std::vector<WSLA_PROCESS_FD> fds(1);
fds[0].Fd = -1;
fds[0].Type = WSLAFdType::WSLAFdTypeDefault;

std::vector<const char*> cmd{"/gns", LX_INIT_GNS_SOCKET_ARG};

// If DNS tunnelling is enabled, use an additional for its channel.
if (m_settings.EnableDnsTunneling)
{
fds.emplace_back(WSLA_PROCESS_FD{.Fd = -1, .Type = WSLAFdType::WSLAFdTypeDefault});
THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods());
}
// Launch GNS
std::vector<WSLA_PROCESS_FD> fds(1);
fds[0].Fd = -1;
fds[0].Type = WSLAFdType::WSLAFdTypeDefault;

WSLA_PROCESS_OPTIONS options{};
options.Executable = "/init";
options.Fds = fds.data();
options.FdsCount = static_cast<DWORD>(fds.size());
std::vector<const char*> cmd{"/gns", LX_INIT_GNS_SOCKET_ARG};

// Because the file descriptors numbers aren't known in advance, the command line needs to be generated after the file
// descriptors are allocated.
// If DNS tunnelling is enabled, use an additional for its channel.
if (m_settings.EnableDnsTunneling)
{
THROW_HR_IF_MSG(
E_NOTIMPL,
m_settings.NetworkingMode == WSLANetworkingModeVirtioProxy,
"DNS tunneling not currently supported for VirtioProxy");

std::string socketFdArg;
std::string dnsFdArg;
int gnsChannelFd = -1;
int dnsChannelFd = -1;
auto prepareCommandLine = [&](const auto& sockets) {
gnsChannelFd = sockets[0].Fd;
socketFdArg = std::to_string(gnsChannelFd);
cmd.emplace_back(socketFdArg.c_str());
fds.emplace_back(WSLA_PROCESS_FD{.Fd = -1, .Type = WSLAFdType::WSLAFdTypeDefault});
THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods());
}

if (sockets.size() > 1)
{
dnsChannelFd = sockets[1].Fd;
dnsFdArg = std::to_string(dnsChannelFd);
cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG);
cmd.emplace_back(dnsFdArg.c_str());
cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP);
cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS);
}
WSLA_PROCESS_OPTIONS options{};
options.Executable = "/init";
options.Fds = fds.data();
options.FdsCount = static_cast<DWORD>(fds.size());

// Because the file descriptors numbers aren't known in advance, the command line needs to be generated after the file
// descriptors are allocated.
std::string socketFdArg;
std::string dnsFdArg;
int gnsChannelFd = -1;
int dnsChannelFd = -1;
auto prepareCommandLine = [&](const auto& sockets) {
gnsChannelFd = sockets[0].Fd;
socketFdArg = std::to_string(gnsChannelFd);
cmd.emplace_back(socketFdArg.c_str());

if (sockets.size() > 1)
{
dnsChannelFd = sockets[1].Fd;
dnsFdArg = std::to_string(dnsChannelFd);
cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG);
cmd.emplace_back(dnsFdArg.c_str());
cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP);
cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS);
}

options.CommandLine = cmd.data();
options.CommandLineCount = static_cast<DWORD>(cmd.size());
};
options.CommandLine = cmd.data();
options.CommandLineCount = static_cast<DWORD>(cmd.size());
};

auto process = CreateLinuxProcess(options, nullptr, prepareCommandLine);
auto process = CreateLinuxProcess(options, nullptr, prepareCommandLine);
auto gnsChannel = wsl::core::GnsChannel(wil::unique_socket{(SOCKET)process->GetStdHandle(gnsChannelFd).release()});

if (m_settings.NetworkingMode == WSLANetworkingModeNAT)
{
// TODO: refactor this to avoid using wsl config
static wsl::core::Config config(nullptr);

Expand All @@ -510,18 +535,18 @@ void WSLAVirtualMachine::ConfigureNetworking()
m_networkEngine = std::make_unique<wsl::core::NatNetworking>(
m_computeSystem.get(),
wsl::core::NatNetworking::CreateNetwork(config),
wil::unique_socket{(SOCKET)process->GetStdHandle(gnsChannelFd).release()},
std::move(gnsChannel),
config,
dnsChannelFd != -1 ? wil::unique_socket{(SOCKET)process->GetStdHandle(dnsChannelFd).release()} : wil::unique_socket{});

m_networkEngine->Initialize();

LaunchPortRelay();
}
else
{
THROW_HR_MSG(E_INVALIDARG, "Invalid networking mode: %lu", m_settings.NetworkingMode);
m_networkEngine = std::make_unique<wsl::core::VirtioNetworking>(std::move(gnsChannel), true, m_guestDeviceManager, m_userToken);
}

m_networkEngine->Initialize();

LaunchPortRelay();
}

void CALLBACK WSLAVirtualMachine::s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context)
Expand Down
5 changes: 4 additions & 1 deletion src/windows/wslaservice/exe/WSLAVirtualMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Module Name:
#include "INetworkingEngine.h"
#include "hcs.hpp"
#include "Dmesg.h"
#include "DnsResolver.h"
#include "GuestDeviceManager.h"
#include "WSLAApi.h"
#include "WSLAProcess.h"

Expand Down Expand Up @@ -120,7 +122,7 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine
int m_coldDiscardShiftSize{};
bool m_running = false;
PSID m_userSid{};
wil::unique_handle m_userToken;
wil::shared_handle m_userToken;
std::wstring m_debugShellPipe;

std::mutex m_trackedProcessesLock;
Expand All @@ -133,6 +135,7 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine
bool m_vmSavedStateCaptured = false;
bool m_crashLogCaptured = false;

std::shared_ptr<GuestDeviceManager> m_guestDeviceManager;
std::shared_ptr<DmesgCollector> m_dmesgCollector;
wil::unique_event m_vmExitEvent{wil::EventOptions::ManualReset};
wil::unique_event m_vmTerminatingEvent{wil::EventOptions::ManualReset};
Expand Down
3 changes: 2 additions & 1 deletion src/windows/wslaservice/inc/wslaservice.idl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ interface IWSLAVirtualMachine : IUnknown
typedef enum _WSLANetworkingMode
{
WSLANetworkingModeNone,
WSLANetworkingModeNAT
WSLANetworkingModeNAT,
WSLANetworkingModeVirtioProxy
} WSLANetworkingMode;

typedef
Expand Down
25 changes: 25 additions & 0 deletions test/windows/WSLATests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,31 @@ class WSLATests
VERIFY_ARE_EQUAL(result.Output[1], std::format("nameserver {}\n", LX_INIT_DNS_TUNNELING_IP_ADDRESS));
}

TEST_METHOD(VirtioProxyNetworking)
{
WSL2_TEST_ONLY();

VIRTUAL_MACHINE_SETTINGS settings{};
settings.CpuCount = 4;
settings.DisplayName = L"WSLA";
settings.MemoryMb = 2048;
settings.BootTimeoutMs = 30 * 1000;
settings.NetworkingMode = WSLANetworkingModeVirtioProxy;
settings.RootVhd = testVhd.c_str();

auto session = CreateSession(settings);

// Validate that eth0 has an ip address
ExpectCommandResult(
session.get(),
{"/bin/bash",
"-c",
"ip a show dev eth0 | grep -iF 'inet ' | grep -E '[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}'"},
0);

ExpectCommandResult(session.get(), {"/bin/grep", "-iF", "nameserver", "/etc/resolv.conf"}, 0);
}

TEST_METHOD(OpenFiles)
{
WSL2_TEST_ONLY();
Expand Down