1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 #![allow(clippy::undocumented_unsafe_blocks)] 4 #![cfg_attr(feature = "alloc", feature(allocator_api))] 5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] 6 #![allow(clippy::missing_safety_doc)] 7 8 use core::{ 9 cell::{Cell, UnsafeCell}, 10 marker::PhantomPinned, 11 ops::{Deref, DerefMut}, 12 pin::Pin, 13 sync::atomic::{AtomicBool, Ordering}, 14 }; 15 use std::{ 16 sync::Arc, 17 thread::{self, park, sleep, Builder, Thread}, 18 time::Duration, 19 }; 20 21 use pin_init::*; 22 #[expect(unused_attributes)] 23 #[path = "./linked_list.rs"] 24 pub mod linked_list; 25 use linked_list::*; 26 27 pub struct SpinLock { 28 inner: AtomicBool, 29 } 30 31 impl SpinLock { 32 #[inline] 33 pub fn acquire(&self) -> SpinLockGuard<'_> { 34 while self 35 .inner 36 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) 37 .is_err() 38 { 39 while self.inner.load(Ordering::Relaxed) { 40 thread::yield_now(); 41 } 42 } 43 SpinLockGuard(self) 44 } 45 46 #[inline] 47 #[allow(clippy::new_without_default)] 48 pub const fn new() -> Self { 49 Self { 50 inner: AtomicBool::new(false), 51 } 52 } 53 } 54 55 pub struct SpinLockGuard<'a>(&'a SpinLock); 56 57 impl Drop for SpinLockGuard<'_> { 58 #[inline] 59 fn drop(&mut self) { 60 self.0.inner.store(false, Ordering::Release); 61 } 62 } 63 64 #[pin_data] 65 pub struct CMutex<T> { 66 #[pin] 67 wait_list: ListHead, 68 spin_lock: SpinLock, 69 locked: Cell<bool>, 70 #[pin] 71 data: UnsafeCell<T>, 72 } 73 74 impl<T> CMutex<T> { 75 #[inline] 76 pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> { 77 pin_init!(CMutex { 78 wait_list <- ListHead::new(), 79 spin_lock: SpinLock::new(), 80 locked: Cell::new(false), 81 data <- unsafe { 82 pin_init_from_closure(|slot: *mut UnsafeCell<T>| { 83 val.__pinned_init(slot.cast::<T>()) 84 }) 85 }, 86 }) 87 } 88 89 #[inline] 90 pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> { 91 let mut sguard = self.spin_lock.acquire(); 92 if self.locked.get() { 93 stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list)); 94 // println!("wait list length: {}", self.wait_list.size()); 95 while self.locked.get() { 96 drop(sguard); 97 park(); 98 sguard = self.spin_lock.acquire(); 99 } 100 // This does have an effect, as the ListHead inside wait_entry implements Drop! 101 #[expect(clippy::drop_non_drop)] 102 drop(wait_entry); 103 } 104 self.locked.set(true); 105 unsafe { 106 Pin::new_unchecked(CMutexGuard { 107 mtx: self, 108 _pin: PhantomPinned, 109 }) 110 } 111 } 112 113 #[allow(dead_code)] 114 pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T { 115 // SAFETY: we have an exclusive reference and thus nobody has access to data. 116 unsafe { &mut *self.data.get() } 117 } 118 } 119 120 unsafe impl<T: Send> Send for CMutex<T> {} 121 unsafe impl<T: Send> Sync for CMutex<T> {} 122 123 pub struct CMutexGuard<'a, T> { 124 mtx: &'a CMutex<T>, 125 _pin: PhantomPinned, 126 } 127 128 impl<T> Drop for CMutexGuard<'_, T> { 129 #[inline] 130 fn drop(&mut self) { 131 let sguard = self.mtx.spin_lock.acquire(); 132 self.mtx.locked.set(false); 133 if let Some(list_field) = self.mtx.wait_list.next() { 134 let wait_entry = list_field.as_ptr().cast::<WaitEntry>(); 135 unsafe { (*wait_entry).thread.unpark() }; 136 } 137 drop(sguard); 138 } 139 } 140 141 impl<T> Deref for CMutexGuard<'_, T> { 142 type Target = T; 143 144 #[inline] 145 fn deref(&self) -> &Self::Target { 146 unsafe { &*self.mtx.data.get() } 147 } 148 } 149 150 impl<T> DerefMut for CMutexGuard<'_, T> { 151 #[inline] 152 fn deref_mut(&mut self) -> &mut Self::Target { 153 unsafe { &mut *self.mtx.data.get() } 154 } 155 } 156 157 #[pin_data] 158 #[repr(C)] 159 struct WaitEntry { 160 #[pin] 161 wait_list: ListHead, 162 thread: Thread, 163 } 164 165 impl WaitEntry { 166 #[inline] 167 fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ { 168 pin_init!(Self { 169 thread: thread::current(), 170 wait_list <- ListHead::insert_prev(list), 171 }) 172 } 173 } 174 175 #[cfg(not(any(feature = "std", feature = "alloc")))] 176 fn main() {} 177 178 #[allow(dead_code)] 179 #[cfg_attr(test, test)] 180 #[cfg(any(feature = "std", feature = "alloc"))] 181 fn main() { 182 let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); 183 let mut handles = vec![]; 184 let thread_count = 20; 185 let workload = if cfg!(miri) { 100 } else { 1_000 }; 186 for i in 0..thread_count { 187 let mtx = mtx.clone(); 188 handles.push( 189 Builder::new() 190 .name(format!("worker #{i}")) 191 .spawn(move || { 192 for _ in 0..workload { 193 *mtx.lock() += 1; 194 } 195 println!("{i} halfway"); 196 sleep(Duration::from_millis((i as u64) * 10)); 197 for _ in 0..workload { 198 *mtx.lock() += 1; 199 } 200 println!("{i} finished"); 201 }) 202 .expect("should not fail"), 203 ); 204 } 205 for h in handles { 206 h.join().expect("thread panicked"); 207 } 208 println!("{:?}", &*mtx.lock()); 209 assert_eq!(*mtx.lock(), workload * thread_count * 2); 210 } 211