xref: /linux/rust/pin-init/examples/pthread_mutex.rs (revision ec7714e4947909190ffb3041a03311a975350fe0)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>
4 #![allow(clippy::undocumented_unsafe_blocks)]
5 #![cfg_attr(feature = "alloc", feature(allocator_api))]
6 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
7 
8 #[cfg(not(windows))]
9 mod pthread_mtx {
10     #[cfg(feature = "alloc")]
11     use core::alloc::AllocError;
12     use core::{
13         cell::UnsafeCell,
14         marker::PhantomPinned,
15         mem::MaybeUninit,
16         ops::{Deref, DerefMut},
17         pin::Pin,
18     };
19     use pin_init::*;
20     use std::convert::Infallible;
21 
22     #[pin_data(PinnedDrop)]
23     pub struct PThreadMutex<T> {
24         #[pin]
25         raw: UnsafeCell<libc::pthread_mutex_t>,
26         data: UnsafeCell<T>,
27         #[pin]
28         pin: PhantomPinned,
29     }
30 
31     unsafe impl<T: Send> Send for PThreadMutex<T> {}
32     unsafe impl<T: Send> Sync for PThreadMutex<T> {}
33 
34     #[pinned_drop]
35     impl<T> PinnedDrop for PThreadMutex<T> {
drop(self: Pin<&mut Self>)36         fn drop(self: Pin<&mut Self>) {
37             unsafe {
38                 libc::pthread_mutex_destroy(self.raw.get());
39             }
40         }
41     }
42 
43     #[derive(Debug)]
44     pub enum Error {
45         #[allow(dead_code)]
46         IO(std::io::Error),
47         Alloc,
48     }
49 
50     impl From<Infallible> for Error {
from(e: Infallible) -> Self51         fn from(e: Infallible) -> Self {
52             match e {}
53         }
54     }
55 
56     #[cfg(feature = "alloc")]
57     impl From<AllocError> for Error {
from(_: AllocError) -> Self58         fn from(_: AllocError) -> Self {
59             Self::Alloc
60         }
61     }
62 
63     impl<T> PThreadMutex<T> {
new(data: T) -> impl PinInit<Self, Error>64         pub fn new(data: T) -> impl PinInit<Self, Error> {
65             fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
66                 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
67                     // we can cast, because `UnsafeCell` has the same layout as T.
68                     let slot: *mut libc::pthread_mutex_t = slot.cast();
69                     let mut attr = MaybeUninit::uninit();
70                     let attr = attr.as_mut_ptr();
71                     // SAFETY: ptr is valid
72                     let ret = unsafe { libc::pthread_mutexattr_init(attr) };
73                     if ret != 0 {
74                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
75                     }
76                     // SAFETY: attr is initialized
77                     let ret = unsafe {
78                         libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
79                     };
80                     if ret != 0 {
81                         // SAFETY: attr is initialized
82                         unsafe { libc::pthread_mutexattr_destroy(attr) };
83                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
84                     }
85                     // SAFETY: slot is valid
86                     unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
87                     // SAFETY: attr and slot are valid ptrs and attr is initialized
88                     let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
89                     // SAFETY: attr was initialized
90                     unsafe { libc::pthread_mutexattr_destroy(attr) };
91                     if ret != 0 {
92                         return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
93                     }
94                     Ok(())
95                 };
96                 // SAFETY: mutex has been initialized
97                 unsafe { pin_init_from_closure(init) }
98             }
99             try_pin_init!(Self {
100             data: UnsafeCell::new(data),
101             raw <- init_raw(),
102             pin: PhantomPinned,
103         }? Error)
104         }
105 
lock(&self) -> PThreadMutexGuard<'_, T>106         pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
107             // SAFETY: raw is always initialized
108             unsafe { libc::pthread_mutex_lock(self.raw.get()) };
109             PThreadMutexGuard { mtx: self }
110         }
111     }
112 
113     pub struct PThreadMutexGuard<'a, T> {
114         mtx: &'a PThreadMutex<T>,
115     }
116 
117     impl<T> Drop for PThreadMutexGuard<'_, T> {
drop(&mut self)118         fn drop(&mut self) {
119             // SAFETY: raw is always initialized
120             unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
121         }
122     }
123 
124     impl<T> Deref for PThreadMutexGuard<'_, T> {
125         type Target = T;
126 
deref(&self) -> &Self::Target127         fn deref(&self) -> &Self::Target {
128             unsafe { &*self.mtx.data.get() }
129         }
130     }
131 
132     impl<T> DerefMut for PThreadMutexGuard<'_, T> {
deref_mut(&mut self) -> &mut Self::Target133         fn deref_mut(&mut self) -> &mut Self::Target {
134             unsafe { &mut *self.mtx.data.get() }
135         }
136     }
137 }
138 
/// Demo driver: hammers a single `PThreadMutex<usize>` from many named threads and checks
/// that no increment was lost.
///
/// This also runs as a `#[test]` when compiled with `--test`. It only does work on
/// non-Windows targets with the `std` or `alloc` feature enabled; otherwise the body is
/// compiled out.
#[cfg_attr(test, test)]
fn main() {
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };

        // The mutex is pin-initialized directly inside the `Arc` allocation.
        let counter: Pin<Arc<PThreadMutex<usize>>> =
            Arc::try_pin_init(PThreadMutex::new(0)).unwrap();
        let thread_count = 20;
        let workload = 1_000_000;

        // Each worker increments the shared counter `workload` times, pauses for a
        // worker-specific delay, then increments it `workload` more times.
        let workers: Vec<_> = (0..thread_count)
            .map(|i| {
                let shared = counter.clone();
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        (0..workload).for_each(|_| *shared.lock() += 1);
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        (0..workload).for_each(|_| *shared.lock() += 1);
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        for worker in workers {
            worker.join().expect("thread panicked");
        }

        println!("{:?}", &*counter.lock());
        assert_eq!(*counter.lock(), workload * thread_count * 2);
    }
}
181