1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 #![allow(clippy::undocumented_unsafe_blocks)] 4 #![cfg_attr(feature = "alloc", feature(allocator_api))] 5 #![cfg_attr(USE_RUSTC_FEATURES, feature(lint_reasons))] 6 #![cfg_attr(USE_RUSTC_FEATURES, feature(raw_ref_op))] 7 #![allow(clippy::missing_safety_doc)] 8 9 use core::{ 10 cell::{Cell, UnsafeCell}, 11 marker::PhantomPinned, 12 ops::{Deref, DerefMut}, 13 pin::Pin, 14 sync::atomic::{AtomicBool, Ordering}, 15 }; 16 #[cfg(feature = "std")] 17 use std::{ 18 sync::Arc, 19 thread::{self, sleep, Builder, Thread}, 20 time::Duration, 21 }; 22 23 use pin_init::*; 24 #[allow(unused_attributes)] 25 #[path = "./linked_list.rs"] 26 pub mod linked_list; 27 use linked_list::*; 28 29 pub struct SpinLock { 30 inner: AtomicBool, 31 } 32 33 impl SpinLock { 34 #[inline] 35 pub fn acquire(&self) -> SpinLockGuard<'_> { 36 while self 37 .inner 38 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) 39 .is_err() 40 { 41 #[cfg(feature = "std")] 42 while self.inner.load(Ordering::Relaxed) { 43 thread::yield_now(); 44 } 45 } 46 SpinLockGuard(self) 47 } 48 49 #[inline] 50 #[allow(clippy::new_without_default)] 51 pub const fn new() -> Self { 52 Self { 53 inner: AtomicBool::new(false), 54 } 55 } 56 } 57 58 pub struct SpinLockGuard<'a>(&'a SpinLock); 59 60 impl Drop for SpinLockGuard<'_> { 61 #[inline] 62 fn drop(&mut self) { 63 self.0.inner.store(false, Ordering::Release); 64 } 65 } 66 67 #[pin_data] 68 pub struct CMutex<T> { 69 #[pin] 70 wait_list: ListHead, 71 spin_lock: SpinLock, 72 locked: Cell<bool>, 73 #[pin] 74 data: UnsafeCell<T>, 75 } 76 77 impl<T> CMutex<T> { 78 #[inline] 79 pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> { 80 pin_init!(CMutex { 81 wait_list <- ListHead::new(), 82 spin_lock: SpinLock::new(), 83 locked: Cell::new(false), 84 data <- unsafe { 85 pin_init_from_closure(|slot: *mut UnsafeCell<T>| { 86 val.__pinned_init(slot.cast::<T>()) 87 }) 88 }, 89 }) 90 } 91 92 #[inline] 93 pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> { 94 let mut sguard = self.spin_lock.acquire(); 95 if self.locked.get() { 96 stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list)); 97 // println!("wait list length: {}", self.wait_list.size()); 98 while self.locked.get() { 99 drop(sguard); 100 #[cfg(feature = "std")] 101 thread::park(); 102 sguard = self.spin_lock.acquire(); 103 } 104 // This does have an effect, as the ListHead inside wait_entry implements Drop! 105 #[expect(clippy::drop_non_drop)] 106 drop(wait_entry); 107 } 108 self.locked.set(true); 109 unsafe { 110 Pin::new_unchecked(CMutexGuard { 111 mtx: self, 112 _pin: PhantomPinned, 113 }) 114 } 115 } 116 117 #[allow(dead_code)] 118 pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T { 119 // SAFETY: we have an exclusive reference and thus nobody has access to data. 120 unsafe { &mut *self.data.get() } 121 } 122 } 123 124 unsafe impl<T: Send> Send for CMutex<T> {} 125 unsafe impl<T: Send> Sync for CMutex<T> {} 126 127 pub struct CMutexGuard<'a, T> { 128 mtx: &'a CMutex<T>, 129 _pin: PhantomPinned, 130 } 131 132 impl<T> Drop for CMutexGuard<'_, T> { 133 #[inline] 134 fn drop(&mut self) { 135 let sguard = self.mtx.spin_lock.acquire(); 136 self.mtx.locked.set(false); 137 if let Some(list_field) = self.mtx.wait_list.next() { 138 let _wait_entry = list_field.as_ptr().cast::<WaitEntry>(); 139 #[cfg(feature = "std")] 140 unsafe { 141 (*_wait_entry).thread.unpark() 142 }; 143 } 144 drop(sguard); 145 } 146 } 147 148 impl<T> Deref for CMutexGuard<'_, T> { 149 type Target = T; 150 151 #[inline] 152 fn deref(&self) -> &Self::Target { 153 unsafe { &*self.mtx.data.get() } 154 } 155 } 156 157 impl<T> DerefMut for CMutexGuard<'_, T> { 158 #[inline] 159 fn deref_mut(&mut self) -> &mut Self::Target { 160 unsafe { &mut *self.mtx.data.get() } 161 } 162 } 163 164 #[pin_data] 165 #[repr(C)] 166 struct WaitEntry { 167 #[pin] 168 wait_list: ListHead, 169 #[cfg(feature = "std")] 170 thread: Thread, 171 } 172 173 impl WaitEntry { 174 #[inline] 175 fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ { 176 #[cfg(feature = "std")] 177 { 178 pin_init!(Self { 179 thread: thread::current(), 180 wait_list <- ListHead::insert_prev(list), 181 }) 182 } 183 #[cfg(not(feature = "std"))] 184 { 185 pin_init!(Self { 186 wait_list <- ListHead::insert_prev(list), 187 }) 188 } 189 } 190 } 191 192 #[cfg_attr(test, test)] 193 #[allow(dead_code)] 194 fn main() { 195 #[cfg(feature = "std")] 196 { 197 let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); 198 let mut handles = vec![]; 199 let thread_count = 20; 200 let workload = if cfg!(miri) { 100 } else { 1_000 }; 201 for i in 0..thread_count { 202 let mtx = mtx.clone(); 203 handles.push( 204 Builder::new() 205 .name(format!("worker #{i}")) 206 .spawn(move || { 207 for _ in 0..workload { 208 *mtx.lock() += 1; 209 } 210 println!("{i} halfway"); 211 sleep(Duration::from_millis((i as u64) * 10)); 212 for _ in 0..workload { 213 *mtx.lock() += 1; 214 } 215 println!("{i} finished"); 216 }) 217 .expect("should not fail"), 218 ); 219 } 220 for h in handles { 221 h.join().expect("thread panicked"); 222 } 223 println!("{:?}", &*mtx.lock()); 224 assert_eq!(*mtx.lock(), workload * thread_count * 2); 225 } 226 } 227