1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 #![allow(clippy::undocumented_unsafe_blocks)] 4 #![cfg_attr(feature = "alloc", feature(allocator_api))] 5 #![allow(clippy::missing_safety_doc)] 6 7 use core::{ 8 cell::{Cell, UnsafeCell}, 9 marker::PhantomPinned, 10 ops::{Deref, DerefMut}, 11 pin::Pin, 12 sync::atomic::{AtomicBool, Ordering}, 13 }; 14 use std::{ 15 sync::Arc, 16 thread::{self, park, sleep, Builder, Thread}, 17 time::Duration, 18 }; 19 20 use pin_init::*; 21 #[expect(unused_attributes)] 22 #[path = "./linked_list.rs"] 23 pub mod linked_list; 24 use linked_list::*; 25 26 pub struct SpinLock { 27 inner: AtomicBool, 28 } 29 30 impl SpinLock { 31 #[inline] 32 pub fn acquire(&self) -> SpinLockGuard<'_> { 33 while self 34 .inner 35 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) 36 .is_err() 37 { 38 while self.inner.load(Ordering::Relaxed) { 39 thread::yield_now(); 40 } 41 } 42 SpinLockGuard(self) 43 } 44 45 #[inline] 46 #[allow(clippy::new_without_default)] 47 pub const fn new() -> Self { 48 Self { 49 inner: AtomicBool::new(false), 50 } 51 } 52 } 53 54 pub struct SpinLockGuard<'a>(&'a SpinLock); 55 56 impl Drop for SpinLockGuard<'_> { 57 #[inline] 58 fn drop(&mut self) { 59 self.0.inner.store(false, Ordering::Release); 60 } 61 } 62 63 #[pin_data] 64 pub struct CMutex<T> { 65 #[pin] 66 wait_list: ListHead, 67 spin_lock: SpinLock, 68 locked: Cell<bool>, 69 #[pin] 70 data: UnsafeCell<T>, 71 } 72 73 impl<T> CMutex<T> { 74 #[inline] 75 pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> { 76 pin_init!(CMutex { 77 wait_list <- ListHead::new(), 78 spin_lock: SpinLock::new(), 79 locked: Cell::new(false), 80 data <- unsafe { 81 pin_init_from_closure(|slot: *mut UnsafeCell<T>| { 82 val.__pinned_init(slot.cast::<T>()) 83 }) 84 }, 85 }) 86 } 87 88 #[inline] 89 pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> { 90 let mut sguard = self.spin_lock.acquire(); 91 if self.locked.get() { 92 stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list)); 93 // println!("wait list length: {}", self.wait_list.size()); 94 while self.locked.get() { 95 drop(sguard); 96 park(); 97 sguard = self.spin_lock.acquire(); 98 } 99 // This does have an effect, as the ListHead inside wait_entry implements Drop! 100 #[expect(clippy::drop_non_drop)] 101 drop(wait_entry); 102 } 103 self.locked.set(true); 104 unsafe { 105 Pin::new_unchecked(CMutexGuard { 106 mtx: self, 107 _pin: PhantomPinned, 108 }) 109 } 110 } 111 112 #[allow(dead_code)] 113 pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T { 114 // SAFETY: we have an exclusive reference and thus nobody has access to data. 115 unsafe { &mut *self.data.get() } 116 } 117 } 118 119 unsafe impl<T: Send> Send for CMutex<T> {} 120 unsafe impl<T: Send> Sync for CMutex<T> {} 121 122 pub struct CMutexGuard<'a, T> { 123 mtx: &'a CMutex<T>, 124 _pin: PhantomPinned, 125 } 126 127 impl<T> Drop for CMutexGuard<'_, T> { 128 #[inline] 129 fn drop(&mut self) { 130 let sguard = self.mtx.spin_lock.acquire(); 131 self.mtx.locked.set(false); 132 if let Some(list_field) = self.mtx.wait_list.next() { 133 let wait_entry = list_field.as_ptr().cast::<WaitEntry>(); 134 unsafe { (*wait_entry).thread.unpark() }; 135 } 136 drop(sguard); 137 } 138 } 139 140 impl<T> Deref for CMutexGuard<'_, T> { 141 type Target = T; 142 143 #[inline] 144 fn deref(&self) -> &Self::Target { 145 unsafe { &*self.mtx.data.get() } 146 } 147 } 148 149 impl<T> DerefMut for CMutexGuard<'_, T> { 150 #[inline] 151 fn deref_mut(&mut self) -> &mut Self::Target { 152 unsafe { &mut *self.mtx.data.get() } 153 } 154 } 155 156 #[pin_data] 157 #[repr(C)] 158 struct WaitEntry { 159 #[pin] 160 wait_list: ListHead, 161 thread: Thread, 162 } 163 164 impl WaitEntry { 165 #[inline] 166 fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ { 167 pin_init!(Self { 168 thread: thread::current(), 169 wait_list <- ListHead::insert_prev(list), 170 }) 171 } 172 } 173 174 #[cfg(not(any(feature = "std", feature = "alloc")))] 175 fn main() {} 176 177 #[allow(dead_code)] 178 #[cfg_attr(test, test)] 179 #[cfg(any(feature = "std", feature = "alloc"))] 180 fn main() { 181 let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); 182 let mut handles = vec![]; 183 let thread_count = 20; 184 let workload = if cfg!(miri) { 100 } else { 1_000 }; 185 for i in 0..thread_count { 186 let mtx = mtx.clone(); 187 handles.push( 188 Builder::new() 189 .name(format!("worker #{i}")) 190 .spawn(move || { 191 for _ in 0..workload { 192 *mtx.lock() += 1; 193 } 194 println!("{i} halfway"); 195 sleep(Duration::from_millis((i as u64) * 10)); 196 for _ in 0..workload { 197 *mtx.lock() += 1; 198 } 199 println!("{i} finished"); 200 }) 201 .expect("should not fail"), 202 ); 203 } 204 for h in handles { 205 h.join().expect("thread panicked"); 206 } 207 println!("{:?}", &*mtx.lock()); 208 assert_eq!(*mtx.lock(), workload * thread_count * 2); 209 } 210