xref: /linux/rust/pin-init/examples/pthread_mutex.rs (revision 430654211d566f86e8ee533ff1b01a42be6b602c)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>
4 #![allow(clippy::undocumented_unsafe_blocks)]
5 #![cfg_attr(feature = "alloc", feature(allocator_api))]
6 
7 #[cfg(not(windows))]
8 mod pthread_mtx {
9     #[cfg(feature = "alloc")]
10     use core::alloc::AllocError;
11     use core::{
12         cell::UnsafeCell,
13         marker::PhantomPinned,
14         mem::MaybeUninit,
15         ops::{Deref, DerefMut},
16         pin::Pin,
17     };
18     use pin_init::*;
19     use std::convert::Infallible;
20 
    /// A mutual-exclusion primitive backed by a POSIX `pthread_mutex_t` that
    /// protects a value of type `T`.
    ///
    /// The type is `!Unpin` (it contains `PhantomPinned`) and is built in place
    /// through the initializer returned by [`PThreadMutex::new`]. Access to the
    /// protected value goes through the guard returned by [`PThreadMutex::lock`].
    #[pin_data(PinnedDrop)]
    pub struct PThreadMutex<T> {
        /// The underlying pthread mutex. It is wrapped in `UnsafeCell` because the
        /// pthread functions mutate it through the shared reference held by callers.
        #[pin]
        raw: UnsafeCell<libc::pthread_mutex_t>,
        /// The protected value; only reached through a [`PThreadMutexGuard`].
        data: UnsafeCell<T>,
        /// Opts the type out of `Unpin` so the raw mutex keeps a fixed address once pinned.
        #[pin]
        pin: PhantomPinned,
    }
29 
30     unsafe impl<T: Send> Send for PThreadMutex<T> {}
31     unsafe impl<T: Send> Sync for PThreadMutex<T> {}
32 
33     #[pinned_drop]
34     impl<T> PinnedDrop for PThreadMutex<T> {
35         fn drop(self: Pin<&mut Self>) {
36             unsafe {
37                 libc::pthread_mutex_destroy(self.raw.get());
38             }
39         }
40     }
41 
42     #[derive(Debug)]
43     pub enum Error {
44         #[allow(dead_code)]
45         IO(std::io::Error),
46         #[allow(dead_code)]
47         Alloc,
48     }
49 
50     impl From<Infallible> for Error {
51         fn from(e: Infallible) -> Self {
52             match e {}
53         }
54     }
55 
56     #[cfg(feature = "alloc")]
57     impl From<AllocError> for Error {
58         fn from(_: AllocError) -> Self {
59             Self::Alloc
60         }
61     }
62 
63     impl<T> PThreadMutex<T> {
64         #[allow(dead_code)]
65         pub fn new(data: T) -> impl PinInit<Self, Error> {
66             fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
67                 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
68                     // we can cast, because `UnsafeCell` has the same layout as T.
69                     let slot: *mut libc::pthread_mutex_t = slot.cast();
70                     let mut attr = MaybeUninit::uninit();
71                     let attr = attr.as_mut_ptr();
72                     // SAFETY: ptr is valid
73                     let ret = unsafe { libc::pthread_mutexattr_init(attr) };
74                     if ret != 0 {
75                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
76                     }
77                     // SAFETY: attr is initialized
78                     let ret = unsafe {
79                         libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
80                     };
81                     if ret != 0 {
82                         // SAFETY: attr is initialized
83                         unsafe { libc::pthread_mutexattr_destroy(attr) };
84                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
85                     }
86                     // SAFETY: slot is valid
87                     unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
88                     // SAFETY: attr and slot are valid ptrs and attr is initialized
89                     let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
90                     // SAFETY: attr was initialized
91                     unsafe { libc::pthread_mutexattr_destroy(attr) };
92                     if ret != 0 {
93                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
94                     }
95                     Ok(())
96                 };
97                 // SAFETY: mutex has been initialized
98                 unsafe { pin_init_from_closure(init) }
99             }
100             pin_init!(Self {
101                 data: UnsafeCell::new(data),
102                 raw <- init_raw(),
103                 pin: PhantomPinned,
104             }? Error)
105         }
106 
107         #[allow(dead_code)]
108         pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
109             // SAFETY: raw is always initialized
110             unsafe { libc::pthread_mutex_lock(self.raw.get()) };
111             PThreadMutexGuard { mtx: self }
112         }
113     }
114 
115     pub struct PThreadMutexGuard<'a, T> {
116         mtx: &'a PThreadMutex<T>,
117     }
118 
119     impl<T> Drop for PThreadMutexGuard<'_, T> {
120         fn drop(&mut self) {
121             // SAFETY: raw is always initialized
122             unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
123         }
124     }
125 
126     impl<T> Deref for PThreadMutexGuard<'_, T> {
127         type Target = T;
128 
129         fn deref(&self) -> &Self::Target {
130             unsafe { &*self.mtx.data.get() }
131         }
132     }
133 
134     impl<T> DerefMut for PThreadMutexGuard<'_, T> {
135         fn deref_mut(&mut self) -> &mut Self::Target {
136             unsafe { &mut *self.mtx.data.get() }
137         }
138     }
139 }
140 
#[cfg_attr(test, test)]
#[cfg_attr(all(test, miri), ignore)]
fn main() {
    // Only runs where pthreads exist and an allocator is available for `Arc`.
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };

        let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap();
        let thread_count = 20;
        let workload = 1_000_000;

        // Spawn every worker up front; each increments the shared counter
        // `workload` times, pauses, and then does another `workload` increments.
        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let shared = Pin::clone(&mtx);
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        let run_batch = || {
                            for _ in 0..workload {
                                *shared.lock() += 1;
                            }
                        };
                        run_batch();
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        run_batch();
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        for handle in handles {
            handle.join().expect("thread panicked");
        }
        println!("{:?}", *mtx.lock());
        assert_eq!(*mtx.lock(), workload * thread_count * 2);
    }
}
184