xref: /linux/rust/pin-init/examples/pthread_mutex.rs (revision 6fa6b5cb60490db2591bb93872b95f72315e5f53)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>
4 #![allow(clippy::undocumented_unsafe_blocks)]
5 #![cfg_attr(feature = "alloc", feature(allocator_api))]
6 #![cfg_attr(USE_RUSTC_FEATURES, feature(lint_reasons))]
7 #![cfg_attr(USE_RUSTC_FEATURES, feature(raw_ref_op))]
8 
9 #[cfg(not(windows))]
10 mod pthread_mtx {
11     #[cfg(feature = "alloc")]
12     use core::alloc::AllocError;
13     use core::{
14         cell::UnsafeCell,
15         marker::PhantomPinned,
16         mem::MaybeUninit,
17         ops::{Deref, DerefMut},
18         pin::Pin,
19     };
20     use pin_init::*;
21     use std::convert::Infallible;
22 
23     #[pin_data(PinnedDrop)]
24     pub struct PThreadMutex<T> {
25         #[pin]
26         raw: UnsafeCell<libc::pthread_mutex_t>,
27         data: UnsafeCell<T>,
28         #[pin]
29         pin: PhantomPinned,
30     }
31 
32     unsafe impl<T: Send> Send for PThreadMutex<T> {}
33     unsafe impl<T: Send> Sync for PThreadMutex<T> {}
34 
35     #[pinned_drop]
36     impl<T> PinnedDrop for PThreadMutex<T> {
37         fn drop(self: Pin<&mut Self>) {
38             unsafe {
39                 libc::pthread_mutex_destroy(self.raw.get());
40             }
41         }
42     }
43 
44     #[derive(Debug)]
45     pub enum Error {
46         #[allow(dead_code)]
47         IO(std::io::Error),
48         #[allow(dead_code)]
49         Alloc,
50     }
51 
52     impl From<Infallible> for Error {
53         fn from(e: Infallible) -> Self {
54             match e {}
55         }
56     }
57 
58     #[cfg(feature = "alloc")]
59     impl From<AllocError> for Error {
60         fn from(_: AllocError) -> Self {
61             Self::Alloc
62         }
63     }
64 
65     impl<T> PThreadMutex<T> {
66         #[allow(dead_code)]
67         pub fn new(data: T) -> impl PinInit<Self, Error> {
68             fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
69                 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
70                     // we can cast, because `UnsafeCell` has the same layout as T.
71                     let slot: *mut libc::pthread_mutex_t = slot.cast();
72                     let mut attr = MaybeUninit::uninit();
73                     let attr = attr.as_mut_ptr();
74                     // SAFETY: ptr is valid
75                     let ret = unsafe { libc::pthread_mutexattr_init(attr) };
76                     if ret != 0 {
77                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
78                     }
79                     // SAFETY: attr is initialized
80                     let ret = unsafe {
81                         libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
82                     };
83                     if ret != 0 {
84                         // SAFETY: attr is initialized
85                         unsafe { libc::pthread_mutexattr_destroy(attr) };
86                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
87                     }
88                     // SAFETY: slot is valid
89                     unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
90                     // SAFETY: attr and slot are valid ptrs and attr is initialized
91                     let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
92                     // SAFETY: attr was initialized
93                     unsafe { libc::pthread_mutexattr_destroy(attr) };
94                     if ret != 0 {
95                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
96                     }
97                     Ok(())
98                 };
99                 // SAFETY: mutex has been initialized
100                 unsafe { pin_init_from_closure(init) }
101             }
102             pin_init!(Self {
103                 data: UnsafeCell::new(data),
104                 raw <- init_raw(),
105                 pin: PhantomPinned,
106             }? Error)
107         }
108 
109         #[allow(dead_code)]
110         pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
111             // SAFETY: raw is always initialized
112             unsafe { libc::pthread_mutex_lock(self.raw.get()) };
113             PThreadMutexGuard { mtx: self }
114         }
115     }
116 
117     pub struct PThreadMutexGuard<'a, T> {
118         mtx: &'a PThreadMutex<T>,
119     }
120 
121     impl<T> Drop for PThreadMutexGuard<'_, T> {
122         fn drop(&mut self) {
123             // SAFETY: raw is always initialized
124             unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
125         }
126     }
127 
128     impl<T> Deref for PThreadMutexGuard<'_, T> {
129         type Target = T;
130 
131         fn deref(&self) -> &Self::Target {
132             unsafe { &*self.mtx.data.get() }
133         }
134     }
135 
136     impl<T> DerefMut for PThreadMutexGuard<'_, T> {
137         fn deref_mut(&mut self) -> &mut Self::Target {
138             unsafe { &mut *self.mtx.data.get() }
139         }
140     }
141 }
142 
#[cfg_attr(test, test)]
#[cfg_attr(all(test, miri), ignore)]
fn main() {
    // Stress test: a group of threads increment one shared counter through `PThreadMutex`,
    // and the final count must equal the total number of increments performed.
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };

        let thread_count: usize = 20;
        let workload: usize = 1_000_000;
        let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap();

        // Each worker runs two batches of `workload` increments, sleeping for a
        // thread-specific delay in between so the batches interleave differently.
        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let mtx = Pin::clone(&mtx);
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        let run_batch = || {
                            for _ in 0..workload {
                                *mtx.lock() += 1;
                            }
                        };
                        run_batch();
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        run_batch();
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        handles
            .into_iter()
            .for_each(|h| h.join().expect("thread panicked"));

        println!("{:?}", *mtx.lock());
        let expected = workload * thread_count * 2;
        assert_eq!(*mtx.lock(), expected);
    }
}
186