xref: /linux/rust/pin-init/examples/pthread_mutex.rs (revision ec7714e4947909190ffb3041a03311a975350fe0)
184837cf6SBenno Lossin // SPDX-License-Identifier: Apache-2.0 OR MIT
284837cf6SBenno Lossin 
3193b5a75SMiguel Ojeda // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>
484837cf6SBenno Lossin #![allow(clippy::undocumented_unsafe_blocks)]
584837cf6SBenno Lossin #![cfg_attr(feature = "alloc", feature(allocator_api))]
65c4167b4SBenno Lossin #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
75c4167b4SBenno Lossin 
884837cf6SBenno Lossin #[cfg(not(windows))]
984837cf6SBenno Lossin mod pthread_mtx {
1084837cf6SBenno Lossin     #[cfg(feature = "alloc")]
1184837cf6SBenno Lossin     use core::alloc::AllocError;
1284837cf6SBenno Lossin     use core::{
1384837cf6SBenno Lossin         cell::UnsafeCell,
1484837cf6SBenno Lossin         marker::PhantomPinned,
1584837cf6SBenno Lossin         mem::MaybeUninit,
1684837cf6SBenno Lossin         ops::{Deref, DerefMut},
1784837cf6SBenno Lossin         pin::Pin,
1884837cf6SBenno Lossin     };
1984837cf6SBenno Lossin     use pin_init::*;
2084837cf6SBenno Lossin     use std::convert::Infallible;
2184837cf6SBenno Lossin 
2284837cf6SBenno Lossin     #[pin_data(PinnedDrop)]
2384837cf6SBenno Lossin     pub struct PThreadMutex<T> {
2484837cf6SBenno Lossin         #[pin]
2584837cf6SBenno Lossin         raw: UnsafeCell<libc::pthread_mutex_t>,
2684837cf6SBenno Lossin         data: UnsafeCell<T>,
2784837cf6SBenno Lossin         #[pin]
2884837cf6SBenno Lossin         pin: PhantomPinned,
2984837cf6SBenno Lossin     }
3084837cf6SBenno Lossin 
3184837cf6SBenno Lossin     unsafe impl<T: Send> Send for PThreadMutex<T> {}
3284837cf6SBenno Lossin     unsafe impl<T: Send> Sync for PThreadMutex<T> {}
3384837cf6SBenno Lossin 
3484837cf6SBenno Lossin     #[pinned_drop]
3584837cf6SBenno Lossin     impl<T> PinnedDrop for PThreadMutex<T> {
drop(self: Pin<&mut Self>)3684837cf6SBenno Lossin         fn drop(self: Pin<&mut Self>) {
3784837cf6SBenno Lossin             unsafe {
3884837cf6SBenno Lossin                 libc::pthread_mutex_destroy(self.raw.get());
3984837cf6SBenno Lossin             }
4084837cf6SBenno Lossin         }
4184837cf6SBenno Lossin     }
4284837cf6SBenno Lossin 
4384837cf6SBenno Lossin     #[derive(Debug)]
4484837cf6SBenno Lossin     pub enum Error {
45*39051adbSBenno Lossin         #[allow(dead_code)]
4684837cf6SBenno Lossin         IO(std::io::Error),
4784837cf6SBenno Lossin         Alloc,
4884837cf6SBenno Lossin     }
4984837cf6SBenno Lossin 
5084837cf6SBenno Lossin     impl From<Infallible> for Error {
from(e: Infallible) -> Self5184837cf6SBenno Lossin         fn from(e: Infallible) -> Self {
5284837cf6SBenno Lossin             match e {}
5384837cf6SBenno Lossin         }
5484837cf6SBenno Lossin     }
5584837cf6SBenno Lossin 
5684837cf6SBenno Lossin     #[cfg(feature = "alloc")]
5784837cf6SBenno Lossin     impl From<AllocError> for Error {
from(_: AllocError) -> Self5884837cf6SBenno Lossin         fn from(_: AllocError) -> Self {
5984837cf6SBenno Lossin             Self::Alloc
6084837cf6SBenno Lossin         }
6184837cf6SBenno Lossin     }
6284837cf6SBenno Lossin 
6384837cf6SBenno Lossin     impl<T> PThreadMutex<T> {
new(data: T) -> impl PinInit<Self, Error>6484837cf6SBenno Lossin         pub fn new(data: T) -> impl PinInit<Self, Error> {
6584837cf6SBenno Lossin             fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
6684837cf6SBenno Lossin                 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
6784837cf6SBenno Lossin                     // we can cast, because `UnsafeCell` has the same layout as T.
6884837cf6SBenno Lossin                     let slot: *mut libc::pthread_mutex_t = slot.cast();
6984837cf6SBenno Lossin                     let mut attr = MaybeUninit::uninit();
7084837cf6SBenno Lossin                     let attr = attr.as_mut_ptr();
7184837cf6SBenno Lossin                     // SAFETY: ptr is valid
7284837cf6SBenno Lossin                     let ret = unsafe { libc::pthread_mutexattr_init(attr) };
7384837cf6SBenno Lossin                     if ret != 0 {
7484837cf6SBenno Lossin                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
7584837cf6SBenno Lossin                     }
7684837cf6SBenno Lossin                     // SAFETY: attr is initialized
7784837cf6SBenno Lossin                     let ret = unsafe {
7884837cf6SBenno Lossin                         libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
7984837cf6SBenno Lossin                     };
8084837cf6SBenno Lossin                     if ret != 0 {
8184837cf6SBenno Lossin                         // SAFETY: attr is initialized
8284837cf6SBenno Lossin                         unsafe { libc::pthread_mutexattr_destroy(attr) };
8384837cf6SBenno Lossin                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
8484837cf6SBenno Lossin                     }
8584837cf6SBenno Lossin                     // SAFETY: slot is valid
8684837cf6SBenno Lossin                     unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
8784837cf6SBenno Lossin                     // SAFETY: attr and slot are valid ptrs and attr is initialized
8884837cf6SBenno Lossin                     let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
8984837cf6SBenno Lossin                     // SAFETY: attr was initialized
9084837cf6SBenno Lossin                     unsafe { libc::pthread_mutexattr_destroy(attr) };
9184837cf6SBenno Lossin                     if ret != 0 {
9284837cf6SBenno Lossin                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
9384837cf6SBenno Lossin                     }
9484837cf6SBenno Lossin                     Ok(())
9584837cf6SBenno Lossin                 };
9684837cf6SBenno Lossin                 // SAFETY: mutex has been initialized
9784837cf6SBenno Lossin                 unsafe { pin_init_from_closure(init) }
9884837cf6SBenno Lossin             }
9984837cf6SBenno Lossin             try_pin_init!(Self {
10084837cf6SBenno Lossin             data: UnsafeCell::new(data),
10184837cf6SBenno Lossin             raw <- init_raw(),
10284837cf6SBenno Lossin             pin: PhantomPinned,
10384837cf6SBenno Lossin         }? Error)
10484837cf6SBenno Lossin         }
10584837cf6SBenno Lossin 
lock(&self) -> PThreadMutexGuard<'_, T>10684837cf6SBenno Lossin         pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
10784837cf6SBenno Lossin             // SAFETY: raw is always initialized
10884837cf6SBenno Lossin             unsafe { libc::pthread_mutex_lock(self.raw.get()) };
10984837cf6SBenno Lossin             PThreadMutexGuard { mtx: self }
11084837cf6SBenno Lossin         }
11184837cf6SBenno Lossin     }
11284837cf6SBenno Lossin 
11384837cf6SBenno Lossin     pub struct PThreadMutexGuard<'a, T> {
11484837cf6SBenno Lossin         mtx: &'a PThreadMutex<T>,
11584837cf6SBenno Lossin     }
11684837cf6SBenno Lossin 
11784837cf6SBenno Lossin     impl<T> Drop for PThreadMutexGuard<'_, T> {
drop(&mut self)11884837cf6SBenno Lossin         fn drop(&mut self) {
11984837cf6SBenno Lossin             // SAFETY: raw is always initialized
12084837cf6SBenno Lossin             unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
12184837cf6SBenno Lossin         }
12284837cf6SBenno Lossin     }
12384837cf6SBenno Lossin 
12484837cf6SBenno Lossin     impl<T> Deref for PThreadMutexGuard<'_, T> {
12584837cf6SBenno Lossin         type Target = T;
12684837cf6SBenno Lossin 
deref(&self) -> &Self::Target12784837cf6SBenno Lossin         fn deref(&self) -> &Self::Target {
12884837cf6SBenno Lossin             unsafe { &*self.mtx.data.get() }
12984837cf6SBenno Lossin         }
13084837cf6SBenno Lossin     }
13184837cf6SBenno Lossin 
13284837cf6SBenno Lossin     impl<T> DerefMut for PThreadMutexGuard<'_, T> {
deref_mut(&mut self) -> &mut Self::Target13384837cf6SBenno Lossin         fn deref_mut(&mut self) -> &mut Self::Target {
13484837cf6SBenno Lossin             unsafe { &mut *self.mtx.data.get() }
13584837cf6SBenno Lossin         }
13684837cf6SBenno Lossin     }
13784837cf6SBenno Lossin }
13884837cf6SBenno Lossin 
/// Stress example: 20 named worker threads each increment a shared `usize` counter,
/// protected by a `PThreadMutex`, `workload` times. Each worker then sleeps for a
/// worker-dependent delay and does another `workload` increments. Finally the total is
/// printed and checked.
#[cfg_attr(test, test)]
fn main() {
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };

        let counter: Pin<Arc<PThreadMutex<usize>>> =
            Arc::try_pin_init(PThreadMutex::new(0)).unwrap();
        let thread_count: usize = 20;
        let workload: usize = 1_000_000;

        // Spawn every worker up front (the `collect` is eager), then wait for all of them.
        let workers: Vec<_> = (0..thread_count)
            .map(|id| {
                let shared = counter.clone();
                Builder::new()
                    .name(format!("worker #{id}"))
                    .spawn(move || {
                        // One batch of `workload` locked increments.
                        let run_batch = || {
                            for _ in 0..workload {
                                *shared.lock() += 1;
                            }
                        };
                        run_batch();
                        println!("{id} halfway");
                        sleep(Duration::from_millis(id as u64 * 10));
                        run_batch();
                        println!("{id} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        workers
            .into_iter()
            .for_each(|worker| worker.join().expect("thread panicked"));

        println!("{:?}", *counter.lock());
        assert_eq!(*counter.lock(), 2 * thread_count * workload);
    }
}
181