xref: /linux/rust/pin-init/examples/mutex.rs (revision 430654211d566f86e8ee533ff1b01a42be6b602c)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![allow(clippy::missing_safety_doc)]
6 
7 use core::{
8     cell::{Cell, UnsafeCell},
9     marker::PhantomPinned,
10     ops::{Deref, DerefMut},
11     pin::Pin,
12     sync::atomic::{AtomicBool, Ordering},
13 };
14 #[cfg(feature = "std")]
15 use std::{
16     sync::Arc,
17     thread::{self, sleep, Builder, Thread},
18     time::Duration,
19 };
20 
21 use pin_init::*;
22 #[allow(unused_attributes)]
23 #[path = "./linked_list.rs"]
24 pub mod linked_list;
25 use linked_list::*;
26 
27 pub struct SpinLock {
28     inner: AtomicBool,
29 }
30 
31 impl SpinLock {
32     #[inline]
33     pub fn acquire(&self) -> SpinLockGuard<'_> {
34         while self
35             .inner
36             .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
37             .is_err()
38         {
39             #[cfg(feature = "std")]
40             while self.inner.load(Ordering::Relaxed) {
41                 thread::yield_now();
42             }
43         }
44         SpinLockGuard(self)
45     }
46 
47     #[inline]
48     #[allow(clippy::new_without_default)]
49     pub const fn new() -> Self {
50         Self {
51             inner: AtomicBool::new(false),
52         }
53     }
54 }
55 
56 pub struct SpinLockGuard<'a>(&'a SpinLock);
57 
58 impl Drop for SpinLockGuard<'_> {
59     #[inline]
60     fn drop(&mut self) {
61         self.0.inner.store(false, Ordering::Release);
62     }
63 }
64 
65 #[pin_data]
66 pub struct CMutex<T> {
67     #[pin]
68     wait_list: ListHead,
69     spin_lock: SpinLock,
70     locked: Cell<bool>,
71     #[pin]
72     data: UnsafeCell<T>,
73 }
74 
75 impl<T> CMutex<T> {
76     #[inline]
77     pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
78         pin_init!(CMutex {
79             wait_list <- ListHead::new(),
80             spin_lock: SpinLock::new(),
81             locked: Cell::new(false),
82             data <- unsafe {
83                 pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
84                     val.__pinned_init(slot.cast::<T>())
85                 })
86             },
87         })
88     }
89 
90     #[inline]
91     pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
92         let mut sguard = self.spin_lock.acquire();
93         if self.locked.get() {
94             stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
95             // println!("wait list length: {}", self.wait_list.size());
96             while self.locked.get() {
97                 drop(sguard);
98                 #[cfg(feature = "std")]
99                 thread::park();
100                 sguard = self.spin_lock.acquire();
101             }
102             // This does have an effect, as the ListHead inside wait_entry implements Drop!
103             #[expect(clippy::drop_non_drop)]
104             drop(wait_entry);
105         }
106         self.locked.set(true);
107         unsafe {
108             Pin::new_unchecked(CMutexGuard {
109                 mtx: self,
110                 _pin: PhantomPinned,
111             })
112         }
113     }
114 
115     #[allow(dead_code)]
116     pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
117         // SAFETY: we have an exclusive reference and thus nobody has access to data.
118         unsafe { &mut *self.data.get() }
119     }
120 }
121 
122 unsafe impl<T: Send> Send for CMutex<T> {}
123 unsafe impl<T: Send> Sync for CMutex<T> {}
124 
125 pub struct CMutexGuard<'a, T> {
126     mtx: &'a CMutex<T>,
127     _pin: PhantomPinned,
128 }
129 
impl<T> Drop for CMutexGuard<'_, T> {
    /// Unlocks the mutex and, if `wait_list` is non-empty, unparks the thread whose entry
    /// `wait_list.next()` returns (with the `std` feature).
    #[inline]
    fn drop(&mut self) {
        // Clearing the flag and inspecting the wait list happen under the spin lock, the same
        // lock `CMutex::lock` holds while it checks `locked` and links/unlinks its entry.
        let sguard = self.mtx.spin_lock.acquire();
        self.mtx.locked.set(false);
        if let Some(list_field) = self.mtx.wait_list.next() {
            // Nodes in `wait_list` are the `wait_list` fields of `WaitEntry`s linked in by
            // `CMutex::lock`. `WaitEntry` is `#[repr(C)]` with that field first, so a pointer
            // to the node is also a pointer to the enclosing entry.
            // The leading underscore avoids an unused-variable warning without `std`.
            let _wait_entry = list_field.as_ptr().cast::<WaitEntry>();
            // SAFETY: a waiter only leaves `CMutex::lock` after its entry goes out of scope
            // while it holds the spin lock. We hold the spin lock, so the entry (and its
            // `thread` handle) is still alive on the waiter's stack.
            #[cfg(feature = "std")]
            unsafe {
                (*_wait_entry).thread.unpark()
            };
        }
        // Release the spin lock only after the wake-up has been issued.
        drop(sguard);
    }
}
145 
146 impl<T> Deref for CMutexGuard<'_, T> {
147     type Target = T;
148 
149     #[inline]
150     fn deref(&self) -> &Self::Target {
151         unsafe { &*self.mtx.data.get() }
152     }
153 }
154 
155 impl<T> DerefMut for CMutexGuard<'_, T> {
156     #[inline]
157     fn deref_mut(&mut self) -> &mut Self::Target {
158         unsafe { &mut *self.mtx.data.get() }
159     }
160 }
161 
/// A node of `CMutex::wait_list` for a thread blocked in `CMutex::lock`, which creates it on
/// its own stack with `stack_pin_init!`.
///
/// The `#[repr(C)]` layout with `wait_list` as the first field is load-bearing:
/// `CMutexGuard::drop` casts a pointer to a list node directly into a pointer to `WaitEntry`.
#[pin_data]
#[repr(C)]
struct WaitEntry {
    /// Links this entry into the mutex's wait list.
    #[pin]
    wait_list: ListHead,
    /// Handle of the waiting thread, unparked by `CMutexGuard::drop`.
    #[cfg(feature = "std")]
    thread: Thread,
}

