Skip to content

Commit f77a636

Browse files
committed
Switch to a safer technique for obtaining the working directory on Windows
Instead of looping 8 times to work around the TOCTOU issue with sizing the current directory buffer, instead keep doubling the buffer up until the 32767 character limit until the result fits. This ensures we always get a working directory if GetWorkingDirectoryW didn't return some other error, rather than returning nil in the case of a race condition.
1 parent 7ae9160 commit f77a636

File tree

2 files changed

+40
-15
lines changed

2 files changed

+40
-15
lines changed

Sources/FoundationEssentials/FileManager/FileManager+Directories.swift

+6-15
Original file line numberDiff line numberDiff line change
@@ -489,22 +489,13 @@ extension _FileManagerImpl {
489489

490490
var currentDirectoryPath: String? {
491491
#if os(Windows)
492-
var dwLength: DWORD = GetCurrentDirectoryW(0, nil)
493-
guard dwLength > 0 else { return nil }
494-
495-
for _ in 0 ... 8 {
496-
if let szCurrentDirectory = withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(dwLength), {
497-
let dwResult: DWORD = GetCurrentDirectoryW(dwLength, $0.baseAddress)
498-
if dwResult == dwLength - 1 {
499-
return String(decodingCString: $0.baseAddress!, as: UTF16.self)
500-
}
501-
dwLength = dwResult
502-
return nil
503-
}) {
504-
return szCurrentDirectory
505-
}
492+
// Make an initial call to GetCurrentDirectoryW to get a buffer size estimate.
493+
// This is solely to minimize the number of allocations and number of bytes allocated versus starting with a hardcoded value like MAX_PATH.
494+
// We should NOT early-return if this returns 0, in order to avoid TOCTOU issues.
495+
let dwSize = GetCurrentDirectoryW(0, nil)
496+
return try? FillNullTerminatedWideStringBuffer(initialSize: dwSize >= 0 ? dwSize : DWORD(MAX_PATH), maxSize: DWORD(Int16.max)) {
497+
GetCurrentDirectoryW(DWORD($0.count), $0.baseAddress)
506498
}
507-
return nil
508499
#else
509500
withUnsafeTemporaryAllocation(of: CChar.self, capacity: FileManager.MAX_PATH_SIZE) { buffer in
510501
guard getcwd(buffer.baseAddress!, FileManager.MAX_PATH_SIZE) != nil else {

Sources/FoundationEssentials/WinSDK+Extensions.swift

+34
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ package var ERROR_FILENAME_EXCED_RANGE: DWORD {
8181
DWORD(WinSDK.ERROR_FILENAME_EXCED_RANGE)
8282
}
8383

84+
package var ERROR_INSUFFICIENT_BUFFER: DWORD {
85+
DWORD(WinSDK.ERROR_INSUFFICIENT_BUFFER)
86+
}
87+
8488
package var ERROR_INVALID_ACCESS: DWORD {
8589
DWORD(WinSDK.ERROR_INVALID_ACCESS)
8690
}
@@ -288,4 +292,34 @@ internal func WIN32_FROM_HRESULT(_ hr: HRESULT) -> DWORD {
288292
return DWORD(hr)
289293
}
290294

295+
/// Calls a Win32 API function that fills a (potentially long path) null-terminated string buffer by continually attempting to allocate more memory up until the true max path is reached.
296+
/// This is especially useful for protecting against race conditions like with GetCurrentDirectoryW where the measured length may no longer be valid on subsequent calls.
297+
/// - parameter initialSize: Initial size of the buffer (including the null terminator) to allocate to hold the returned string.
298+
/// - parameter maxSize: Maximum size of the buffer (including the null terminator) to allocate to hold the returned string.
299+
/// - parameter body: Closure to call the Win32 API function to populate the provided buffer.
300+
/// Should return the number of UTF-16 code units (not including the null terminator) copied, 0 to indicate an error.
301+
/// If the buffer is not of sufficient size, should return a value greater than or equal to the size of the buffer.
302+
internal func FillNullTerminatedWideStringBuffer(initialSize: DWORD, maxSize: DWORD, _ body: (UnsafeMutableBufferPointer<WCHAR>) throws -> DWORD) throws -> String {
303+
var bufferCount = max(1, min(initialSize, maxSize))
304+
while bufferCount <= maxSize {
305+
if let result = try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(bufferCount), { buffer in
306+
let count = try body(buffer)
307+
switch count {
308+
case 0:
309+
throw Win32Error(GetLastError())
310+
case 1..<DWORD(buffer.count):
311+
let result = String(decodingCString: buffer.baseAddress!, as: UTF16.self)
312+
assert(result.utf16.count == count, "Parsed UTF-16 count \(result.utf16.count) != reported UTF-16 count \(count)")
313+
return result
314+
default:
315+
bufferCount *= 2
316+
return nil
317+
}
318+
}) {
319+
return result
320+
}
321+
}
322+
throw Win32Error(ERROR_INSUFFICIENT_BUFFER)
323+
}
324+
291325
#endif

0 commit comments

Comments
 (0)