Skip to content

Commit b83f5ac

Browse files
committed
Selecting TCS used when new thread launches
1 parent f7323b1 commit b83f5ac

File tree

7 files changed

+156
-27
lines changed

7 files changed

+156
-27
lines changed

Diff for: Cargo.lock

+2-3
Original file line numberDiff line numberDiff line change
@@ -1190,9 +1190,8 @@ dependencies = [
11901190

11911191
[[package]]
11921192
name = "fortanix-sgx-abi"
1193-
version = "0.3.3"
1194-
source = "registry+https://github.com/rust-lang/crates.io-index"
1195-
checksum = "c56c422ef86062869b2d57ae87270608dc5929969dd130a6e248979cf4fb6ca6"
1193+
version = "0.5.0"
1194+
source = "git+https://github.com/fortanix/rust-sgx.git?branch=raoul/tcs_control#495f16cb6cd0d8f6a7b25ade5196a96a0d22fe35"
11961195
dependencies = [
11971196
"compiler_builtins",
11981197
"rustc-std-workspace-core",

Diff for: Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ cargo-util = { path = "src/tools/cargo/crates/cargo-util" }
101101
rustfmt-nightly = { path = "src/tools/rustfmt" }
102102

103103
[patch.crates-io]
104+
fortanix-sgx-abi = { git = "https://github.com/fortanix/rust-sgx.git", branch = "raoul/tcs_control" }
105+
104106
# See comments in `src/tools/rustc-workspace-hack/README.md` for what's going on
105107
# here
106108
rustc-workspace-hack = { path = 'src/tools/rustc-workspace-hack' }

Diff for: library/std/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ rand = "0.7"
4040
dlmalloc = { version = "0.2.1", features = ['rustc-dep-of-std'] }
4141

4242
[target.x86_64-fortanix-unknown-sgx.dependencies]
43-
fortanix-sgx-abi = { version = "0.3.2", features = ['rustc-dep-of-std'] }
43+
fortanix-sgx-abi = { version = "0.5.0", features = ['rustc-dep-of-std'] }
4444

4545
[target.'cfg(all(any(target_arch = "x86_64", target_arch = "aarch64"), target_os = "hermit"))'.dependencies]
4646
hermit-abi = { version = "0.1.17", features = ['rustc-dep-of-std'] }

Diff for: library/std/src/sys/sgx/abi/entry.S

+33-10
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ IMAGE_BASE:
1111
.long 1 /* type = NT_VERSION */
1212
0: .asciz "toolchain-version" /* name */
1313
1: .align 4
14-
2: .long 1 /* desc - toolchain version number, 32-bit LE */
14+
2: .long 2 /* desc - toolchain version number, 32-bit LE */
1515
3: .align 4
1616

1717
.section .rodata
@@ -90,15 +90,17 @@ IMAGE_BASE:
9090
.equ tcsls_last_rsp, 0x10 /* initialized by loader to 0 */
9191
.equ tcsls_panic_last_rsp, 0x18 /* initialized by loader to 0 */
9292
.equ tcsls_debug_panic_buf_ptr, 0x20 /* initialized by loader to 0 */
93-
.equ tcsls_user_rsp, 0x28
94-
.equ tcsls_user_retip, 0x30
95-
.equ tcsls_user_rbp, 0x38
96-
.equ tcsls_user_r12, 0x40
97-
.equ tcsls_user_r13, 0x48
98-
.equ tcsls_user_r14, 0x50
99-
.equ tcsls_user_r15, 0x58
100-
.equ tcsls_tls_ptr, 0x60
101-
.equ tcsls_tcs_addr, 0x68
93+
.equ tcsls_static_tcs_addr, 0x28 /* initialized by loader to *offset* from image base to static TCS */
94+
.equ tcsls_clist_next, 0x30 /* initialized by loader to *offset* from image base to next TCLS, circular linked list */
95+
.equ tcsls_user_rsp, 0x38
96+
.equ tcsls_user_retip, 0x40
97+
.equ tcsls_user_rbp, 0x48
98+
.equ tcsls_user_r12, 0x50
99+
.equ tcsls_user_r13, 0x58
100+
.equ tcsls_user_r14, 0x60
101+
.equ tcsls_user_r15, 0x68
102+
.equ tcsls_tls_ptr, 0x70
103+
.equ tcsls_tcs_addr, 0x78
102104

103105
.macro load_tcsls_flag_secondary_bool reg:req comments:vararg
104106
.ifne tcsls_flag_secondary /* to convert to a bool, must be the first bit */
@@ -370,3 +372,24 @@ take_debug_panic_buf_ptr:
370372
pop %r11
371373
lfence
372374
jmp *%r11
375+
376+
.global next_tcsls
377+
next_tcsls:
378+
mov %gs:tcsls_clist_next,%rax
379+
pop %r11
380+
lfence
381+
jmp *%r11
382+
383+
.global static_tcs_offset
384+
static_tcs_offset:
385+
mov $tcsls_static_tcs_addr, %rax
386+
pop %r11
387+
lfence
388+
jmp *%r11
389+
390+
.global clist_next_offset
391+
clist_next_offset:
392+
mov $tcsls_clist_next, %rax
393+
pop %r11
394+
lfence
395+
jmp *%r11

Diff for: library/std/src/sys/sgx/abi/mem.rs

+34
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,37 @@ pub fn is_user_range(p: *const u8, len: usize) -> bool {
8989
let base = image_base() as usize;
9090
end < base || start > base + (unsafe { ENCLAVE_SIZE } - 1) // unsafe ok: link-time constant
9191
}
92+
93+
#[repr(C, packed)]
94+
#[derive(Default)]
95+
struct TcslsTcsListItem {
96+
tcs_offset: u64,
97+
next_offset: u64,
98+
}
99+
100+
extern "C" {
101+
fn next_tcsls() -> *const u8;
102+
fn static_tcs_offset() -> u64;
103+
fn clist_next_offset() -> u64;
104+
}
105+
106+
/// Returns the location of all TCSes available at compile time in the enclave
107+
#[unstable(feature = "sgx_platform", issue = "56975")]
108+
pub fn static_tcses() -> Vec<*const u8> {
109+
unsafe {
110+
let mut tcsls = next_tcsls();
111+
let mut tcses = Vec::new();
112+
113+
loop {
114+
let tcs_addr = rel_ptr(*rel_ptr::<u64>(tcsls as u64 + static_tcs_offset()));
115+
tcsls = *(rel_ptr::<*const u8>(tcsls as u64 + clist_next_offset()));
116+
117+
if tcses.first() != Some(&tcs_addr) {
118+
tcses.push(tcs_addr);
119+
} else {
120+
break;
121+
}
122+
}
123+
tcses
124+
}
125+
}

Diff for: library/std/src/sys/sgx/abi/usercalls/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,9 @@ pub fn connect_stream(addr: &str) -> IoResult<(Fd, String, String)> {
139139

140140
/// Usercall `launch_thread`. See the ABI documentation for more information.
141141
#[unstable(feature = "sgx_platform", issue = "56975")]
142-
pub unsafe fn launch_thread() -> IoResult<()> {
142+
pub unsafe fn launch_thread(tcs: Option<Tcs>) -> IoResult<()> {
143143
// SAFETY: The caller must uphold the safety contract for `launch_thread`.
144-
unsafe { raw::launch_thread().from_sgx_result() }
144+
unsafe { raw::launch_thread(tcs).from_sgx_result() }
145145
}
146146

147147
/// Usercall `exit`. See the ABI documentation for more information.

Diff for: library/std/src/sys/sgx/thread.rs

+82-11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::io;
55
use crate::num::NonZeroUsize;
66
use crate::time::Duration;
77

8+
use super::abi::thread;
89
use super::abi::usercalls;
910

1011
pub struct Thread(task_queue::JoinHandle);
@@ -13,8 +14,65 @@ pub const DEFAULT_MIN_STACK_SIZE: usize = 4096;
1314

1415
pub use self::task_queue::JoinNotifier;
1516

17+
mod tcs_queue {
18+
use super::super::abi::mem as sgx_mem;
19+
use super::super::abi::thread;
20+
use crate::ptr::NonNull;
21+
use crate::sync::{Mutex, MutexGuard, Once};
22+
23+
#[derive(Clone, PartialEq, Eq, Debug)]
24+
pub(super) struct Tcs {
25+
address: NonNull<u8>,
26+
}
27+
28+
impl Tcs {
29+
fn new(address: NonNull<u8>) -> Tcs {
30+
Tcs { address }
31+
}
32+
33+
pub(super) fn address(&self) -> &NonNull<u8> {
34+
&self.address
35+
}
36+
}
37+
38+
/// A queue of not running TCS structs
39+
static mut TCS_QUEUE: Option<Mutex<Vec<Tcs>>> = None;
40+
static TCS_QUEUE_INIT: Once = Once::new();
41+
42+
fn init_tcs_queue() -> Vec<Tcs> {
43+
sgx_mem::static_tcses()
44+
.iter()
45+
.filter_map(|addr| if NonNull::new(*addr as _) != Some(thread::current()) {
46+
Some(Tcs::new(NonNull::new(*addr as _).expect("Compile-time value unexpected NULL")))
47+
} else {
48+
None
49+
})
50+
.collect()
51+
}
52+
53+
fn lock() -> MutexGuard<'static, Vec<Tcs>> {
54+
unsafe {
55+
TCS_QUEUE_INIT.call_once(|| TCS_QUEUE = Some(Mutex::new(init_tcs_queue())));
56+
TCS_QUEUE.as_ref().unwrap().lock().unwrap()
57+
}
58+
}
59+
60+
pub(super) fn take_tcs() -> Option<Tcs> {
61+
let mut tcs_queue = lock();
62+
if let Some(tcs) = tcs_queue.pop() { Some(tcs) } else { None }
63+
}
64+
65+
pub(super) fn add_tcs(tcs: Tcs) {
66+
let mut tcs_queue = lock();
67+
tcs_queue.insert(0, tcs);
68+
}
69+
}
70+
1671
mod task_queue {
72+
use super::tcs_queue::{self, Tcs};
1773
use super::wait_notify;
74+
use crate::ptr::NonNull;
75+
use crate::sync::mpsc;
1876
use crate::sync::{Mutex, MutexGuard, Once};
1977

2078
pub type JoinHandle = wait_notify::Waiter;
@@ -30,18 +88,24 @@ mod task_queue {
3088
pub(super) struct Task {
3189
p: Box<dyn FnOnce()>,
3290
done: JoinNotifier,
91+
tcs: Tcs,
3392
}
3493

3594
impl Task {
36-
pub(super) fn new(p: Box<dyn FnOnce()>) -> (Task, JoinHandle) {
95+
pub(super) fn new(tcs: Tcs, p: Box<dyn FnOnce()>) -> (Task, JoinHandle) {
3796
let (done, recv) = wait_notify::new();
3897
let done = JoinNotifier(Some(done));
39-
(Task { p, done }, recv)
98+
let task = Task { p, done, tcs };
99+
(task, recv)
40100
}
41101

42102
pub(super) fn run(self) -> JoinNotifier {
43-
(self.p)();
44-
self.done
103+
let Task { tcs, p, done } = self;
104+
105+
p();
106+
107+
tcs_queue::add_tcs(tcs);
108+
done
45109
}
46110
}
47111

@@ -58,6 +122,13 @@ mod task_queue {
58122
TASK_QUEUE.as_ref().unwrap().lock().unwrap()
59123
}
60124
}
125+
126+
pub(super) fn take_task(tcs: NonNull<u8>) -> Option<Task> {
127+
let mut tasks = lock();
128+
let (i, _) = tasks.iter().enumerate().find(|(_i, task)| *task.tcs.address() == tcs)?;
129+
let task = tasks.remove(i);
130+
Some(task)
131+
}
61132
}
62133

63134
/// This module provides a synchronization primitive that does not use thread
@@ -105,17 +176,17 @@ pub mod wait_notify {
105176
impl Thread {
106177
// unsafe: see thread::Builder::spawn_unchecked for safety requirements
107178
pub unsafe fn new(_stack: usize, p: Box<dyn FnOnce()>) -> io::Result<Thread> {
108-
let mut queue_lock = task_queue::lock();
109-
unsafe { usercalls::launch_thread()? };
110-
let (task, handle) = task_queue::Task::new(p);
111-
queue_lock.push(task);
179+
let tcs = tcs_queue::take_tcs().ok_or(io::Error::from(io::ErrorKind::WouldBlock))?;
180+
let mut tasks = task_queue::lock();
181+
unsafe { usercalls::launch_thread(Some(*tcs.address()))? };
182+
let (task, handle) = task_queue::Task::new(tcs, p);
183+
tasks.push(task);
112184
Ok(Thread(handle))
113185
}
114186

115187
pub(super) fn entry() -> JoinNotifier {
116-
let mut pending_tasks = task_queue::lock();
117-
let task = rtunwrap!(Some, pending_tasks.pop());
118-
drop(pending_tasks); // make sure to not hold the task queue lock longer than necessary
188+
let task = task_queue::take_task(thread::current())
189+
.expect("enclave entered through TCS unexpectedly");
119190
task.run()
120191
}
121192

0 commit comments

Comments
 (0)