impl WaitEntry {
    /// Returns an initializer that links the new entry into `list` via `ListHead::insert_prev`
    /// and, with the `std` feature, records `thread::current()` so the waiter can be unparked.
    ///
    /// The initializer borrows `list`, hence the `+ '_` on the return type. `CMutex::lock`
    /// runs it while holding the mutex's spin lock.
    #[inline]
    fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
        #[cfg(feature = "std")]
        {
            pin_init!(Self {
                thread: thread::current(),
                wait_list <- ListHead::insert_prev(list),
            })
        }
        // Without `std` there is no thread handle to record; only the list link is set up.
        #[cfg(not(feature = "std"))]
        {
            pin_init!(Self {
                wait_list <- ListHead::insert_prev(list),
            })
        }
    }
}
189 
/// Stress test for [`CMutex`].
///
/// `thread_count` named workers each increment a shared `CMutex<usize>` `workload` times,
/// sleep for a per-worker delay, then increment it another `workload` times. After all workers
/// are joined the counter must equal `workload * thread_count * 2`. Without the `std` feature
/// this does nothing.
#[cfg_attr(test, test)]
#[allow(dead_code)]
fn main() {
    #[cfg(feature = "std")]
    {
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
        let thread_count = 20;
        // Keep runs under Miri short.
        let workload = if cfg!(miri) { 100 } else { 1_000 };

        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let shared = Pin::clone(&mtx);
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        let bump_many = || {
                            for _ in 0..workload {
                                *shared.lock() += 1;
                            }
                        };
                        bump_many();
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        bump_many();
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        for handle in handles {
            handle.join().expect("thread panicked");
        }

        println!("{:?}", *mtx.lock());
        let expected = workload * thread_count * 2;
        assert_eq!(*mtx.lock(), expected);
    }
}
225