Rollup merge of #124881 - Sp00ph:reentrant_lock_tid, r=joboet · model-checking/verify-rust-std@b0c85ba (original) (raw)
1
1
`#[cfg(all(test, not(target_os = "emscripten")))]
`
2
2
`mod tests;
`
3
3
``
``
4
`+
use cfg_if::cfg_if;
`
``
5
+
4
6
`use crate::cell::UnsafeCell;
`
5
7
`use crate::fmt;
`
6
8
`use crate::ops::Deref;
`
7
9
`use crate::panic::{RefUnwindSafe, UnwindSafe};
`
8
``
`-
use crate::sync::atomic::{AtomicUsize, Ordering::Relaxed};
`
9
10
`use crate::sys::sync as sys;
`
``
11
`+
use crate::thread::{current_id, ThreadId};
`
10
12
``
11
13
`/// A re-entrant mutual exclusion lock
`
12
14
`///
`
`@@ -53,8 +55,8 @@ use crate::sys::sync as sys;
`
53
55
`//
`
54
56
`// The 'owner' field tracks which thread has locked the mutex.
`
55
57
`//
`
56
``
`-
// We use current_thread_unique_ptr() as the thread identifier,
`
57
``
`-
// which is just the address of a thread local variable.
`
``
58
`+
// We use thread::current_id() as the thread identifier, which is just the
`
``
59
`+
// current thread's ThreadId, so it's unique across the process lifetime.
`
58
60
`//
`
59
61
`` // If owner
is set to the identifier of the current thread,
``
60
62
`// we assume the mutex is already locked and instead of locking it again,
`
`@@ -72,14 +74,109 @@ use crate::sys::sync as sys;
`
72
74
`// since we're not dealing with multiple threads. If it's not equal,
`
73
75
`// synchronization is left to the mutex, making relaxed memory ordering for
`
74
76
`` // the owner
field fine in all cases.
``
``
77
`+
//
`
``
78
`+
// On systems without 64 bit atomics we also store the address of a TLS variable
`
``
79
`+
// along the 64-bit TID. We then first check that address against the address
`
``
80
`+
// of that variable on the current thread, and only if they compare equal do we
`
``
81
`+
// compare the actual TIDs. Because we only ever read the TID on the same thread
`
``
82
`+
// that it was written on (or a thread sharing the TLS block with that writer thread),
`
``
83
`+
// we don't need to further synchronize the TID accesses, so they can be regular 64-bit
`
``
84
`+
// non-atomic accesses.
`
75
85
`#[unstable(feature = "reentrant_lock", issue = "121440")]
`
76
86
`pub struct ReentrantLock<T: ?Sized> {
`
77
87
`mutex: sys::Mutex,
`
78
``
`-
owner: AtomicUsize,
`
``
88
`+
owner: Tid,
`
79
89
`lock_count: UnsafeCell,
`
80
90
`data: T,
`
81
91
`}
`
82
92
``
``
93
`+
cfg_if!(
`
``
94
`+
if #[cfg(target_has_atomic = "64")] {
`
``
95
`+
use crate::sync::atomic::{AtomicU64, Ordering::Relaxed};
`
``
96
+
``
97
`+
struct Tid(AtomicU64);
`
``
98
+
``
99
`+
impl Tid {
`
``
100
`+
const fn new() -> Self {
`
``
101
`+
Self(AtomicU64::new(0))
`
``
102
`+
}
`
``
103
+
``
104
`+
#[inline]
`
``
105
`+
fn contains(&self, owner: ThreadId) -> bool {
`
``
106
`+
owner.as_u64().get() == self.0.load(Relaxed)
`
``
107
`+
}
`
``
108
+
``
109
`+
#[inline]
`
``
110
`+
// This is just unsafe to match the API of the Tid type below.
`
``
111
`+
unsafe fn set(&self, tid: Option) {
`
``
112
`+
let value = tid.map_or(0, |tid| tid.as_u64().get());
`
``
113
`+
self.0.store(value, Relaxed);
`
``
114
`+
}
`
``
115
`+
}
`
``
116
`+
} else {
`
``
117
`+
/// Returns the address of a TLS variable. This is guaranteed to
`
``
118
`+
/// be unique across all currently alive threads.
`
``
119
`+
fn tls_addr() -> usize {
`
``
120
`+
thread_local! { static X: u8 = const { 0u8 } };
`
``
121
+
``
122
`+
X.with(|p| <*const u8>::addr(p))
`
``
123
`+
}
`
``
124
+
``
125
`+
use crate::sync::atomic::{
`
``
126
`+
AtomicUsize,
`
``
127
`+
Ordering,
`
``
128
`+
};
`
``
129
+
``
130
`+
struct Tid {
`
``
131
`` +
// When a thread calls set()
, this value gets updated to
``
``
132
`+
// the address of a thread local on that thread. This is
`
``
133
`` +
// used as a first check in contains()
; if the tls_addr
``
``
134
`+
// doesn't match the TLS address of the current thread, then
`
``
135
`+
// the ThreadId also can't match. Only if the TLS addresses do
`
``
136
`+
// match do we read out the actual TID.
`
``
137
`+
// Note also that we can use relaxed atomic operations here, because
`
``
138
`` +
// we only ever read from the tid if tls_addr
matches the current
``
``
139
`+
// TLS address. In that case, either the the tid has been set by
`
``
140
`+
// the current thread, or by a thread that has terminated before
`
``
141
`+
// the current thread was created. In either case, no further
`
``
142
`+
// synchronization is needed (as per https://github.com/rust-lang/miri/issues/3450)
`
``
143
`+
tls_addr: AtomicUsize,
`
``
144
`+
tid: UnsafeCell,
`
``
145
`+
}
`
``
146
+
``
147
`+
unsafe impl Send for Tid {}
`
``
148
`+
unsafe impl Sync for Tid {}
`
``
149
+
``
150
`+
impl Tid {
`
``
151
`+
const fn new() -> Self {
`
``
152
`+
Self { tls_addr: AtomicUsize::new(0), tid: UnsafeCell::new(0) }
`
``
153
`+
}
`
``
154
+
``
155
`+
#[inline]
`
``
156
`` +
// NOTE: This assumes that owner
is the ID of the current
``
``
157
`` +
// thread, and may spuriously return false
if that's not the case.
``
``
158
`+
fn contains(&self, owner: ThreadId) -> bool {
`
``
159
`+
// SAFETY: See the comments in the struct definition.
`
``
160
`+
self.tls_addr.load(Ordering::Relaxed) == tls_addr()
`
``
161
`+
&& unsafe { *self.tid.get() } == owner.as_u64().get()
`
``
162
`+
}
`
``
163
+
``
164
`+
#[inline]
`
``
165
`+
// This may only be called by one thread at a time, and can lead to
`
``
166
`+
// race conditions otherwise.
`
``
167
`+
unsafe fn set(&self, tid: Option) {
`
``
168
`` +
// It's important that we set self.tls_addr
to 0 if the tid is
``
``
169
`+
// cleared. Otherwise, there might be race conditions between
`
``
170
`` +
// set()
and get()
.
``
``
171
`+
let tls_addr = if tid.is_some() { tls_addr() } else { 0 };
`
``
172
`+
let value = tid.map_or(0, |tid| tid.as_u64().get());
`
``
173
`+
self.tls_addr.store(tls_addr, Ordering::Relaxed);
`
``
174
`+
unsafe { *self.tid.get() = value };
`
``
175
`+
}
`
``
176
`+
}
`
``
177
`+
}
`
``
178
`+
);
`
``
179
+
83
180
`#[unstable(feature = "reentrant_lock", issue = "121440")]
`
84
181
`unsafe impl<T: Send + ?Sized> Send for ReentrantLock {}
`
85
182
`#[unstable(feature = "reentrant_lock", issue = "121440")]
`
`@@ -134,7 +231,7 @@ impl ReentrantLock {
`
134
231
`pub const fn new(t: T) -> ReentrantLock {
`
135
232
`ReentrantLock {
`
136
233
`mutex: sys::Mutex::new(),
`
137
``
`-
owner: AtomicUsize::new(0),
`
``
234
`+
owner: Tid::new(),
`
138
235
`lock_count: UnsafeCell::new(0),
`
139
236
`data: t,
`
140
237
`}
`
`@@ -184,14 +281,16 @@ impl<T: ?Sized> ReentrantLock {
`
184
281
`/// assert_eq!(lock.lock().get(), 10);
`
185
282
```` /// ```
````
186
283
`pub fn lock(&self) -> ReentrantLockGuard<'_, T> {
`
187
``
`-
let this_thread = current_thread_unique_ptr();
`
188
``
`-
// Safety: We only touch lock_count when we own the lock.
`
``
284
`+
let this_thread = current_id();
`
``
285
`+
// Safety: We only touch lock_count when we own the inner mutex.
`
``
286
`` +
// Additionally, we only call self.owner.set()
while holding
``
``
287
`+
// the inner mutex, so no two threads can call it concurrently.
`
189
288
`unsafe {
`
190
``
`-
if self.owner.load(Relaxed) == this_thread {
`
``
289
`+
if self.owner.contains(this_thread) {
`
191
290
`self.increment_lock_count().expect("lock count overflow in reentrant mutex");
`
192
291
`} else {
`
193
292
`self.mutex.lock();
`
194
``
`-
self.owner.store(this_thread, Relaxed);
`
``
293
`+
self.owner.set(Some(this_thread));
`
195
294
`debug_assert_eq!(*self.lock_count.get(), 0);
`
196
295
`*self.lock_count.get() = 1;
`
197
296
`}
`
`@@ -226,14 +325,16 @@ impl<T: ?Sized> ReentrantLock {
`
226
325
`///
`
227
326
`/// This function does not block.
`
228
327
`pub(crate) fn try_lock(&self) -> Option<ReentrantLockGuard<'_, T>> {
`
229
``
`-
let this_thread = current_thread_unique_ptr();
`
230
``
`-
// Safety: We only touch lock_count when we own the lock.
`
``
328
`+
let this_thread = current_id();
`
``
329
`+
// Safety: We only touch lock_count when we own the inner mutex.
`
``
330
`` +
// Additionally, we only call self.owner.set()
while holding
``
``
331
`+
// the inner mutex, so no two threads can call it concurrently.
`
231
332
`unsafe {
`
232
``
`-
if self.owner.load(Relaxed) == this_thread {
`
``
333
`+
if self.owner.contains(this_thread) {
`
233
334
`self.increment_lock_count()?;
`
234
335
`Some(ReentrantLockGuard { lock: self })
`
235
336
`} else if self.mutex.try_lock() {
`
236
``
`-
self.owner.store(this_thread, Relaxed);
`
``
337
`+
self.owner.set(Some(this_thread));
`
237
338
`debug_assert_eq!(*self.lock_count.get(), 0);
`
238
339
`*self.lock_count.get() = 1;
`
239
340
`Some(ReentrantLockGuard { lock: self })
`
`@@ -308,18 +409,9 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
`
308
409
`unsafe {
`
309
410
`*self.lock.lock_count.get() -= 1;
`
310
411
`if *self.lock.lock_count.get() == 0 {
`
311
``
`-
self.lock.owner.store(0, Relaxed);
`
``
412
`+
self.lock.owner.set(None);
`
312
413
`self.lock.mutex.unlock();
`
313
414
`}
`
314
415
`}
`
315
416
`}
`
316
417
`}
`
317
``
-
318
``
`-
/// Get an address that is unique per running thread.
`
319
``
`-
///
`
320
``
`-
/// This can be used as a non-null usize-sized ID.
`
321
``
`-
pub(crate) fn current_thread_unique_ptr() -> usize {
`
322
``
`-
// Use a non-drop type to make sure it's still available during thread destruction.
`
323
``
`-
thread_local! { static X: u8 = const { 0 } }
`
324
``
`-
X.with(|x| <*const _>::addr(x))
`
325
``
`-
}
`