]> git.lizzy.rs Git - rust.git/commitdiff
std: optimize TLS on Windows
authorjoboet <jonasboettiger@icloud.com>
Sat, 8 Oct 2022 18:19:21 +0000 (20:19 +0200)
committerjoboet <jonasboettiger@icloud.com>
Sat, 8 Oct 2022 18:19:21 +0000 (20:19 +0200)
library/std/src/sys/sgx/thread_local_key.rs
library/std/src/sys/solid/thread_local_key.rs
library/std/src/sys/unix/thread_local_key.rs
library/std/src/sys/unsupported/thread_local_key.rs
library/std/src/sys/windows/c.rs
library/std/src/sys/windows/thread_local_key.rs
library/std/src/sys/windows/thread_local_key/tests.rs [new file with mode: 0644]
library/std/src/sys_common/mod.rs
library/std/src/sys_common/thread_local_key.rs

index b21784475f0d2d020aa4f0eb14cff31d0b1f558c..c7a57d3a3d47e3fe290944f6fe9c23277a2ee84e 100644 (file)
@@ -21,8 +21,3 @@ pub unsafe fn get(key: Key) -> *mut u8 {
 pub unsafe fn destroy(key: Key) {
     Tls::destroy(AbiKey::from_usize(key))
 }
-
-#[inline]
-pub fn requires_synchronized_create() -> bool {
-    false
-}
index b17521f701daf99ad021118dc4cbb5ec1a0ef1bc..b37bf999698873fe1b159c8e0ecc7985b198fa46 100644 (file)
@@ -19,8 +19,3 @@ pub unsafe fn get(_key: Key) -> *mut u8 {
 pub unsafe fn destroy(_key: Key) {
     panic!("should not be used on the solid target");
 }
-
-#[inline]
-pub fn requires_synchronized_create() -> bool {
-    panic!("should not be used on the solid target");
-}
index 2c5b94b1e61e5710717363d92396c6b5e1027305..2b2d079ee4d012730b8324561c78beab7191fda7 100644 (file)
@@ -27,8 +27,3 @@ pub unsafe fn destroy(key: Key) {
     let r = libc::pthread_key_delete(key);
     debug_assert_eq!(r, 0);
 }
-
-#[inline]
-pub fn requires_synchronized_create() -> bool {
-    false
-}
index c31b61cbf56d386785f2f43e3ce292eb70004232..b6e5e4cd2e197872f23b17e2101e129021029ce8 100644 (file)
@@ -19,8 +19,3 @@ pub unsafe fn get(_key: Key) -> *mut u8 {
 pub unsafe fn destroy(_key: Key) {
     panic!("should not be used on this target");
 }
-
-#[inline]
-pub fn requires_synchronized_create() -> bool {
-    panic!("should not be used on this target");
-}
index 89d0ab59be89f3edcc15eba93163fd3a1485bcc4..e43229588f888304c2772afbc5a2bc5f650aa9fc 100644 (file)
@@ -71,6 +71,7 @@
 pub type PCONDITION_VARIABLE = *mut CONDITION_VARIABLE;
 pub type PLARGE_INTEGER = *mut c_longlong;
 pub type PSRWLOCK = *mut SRWLOCK;
+pub type LPINIT_ONCE = *mut INIT_ONCE;
 
 pub type SOCKET = crate::os::windows::raw::SOCKET;
 pub type socklen_t = c_int;
@@ -194,6 +195,9 @@ fn clone(&self) -> Self {
 
 pub const CONDITION_VARIABLE_INIT: CONDITION_VARIABLE = CONDITION_VARIABLE { ptr: ptr::null_mut() };
 pub const SRWLOCK_INIT: SRWLOCK = SRWLOCK { ptr: ptr::null_mut() };
+pub const INIT_ONCE_STATIC_INIT: INIT_ONCE = INIT_ONCE { ptr: ptr::null_mut() };
+
+pub const INIT_ONCE_INIT_FAILED: DWORD = 0x00000004;
 
 pub const DETACHED_PROCESS: DWORD = 0x00000008;
 pub const CREATE_NEW_PROCESS_GROUP: DWORD = 0x00000200;
@@ -565,6 +569,10 @@ pub struct CONDITION_VARIABLE {
 pub struct SRWLOCK {
     pub ptr: LPVOID,
 }
+#[repr(C)]
+pub struct INIT_ONCE {
+    pub ptr: LPVOID,
+}
 
 #[repr(C)]
 pub struct REPARSE_MOUNTPOINT_DATA_BUFFER {
@@ -959,6 +967,7 @@ pub fn FormatMessageW(
     pub fn TlsAlloc() -> DWORD;
     pub fn TlsGetValue(dwTlsIndex: DWORD) -> LPVOID;
     pub fn TlsSetValue(dwTlsIndex: DWORD, lpTlsvalue: LPVOID) -> BOOL;
+    pub fn TlsFree(dwTlsIndex: DWORD) -> BOOL;
     pub fn GetLastError() -> DWORD;
     pub fn QueryPerformanceFrequency(lpFrequency: *mut LARGE_INTEGER) -> BOOL;
     pub fn QueryPerformanceCounter(lpPerformanceCount: *mut LARGE_INTEGER) -> BOOL;
@@ -1118,6 +1127,14 @@ pub fn SleepConditionVariableSRW(
     pub fn TryAcquireSRWLockExclusive(SRWLock: PSRWLOCK) -> BOOLEAN;
     pub fn TryAcquireSRWLockShared(SRWLock: PSRWLOCK) -> BOOLEAN;
 
+    pub fn InitOnceBeginInitialize(
+        lpInitOnce: LPINIT_ONCE,
+        dwFlags: DWORD,
+        fPending: LPBOOL,
+        lpContext: *mut LPVOID,
+    ) -> BOOL;
+    pub fn InitOnceComplete(lpInitOnce: LPINIT_ONCE, dwFlags: DWORD, lpContext: LPVOID) -> BOOL;
+
     pub fn CompareStringOrdinal(
         lpString1: LPCWSTR,
         cchCount1: c_int,
index ec670238e6f0eaa9ba65ec60687f173d0ebbab6d..17628b7579b8db3ac2a8eaa89e4fc20b1ce81af7 100644 (file)
@@ -1,11 +1,16 @@
-use crate::mem::ManuallyDrop;
+use crate::cell::UnsafeCell;
 use crate::ptr;
-use crate::sync::atomic::AtomicPtr;
-use crate::sync::atomic::Ordering::SeqCst;
+use crate::sync::atomic::{
+    AtomicPtr, AtomicU32,
+    Ordering::{AcqRel, Acquire, Relaxed, Release},
+};
 use crate::sys::c;
 
-pub type Key = c::DWORD;
-pub type Dtor = unsafe extern "C" fn(*mut u8);
+#[cfg(test)]
+mod tests;
+
+type Key = c::DWORD;
+type Dtor = unsafe extern "C" fn(*mut u8);
 
 // Turns out, like pretty much everything, Windows is pretty close the
 // functionality that Unix provides, but slightly different! In the case of
 // To accomplish this feat, we perform a number of threads, all contained
 // within this module:
 //
-// * All TLS destructors are tracked by *us*, not the windows runtime. This
+// * All TLS destructors are tracked by *us*, not the Windows runtime. This
 //   means that we have a global list of destructors for each TLS key that
 //   we know about.
 // * When a thread exits, we run over the entire list and run dtors for all
 //   non-null keys. This attempts to match Unix semantics in this regard.
 //
-// This ends up having the overhead of using a global list, having some
-// locks here and there, and in general just adding some more code bloat. We
-// attempt to optimize runtime by forgetting keys that don't have
-// destructors, but this only gets us so far.
-//
 // For more details and nitty-gritty, see the code sections below!
 //
 // [1]: https://www.codeproject.com/Articles/8113/Thread-Local-Storage-The-C-Way
-// [2]: https://github.com/ChromiumWebApps/chromium/blob/master/base
-//                        /threading/thread_local_storage_win.cc#L42
+// [2]: https://github.com/ChromiumWebApps/chromium/blob/master/base/threading/thread_local_storage_win.cc#L42
 
-// -------------------------------------------------------------------------
-// Native bindings
-//
-// This section is just raw bindings to the native functions that Windows
-// provides, There's a few extra calls to deal with destructors.
+pub struct StaticKey {
+    /// The key value shifted up by one. Since TLS_OUT_OF_INDEXES == DWORD::MAX
+    /// is not a valid key value, this allows us to use zero as sentinel value
+    /// without risking overflow.
+    key: AtomicU32,
+    dtor: Option<Dtor>,
+    next: AtomicPtr<StaticKey>,
+    /// Currently, destructors cannot be unregistered, so we cannot use racy
+    /// initialization for keys. Instead, we need synchronize initialization.
+    /// Use the Windows-provided `Once` since it does not require TLS.
+    once: UnsafeCell<c::INIT_ONCE>,
+}
 
-#[inline]
-pub unsafe fn create(dtor: Option<Dtor>) -> Key {
-    let key = c::TlsAlloc();
-    assert!(key != c::TLS_OUT_OF_INDEXES);
-    if let Some(f) = dtor {
-        register_dtor(key, f);
+impl StaticKey {
+    #[inline]
+    pub const fn new(dtor: Option<Dtor>) -> StaticKey {
+        StaticKey {
+            key: AtomicU32::new(0),
+            dtor,
+            next: AtomicPtr::new(ptr::null_mut()),
+            once: UnsafeCell::new(c::INIT_ONCE_STATIC_INIT),
+        }
     }
-    key
-}
 
-#[inline]
-pub unsafe fn set(key: Key, value: *mut u8) {
-    let r = c::TlsSetValue(key, value as c::LPVOID);
-    debug_assert!(r != 0);
-}
+    #[inline]
+    pub unsafe fn set(&'static self, val: *mut u8) {
+        let r = c::TlsSetValue(self.key(), val.cast());
+        debug_assert_eq!(r, c::TRUE);
+    }
 
-#[inline]
-pub unsafe fn get(key: Key) -> *mut u8 {
-    c::TlsGetValue(key) as *mut u8
-}
+    #[inline]
+    pub unsafe fn get(&'static self) -> *mut u8 {
+        c::TlsGetValue(self.key()).cast()
+    }
 
-#[inline]
-pub unsafe fn destroy(_key: Key) {
-    rtabort!("can't destroy tls keys on windows")
-}
+    #[inline]
+    unsafe fn key(&'static self) -> Key {
+        match self.key.load(Acquire) {
+            0 => self.init(),
+            key => key - 1,
+        }
+    }
+
+    #[cold]
+    unsafe fn init(&'static self) -> Key {
+        if self.dtor.is_some() {
+            let mut pending = c::FALSE;
+            let r = c::InitOnceBeginInitialize(self.once.get(), 0, &mut pending, ptr::null_mut());
+            assert_eq!(r, c::TRUE);
 
-#[inline]
-pub fn requires_synchronized_create() -> bool {
-    true
+            if pending == c::FALSE {
+                // Some other thread initialized the key, load it.
+                self.key.load(Relaxed) - 1
+            } else {
+                let key = c::TlsAlloc();
+                if key == c::TLS_OUT_OF_INDEXES {
+                    // Wakeup the waiting threads before panicking to avoid deadlock.
+                    c::InitOnceComplete(self.once.get(), c::INIT_ONCE_INIT_FAILED, ptr::null_mut());
+                    panic!("out of TLS indexes");
+                }
+
+                self.key.store(key + 1, Release);
+                register_dtor(self);
+
+                let r = c::InitOnceComplete(self.once.get(), 0, ptr::null_mut());
+                debug_assert_eq!(r, c::TRUE);
+
+                key
+            }
+        } else {
+            // If there is no destructor to clean up, we can use racy initialization.
+
+            let key = c::TlsAlloc();
+            assert_ne!(key, c::TLS_OUT_OF_INDEXES, "out of TLS indexes");
+
+            match self.key.compare_exchange(0, key + 1, AcqRel, Acquire) {
+                Ok(_) => key,
+                Err(new) => {
+                    // Some other thread completed initialization first, so destroy
+                    // our key and use theirs.
+                    let r = c::TlsFree(key);
+                    debug_assert_eq!(r, c::TRUE);
+                    new - 1
+                }
+            }
+        }
+    }
 }
 
+unsafe impl Send for StaticKey {}
+unsafe impl Sync for StaticKey {}
+
 // -------------------------------------------------------------------------
 // Dtor registration
 //
@@ -96,29 +150,21 @@ pub fn requires_synchronized_create() -> bool {
 // Typically processes have a statically known set of TLS keys which is pretty
 // small, and we'd want to keep this memory alive for the whole process anyway
 // really.
-//
-// Perhaps one day we can fold the `Box` here into a static allocation,
-// expanding the `StaticKey` structure to contain not only a slot for the TLS
-// key but also a slot for the destructor queue on windows. An optimization for
-// another day!
-
-static DTORS: AtomicPtr<Node> = AtomicPtr::new(ptr::null_mut());
-
-struct Node {
-    dtor: Dtor,
-    key: Key,
-    next: *mut Node,
-}
 
-unsafe fn register_dtor(key: Key, dtor: Dtor) {
-    let mut node = ManuallyDrop::new(Box::new(Node { key, dtor, next: ptr::null_mut() }));
+static DTORS: AtomicPtr<StaticKey> = AtomicPtr::new(ptr::null_mut());
 
-    let mut head = DTORS.load(SeqCst);
+/// Should only be called once per key, otherwise loops or breaks may occur in
+/// the linked list.
+unsafe fn register_dtor(key: &'static StaticKey) {
+    let this = <*const StaticKey>::cast_mut(key);
+    // Use acquire ordering to pass along the changes done by the previously
+    // registered keys when we store the new head with release ordering.
+    let mut head = DTORS.load(Acquire);
     loop {
-        node.next = head;
-        match DTORS.compare_exchange(head, &mut **node, SeqCst, SeqCst) {
-            Ok(_) => return, // nothing to drop, we successfully added the node to the list
-            Err(cur) => head = cur,
+        key.next.store(head, Relaxed);
+        match DTORS.compare_exchange_weak(head, this, Release, Acquire) {
+            Ok(_) => break,
+            Err(new) => head = new,
         }
     }
 }
@@ -214,25 +260,29 @@ unsafe fn reference_tls_used() {
     unsafe fn reference_tls_used() {}
 }
 
-#[allow(dead_code)] // actually called above
+#[allow(dead_code)] // actually called below
 unsafe fn run_dtors() {
-    let mut any_run = true;
     for _ in 0..5 {
-        if !any_run {
-            break;
-        }
-        any_run = false;
-        let mut cur = DTORS.load(SeqCst);
+        let mut any_run = false;
+
+        // Use acquire ordering to observe key initialization.
+        let mut cur = DTORS.load(Acquire);
         while !cur.is_null() {
-            let ptr = c::TlsGetValue((*cur).key);
+            let key = (*cur).key.load(Relaxed) - 1;
+            let dtor = (*cur).dtor.unwrap();
 
+            let ptr = c::TlsGetValue(key);
             if !ptr.is_null() {
-                c::TlsSetValue((*cur).key, ptr::null_mut());
-                ((*cur).dtor)(ptr as *mut _);
+                c::TlsSetValue(key, ptr::null_mut());
+                dtor(ptr as *mut _);
                 any_run = true;
             }
 
-            cur = (*cur).next;
+            cur = (*cur).next.load(Relaxed);
+        }
+
+        if !any_run {
+            break;
         }
     }
 }
diff --git a/library/std/src/sys/windows/thread_local_key/tests.rs b/library/std/src/sys/windows/thread_local_key/tests.rs
new file mode 100644 (file)
index 0000000..c95f383
--- /dev/null
@@ -0,0 +1,53 @@
+use super::StaticKey;
+use crate::ptr;
+
+#[test]
+fn smoke() {
+    static K1: StaticKey = StaticKey::new(None);
+    static K2: StaticKey = StaticKey::new(None);
+
+    unsafe {
+        assert!(K1.get().is_null());
+        assert!(K2.get().is_null());
+        K1.set(ptr::invalid_mut(1));
+        K2.set(ptr::invalid_mut(2));
+        assert_eq!(K1.get() as usize, 1);
+        assert_eq!(K2.get() as usize, 2);
+    }
+}
+
+#[test]
+fn destructors() {
+    use crate::mem::ManuallyDrop;
+    use crate::sync::Arc;
+    use crate::thread;
+
+    unsafe extern "C" fn destruct(ptr: *mut u8) {
+        drop(Arc::from_raw(ptr as *const ()));
+    }
+
+    static KEY: StaticKey = StaticKey::new(Some(destruct));
+
+    let shared1 = Arc::new(());
+    let shared2 = Arc::clone(&shared1);
+
+    unsafe {
+        assert!(KEY.get().is_null());
+        KEY.set(Arc::into_raw(shared1) as *mut u8);
+    }
+
+    thread::spawn(move || unsafe {
+        assert!(KEY.get().is_null());
+        KEY.set(Arc::into_raw(shared2) as *mut u8);
+    })
+    .join()
+    .unwrap();
+
+    // Leak the Arc, let the TLS destructor clean it up.
+    let shared1 = unsafe { ManuallyDrop::new(Arc::from_raw(KEY.get() as *const ())) };
+    assert_eq!(
+        Arc::strong_count(&shared1),
+        1,
+        "destructor should have dropped the other reference on thread exit"
+    );
+}
index 80f56bf7522b67edbf547f77cf8707fded8c9b21..9ea3c52fa6d713580d34053caafbed6d79bf34d1 100644 (file)
 pub mod thread;
 pub mod thread_info;
 pub mod thread_local_dtor;
-pub mod thread_local_key;
 pub mod thread_parker;
 pub mod wtf8;
 
+cfg_if::cfg_if! {
+    if #[cfg(target_os = "windows")] {
+        pub use crate::sys::thread_local_key;
+    } else {
+        pub mod thread_local_key;
+    }
+}
+
 cfg_if::cfg_if! {
     if #[cfg(any(target_os = "l4re",
                  target_os = "hermit",
index 032bf604d73889de6aac6fa8418f63c37fb40c2d..747579f178127c2897268a6df906f49dc7cd2601 100644 (file)
@@ -53,7 +53,6 @@
 
 use crate::sync::atomic::{self, AtomicUsize, Ordering};
 use crate::sys::thread_local_key as imp;
-use crate::sys_common::mutex::StaticMutex;
 
 /// A type for TLS keys that are statically allocated.
 ///
@@ -151,25 +150,6 @@ unsafe fn key(&self) -> imp::Key {
     }
 
     unsafe fn lazy_init(&self) -> usize {
-        // Currently the Windows implementation of TLS is pretty hairy, and
-        // it greatly simplifies creation if we just synchronize everything.
-        //
-        // Additionally a 0-index of a tls key hasn't been seen on windows, so
-        // we just simplify the whole branch.
-        if imp::requires_synchronized_create() {
-            // We never call `INIT_LOCK.init()`, so it is UB to attempt to
-            // acquire this mutex reentrantly!
-            static INIT_LOCK: StaticMutex = StaticMutex::new();
-            let _guard = INIT_LOCK.lock();
-            let mut key = self.key.load(Ordering::SeqCst);
-            if key == 0 {
-                key = imp::create(self.dtor) as usize;
-                self.key.store(key, Ordering::SeqCst);
-            }
-            rtassert!(key != 0);
-            return key;
-        }
-
         // POSIX allows the key created here to be 0, but the compare_exchange
         // below relies on using 0 as a sentinel value to check who won the
         // race to set the shared TLS key. As far as I know, there is no
@@ -232,8 +212,6 @@ pub fn set(&self, val: *mut u8) {
 
 impl Drop for Key {
     fn drop(&mut self) {
-        // Right now Windows doesn't support TLS key destruction, but this also
-        // isn't used anywhere other than tests, so just leak the TLS key.
-        // unsafe { imp::destroy(self.key) }
+        unsafe { imp::destroy(self.key) }
     }
 }