xref: /linux/rust/pin-init/examples/mutex.rs (revision 09808839c7aa6695ceff5cd822c18b0d9550184d)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![cfg_attr(USE_RUSTC_FEATURES, feature(lint_reasons))]
6 #![cfg_attr(USE_RUSTC_FEATURES, feature(raw_ref_op))]
7 #![allow(clippy::missing_safety_doc)]
8 
9 use core::{
10     cell::{Cell, UnsafeCell},
11     marker::PhantomPinned,
12     ops::{Deref, DerefMut},
13     pin::Pin,
14     sync::atomic::{AtomicBool, Ordering},
15 };
16 #[cfg(feature = "std")]
17 use std::{
18     sync::Arc,
19     thread::{self, sleep, Builder, Thread},
20     time::Duration,
21 };
22 
23 use pin_init::*;
24 #[allow(unused_attributes)]
25 #[path = "./linked_list.rs"]
26 pub mod linked_list;
27 use linked_list::*;
28 
/// A minimal test-and-set spinlock built on a single [`AtomicBool`].
///
/// `false` means unlocked, `true` means locked; the lock is taken by
/// [`SpinLock::acquire`] and released by dropping the returned guard.
pub struct SpinLock {
    // Lock state: `true` while a `SpinLockGuard` for this lock is alive.
    inner: AtomicBool,
}
32 
33 impl SpinLock {
34     #[inline]
35     pub fn acquire(&self) -> SpinLockGuard<'_> {
36         while self
37             .inner
38             .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
39             .is_err()
40         {
41             #[cfg(feature = "std")]
42             while self.inner.load(Ordering::Relaxed) {
43                 thread::yield_now();
44             }
45         }
46         SpinLockGuard(self)
47     }
48 
49     #[inline]
50     #[allow(clippy::new_without_default)]
51     pub const fn new() -> Self {
52         Self {
53             inner: AtomicBool::new(false),
54         }
55     }
56 }
57 
58 pub struct SpinLockGuard<'a>(&'a SpinLock);
59 
60 impl Drop for SpinLockGuard<'_> {
61     #[inline]
62     fn drop(&mut self) {
63         self.0.inner.store(false, Ordering::Release);
64     }
65 }
66 
/// A mutex in which contended lockers enqueue themselves on an intrusive wait
/// list and sleep (park, under `std`) instead of spinning.
#[pin_data]
pub struct CMutex<T> {
    /// Intrusive list of waiters; the `WaitEntry` nodes live on the waiting
    /// threads' stacks (see `CMutex::lock`).
    #[pin]
    wait_list: ListHead,
    /// Protects `wait_list` and `locked`.
    spin_lock: SpinLock,
    /// Whether the mutex is currently held; only read/written while holding
    /// `spin_lock`.
    locked: Cell<bool>,
    /// The protected value; `UnsafeCell`, since the lock holder mutates it
    /// through a shared reference (see the `Deref`/`DerefMut` impls).
    #[pin]
    data: UnsafeCell<T>,
}
76 
impl<T> CMutex<T> {
    /// Returns a pin-initializer producing a new, unlocked `CMutex` whose data
    /// is initialized in place by `val`.
    #[inline]
    pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
        pin_init!(CMutex {
            wait_list <- ListHead::new(),
            spin_lock: SpinLock::new(),
            locked: Cell::new(false),
            // SAFETY: a `*mut UnsafeCell<T>` slot is cast to `*mut T` and
            // handed to `val` for in-place initialization.
            data <- unsafe {
                pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
                    val.__pinned_init(slot.cast::<T>())
                })
            },
        })
    }

    /// Acquires the mutex, sleeping while it is contended (parking the thread
    /// under `std`). Returns a pinned guard that unlocks on drop.
    #[inline]
    pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
        let mut sguard = self.spin_lock.acquire();
        if self.locked.get() {
            // Contended: enqueue a wait entry for this thread; it lives on our
            // stack and its `ListHead` unlinks itself when dropped below.
            stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
            // println!("wait list length: {}", self.wait_list.size());
            while self.locked.get() {
                // Release the spinlock before sleeping so the holder can make
                // progress. An `unpark` between `drop` and `park` is not lost,
                // since the park token is sticky.
                drop(sguard);
                #[cfg(feature = "std")]
                thread::park();
                // Re-take the spinlock before re-checking `locked`.
                sguard = self.spin_lock.acquire();
            }
            // This does have an effect, as the ListHead inside wait_entry implements Drop!
            #[expect(clippy::drop_non_drop)]
            drop(wait_entry);
        }
        self.locked.set(true);
        // SAFETY: the guard is handed out pinned; this module never moves it
        // out of the `Pin`.
        unsafe {
            Pin::new_unchecked(CMutexGuard {
                mtx: self,
                _pin: PhantomPinned,
            })
        }
    }

    /// Direct access to the data when holding an exclusive pinned reference to
    /// the whole mutex (no locking needed).
    #[allow(dead_code)]
    pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
        // SAFETY: we have an exclusive reference and thus nobody has access to data.
        unsafe { &mut *self.data.get() }
    }
}
123 
// SAFETY: a `CMutex<T>` may move to another thread whenever `T` may; it holds
// no thread-affine state besides `T` itself.
unsafe impl<T: Send> Send for CMutex<T> {}
// SAFETY: all cross-thread mutation (`locked`, `wait_list`, `data`) is
// serialized by the internal spinlock / mutex protocol, so `&CMutex<T>` may be
// shared across threads when `T: Send`.
unsafe impl<T: Send> Sync for CMutex<T> {}
126 
/// Guard proving that the [`CMutex`] is held; unlocks it (and wakes one
/// waiter, under `std`) when dropped.
pub struct CMutexGuard<'a, T> {
    /// The mutex this guard keeps locked.
    mtx: &'a CMutex<T>,
    /// Guards are only handed out as `Pin<CMutexGuard<..>>`; being `!Unpin`
    /// keeps them from being moved back out of the `Pin`.
    _pin: PhantomPinned,
}
131 
impl<T> Drop for CMutexGuard<'_, T> {
    /// Unlocks the mutex and unparks the first waiter, if any.
    #[inline]
    fn drop(&mut self) {
        // Take the spinlock so `locked` and `wait_list` are updated atomically
        // with respect to concurrent `lock` calls.
        let sguard = self.mtx.spin_lock.acquire();
        self.mtx.locked.set(false);
        if let Some(list_field) = self.mtx.wait_list.next() {
            // `wait_list` is the first field of the `repr(C)` `WaitEntry`, so
            // a pointer to the embedded `ListHead` is also a pointer to the
            // containing entry.
            let _wait_entry = list_field.as_ptr().cast::<WaitEntry>();
            // SAFETY: the entry is still alive — a waiter only removes itself
            // from the list after re-acquiring the spinlock we currently hold
            // (see `CMutex::lock`).
            #[cfg(feature = "std")]
            unsafe {
                (*_wait_entry).thread.unpark()
            };
        }
        drop(sguard);
    }
}
147 
148 impl<T> Deref for CMutexGuard<'_, T> {
149     type Target = T;
150 
151     #[inline]
152     fn deref(&self) -> &Self::Target {
153         unsafe { &*self.mtx.data.get() }
154     }
155 }
156 
157 impl<T> DerefMut for CMutexGuard<'_, T> {
158     #[inline]
159     fn deref_mut(&mut self) -> &mut Self::Target {
160         unsafe { &mut *self.mtx.data.get() }
161     }
162 }
163 
/// A node in a [`CMutex`]'s wait list, pinned on the waiting thread's stack.
///
/// `repr(C)` with `wait_list` as the first field is required: the mutex
/// guard's `Drop` casts a pointer to the embedded `ListHead` back to the
/// whole `WaitEntry`.
#[pin_data]
#[repr(C)]
struct WaitEntry {
    /// Link into the mutex's wait list; must stay the first field.
    #[pin]
    wait_list: ListHead,
    /// Handle used by the unlocker to wake this thread (only with `std`).
    #[cfg(feature = "std")]
    thread: Thread,
}
172 
impl WaitEntry {
    /// Returns an initializer for a `WaitEntry` that is linked into `list`
    /// (via `ListHead::insert_prev`) as part of its initialization.
    #[inline]
    fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
        #[cfg(feature = "std")]
        {
            pin_init!(Self {
                // Remember the current thread so the unlocker can wake us.
                thread: thread::current(),
                wait_list <- ListHead::insert_prev(list),
            })
        }
        // Without `std` there is no `thread` field (no parking support).
        #[cfg(not(feature = "std"))]
        {
            pin_init!(Self {
                wait_list <- ListHead::insert_prev(list),
            })
        }
    }
}
191 
/// Stress test: many threads hammer a shared counter behind a `CMutex` and the
/// final count is checked. Runs as a `#[test]` under `cargo test`.
#[cfg_attr(test, test)]
#[allow(dead_code)]
fn main() {
    #[cfg(feature = "std")]
    {
        // Shared counter protected by the custom mutex.
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
        let thread_count = 20;
        // Keep the iteration count small under Miri, where execution is slow.
        let workload = if cfg!(miri) { 100 } else { 1_000 };
        // Each worker bumps the counter `2 * workload` times, with a staggered
        // sleep in the middle to vary contention patterns.
        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let mtx = mtx.clone();
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        println!("{:?}", &*mtx.lock());
        assert_eq!(*mtx.lock(), workload * thread_count * 2);
    }
}
227