xref: /linux/rust/pin-init/examples/mutex.rs (revision ec7714e4947909190ffb3041a03311a975350fe0)
184837cf6SBenno Lossin // SPDX-License-Identifier: Apache-2.0 OR MIT
284837cf6SBenno Lossin 
384837cf6SBenno Lossin #![allow(clippy::undocumented_unsafe_blocks)]
484837cf6SBenno Lossin #![cfg_attr(feature = "alloc", feature(allocator_api))]
5*5c4167b4SBenno Lossin #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
684837cf6SBenno Lossin #![allow(clippy::missing_safety_doc)]
784837cf6SBenno Lossin 
884837cf6SBenno Lossin use core::{
984837cf6SBenno Lossin     cell::{Cell, UnsafeCell},
1084837cf6SBenno Lossin     marker::PhantomPinned,
1184837cf6SBenno Lossin     ops::{Deref, DerefMut},
1284837cf6SBenno Lossin     pin::Pin,
1384837cf6SBenno Lossin     sync::atomic::{AtomicBool, Ordering},
1484837cf6SBenno Lossin };
1584837cf6SBenno Lossin use std::{
1684837cf6SBenno Lossin     sync::Arc,
1784837cf6SBenno Lossin     thread::{self, park, sleep, Builder, Thread},
1884837cf6SBenno Lossin     time::Duration,
1984837cf6SBenno Lossin };
2084837cf6SBenno Lossin 
2184837cf6SBenno Lossin use pin_init::*;
2284837cf6SBenno Lossin #[expect(unused_attributes)]
2384837cf6SBenno Lossin #[path = "./linked_list.rs"]
2484837cf6SBenno Lossin pub mod linked_list;
2584837cf6SBenno Lossin use linked_list::*;
2684837cf6SBenno Lossin 
/// A minimal test-and-test-and-set spin lock.
///
/// Used by [`CMutex`] to protect its bookkeeping (the `locked` flag and the wait list).
pub struct SpinLock {
    /// `true` while a [`SpinLockGuard`] for this lock is alive.
    inner: AtomicBool,
}

