1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs> 4 #![allow(clippy::undocumented_unsafe_blocks)] 5 #![cfg_attr(feature = "alloc", feature(allocator_api))] 6 #![cfg_attr(USE_RUSTC_FEATURES, feature(lint_reasons))] 7 #![cfg_attr(USE_RUSTC_FEATURES, feature(raw_ref_op))] 8 9 #[cfg(not(windows))] 10 mod pthread_mtx { 11 #[cfg(feature = "alloc")] 12 use core::alloc::AllocError; 13 use core::{ 14 cell::UnsafeCell, 15 marker::PhantomPinned, 16 mem::MaybeUninit, 17 ops::{Deref, DerefMut}, 18 pin::Pin, 19 }; 20 use pin_init::*; 21 use std::convert::Infallible; 22 23 #[pin_data(PinnedDrop)] 24 pub struct PThreadMutex<T> { 25 #[pin] 26 raw: UnsafeCell<libc::pthread_mutex_t>, 27 data: UnsafeCell<T>, 28 #[pin] 29 pin: PhantomPinned, 30 } 31 32 unsafe impl<T: Send> Send for PThreadMutex<T> {} 33 unsafe impl<T: Send> Sync for PThreadMutex<T> {} 34 35 #[pinned_drop] 36 impl<T> PinnedDrop for PThreadMutex<T> { 37 fn drop(self: Pin<&mut Self>) { 38 unsafe { 39 libc::pthread_mutex_destroy(self.raw.get()); 40 } 41 } 42 } 43 44 #[derive(Debug)] 45 pub enum Error { 46 #[allow(dead_code)] 47 IO(std::io::Error), 48 #[allow(dead_code)] 49 Alloc, 50 } 51 52 impl From<Infallible> for Error { 53 fn from(e: Infallible) -> Self { 54 match e {} 55 } 56 } 57 58 #[cfg(feature = "alloc")] 59 impl From<AllocError> for Error { 60 fn from(_: AllocError) -> Self { 61 Self::Alloc 62 } 63 } 64 65 impl<T> PThreadMutex<T> { 66 #[allow(dead_code)] 67 pub fn new(data: T) -> impl PinInit<Self, Error> { 68 fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> { 69 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| { 70 // we can cast, because `UnsafeCell` has the same layout as T. 71 let slot: *mut libc::pthread_mutex_t = slot.cast(); 72 let mut attr = MaybeUninit::uninit(); 73 let attr = attr.as_mut_ptr(); 74 // SAFETY: ptr is valid 75 let ret = unsafe { libc::pthread_mutexattr_init(attr) }; 76 if ret != 0 { 77 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 78 } 79 // SAFETY: attr is initialized 80 let ret = unsafe { 81 libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL) 82 }; 83 if ret != 0 { 84 // SAFETY: attr is initialized 85 unsafe { libc::pthread_mutexattr_destroy(attr) }; 86 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 87 } 88 // SAFETY: slot is valid 89 unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) }; 90 // SAFETY: attr and slot are valid ptrs and attr is initialized 91 let ret = unsafe { libc::pthread_mutex_init(slot, attr) }; 92 // SAFETY: attr was initialized 93 unsafe { libc::pthread_mutexattr_destroy(attr) }; 94 if ret != 0 { 95 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 96 } 97 Ok(()) 98 }; 99 // SAFETY: mutex has been initialized 100 unsafe { pin_init_from_closure(init) } 101 } 102 pin_init!(Self { 103 data: UnsafeCell::new(data), 104 raw <- init_raw(), 105 pin: PhantomPinned, 106 }? Error) 107 } 108 109 #[allow(dead_code)] 110 pub fn lock(&self) -> PThreadMutexGuard<'_, T> { 111 // SAFETY: raw is always initialized 112 unsafe { libc::pthread_mutex_lock(self.raw.get()) }; 113 PThreadMutexGuard { mtx: self } 114 } 115 } 116 117 pub struct PThreadMutexGuard<'a, T> { 118 mtx: &'a PThreadMutex<T>, 119 } 120 121 impl<T> Drop for PThreadMutexGuard<'_, T> { 122 fn drop(&mut self) { 123 // SAFETY: raw is always initialized 124 unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) }; 125 } 126 } 127 128 impl<T> Deref for PThreadMutexGuard<'_, T> { 129 type Target = T; 130 131 fn deref(&self) -> &Self::Target { 132 unsafe { &*self.mtx.data.get() } 133 } 134 } 135 136 impl<T> DerefMut for PThreadMutexGuard<'_, T> { 137 fn deref_mut(&mut self) -> &mut Self::Target { 138 unsafe { &mut *self.mtx.data.get() } 139 } 140 } 141 } 142 143 #[cfg_attr(test, test)] 144 #[cfg_attr(all(test, miri), ignore)] 145 fn main() { 146 #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))] 147 { 148 use core::pin::Pin; 149 use pin_init::*; 150 use pthread_mtx::*; 151 use std::{ 152 sync::Arc, 153 thread::{sleep, Builder}, 154 time::Duration, 155 }; 156 let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap(); 157 let mut handles = vec![]; 158 let thread_count = 20; 159 let workload = 1_000_000; 160 for i in 0..thread_count { 161 let mtx = mtx.clone(); 162 handles.push( 163 Builder::new() 164 .name(format!("worker #{i}")) 165 .spawn(move || { 166 for _ in 0..workload { 167 *mtx.lock() += 1; 168 } 169 println!("{i} halfway"); 170 sleep(Duration::from_millis((i as u64) * 10)); 171 for _ in 0..workload { 172 *mtx.lock() += 1; 173 } 174 println!("{i} finished"); 175 }) 176 .expect("should not fail"), 177 ); 178 } 179 for h in handles { 180 h.join().expect("thread panicked"); 181 } 182 println!("{:?}", &*mtx.lock()); 183 assert_eq!(*mtx.lock(), workload * thread_count * 2); 184 } 185 } 186