diff --git a/CMakeLists.txt b/CMakeLists.txt index d5f7651e8..c66cd5c7e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -431,6 +431,7 @@ add_subdirectory(src/windows/wslhost) add_subdirectory(src/windows/wslrelay) add_subdirectory(src/windows/wslinstall) add_subdirectory(src/windows/wslaclient) +add_subdirectory(src/windows/wsladiag) if (WSL_BUILD_WSL_SETTINGS) add_subdirectory(src/windows/libwsl) diff --git a/msipackage/CMakeLists.txt b/msipackage/CMakeLists.txt index 7f7bcb1ed..699638700 100644 --- a/msipackage/CMakeLists.txt +++ b/msipackage/CMakeLists.txt @@ -3,7 +3,7 @@ set(OUTPUT_PACKAGE ${BIN}/wsl.msi) set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in) set(PACKAGE_WIX ${BIN}/package.wix) set(CAB_CACHE ${BIN}/cab) -set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll;wslaserviceproxystub.dll;wslaservice.exe) +set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll;wslaserviceproxystub.dll;wslaservice.exe;wsladiag.exe) if (WSL_BUILD_WSL_SETTINGS) list(APPEND BINARIES_DEPENDENCIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll") @@ -39,7 +39,7 @@ add_custom_command( add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE}) set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN}) -add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslaservice wslaserviceproxystub) +add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslaservice wslaserviceproxystub wsladiag) if (WSL_BUILD_WSL_SETTINGS) add_dependencies(msipackage wslsettings libwsl) diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index f51bc8f3a..ea3a47f9f 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -27,6 +27,7 @@ + @@ -378,7 +379,6 @@ - diff --git a/src/windows/common/WslTelemetry.cpp b/src/windows/common/WslTelemetry.cpp index 766513d2c..cac31bb69 100644 --- a/src/windows/common/WslTelemetry.cpp +++ b/src/windows/common/WslTelemetry.cpp @@ -33,7 +33,7 @@ TRACELOGGING_DEFINE_PROVIDER( TraceLoggingOptionMicrosoftTelemetry()); TRACELOGGING_DEFINE_PROVIDER( - WslaServiceTelemetryProvider, + WslaTelemetryProvider, "Microsoft.Windows.Wsla", // {0383CE62-8F86-4766-AFB2-9D66A7FB1E90} (0x383ce62, 0x8f86, 0x4766, 0xaf, 0xb2, 0x9d, 0x66, 0xa7, 0xfb, 0x1e, 0x90), diff --git a/src/windows/common/WslTelemetry.h b/src/windows/common/WslTelemetry.h index 209dfd6ff..96ee3dead 100644 --- a/src/windows/common/WslTelemetry.h +++ b/src/windows/common/WslTelemetry.h @@ -28,7 +28,7 @@ extern "C" { #endif TRACELOGGING_DECLARE_PROVIDER(LxssTelemetryProvider); TRACELOGGING_DECLARE_PROVIDER(WslServiceTelemetryProvider); -TRACELOGGING_DECLARE_PROVIDER(WslaServiceTelemetryProvider); +TRACELOGGING_DECLARE_PROVIDER(WslaTelemetryProvider); #ifdef __cplusplus } #endif diff --git a/src/windows/wsladiag/CMakeLists.txt b/src/windows/wsladiag/CMakeLists.txt new file mode 100644 index 000000000..041124538 --- /dev/null +++ b/src/windows/wsladiag/CMakeLists.txt @@ -0,0 +1,16 @@ + + +set(SOURCES + main.cpp +) + +add_executable(wsladiag ${SOURCES}) + +target_link_libraries(wsladiag + ${COMMON_LINK_LIBRARIES} + common +) + +target_precompile_headers(wsladiag REUSE_FROM common) + +set_target_properties(wsladiag PROPERTIES FOLDER windows) diff --git a/src/windows/wsladiag/main.cpp b/src/windows/wsladiag/main.cpp new file mode 100644 index 000000000..a8e10fbb3 --- /dev/null +++ b/src/windows/wsladiag/main.cpp @@ -0,0 +1,139 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + main.cpp + +Abstract: + + Entry point for the wsladiag tool, performs WSL runtime initialization and parses --list/--help. + +--*/ + +#include "precomp.h" +#include "CommandLine.h" +#include "wslutil.h" +#include "wslaservice.h" +#include "WslSecurity.h" + +using namespace wsl::shared; +namespace wslutil = wsl::windows::common::wslutil; + +int wsladiag_main(std::wstring_view commandLine) +{ + wslutil::ConfigureCrt(); + wslutil::InitializeWil(); + + WslTraceLoggingInitialize(WslaTelemetryProvider, !wsl::shared::OfficialBuild); + auto cleanupTelemetry = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, []() { WslTraceLoggingUninitialize(); }); + + wslutil::SetCrtEncoding(_O_U8TEXT); + + auto coInit = wil::CoInitializeEx(COINIT_MULTITHREADED); + wslutil::CoInitializeSecurity(); + + WSADATA data{}; + THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &data)); + auto wsaCleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, []() { WSACleanup(); }); + + // Command-line parsing using ArgumentParser. + ArgumentParser parser(std::wstring{commandLine}, L"wsladiag"); + + bool help = false; + bool list = false; + + parser.AddArgument(list, L"--list"); + parser.AddArgument(help, L"--help", L'h'); // short option is a single wide char + parser.Parse(); + + auto printUsage = []() { + wslutil::PrintMessage( + L"wsladiag - WSLA diagnostics tool\n" + L"Usage:\n" + L" wsladiag --list List WSLA sessions\n" + L" wsladiag --help Show this help", + stderr); + }; + + // If '--help' was requested, print usage and exit. + if (help) + { + printUsage(); + return 0; + } + + if (!list) + { + // No recognized command → show usage + printUsage(); + return 0; + } + + // --list: Call WSLA service COM interface to retrieve and display sessions. + + try + { + wil::com_ptr userSession; + THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); + + wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); + + wil::unique_cotaskmem_array_ptr sessions; + + THROW_IF_FAILED(userSession->ListSessions(&sessions, sessions.size_address())); + + if (sessions.size() == 0) + { + wslutil::PrintMessage(L"No WSLA sessions found.\n", stdout); + } + else + { + wslutil::PrintMessage(std::format(L"Found {} WSLA session{}:\n", sessions.size(), sessions.size() > 1 ? L"s" : L""), stdout); + + wslutil::PrintMessage(L"ID\tCreator PID\tDisplay Name\n", stdout); + wslutil::PrintMessage(L"--\t-----------\t------------\n", stdout); + + for (const auto& session : sessions) + { + const auto * displayName = session.DisplayName; + if (displayName[0] == L'\0') + { + displayName = L""; + } + + wslutil::PrintMessage( + std::format(L"{}\t{}\t\t{}\n", session.SessionId, session.CreatorPid, displayName), stdout); + } + } + + return 0; + } + catch (...) + { + const auto hr = wil::ResultFromCaughtException(); + const std::wstring hrMessage = wslutil::ErrorCodeToString(hr); + + if (!hrMessage.empty()) + { + wslutil::PrintMessage(std::format(L"Error listing WSLA sessions: 0x{:08x} - {}\n", static_cast(hr), hrMessage), stderr); + } + else + { + wslutil::PrintMessage(std::format(L"Error listing WSLA sessions: 0x{:08x}\n", static_cast(hr)), stderr); + } + + return 1; + } +} + +int wmain(int /*argc*/, wchar_t** /*argv*/) +{ + try + { + // Use raw Unicode command line so ArgumentParser gets original input. + return wsladiag_main(GetCommandLineW()); + } + CATCH_RETURN(); +} diff --git a/src/windows/wslaservice/exe/ServiceMain.cpp b/src/windows/wslaservice/exe/ServiceMain.cpp index 3eeb06c96..872ed438a 100644 --- a/src/windows/wslaservice/exe/ServiceMain.cpp +++ b/src/windows/wslaservice/exe/ServiceMain.cpp @@ -67,7 +67,7 @@ try // Initialize telemetry. // TODO-WSLA: Create a dedicated WSLA provider - WslTraceLoggingInitialize(WslaServiceTelemetryProvider, !wsl::shared::OfficialBuild); + WslTraceLoggingInitialize(WslaTelemetryProvider, !wsl::shared::OfficialBuild); WSL_LOG("Service starting", TraceLoggingLevel(WINEVENT_LEVEL_INFO)); diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index 525eb3f4d..a2a0ae9a1 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -20,7 +20,10 @@ Module Name: using wsl::windows::service::wsla::WSLASession; -WSLASession::WSLASession(const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl, const VIRTUAL_MACHINE_SETTINGS& VmSettings) : +WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, + WSLAUserSessionImpl& userSessionImpl, + const VIRTUAL_MACHINE_SETTINGS& VmSettings) : + m_id(id), m_sessionSettings(Settings), m_userSession(&userSessionImpl), m_virtualMachine(wil::MakeOrThrow(VmSettings, userSessionImpl.GetUserSid(), &userSessionImpl)), diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index 843b88d50..f3f2d8238 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -23,8 +23,12 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession : public Microsoft::WRL::RuntimeClass, IWSLASession, IFastRundown> { public: - WSLASession(const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl, const VIRTUAL_MACHINE_SETTINGS& VmSettings); - ~WSLASession(); + WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl, const VIRTUAL_MACHINE_SETTINGS& VmSettings); + ~WSLASession(); + ULONG GetId() const noexcept + { + return m_id; + } IFACEMETHOD(GetDisplayName)(LPWSTR* DisplayName) override; @@ -51,6 +55,7 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession void OnUserSessionTerminating(); private: + ULONG m_id = 0; WSLA_SESSION_SETTINGS m_sessionSettings; // TODO: Revisit to see if we should have session settings as a member or not WSLAUserSessionImpl* m_userSession = nullptr; Microsoft::WRL::ComPtr m_virtualMachine; diff --git a/src/windows/wslaservice/exe/WSLAUserSession.cpp b/src/windows/wslaservice/exe/WSLAUserSession.cpp index f7d59418d..9ebe2f0b7 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.cpp +++ b/src/windows/wslaservice/exe/WSLAUserSession.cpp @@ -50,19 +50,51 @@ PSID WSLAUserSessionImpl::GetUserSid() const HRESULT wsl::windows::service::wsla::WSLAUserSessionImpl::CreateSession( const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession) { - auto session = wil::MakeOrThrow(*Settings, *this, *VmSettings); - - std::lock_guard lock(m_wslaSessionsLock); - auto it = m_sessions.emplace(session.Get()); + ULONG id = m_nextSessionId++; + auto session = wil::MakeOrThrow(id, *Settings, *this, *VmSettings); + { + std::lock_guard lock(m_wslaSessionsLock); + auto it = m_sessions.emplace(session.Get()); + m_wslaSessions.emplace_back(session); // Client now owns the session. // TODO: Add a flag for the client to specify that the session should outlive its process. - + } + THROW_IF_FAILED(session.CopyTo(__uuidof(IWSLASession), (void**)WslaSession)); return S_OK; } +HRESULT wsl::windows::service::wsla::WSLAUserSessionImpl::ListSessions(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount) +{ + const auto count = m_wslaSessions.size(); + auto output = wil::make_unique_cotaskmem(m_wslaSessions.size()); + std::lock_guard lock(m_wslaSessionsLock); + for (size_t i = 0; i < m_wslaSessions.size(); ++i) + { + output[i].SessionId = m_wslaSessions[i]->GetId(); + output[i].CreatorPid = 0; // placeholder until we populate this later + PWSTR tempName = nullptr; + + RETURN_IF_FAILED(m_wslaSessions[i]->GetDisplayName(&tempName)); + + if (tempName) + { + wcscpy_s(output[i].DisplayName, tempName); + CoTaskMemFree(tempName); + } + else + { + output[i].DisplayName[0] = L'\0'; + } + + } + *Sessions = output.release(); + *SessionsCount = static_cast(count); + return S_OK; +} + wsl::windows::service::wsla::WSLAUserSession::WSLAUserSession(std::weak_ptr&& Session) : m_session(std::move(Session)) { @@ -89,9 +121,21 @@ try CATCH_RETURN(); HRESULT wsl::windows::service::wsla::WSLAUserSession::ListSessions(WSLA_SESSION_INFORMATION** Sessions, ULONG* SessionsCount) +try { - return E_NOTIMPL; + if (!Sessions || !SessionsCount) + { + return E_INVALIDARG; + } + + auto session = m_session.lock(); + RETURN_HR_IF(RPC_E_DISCONNECTED, !session); + + RETURN_IF_FAILED(session->ListSessions(Sessions, SessionsCount)); + return S_OK; } +CATCH_RETURN(); + HRESULT wsl::windows::service::wsla::WSLAUserSession::OpenSession(ULONG Id, IWSLASession** Session) { return E_NOTIMPL; diff --git a/src/windows/wslaservice/exe/WSLAUserSession.h b/src/windows/wslaservice/exe/WSLAUserSession.h index a7b6635ce..8657b77d4 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.h +++ b/src/windows/wslaservice/exe/WSLAUserSession.h @@ -15,6 +15,8 @@ Module Name: #pragma once #include "WSLAVirtualMachine.h" #include "WSLASession.h" +#include +#include namespace wsl::windows::service::wsla { @@ -30,14 +32,17 @@ class WSLAUserSessionImpl PSID GetUserSid() const; HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession); - + HRESULT ListSessions(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount); + void OnVmTerminated(WSLAVirtualMachine* machine); void OnSessionTerminated(WSLASession* Session); private: wil::unique_tokeninfo_ptr m_tokenInfo; - + std::atomic m_nextSessionId{1}; std::recursive_mutex m_wslaSessionsLock; std::recursive_mutex m_lock; + // Track active sessions for diagnostics / ListSessions. + std::vector> m_wslaSessions; // TODO-WSLA: Consider using a weak_ptr to easily destroy when the last client reference is released. std::unordered_set m_sessions; diff --git a/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp b/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp index e13effec7..0a28ede8e 100644 --- a/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp +++ b/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp @@ -50,11 +50,15 @@ HRESULT WSLAUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ RE THROW_HR_IF(CO_E_SERVER_STOPPING, !g_sessions.has_value()); auto session = std::find_if(g_sessions->begin(), g_sessions->end(), [&tokenInfo](auto it) { - return EqualSid(it->GetUserSid(), &tokenInfo->User.Sid); + return EqualSid(it->GetUserSid(), tokenInfo->User.Sid); }); if (session == g_sessions->end()) { + wil::unique_hlocal_string sid; + THROW_IF_WIN32_BOOL_FALSE(ConvertSidToStringSid(tokenInfo->User.Sid, &sid)); + WSL_LOG("WSLAUserSession created", TraceLoggingValue(sid.get(), "sid")); + session = g_sessions->insert(g_sessions->end(), std::make_shared(userToken.get(), std::move(tokenInfo))); } diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h index a2ed05323..b0f01a957 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -48,29 +48,29 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine ~WSLAVirtualMachine(); void Start(); - void OnSessionTerminated(); + void OnSessionTerminating(); IFACEMETHOD(CreateLinuxProcess(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** Process, _Out_ int* Errno)) override; IFACEMETHOD(WaitPid(_In_ LONG Pid, _In_ ULONGLONG TimeoutMs, _Out_ ULONG* State, _Out_ int* Code)) override; IFACEMETHOD(Signal(_In_ LONG Pid, _In_ int Signal)) override; IFACEMETHOD(Shutdown(ULONGLONG _In_ TimeoutMs)) override; + IFACEMETHOD(RegisterCallback(_In_ ITerminationCallback* callback)) override; IFACEMETHOD(GetDebugShellPipe(_Out_ LPWSTR* pipePath)) override; IFACEMETHOD(MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)) override; IFACEMETHOD(Unmount(_In_ const char* Path)) override; + IFACEMETHOD(DetachDisk(_In_ ULONG Lun)) override; IFACEMETHOD(MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly)) override; IFACEMETHOD(UnmountWindowsFolder(_In_ LPCSTR LinuxPath)) override; void MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags); void OnProcessReleased(int Pid); - void RegisterCallback(_In_ ITerminationCallback* callback); Microsoft::WRL::ComPtr CreateLinuxProcess( _In_ const WSLA_PROCESS_OPTIONS& Options, int* Errno = nullptr, const TPrepareCommandLine& PrepareCommandLine = [](const auto&) {}); +private: std::pair AttachDisk(_In_ PCWSTR Path, _In_ BOOL ReadOnly); - void DetachDisk(_In_ ULONG Lun); -private: static void Mount(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context); static bool ParseTtyInformation( @@ -82,8 +82,7 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine void OnCrash(_In_ const HCS_EVENT* Event); std::tuple Fork(enum WSLA_FORK::ForkType Type); - std::tuple Fork( - wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type, ULONG TtyRows = 0, ULONG TtyColumns = 0); + std::tuple Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type); int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel); ConnectedSocket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd); @@ -122,8 +121,6 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine PSID m_userSid{}; wil::unique_handle m_userToken; std::wstring m_debugShellPipe; - - std::mutex m_trackedProcessesLock; std::vector m_trackedProcesses; wsl::windows::common::hcs::unique_hcs_system m_computeSystem; @@ -147,5 +144,6 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine std::map m_plan9Mounts; std::recursive_mutex m_lock; std::mutex m_portRelaylock; + WSLAUserSessionImpl* m_userSession; }; } // namespace wsl::windows::service::wsla \ No newline at end of file diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 33c055303..a25317e93 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -296,9 +296,9 @@ interface IWSLASession : IUnknown struct WSLA_SESSION_INFORMATION { - ULONG Id; + ULONG SessionId; DWORD CreatorPid; - LPSTR DisplayName; + wchar_t DisplayName[256]; }; [