impl SpinLock {
    /// Spins (yielding the thread between attempts) until the lock is acquired.
    ///
    /// The lock is released when the returned guard is dropped.
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        // `compare_exchange_weak` may fail spuriously, which is harmless inside this retry
        // loop and can compile to cheaper code on LL/SC architectures.
        while self
            .inner
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Wait with plain loads until the lock looks free, so contending threads do not
            // keep issuing failing read-modify-write operations.
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates a new, unlocked spin lock.
    #[inline]
    // `new` is a `const fn`; a `Default` impl would add nothing for this example.
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

/// RAII guard returned by [`SpinLock::acquire`]; dropping it releases the lock.
#[must_use = "dropping the guard immediately releases the spin lock"]
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        // `Release` pairs with the `Acquire` in `acquire`, publishing the critical section.
        self.0.inner.store(false, Ordering::Release);
    }
}
6384837cf6SBenno Lossin 
6484837cf6SBenno Lossin #[pin_data]
6584837cf6SBenno Lossin pub struct CMutex<T> {
6684837cf6SBenno Lossin     #[pin]
6784837cf6SBenno Lossin     wait_list: ListHead,
6884837cf6SBenno Lossin     spin_lock: SpinLock,
6984837cf6SBenno Lossin     locked: Cell<bool>,
7084837cf6SBenno Lossin     #[pin]
7184837cf6SBenno Lossin     data: UnsafeCell<T>,
7284837cf6SBenno Lossin }
7384837cf6SBenno Lossin 
7484837cf6SBenno Lossin impl<T> CMutex<T> {
7584837cf6SBenno Lossin     #[inline]
new(val: impl PinInit<T>) -> impl PinInit<Self>7684837cf6SBenno Lossin     pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
7784837cf6SBenno Lossin         pin_init!(CMutex {
7884837cf6SBenno Lossin             wait_list <- ListHead::new(),
7984837cf6SBenno Lossin             spin_lock: SpinLock::new(),
8084837cf6SBenno Lossin             locked: Cell::new(false),
8184837cf6SBenno Lossin             data <- unsafe {
8284837cf6SBenno Lossin                 pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
8384837cf6SBenno Lossin                     val.__pinned_init(slot.cast::<T>())
8484837cf6SBenno Lossin                 })
8584837cf6SBenno Lossin             },
8684837cf6SBenno Lossin         })
8784837cf6SBenno Lossin     }
8884837cf6SBenno Lossin 
8984837cf6SBenno Lossin     #[inline]
lock(&self) -> Pin<CMutexGuard<'_, T>>9084837cf6SBenno Lossin     pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
9184837cf6SBenno Lossin         let mut sguard = self.spin_lock.acquire();
9284837cf6SBenno Lossin         if self.locked.get() {
9384837cf6SBenno Lossin             stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
9484837cf6SBenno Lossin             // println!("wait list length: {}", self.wait_list.size());
9584837cf6SBenno Lossin             while self.locked.get() {
9684837cf6SBenno Lossin                 drop(sguard);
9784837cf6SBenno Lossin                 park();
9884837cf6SBenno Lossin                 sguard = self.spin_lock.acquire();
9984837cf6SBenno Lossin             }
10084837cf6SBenno Lossin             // This does have an effect, as the ListHead inside wait_entry implements Drop!
10184837cf6SBenno Lossin             #[expect(clippy::drop_non_drop)]
10284837cf6SBenno Lossin             drop(wait_entry);
10384837cf6SBenno Lossin         }
10484837cf6SBenno Lossin         self.locked.set(true);
10584837cf6SBenno Lossin         unsafe {
10684837cf6SBenno Lossin             Pin::new_unchecked(CMutexGuard {
10784837cf6SBenno Lossin                 mtx: self,
10884837cf6SBenno Lossin                 _pin: PhantomPinned,
10984837cf6SBenno Lossin             })
11084837cf6SBenno Lossin         }
11184837cf6SBenno Lossin     }
11284837cf6SBenno Lossin 
11384837cf6SBenno Lossin     #[allow(dead_code)]
get_data_mut(self: Pin<&mut Self>) -> &mut T11484837cf6SBenno Lossin     pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
11584837cf6SBenno Lossin         // SAFETY: we have an exclusive reference and thus nobody has access to data.
11684837cf6SBenno Lossin         unsafe { &mut *self.data.get() }
11784837cf6SBenno Lossin     }
11884837cf6SBenno Lossin }
11984837cf6SBenno Lossin 
12084837cf6SBenno Lossin unsafe impl<T: Send> Send for CMutex<T> {}
12184837cf6SBenno Lossin unsafe impl<T: Send> Sync for CMutex<T> {}
12284837cf6SBenno Lossin 
12384837cf6SBenno Lossin pub struct CMutexGuard<'a, T> {
12484837cf6SBenno Lossin     mtx: &'a CMutex<T>,
12584837cf6SBenno Lossin     _pin: PhantomPinned,
12684837cf6SBenno Lossin }
12784837cf6SBenno Lossin 
12884837cf6SBenno Lossin impl<T> Drop for CMutexGuard<'_, T> {
12984837cf6SBenno Lossin     #[inline]
drop(&mut self)13084837cf6SBenno Lossin     fn drop(&mut self) {
13184837cf6SBenno Lossin         let sguard = self.mtx.spin_lock.acquire();
13284837cf6SBenno Lossin         self.mtx.locked.set(false);
13384837cf6SBenno Lossin         if let Some(list_field) = self.mtx.wait_list.next() {
13484837cf6SBenno Lossin             let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
13584837cf6SBenno Lossin             unsafe { (*wait_entry).thread.unpark() };
13684837cf6SBenno Lossin         }
13784837cf6SBenno Lossin         drop(sguard);
13884837cf6SBenno Lossin     }
13984837cf6SBenno Lossin }
14084837cf6SBenno Lossin 
14184837cf6SBenno Lossin impl<T> Deref for CMutexGuard<'_, T> {
14284837cf6SBenno Lossin     type Target = T;
14384837cf6SBenno Lossin 
14484837cf6SBenno Lossin     #[inline]
deref(&self) -> &Self::Target14584837cf6SBenno Lossin     fn deref(&self) -> &Self::Target {
14684837cf6SBenno Lossin         unsafe { &*self.mtx.data.get() }
14784837cf6SBenno Lossin     }
14884837cf6SBenno Lossin }
14984837cf6SBenno Lossin 
15084837cf6SBenno Lossin impl<T> DerefMut for CMutexGuard<'_, T> {
15184837cf6SBenno Lossin     #[inline]
deref_mut(&mut self) -> &mut Self::Target15284837cf6SBenno Lossin     fn deref_mut(&mut self) -> &mut Self::Target {
15384837cf6SBenno Lossin         unsafe { &mut *self.mtx.data.get() }
15484837cf6SBenno Lossin     }
15584837cf6SBenno Lossin }
15684837cf6SBenno Lossin 
15784837cf6SBenno Lossin #[pin_data]
15884837cf6SBenno Lossin #[repr(C)]
15984837cf6SBenno Lossin struct WaitEntry {
16084837cf6SBenno Lossin     #[pin]
16184837cf6SBenno Lossin     wait_list: ListHead,
16284837cf6SBenno Lossin     thread: Thread,
16384837cf6SBenno Lossin }
16484837cf6SBenno Lossin 
16584837cf6SBenno Lossin impl WaitEntry {
16684837cf6SBenno Lossin     #[inline]
insert_new(list: &ListHead) -> impl PinInit<Self> + '_16784837cf6SBenno Lossin     fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
16884837cf6SBenno Lossin         pin_init!(Self {
16984837cf6SBenno Lossin             thread: thread::current(),
17084837cf6SBenno Lossin             wait_list <- ListHead::insert_prev(list),
17184837cf6SBenno Lossin         })
17284837cf6SBenno Lossin     }
17384837cf6SBenno Lossin }
17484837cf6SBenno Lossin 
/// Empty entry point used when neither the `std` nor the `alloc` feature is enabled; the
/// mutex demo is only compiled with one of those features.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}
17784837cf6SBenno Lossin 
/// Spawns `thread_count` workers that each increment a shared counter `2 * workload` times
/// (with a per-thread pause between the two halves), then checks that no increment was lost.
#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count: usize = 20;
    // Smaller workload under Miri, which runs far slower.
    let workload: usize = if cfg!(miri) { 100 } else { 1_000 };

    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let mtx = mtx.clone();
            let work = move || {
                for _ in 0..workload {
                    *mtx.lock() += 1;
                }
                println!("{i} halfway");
                // Thread `i` pauses for `i * 10` ms between its two halves.
                sleep(Duration::from_millis((i as u64) * 10));
                for _ in 0..workload {
                    *mtx.lock() += 1;
                }
                println!("{i} finished");
            };
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(work)
                .expect("should not fail")
        })
        .collect();

    for handle in handles {
        handle.join().expect("thread panicked");
    }

    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}
211