xref: /linux/rust/pin-init/examples/mutex.rs (revision ec7714e4947909190ffb3041a03311a975350fe0)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
6 #![allow(clippy::missing_safety_doc)]
7 
8 use core::{
9     cell::{Cell, UnsafeCell},
10     marker::PhantomPinned,
11     ops::{Deref, DerefMut},
12     pin::Pin,
13     sync::atomic::{AtomicBool, Ordering},
14 };
15 use std::{
16     sync::Arc,
17     thread::{self, park, sleep, Builder, Thread},
18     time::Duration,
19 };
20 
21 use pin_init::*;
22 #[expect(unused_attributes)]
23 #[path = "./linked_list.rs"]
24 pub mod linked_list;
25 use linked_list::*;
26 
/// A minimal test-and-test-and-set spin lock built on a single atomic flag.
pub struct SpinLock {
    /// `true` while some `SpinLockGuard` holds the lock, `false` otherwise.
    inner: AtomicBool,
}
30 
31 impl SpinLock {
32     #[inline]
acquire(&self) -> SpinLockGuard<'_>33     pub fn acquire(&self) -> SpinLockGuard<'_> {
34         while self
35             .inner
36             .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
37             .is_err()
38         {
39             while self.inner.load(Ordering::Relaxed) {
40                 thread::yield_now();
41             }
42         }
43         SpinLockGuard(self)
44     }
45 
46     #[inline]
47     #[allow(clippy::new_without_default)]
new() -> Self48     pub const fn new() -> Self {
49         Self {
50             inner: AtomicBool::new(false),
51         }
52     }
53 }
54 
55 pub struct SpinLockGuard<'a>(&'a SpinLock);
56 
57 impl Drop for SpinLockGuard<'_> {
58     #[inline]
drop(&mut self)59     fn drop(&mut self) {
60         self.0.inner.store(false, Ordering::Release);
61     }
62 }
63 
64 #[pin_data]
65 pub struct CMutex<T> {
66     #[pin]
67     wait_list: ListHead,
68     spin_lock: SpinLock,
69     locked: Cell<bool>,
70     #[pin]
71     data: UnsafeCell<T>,
72 }
73 
74 impl<T> CMutex<T> {
75     #[inline]
new(val: impl PinInit<T>) -> impl PinInit<Self>76     pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
77         pin_init!(CMutex {
78             wait_list <- ListHead::new(),
79             spin_lock: SpinLock::new(),
80             locked: Cell::new(false),
81             data <- unsafe {
82                 pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
83                     val.__pinned_init(slot.cast::<T>())
84                 })
85             },
86         })
87     }
88 
89     #[inline]
lock(&self) -> Pin<CMutexGuard<'_, T>>90     pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
91         let mut sguard = self.spin_lock.acquire();
92         if self.locked.get() {
93             stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
94             // println!("wait list length: {}", self.wait_list.size());
95             while self.locked.get() {
96                 drop(sguard);
97                 park();
98                 sguard = self.spin_lock.acquire();
99             }
100             // This does have an effect, as the ListHead inside wait_entry implements Drop!
101             #[expect(clippy::drop_non_drop)]
102             drop(wait_entry);
103         }
104         self.locked.set(true);
105         unsafe {
106             Pin::new_unchecked(CMutexGuard {
107                 mtx: self,
108                 _pin: PhantomPinned,
109             })
110         }
111     }
112 
113     #[allow(dead_code)]
get_data_mut(self: Pin<&mut Self>) -> &mut T114     pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
115         // SAFETY: we have an exclusive reference and thus nobody has access to data.
116         unsafe { &mut *self.data.get() }
117     }
118 }
119 
120 unsafe impl<T: Send> Send for CMutex<T> {}
121 unsafe impl<T: Send> Sync for CMutex<T> {}
122 
123 pub struct CMutexGuard<'a, T> {
124     mtx: &'a CMutex<T>,
125     _pin: PhantomPinned,
126 }
127 
128 impl<T> Drop for CMutexGuard<'_, T> {
129     #[inline]
drop(&mut self)130     fn drop(&mut self) {
131         let sguard = self.mtx.spin_lock.acquire();
132         self.mtx.locked.set(false);
133         if let Some(list_field) = self.mtx.wait_list.next() {
134             let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
135             unsafe { (*wait_entry).thread.unpark() };
136         }
137         drop(sguard);
138     }
139 }
140 
141 impl<T> Deref for CMutexGuard<'_, T> {
142     type Target = T;
143 
144     #[inline]
deref(&self) -> &Self::Target145     fn deref(&self) -> &Self::Target {
146         unsafe { &*self.mtx.data.get() }
147     }
148 }
149 
150 impl<T> DerefMut for CMutexGuard<'_, T> {
151     #[inline]
deref_mut(&mut self) -> &mut Self::Target152     fn deref_mut(&mut self) -> &mut Self::Target {
153         unsafe { &mut *self.mtx.data.get() }
154     }
155 }
156 
157 #[pin_data]
158 #[repr(C)]
159 struct WaitEntry {
160     #[pin]
161     wait_list: ListHead,
162     thread: Thread,
163 }
164 
165 impl WaitEntry {
166     #[inline]
insert_new(list: &ListHead) -> impl PinInit<Self> + '_167     fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
168         pin_init!(Self {
169             thread: thread::current(),
170             wait_list <- ListHead::insert_prev(list),
171         })
172     }
173 }
174 
/// Stub entry point: the threaded demo needs the `std` or `alloc` feature, so without
/// either one this example just does nothing.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {
    // Intentionally empty.
}
177 
/// Stress test: several named worker threads increment a shared counter behind a `CMutex`,
/// then the final total is checked against the expected number of increments.
#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    // Keep the run short under Miri.
    let workload = if cfg!(miri) { 100 } else { 1_000 };

    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let mtx = Pin::clone(&mtx);
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    // Two rounds of increments separated by a per-thread pause, which
                    // staggers when the workers contend in the second round.
                    for _ in 0..workload {
                        *mtx.lock() += 1;
                    }
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail")
        })
        .collect();

    handles
        .into_iter()
        .for_each(|handle| handle.join().expect("thread panicked"));

    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}
211