Rollup merge of #124881 - Sp00ph:reentrant_lock_tid, r=joboet · model-checking/verify-rust-std@b0c85ba
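A diff against `library/std/src/sync/reentrant_lock.rs`: `ReentrantLock` previously identified its owning thread by the address of a thread-local variable (`current_thread_unique_ptr()`); this change stores the owner's `ThreadId` instead, which is unique for the whole lifetime of the process rather than merely among currently alive threads. On targets with 64-bit atomics the ID lives in an `AtomicU64`; on other targets a TLS address is kept alongside a plain 64-bit TID as a pre-check.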

```diff
 #[cfg(all(test, not(target_os = "emscripten")))]
 mod tests;
 
+use cfg_if::cfg_if;
+
 use crate::cell::UnsafeCell;
 use crate::fmt;
 use crate::ops::Deref;
 use crate::panic::{RefUnwindSafe, UnwindSafe};
-use crate::sync::atomic::{AtomicUsize, Ordering::Relaxed};
 use crate::sys::sync as sys;
+use crate::thread::{current_id, ThreadId};
 
 /// A re-entrant mutual exclusion lock
 ///
```

```diff
@@ -53,8 +55,8 @@ use crate::sys::sync as sys;
 //
 // The `owner` field tracks which thread has locked the mutex.
 //
-// We use current_thread_unique_ptr() as the thread identifier,
-// which is just the address of a thread local variable.
+// We use thread::current_id() as the thread identifier, which is just the
+// current thread's ThreadId, so it's unique across the process lifetime.
 //
 // If `owner` is set to the identifier of the current thread,
 // we assume the mutex is already locked and instead of locking it again,
```
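The property the new identifier relies on: a `ThreadId` is never reused while the process runs, whereas a TLS address only distinguishes threads that are alive at the same time. A minimal sketch of the ownership comparison using only the stable `std::thread` API (`current_id()` itself is internal to std):

```rust
use std::thread;

fn main() {
    let main_id = thread::current().id();
    let spawned_id = thread::spawn(|| thread::current().id()).join().unwrap();

    // Distinct threads always get distinct IDs, and an ID is never handed
    // out again for the lifetime of the process.
    assert_ne!(main_id, spawned_id);

    // The lock's owner check is morally `owner == current thread's ID`.
    assert_eq!(main_id, thread::current().id());
}
```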

```diff
@@ -72,14 +74,109 @@ use crate::sys::sync as sys;
 // since we're not dealing with multiple threads. If it's not equal,
 // synchronization is left to the mutex, making relaxed memory ordering for
 // the `owner` field fine in all cases.
+//
+// On systems without 64 bit atomics we also store the address of a TLS variable
+// alongside the 64-bit TID. We then first check that address against the address
+// of that variable on the current thread, and only if they compare equal do we
+// compare the actual TIDs. Because we only ever read the TID on the same thread
+// that it was written on (or a thread sharing the TLS block with that writer thread),
+// we don't need to further synchronize the TID accesses, so they can be regular 64-bit
+// non-atomic accesses.
 #[unstable(feature = "reentrant_lock", issue = "121440")]
 pub struct ReentrantLock<T: ?Sized> {
     mutex: sys::Mutex,
-    owner: AtomicUsize,
+    owner: Tid,
     lock_count: UnsafeCell<u32>,
     data: T,
 }
 
+cfg_if!(
+    if #[cfg(target_has_atomic = "64")] {
+        use crate::sync::atomic::{AtomicU64, Ordering::Relaxed};
+
+        struct Tid(AtomicU64);
+
+        impl Tid {
+            const fn new() -> Self {
+                Self(AtomicU64::new(0))
+            }
+
+            #[inline]
+            fn contains(&self, owner: ThreadId) -> bool {
+                owner.as_u64().get() == self.0.load(Relaxed)
+            }
+
+            #[inline]
+            // This is just `unsafe` to match the API of the `Tid` type below.
+            unsafe fn set(&self, tid: Option<ThreadId>) {
+                let value = tid.map_or(0, |tid| tid.as_u64().get());
+                self.0.store(value, Relaxed);
+            }
+        }
+    } else {
+        /// Returns the address of a TLS variable. This is guaranteed to
+        /// be unique across all currently alive threads.
+        fn tls_addr() -> usize {
+            thread_local! { static X: u8 = const { 0u8 } };
+
+            X.with(|p| <*const u8>::addr(p))
+        }
+
+        use crate::sync::atomic::{
+            AtomicUsize,
+            Ordering,
+        };
+
+        struct Tid {
+            // When a thread calls `set()`, this value gets updated to
+            // the address of a thread local on that thread. This is
+            // used as a first check in `contains()`; if the `tls_addr`
+            // doesn't match the TLS address of the current thread, then
+            // the ThreadId also can't match. Only if the TLS addresses do
+            // match do we read out the actual TID.
+            // Note also that we can use relaxed atomic operations here, because
+            // we only ever read from the tid if `tls_addr` matches the current
+            // TLS address. In that case, either the tid has been set by
+            // the current thread, or by a thread that has terminated before
+            // the current thread was created. In either case, no further
+            // synchronization is needed (as per https://github.com/rust-lang/miri/issues/3450)
+            tls_addr: AtomicUsize,
+            tid: UnsafeCell<u64>,
+        }
+
+        unsafe impl Send for Tid {}
+        unsafe impl Sync for Tid {}
+
+        impl Tid {
+            const fn new() -> Self {
+                Self { tls_addr: AtomicUsize::new(0), tid: UnsafeCell::new(0) }
+            }
+
+            #[inline]
+            // NOTE: This assumes that `owner` is the ID of the current
+            // thread, and may spuriously return `false` if that's not the case.
+            fn contains(&self, owner: ThreadId) -> bool {
+                // SAFETY: See the comments in the struct definition.
+                self.tls_addr.load(Ordering::Relaxed) == tls_addr()
+                    && unsafe { *self.tid.get() } == owner.as_u64().get()
+            }
+
+            #[inline]
+            // This may only be called by one thread at a time, and can lead to
+            // race conditions otherwise.
+            unsafe fn set(&self, tid: Option<ThreadId>) {
+                // It's important that we set `self.tls_addr` to 0 if the tid is
+                // cleared. Otherwise, there might be race conditions between
+                // `set()` and `contains()`.
+                let tls_addr = if tid.is_some() { tls_addr() } else { 0 };
+                let value = tid.map_or(0, |tid| tid.as_u64().get());
+                self.tls_addr.store(tls_addr, Ordering::Relaxed);
+                unsafe { *self.tid.get() = value };
+            }
+        }
+    }
+);
+
 #[unstable(feature = "reentrant_lock", issue = "121440")]
 unsafe impl<T: Send + ?Sized> Send for ReentrantLock<T> {}
 #[unstable(feature = "reentrant_lock", issue = "121440")]
```
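Both `Tid` variants encode "no owner" as `0`. That is sound because `ThreadId::as_u64()` (the unstable `thread_id_value` API) returns a `NonZeroU64`, so no live thread can ever have ID 0. A hedged stand-in for that encoding, with `encode`/`decode` as hypothetical names, using plain `NonZeroU64`:

```rust
use std::num::NonZeroU64;

// `ThreadId::as_u64()` returns a `NonZeroU64`, so 0 is free to mean
// "no owner" -- exactly what `tid.map_or(0, ...)` exploits above.
fn encode(tid: Option<NonZeroU64>) -> u64 {
    tid.map_or(0, |tid| tid.get())
}

fn decode(value: u64) -> Option<NonZeroU64> {
    NonZeroU64::new(value)
}

fn main() {
    assert_eq!(encode(None), 0);
    let id = NonZeroU64::new(7).unwrap();
    assert_eq!(decode(encode(Some(id))), Some(id));
}
```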

```diff
@@ -134,7 +231,7 @@ impl<T> ReentrantLock<T> {
     pub const fn new(t: T) -> ReentrantLock<T> {
         ReentrantLock {
             mutex: sys::Mutex::new(),
-            owner: AtomicUsize::new(0),
+            owner: Tid::new(),
             lock_count: UnsafeCell::new(0),
             data: t,
         }
```
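`Tid::new()` is a `const fn` in both branches of the `cfg_if!`, which is what lets `ReentrantLock::new` stay `const`. A hedged nightly-only sketch (the `reentrant_lock` feature gate comes from the `#[unstable]` attribute above):

```rust
#![feature(reentrant_lock)]
use std::sync::ReentrantLock;

// `ReentrantLock::new` is still a `const fn`, so statics keep working.
static LOCK: ReentrantLock<i32> = ReentrantLock::new(0);

fn main() {
    // The guard derefs to `&i32`.
    assert_eq!(*LOCK.lock(), 0);
}
```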

```diff
@@ -184,14 +281,16 @@ impl<T: ?Sized> ReentrantLock<T> {
     /// assert_eq!(lock.lock().get(), 10);
     /// ```
     pub fn lock(&self) -> ReentrantLockGuard<'_, T> {
-        let this_thread = current_thread_unique_ptr();
-        // Safety: We only touch lock_count when we own the lock.
+        let this_thread = current_id();
+        // Safety: We only touch lock_count when we own the inner mutex.
+        // Additionally, we only call `self.owner.set()` while holding
+        // the inner mutex, so no two threads can call it concurrently.
         unsafe {
-            if self.owner.load(Relaxed) == this_thread {
+            if self.owner.contains(this_thread) {
                 self.increment_lock_count().expect("lock count overflow in reentrant mutex");
             } else {
                 self.mutex.lock();
-                self.owner.store(this_thread, Relaxed);
+                self.owner.set(Some(this_thread));
                 debug_assert_eq!(*self.lock_count.get(), 0);
                 *self.lock_count.get() = 1;
             }
```
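The `contains` fast path is what makes re-entry cheap: when the current thread already owns the lock, `lock()` only bumps `lock_count` and never touches the inner mutex. A hedged nightly sketch of that behavior from user code (`countdown` is a hypothetical helper):

```rust
#![feature(reentrant_lock)]
use std::cell::Cell;
use std::sync::ReentrantLock;

// Each nested call takes `lock()` again; only the first acquisition
// touches the inner mutex, the rest just increment the lock count.
fn countdown(lock: &ReentrantLock<Cell<u32>>) {
    let guard = lock.lock();
    if guard.get() > 0 {
        guard.set(guard.get() - 1);
        countdown(lock);
    }
}

fn main() {
    let lock = ReentrantLock::new(Cell::new(3));
    countdown(&lock);
    assert_eq!(lock.lock().get(), 0);
}
```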

```diff
@@ -226,14 +325,16 @@ impl<T: ?Sized> ReentrantLock<T> {
     ///
     /// This function does not block.
     pub(crate) fn try_lock(&self) -> Option<ReentrantLockGuard<'_, T>> {
-        let this_thread = current_thread_unique_ptr();
-        // Safety: We only touch lock_count when we own the lock.
+        let this_thread = current_id();
+        // Safety: We only touch lock_count when we own the inner mutex.
+        // Additionally, we only call `self.owner.set()` while holding
+        // the inner mutex, so no two threads can call it concurrently.
         unsafe {
-            if self.owner.load(Relaxed) == this_thread {
+            if self.owner.contains(this_thread) {
                 self.increment_lock_count()?;
                 Some(ReentrantLockGuard { lock: self })
             } else if self.mutex.try_lock() {
-                self.owner.store(this_thread, Relaxed);
+                self.owner.set(Some(this_thread));
                 debug_assert_eq!(*self.lock_count.get(), 0);
                 *self.lock_count.get() = 1;
                 Some(ReentrantLockGuard { lock: self })
```
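For a non-owning thread, `contains` returns false and the call falls through to the inner mutex, so `lock()` blocks (and `try_lock()` fails) until every guard on the owning thread is dropped. A hedged nightly sketch of the blocking path:

```rust
#![feature(reentrant_lock)]
use std::sync::{Arc, ReentrantLock};
use std::thread;

fn main() {
    let lock = Arc::new(ReentrantLock::new(()));
    let guard = lock.lock();

    let lock2 = Arc::clone(&lock);
    let waiter = thread::spawn(move || {
        // `owner.contains(this_thread)` is false on this thread, so it
        // falls through to the inner mutex and blocks here.
        let _g = lock2.lock();
    });

    drop(guard); // clears the owner and unlocks the inner mutex
    waiter.join().unwrap();
}
```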

```diff
@@ -308,18 +409,9 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
         unsafe {
             *self.lock.lock_count.get() -= 1;
             if *self.lock.lock_count.get() == 0 {
-                self.lock.owner.store(0, Relaxed);
+                self.lock.owner.set(None);
                 self.lock.mutex.unlock();
             }
         }
     }
 }
-
-/// Get an address that is unique per running thread.
-///
-/// This can be used as a non-null usize-sized ID.
-pub(crate) fn current_thread_unique_ptr() -> usize {
-    // Use a non-drop type to make sure it's still available during thread destruction.
-    thread_local! { static X: u8 = const { 0 } }
-    X.with(|x| <*const _>::addr(x))
-}
```
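The deleted helper is the heart of the change: the address of a thread-local is only unique among *currently alive* threads, so once a thread exits, the same address can be handed to a new thread, which would then appear to own a `ReentrantLock` it never acquired. A hedged sketch of that reuse (whether the address is actually recycled depends on the platform's TLS allocator; the pointer cast stands in for the strict-provenance `addr` call used inside std):

```rust
use std::thread;

// Stand-in for the deleted `current_thread_unique_ptr()` helper.
fn tls_addr() -> usize {
    thread_local! { static X: u8 = const { 0 } }
    X.with(|x| x as *const u8 as usize)
}

fn main() {
    let first = thread::spawn(tls_addr).join().unwrap();
    let second = thread::spawn(tls_addr).join().unwrap();
    // Reuse is permitted once the first thread has exited; a `ThreadId`
    // can never collide like this.
    println!("first: {first:#x}, second: {second:#x}, reused: {}", first == second);
}
```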