xref: /linux/rust/pin-init/examples/pthread_mutex.rs (revision 9fd2da71c301184d98fe37674ca8d017d1ce6600)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>
4 #![allow(clippy::undocumented_unsafe_blocks)]
5 #![cfg_attr(feature = "alloc", feature(allocator_api))]
6 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
7 
8 #[cfg(not(windows))]
9 mod pthread_mtx {
10     #[cfg(feature = "alloc")]
11     use core::alloc::AllocError;
12     use core::{
13         cell::UnsafeCell,
14         marker::PhantomPinned,
15         mem::MaybeUninit,
16         ops::{Deref, DerefMut},
17         pin::Pin,
18     };
19     use pin_init::*;
20     use std::convert::Infallible;
21 
22     #[pin_data(PinnedDrop)]
23     pub struct PThreadMutex<T> {
24         #[pin]
25         raw: UnsafeCell<libc::pthread_mutex_t>,
26         data: UnsafeCell<T>,
27         #[pin]
28         pin: PhantomPinned,
29     }
30 
31     unsafe impl<T: Send> Send for PThreadMutex<T> {}
32     unsafe impl<T: Send> Sync for PThreadMutex<T> {}
33 
34     #[pinned_drop]
35     impl<T> PinnedDrop for PThreadMutex<T> {
36         fn drop(self: Pin<&mut Self>) {
37             unsafe {
38                 libc::pthread_mutex_destroy(self.raw.get());
39             }
40         }
41     }
42 
43     #[derive(Debug)]
44     pub enum Error {
45         #[allow(dead_code)]
46         IO(std::io::Error),
47         #[allow(dead_code)]
48         Alloc,
49     }
50 
51     impl From<Infallible> for Error {
52         fn from(e: Infallible) -> Self {
53             match e {}
54         }
55     }
56 
57     #[cfg(feature = "alloc")]
58     impl From<AllocError> for Error {
59         fn from(_: AllocError) -> Self {
60             Self::Alloc
61         }
62     }
63 
64     impl<T> PThreadMutex<T> {
65         #[allow(dead_code)]
66         pub fn new(data: T) -> impl PinInit<Self, Error> {
67             fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
68                 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
69                     // we can cast, because `UnsafeCell` has the same layout as T.
70                     let slot: *mut libc::pthread_mutex_t = slot.cast();
71                     let mut attr = MaybeUninit::uninit();
72                     let attr = attr.as_mut_ptr();
73                     // SAFETY: ptr is valid
74                     let ret = unsafe { libc::pthread_mutexattr_init(attr) };
75                     if ret != 0 {
76                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
77                     }
78                     // SAFETY: attr is initialized
79                     let ret = unsafe {
80                         libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
81                     };
82                     if ret != 0 {
83                         // SAFETY: attr is initialized
84                         unsafe { libc::pthread_mutexattr_destroy(attr) };
85                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
86                     }
87                     // SAFETY: slot is valid
88                     unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
89                     // SAFETY: attr and slot are valid ptrs and attr is initialized
90                     let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
91                     // SAFETY: attr was initialized
92                     unsafe { libc::pthread_mutexattr_destroy(attr) };
93                     if ret != 0 {
94                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
95                     }
96                     Ok(())
97                 };
98                 // SAFETY: mutex has been initialized
99                 unsafe { pin_init_from_closure(init) }
100             }
101             try_pin_init!(Self {
102             data: UnsafeCell::new(data),
103             raw <- init_raw(),
104             pin: PhantomPinned,
105         }? Error)
106         }
107 
108         #[allow(dead_code)]
109         pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
110             // SAFETY: raw is always initialized
111             unsafe { libc::pthread_mutex_lock(self.raw.get()) };
112             PThreadMutexGuard { mtx: self }
113         }
114     }
115 
116     pub struct PThreadMutexGuard<'a, T> {
117         mtx: &'a PThreadMutex<T>,
118     }
119 
120     impl<T> Drop for PThreadMutexGuard<'_, T> {
121         fn drop(&mut self) {
122             // SAFETY: raw is always initialized
123             unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
124         }
125     }
126 
127     impl<T> Deref for PThreadMutexGuard<'_, T> {
128         type Target = T;
129 
130         fn deref(&self) -> &Self::Target {
131             unsafe { &*self.mtx.data.get() }
132         }
133     }
134 
135     impl<T> DerefMut for PThreadMutexGuard<'_, T> {
136         fn deref_mut(&mut self) -> &mut Self::Target {
137             unsafe { &mut *self.mtx.data.get() }
138         }
139     }
140 }
141 
/// Stress test for `PThreadMutex`.
///
/// Spawns `thread_count` named workers. Each one increments a shared counter `workload` times,
/// sleeps for `i * 10` ms, then increments it `workload` more times. Finally the total is
/// printed and checked.
#[cfg_attr(test, test)]
#[cfg_attr(all(test, miri), ignore)]
fn main() {
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };

        let counter: Pin<Arc<PThreadMutex<usize>>> =
            Arc::try_pin_init(PThreadMutex::new(0)).unwrap();
        let thread_count: usize = 20;
        let workload: usize = 1_000_000;

        // All workers are spawned (collected) before any of them is joined.
        let workers: Vec<_> = (0..thread_count)
            .map(|i| {
                let shared = Pin::clone(&counter);
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        // One pass: `workload` separate lock/increment/unlock cycles.
                        let run_pass = || {
                            for _ in 0..workload {
                                *shared.lock() += 1;
                            }
                        };
                        run_pass();
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        run_pass();
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        for worker in workers {
            worker.join().expect("thread panicked");
        }

        println!("{:?}", &*counter.lock());
        assert_eq!(*counter.lock(), workload * thread_count * 2);
    }
}
185