1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 #![allow(clippy::undocumented_unsafe_blocks)] 4 #![cfg_attr(feature = "alloc", feature(allocator_api))] 5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] 6 #![allow(clippy::missing_safety_doc)] 7 8 use core::{ 9 cell::{Cell, UnsafeCell}, 10 marker::PhantomPinned, 11 ops::{Deref, DerefMut}, 12 pin::Pin, 13 sync::atomic::{AtomicBool, Ordering}, 14 }; 15 #[cfg(feature = "std")] 16 use std::{ 17 sync::Arc, 18 thread::{self, sleep, Builder, Thread}, 19 time::Duration, 20 }; 21 22 use pin_init::*; 23 #[allow(unused_attributes)] 24 #[path = "./linked_list.rs"] 25 pub mod linked_list; 26 use linked_list::*; 27 28 pub struct SpinLock { 29 inner: AtomicBool, 30 } 31 32 impl SpinLock { 33 #[inline] 34 pub fn acquire(&self) -> SpinLockGuard<'_> { 35 while self 36 .inner 37 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) 38 .is_err() 39 { 40 #[cfg(feature = "std")] 41 while self.inner.load(Ordering::Relaxed) { 42 thread::yield_now(); 43 } 44 } 45 SpinLockGuard(self) 46 } 47 48 #[inline] 49 #[allow(clippy::new_without_default)] 50 pub const fn new() -> Self { 51 Self { 52 inner: AtomicBool::new(false), 53 } 54 } 55 } 56 57 pub struct SpinLockGuard<'a>(&'a SpinLock); 58 59 impl Drop for SpinLockGuard<'_> { 60 #[inline] 61 fn drop(&mut self) { 62 self.0.inner.store(false, Ordering::Release); 63 } 64 } 65 66 #[pin_data] 67 pub struct CMutex<T> { 68 #[pin] 69 wait_list: ListHead, 70 spin_lock: SpinLock, 71 locked: Cell<bool>, 72 #[pin] 73 data: UnsafeCell<T>, 74 } 75 76 impl<T> CMutex<T> { 77 #[inline] 78 pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> { 79 pin_init!(CMutex { 80 wait_list <- ListHead::new(), 81 spin_lock: SpinLock::new(), 82 locked: Cell::new(false), 83 data <- unsafe { 84 pin_init_from_closure(|slot: *mut UnsafeCell<T>| { 85 val.__pinned_init(slot.cast::<T>()) 86 }) 87 }, 88 }) 89 } 90 91 #[inline] 92 pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> { 93 let mut sguard = self.spin_lock.acquire(); 94 if self.locked.get() { 95 stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list)); 96 // println!("wait list length: {}", self.wait_list.size()); 97 while self.locked.get() { 98 drop(sguard); 99 #[cfg(feature = "std")] 100 thread::park(); 101 sguard = self.spin_lock.acquire(); 102 } 103 // This does have an effect, as the ListHead inside wait_entry implements Drop! 104 #[expect(clippy::drop_non_drop)] 105 drop(wait_entry); 106 } 107 self.locked.set(true); 108 unsafe { 109 Pin::new_unchecked(CMutexGuard { 110 mtx: self, 111 _pin: PhantomPinned, 112 }) 113 } 114 } 115 116 #[allow(dead_code)] 117 pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T { 118 // SAFETY: we have an exclusive reference and thus nobody has access to data. 119 unsafe { &mut *self.data.get() } 120 } 121 } 122 123 unsafe impl<T: Send> Send for CMutex<T> {} 124 unsafe impl<T: Send> Sync for CMutex<T> {} 125 126 pub struct CMutexGuard<'a, T> { 127 mtx: &'a CMutex<T>, 128 _pin: PhantomPinned, 129 } 130 131 impl<T> Drop for CMutexGuard<'_, T> { 132 #[inline] 133 fn drop(&mut self) { 134 let sguard = self.mtx.spin_lock.acquire(); 135 self.mtx.locked.set(false); 136 if let Some(list_field) = self.mtx.wait_list.next() { 137 let _wait_entry = list_field.as_ptr().cast::<WaitEntry>(); 138 #[cfg(feature = "std")] 139 unsafe { 140 (*_wait_entry).thread.unpark() 141 }; 142 } 143 drop(sguard); 144 } 145 } 146 147 impl<T> Deref for CMutexGuard<'_, T> { 148 type Target = T; 149 150 #[inline] 151 fn deref(&self) -> &Self::Target { 152 unsafe { &*self.mtx.data.get() } 153 } 154 } 155 156 impl<T> DerefMut for CMutexGuard<'_, T> { 157 #[inline] 158 fn deref_mut(&mut self) -> &mut Self::Target { 159 unsafe { &mut *self.mtx.data.get() } 160 } 161 } 162 163 #[pin_data] 164 #[repr(C)] 165 struct WaitEntry { 166 #[pin] 167 wait_list: ListHead, 168 #[cfg(feature = "std")] 169 thread: Thread, 170 } 171 172 impl WaitEntry { 173 #[inline] 174 fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ { 175 #[cfg(feature = "std")] 176 { 177 pin_init!(Self { 178 thread: thread::current(), 179 wait_list <- ListHead::insert_prev(list), 180 }) 181 } 182 #[cfg(not(feature = "std"))] 183 { 184 pin_init!(Self { 185 wait_list <- ListHead::insert_prev(list), 186 }) 187 } 188 } 189 } 190 191 #[cfg_attr(test, test)] 192 #[allow(dead_code)] 193 fn main() { 194 #[cfg(feature = "std")] 195 { 196 let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); 197 let mut handles = vec![]; 198 let thread_count = 20; 199 let workload = if cfg!(miri) { 100 } else { 1_000 }; 200 for i in 0..thread_count { 201 let mtx = mtx.clone(); 202 handles.push( 203 Builder::new() 204 .name(format!("worker #{i}")) 205 .spawn(move || { 206 for _ in 0..workload { 207 *mtx.lock() += 1; 208 } 209 println!("{i} halfway"); 210 sleep(Duration::from_millis((i as u64) * 10)); 211 for _ in 0..workload { 212 *mtx.lock() += 1; 213 } 214 println!("{i} finished"); 215 }) 216 .expect("should not fail"), 217 ); 218 } 219 for h in handles { 220 h.join().expect("thread panicked"); 221 } 222 println!("{:?}", &*mtx.lock()); 223 assert_eq!(*mtx.lock(), workload * thread_count * 2); 224 } 225 } 226