1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 #![allow(clippy::undocumented_unsafe_blocks)] 4 #![cfg_attr(feature = "alloc", feature(allocator_api))] 5 #![allow(clippy::missing_safety_doc)] 6 7 use core::{ 8 cell::{Cell, UnsafeCell}, 9 marker::PhantomPinned, 10 ops::{Deref, DerefMut}, 11 pin::Pin, 12 sync::atomic::{AtomicBool, Ordering}, 13 }; 14 #[cfg(feature = "std")] 15 use std::{ 16 sync::Arc, 17 thread::{self, sleep, Builder, Thread}, 18 time::Duration, 19 }; 20 21 use pin_init::*; 22 #[allow(unused_attributes)] 23 #[path = "./linked_list.rs"] 24 pub mod linked_list; 25 use linked_list::*; 26 27 pub struct SpinLock { 28 inner: AtomicBool, 29 } 30 31 impl SpinLock { 32 #[inline] 33 pub fn acquire(&self) -> SpinLockGuard<'_> { 34 while self 35 .inner 36 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) 37 .is_err() 38 { 39 #[cfg(feature = "std")] 40 while self.inner.load(Ordering::Relaxed) { 41 thread::yield_now(); 42 } 43 } 44 SpinLockGuard(self) 45 } 46 47 #[inline] 48 #[allow(clippy::new_without_default)] 49 pub const fn new() -> Self { 50 Self { 51 inner: AtomicBool::new(false), 52 } 53 } 54 } 55 56 pub struct SpinLockGuard<'a>(&'a SpinLock); 57 58 impl Drop for SpinLockGuard<'_> { 59 #[inline] 60 fn drop(&mut self) { 61 self.0.inner.store(false, Ordering::Release); 62 } 63 } 64 65 #[pin_data] 66 pub struct CMutex<T> { 67 #[pin] 68 wait_list: ListHead, 69 spin_lock: SpinLock, 70 locked: Cell<bool>, 71 #[pin] 72 data: UnsafeCell<T>, 73 } 74 75 impl<T> CMutex<T> { 76 #[inline] 77 pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> { 78 pin_init!(CMutex { 79 wait_list <- ListHead::new(), 80 spin_lock: SpinLock::new(), 81 locked: Cell::new(false), 82 data <- unsafe { 83 pin_init_from_closure(|slot: *mut UnsafeCell<T>| { 84 val.__pinned_init(slot.cast::<T>()) 85 }) 86 }, 87 }) 88 } 89 90 #[inline] 91 pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> { 92 let mut sguard = self.spin_lock.acquire(); 93 if self.locked.get() { 94 stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list)); 95 // println!("wait list length: {}", self.wait_list.size()); 96 while self.locked.get() { 97 drop(sguard); 98 #[cfg(feature = "std")] 99 thread::park(); 100 sguard = self.spin_lock.acquire(); 101 } 102 // This does have an effect, as the ListHead inside wait_entry implements Drop! 103 #[expect(clippy::drop_non_drop)] 104 drop(wait_entry); 105 } 106 self.locked.set(true); 107 unsafe { 108 Pin::new_unchecked(CMutexGuard { 109 mtx: self, 110 _pin: PhantomPinned, 111 }) 112 } 113 } 114 115 #[allow(dead_code)] 116 pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T { 117 // SAFETY: we have an exclusive reference and thus nobody has access to data. 118 unsafe { &mut *self.data.get() } 119 } 120 } 121 122 unsafe impl<T: Send> Send for CMutex<T> {} 123 unsafe impl<T: Send> Sync for CMutex<T> {} 124 125 pub struct CMutexGuard<'a, T> { 126 mtx: &'a CMutex<T>, 127 _pin: PhantomPinned, 128 } 129 130 impl<T> Drop for CMutexGuard<'_, T> { 131 #[inline] 132 fn drop(&mut self) { 133 let sguard = self.mtx.spin_lock.acquire(); 134 self.mtx.locked.set(false); 135 if let Some(list_field) = self.mtx.wait_list.next() { 136 let _wait_entry = list_field.as_ptr().cast::<WaitEntry>(); 137 #[cfg(feature = "std")] 138 unsafe { 139 (*_wait_entry).thread.unpark() 140 }; 141 } 142 drop(sguard); 143 } 144 } 145 146 impl<T> Deref for CMutexGuard<'_, T> { 147 type Target = T; 148 149 #[inline] 150 fn deref(&self) -> &Self::Target { 151 unsafe { &*self.mtx.data.get() } 152 } 153 } 154 155 impl<T> DerefMut for CMutexGuard<'_, T> { 156 #[inline] 157 fn deref_mut(&mut self) -> &mut Self::Target { 158 unsafe { &mut *self.mtx.data.get() } 159 } 160 } 161 162 #[pin_data] 163 #[repr(C)] 164 struct WaitEntry { 165 #[pin] 166 wait_list: ListHead, 167 #[cfg(feature = "std")] 168 thread: Thread, 169 } 170 171 impl WaitEntry { 172 #[inline] 173 fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ { 174 #[cfg(feature = "std")] 175 { 176 pin_init!(Self { 177 thread: thread::current(), 178 wait_list <- ListHead::insert_prev(list), 179 }) 180 } 181 #[cfg(not(feature = "std"))] 182 { 183 pin_init!(Self { 184 wait_list <- ListHead::insert_prev(list), 185 }) 186 } 187 } 188 } 189 190 #[cfg_attr(test, test)] 191 #[allow(dead_code)] 192 fn main() { 193 #[cfg(feature = "std")] 194 { 195 let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); 196 let mut handles = vec![]; 197 let thread_count = 20; 198 let workload = if cfg!(miri) { 100 } else { 1_000 }; 199 for i in 0..thread_count { 200 let mtx = mtx.clone(); 201 handles.push( 202 Builder::new() 203 .name(format!("worker #{i}")) 204 .spawn(move || { 205 for _ in 0..workload { 206 *mtx.lock() += 1; 207 } 208 println!("{i} halfway"); 209 sleep(Duration::from_millis((i as u64) * 10)); 210 for _ in 0..workload { 211 *mtx.lock() += 1; 212 } 213 println!("{i} finished"); 214 }) 215 .expect("should not fail"), 216 ); 217 } 218 for h in handles { 219 h.join().expect("thread panicked"); 220 } 221 println!("{:?}", *mtx.lock()); 222 assert_eq!(*mtx.lock(), workload * thread_count * 2); 223 } 224 } 225