Skip to content

Commit 2c681e3

Browse files
authored
fix(swift-sdk): make SPV C callbacks Swift 6–safe; eliminate races and TOCTOU (#2814)
1 parent 146643b commit 2c681e3

File tree

7 files changed

+114
-86
lines changed

7 files changed

+114
-86
lines changed

packages/swift-sdk/Sources/SwiftDashSDK/SPV/SPVClient.swift

Lines changed: 108 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ extension SPVClient {
2626
}
2727

2828
// MARK: - C Callback Functions
29-
// These must be global functions to be used as C function pointers
29+
// Use top-level C-compatible functions to avoid actor-isolation init issues
3030

3131
private func spvProgressCallback(
3232
progressPtr: UnsafePointer<FFIDetailedSyncProgress>?,
@@ -35,8 +35,10 @@ private func spvProgressCallback(
3535
guard let progressPtr = progressPtr,
3636
let userData = userData else { return }
3737
let snapshot = progressPtr.pointee
38-
let context = Unmanaged<CallbackContext>.fromOpaque(userData).takeUnretainedValue()
38+
let ptrVal = UInt(bitPattern: userData)
3939
DispatchQueue.main.async {
40+
guard let userData = UnsafeMutableRawPointer(bitPattern: ptrVal) else { return }
41+
let context = Unmanaged<CallbackContext>.fromOpaque(userData).takeUnretainedValue()
4042
context.handleProgressUpdate(snapshot)
4143
}
4244
}
@@ -48,12 +50,80 @@ private func spvCompletionCallback(
4850
) {
4951
guard let userData = userData else { return }
5052
let errorString: String? = errorMsg.map { String(cString: $0) }
51-
let context = Unmanaged<CallbackContext>.fromOpaque(userData).takeUnretainedValue()
53+
let ptrVal = UInt(bitPattern: userData)
5254
DispatchQueue.main.async {
55+
guard let userData = UnsafeMutableRawPointer(bitPattern: ptrVal) else { return }
56+
let context = Unmanaged<CallbackContext>.fromOpaque(userData).takeUnretainedValue()
5357
context.handleSyncCompletion(success: success, error: errorString)
5458
}
5559
}
5660

61+
// Global C-compatible event callbacks that use userData context
62+
private typealias Byte32 = (
63+
UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8,
64+
UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8,
65+
UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8,
66+
UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8
67+
)
68+
69+
private func onBlockCallbackC(
70+
_ height: UInt32,
71+
_ hashPtr: UnsafePointer<Byte32>?,
72+
_ userData: UnsafeMutableRawPointer?
73+
) {
74+
guard let userData = userData else { return }
75+
// Synchronously copy 32-byte hash into Swift-owned buffer to avoid TOCTOU
76+
var hashBytes: [UInt8] = []
77+
if let hashPtr = hashPtr {
78+
let raw = UnsafeRawPointer(hashPtr).assumingMemoryBound(to: UInt8.self)
79+
let buf = UnsafeBufferPointer(start: raw, count: 32)
80+
hashBytes = Array(buf)
81+
}
82+
let ctxAddr = UInt(bitPattern: userData)
83+
Task { @MainActor in
84+
guard let userData = UnsafeMutableRawPointer(bitPattern: ctxAddr) else { return }
85+
let context = Unmanaged<CallbackContext>.fromOpaque(userData).takeUnretainedValue()
86+
let hashData = Data(hashBytes)
87+
context.client?.handleBlockEvent(height: height, hash: hashData)
88+
}
89+
}
90+
91+
private func onTransactionCallbackC(
92+
_ txidPtr: UnsafePointer<Byte32>?,
93+
_ confirmed: Bool,
94+
_ amount: Int64,
95+
_ addressesPtr: UnsafePointer<CChar>?,
96+
_ blockHeight: UInt32,
97+
_ userData: UnsafeMutableRawPointer?
98+
) {
99+
guard let userData = userData else { return }
100+
// Synchronously copy 32-byte txid and address string to Swift-owned values
101+
var txidBytes: [UInt8] = []
102+
if let txidPtr = txidPtr {
103+
let raw = UnsafeRawPointer(txidPtr).assumingMemoryBound(to: UInt8.self)
104+
let buf = UnsafeBufferPointer(start: raw, count: 32)
105+
txidBytes = Array(buf)
106+
}
107+
var addresses: [String] = []
108+
if let addressesPtr = addressesPtr {
109+
let addressesStr = String(cString: addressesPtr)
110+
addresses = addressesStr.components(separatedBy: ",")
111+
}
112+
let ctxAddr = UInt(bitPattern: userData)
113+
Task { @MainActor in
114+
guard let userData = UnsafeMutableRawPointer(bitPattern: ctxAddr) else { return }
115+
let context = Unmanaged<CallbackContext>.fromOpaque(userData).takeUnretainedValue()
116+
let txid = Data(txidBytes)
117+
context.client?.handleTransactionEvent(
118+
txid: txid,
119+
confirmed: confirmed,
120+
amount: amount,
121+
addresses: addresses,
122+
blockHeight: blockHeight > 0 ? blockHeight : nil
123+
)
124+
}
125+
}
126+
57127
// MARK: - SPV Sync Progress
58128

59129
public struct SPVSyncProgress {
@@ -218,12 +288,12 @@ public class SPVClient: ObservableObject {
218288

219289
// Create configuration based on network raw value
220290
let configPtr: UnsafeMutablePointer<FFIClientConfig>? = {
221-
switch network {
222-
case DashSDKNetwork(rawValue: 0):
291+
switch network.rawValue {
292+
case 0:
223293
return dash_spv_ffi_config_mainnet()
224-
case DashSDKNetwork(rawValue: 1):
294+
case 1:
225295
return dash_spv_ffi_config_testnet()
226-
case DashSDKNetwork(rawValue: 2):
296+
case 3:
227297
// Map devnet to custom FFINetwork value 3
228298
return dash_spv_ffi_config_new(FFINetwork(rawValue: 3))
229299
default:
@@ -487,7 +557,12 @@ public class SPVClient: ObservableObject {
487557
}
488558

489559
// Start sync in the background to avoid blocking the main thread
560+
// Copy pointer addresses to avoid capturing non-Sendable pointers inside the GCD closure
561+
let clientAddr = UInt(bitPattern: clientPtr)
562+
let ctxAddr = UInt(bitPattern: contextPtr)
490563
DispatchQueue.global(qos: .userInitiated).async { [weak self] in
564+
guard let clientPtr = UnsafeMutablePointer<FFIDashSpvClient>(bitPattern: clientAddr),
565+
let contextPtr = UnsafeMutableRawPointer(bitPattern: ctxAddr) else { return }
491566
let result = dash_spv_ffi_client_sync_to_tip_with_progress(
492567
clientPtr,
493568
spvProgressCallback,
@@ -560,57 +635,18 @@ public class SPVClient: ObservableObject {
560635
let contextPtr = Unmanaged.passUnretained(context).toOpaque()
561636

562637
var callbacks = FFIEventCallbacks()
563-
564-
callbacks.on_block = { height, hashPtr, userData in
565-
guard let userData = userData else { return }
566-
567-
let context = Unmanaged<CallbackContext>.fromOpaque(userData).takeUnretainedValue()
568-
569-
var hash = Data()
570-
if let hashPtr = hashPtr {
571-
hash = Data(bytes: hashPtr, count: 32)
572-
}
573-
574-
let clientRef = context.client
575-
Task { @MainActor [weak clientRef] in
576-
clientRef?.handleBlockEvent(height: height, hash: hash)
577-
}
578-
}
579-
580-
callbacks.on_transaction = { txidPtr, confirmed, amount, addressesPtr, blockHeight, userData in
581-
guard let userData = userData else { return }
582-
583-
let context = Unmanaged<CallbackContext>.fromOpaque(userData).takeUnretainedValue()
584-
585-
var txid = Data()
586-
if let txidPtr = txidPtr {
587-
txid = Data(bytes: txidPtr, count: 32)
588-
}
589-
590-
var addresses: [String] = []
591-
if let addressesPtr = addressesPtr {
592-
let addressesStr = String(cString: addressesPtr)
593-
addresses = addressesStr.components(separatedBy: ",")
594-
}
595-
596-
let clientRef = context.client
597-
Task { @MainActor [weak clientRef] in
598-
clientRef?.handleTransactionEvent(
599-
txid: txid,
600-
confirmed: confirmed,
601-
amount: amount,
602-
addresses: addresses,
603-
blockHeight: blockHeight > 0 ? blockHeight : nil
604-
)
605-
}
606-
}
638+
639+
// Assign C-compatible top-level functions which match the imported C signatures
640+
callbacks.on_block = onBlockCallbackC
641+
callbacks.on_transaction = onTransactionCallbackC
607642

608643
callbacks.on_compact_filter_matched = { _blockHashPtr, _scripts, _wallet, userData in
609644
guard let userData = userData else { return }
610-
let context = Unmanaged<CallbackContext>.fromOpaque(userData).takeUnretainedValue()
611-
let clientRef = context.client
612-
Task { @MainActor [weak clientRef] in
613-
guard let client = clientRef else { return }
645+
let ptrVal = UInt(bitPattern: userData)
646+
Task { @MainActor in
647+
guard let userData = UnsafeMutableRawPointer(bitPattern: ptrVal) else { return }
648+
let context = Unmanaged<CallbackContext>.fromOpaque(userData).takeUnretainedValue()
649+
guard let client = context.client else { return }
614650
client.blocksHit &+= 1
615651
client.delegate?.spvClient(client, didUpdateBlocksHit: client.blocksHit)
616652
}
@@ -624,7 +660,7 @@ public class SPVClient: ObservableObject {
624660
// MARK: - Filter progress event handler
625661
// MARK: - Event Handlers
626662

627-
private func handleBlockEvent(height: UInt32, hash: Data) {
663+
fileprivate func handleBlockEvent(height: UInt32, hash: Data) {
628664
let block = SPVBlockEvent(
629665
height: height,
630666
hash: hash,
@@ -706,7 +742,7 @@ public class SPVClient: ObservableObject {
706742
}
707743
}
708744

709-
private func handleTransactionEvent(txid: Data, confirmed: Bool, amount: Int64, addresses: [String], blockHeight: UInt32?) {
745+
fileprivate func handleTransactionEvent(txid: Data, confirmed: Bool, amount: Int64, addresses: [String], blockHeight: UInt32?) {
710746
let transaction = SPVTransactionEvent(
711747
txid: txid,
712748
confirmed: confirmed,
@@ -808,15 +844,11 @@ public class SPVClient: ObservableObject {
808844
public func getLatestCheckpointHeight() -> UInt32? {
809845
// Derive FFINetwork matching how we built config
810846
let ffiNet: FFINetwork
811-
switch network {
812-
case DashSDKNetwork(rawValue: 0): // mainnet
813-
ffiNet = FFINetwork(rawValue: 0)
814-
case DashSDKNetwork(rawValue: 1): // testnet
815-
ffiNet = FFINetwork(rawValue: 1)
816-
case DashSDKNetwork(rawValue: 2): // devnet
817-
ffiNet = FFINetwork(rawValue: 3)
818-
default:
819-
ffiNet = FFINetwork(rawValue: 1)
847+
switch network.rawValue {
848+
case 0: ffiNet = FFINetwork(rawValue: 0)
849+
case 1: ffiNet = FFINetwork(rawValue: 1)
850+
case 3: ffiNet = FFINetwork(rawValue: 3)
851+
default: ffiNet = FFINetwork(rawValue: 1)
820852
}
821853

822854
var outHeight: UInt32 = 0
@@ -832,15 +864,11 @@ public class SPVClient: ObservableObject {
832864
/// without depending on the client's configured network.
833865
public static func latestCheckpointHeight(forNetwork net: DashSDKNetwork) -> UInt32? {
834866
let ffiNet: FFINetwork
835-
switch net {
836-
case DashSDKNetwork(rawValue: 0): // mainnet
837-
ffiNet = FFINetwork(rawValue: 0)
838-
case DashSDKNetwork(rawValue: 1): // testnet
839-
ffiNet = FFINetwork(rawValue: 1)
840-
case DashSDKNetwork(rawValue: 2): // devnet
841-
ffiNet = FFINetwork(rawValue: 3)
842-
default:
843-
ffiNet = FFINetwork(rawValue: 1)
867+
switch net.rawValue {
868+
case 0: ffiNet = FFINetwork(rawValue: 0)
869+
case 1: ffiNet = FFINetwork(rawValue: 1)
870+
case 3: ffiNet = FFINetwork(rawValue: 3)
871+
default: ffiNet = FFINetwork(rawValue: 1)
844872
}
845873

846874
var outHeight: UInt32 = 0
@@ -855,11 +883,11 @@ public class SPVClient: ObservableObject {
855883
/// Returns the checkpoint height at or before a given UNIX timestamp (seconds) for this network
856884
public func getCheckpointHeight(beforeTimestamp timestamp: UInt32) -> UInt32? {
857885
let ffiNet: FFINetwork
858-
switch network {
859-
case DashSDKNetwork(rawValue: 0): ffiNet = FFINetwork(rawValue: 0)
860-
case DashSDKNetwork(rawValue: 1): ffiNet = FFINetwork(rawValue: 1)
861-
case DashSDKNetwork(rawValue: 2): ffiNet = FFINetwork(rawValue: 3)
862-
default: ffiNet = FFINetwork(rawValue: 1)
886+
switch network.rawValue {
887+
case 0: ffiNet = FFINetwork(rawValue: 0)
888+
case 1: ffiNet = FFINetwork(rawValue: 1)
889+
case 3: ffiNet = FFINetwork(rawValue: 3)
890+
default: ffiNet = FFINetwork(rawValue: 1)
863891
}
864892
var outHeight: UInt32 = 0
865893
var outHash = [UInt8](repeating: 0, count: 32)

packages/swift-sdk/SwiftExampleApp/SwiftExampleApp/Models/Network.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ enum Network: String, CaseIterable, Codable {
2424
case .testnet:
2525
return DashSDKNetwork(rawValue: 1)
2626
case .devnet:
27-
return DashSDKNetwork(rawValue: 2)
27+
return DashSDKNetwork(rawValue: 3)
2828
}
2929
}
3030

packages/swift-sdk/SwiftExampleApp/SwiftExampleAppTests/CrashDebugTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,4 @@ final class CrashDebugTests: XCTestCase {
116116

117117
print("=== Method existence test completed ===")
118118
}
119-
}
119+
}

packages/swift-sdk/SwiftExampleApp/SwiftExampleAppTests/DebugTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,4 @@ final class DebugTests: XCTestCase {
117117

118118
XCTAssertTrue(true)
119119
}
120-
}
120+
}

packages/swift-sdk/SwiftExampleApp/SwiftExampleAppTests/SDKMethodTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,4 @@ final class SDKMethodTests: XCTestCase {
108108
throw error
109109
}
110110
}
111-
}
111+
}

packages/swift-sdk/SwiftExampleApp/SwiftExampleAppTests/SimpleTransitionTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,4 @@ final class SimpleTransitionTests: XCTestCase {
108108

109109
print(">>> SimpleTransitionTests.testIdentityCreditTransfer completed")
110110
}
111-
}
111+
}

packages/swift-sdk/SwiftExampleApp/SwiftExampleAppTests/StateTransitionTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ final class StateTransitionTests: XCTestCase {
612612
SDK.initialize()
613613

614614
// Create SDK instance for testnet
615-
let testnetNetwork = DashSDKNetwork(rawValue: 1) // Testnet
615+
let testnetNetwork = DashSDKNetwork(rawValue: 1)
616616
return try SDK(network: testnetNetwork)
617617
}
618618

0 commit comments

Comments
 (0)