Skip to content

Switch to a safer technique for obtaining the working directory on Windows #1277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -489,22 +489,13 @@ extension _FileManagerImpl {

var currentDirectoryPath: String? {
#if os(Windows)
var dwLength: DWORD = GetCurrentDirectoryW(0, nil)
guard dwLength > 0 else { return nil }

for _ in 0 ... 8 {
if let szCurrentDirectory = withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(dwLength), {
let dwResult: DWORD = GetCurrentDirectoryW(dwLength, $0.baseAddress)
if dwResult == dwLength - 1 {
return String(decodingCString: $0.baseAddress!, as: UTF16.self)
}
dwLength = dwResult
return nil
}) {
return szCurrentDirectory
}
// Make an initial call to GetCurrentDirectoryW to get a buffer size estimate.
// This is solely to minimize the number of allocations and number of bytes allocated versus starting with a hardcoded value like MAX_PATH.
// We should NOT early-return if this returns 0, in order to avoid TOCTOU issues.
let dwSize = GetCurrentDirectoryW(0, nil)
return try? FillNullTerminatedWideStringBuffer(initialSize: dwSize >= 0 ? dwSize : DWORD(MAX_PATH), maxSize: DWORD(Int16.max)) {
GetCurrentDirectoryW(DWORD($0.count), $0.baseAddress)
}
return nil
#else
withUnsafeTemporaryAllocation(of: CChar.self, capacity: FileManager.MAX_PATH_SIZE) { buffer in
guard getcwd(buffer.baseAddress!, FileManager.MAX_PATH_SIZE) != nil else {
Expand Down
34 changes: 34 additions & 0 deletions Sources/FoundationEssentials/WinSDK+Extensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ package var ERROR_FILENAME_EXCED_RANGE: DWORD {
DWORD(WinSDK.ERROR_FILENAME_EXCED_RANGE)
}

package var ERROR_INSUFFICIENT_BUFFER: DWORD {
DWORD(WinSDK.ERROR_INSUFFICIENT_BUFFER)
}

package var ERROR_INVALID_ACCESS: DWORD {
DWORD(WinSDK.ERROR_INVALID_ACCESS)
}
Expand Down Expand Up @@ -288,4 +292,34 @@ internal func WIN32_FROM_HRESULT(_ hr: HRESULT) -> DWORD {
return DWORD(hr)
}

/// Calls a Win32 API function that fills a (potentially long path) null-terminated string buffer by continually attempting to allocate more memory up until the true max path is reached.
/// This is especially useful for protecting against race conditions like with GetCurrentDirectoryW where the measured length may no longer be valid on subsequent calls.
/// - parameter initialSize: Initial size of the buffer (including the null terminator) to allocate to hold the returned string.
/// - parameter maxSize: Maximum size of the buffer (including the null terminator) to allocate to hold the returned string.
/// - parameter body: Closure to call the Win32 API function to populate the provided buffer.
/// Should return the number of UTF-16 code units (not including the null terminator) copied, 0 to indicate an error.
/// If the buffer is not of sufficient size, should return a value greater than or equal to the size of the buffer.
internal func FillNullTerminatedWideStringBuffer(initialSize: DWORD, maxSize: DWORD, _ body: (UnsafeMutableBufferPointer<WCHAR>) throws -> DWORD) throws -> String {
var bufferCount = max(1, min(initialSize, maxSize))
while bufferCount <= maxSize {
if let result = try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(bufferCount), { buffer in
let count = try body(buffer)
switch count {
case 0:
throw Win32Error(GetLastError())
case 1..<DWORD(buffer.count):
let result = String(decodingCString: buffer.baseAddress!, as: UTF16.self)
assert(result.utf16.count == count, "Parsed UTF-16 count \(result.utf16.count) != reported UTF-16 count \(count)")
return result
default:
bufferCount *= 2
return nil
}
}) {
return result
}
}
throw Win32Error(ERROR_INSUFFICIENT_BUFFER)
}

#endif