1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs> 4 #![allow(clippy::undocumented_unsafe_blocks)] 5 #![cfg_attr(feature = "alloc", feature(allocator_api))] 6 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] 7 8 #[cfg(not(windows))] 9 mod pthread_mtx { 10 #[cfg(feature = "alloc")] 11 use core::alloc::AllocError; 12 use core::{ 13 cell::UnsafeCell, 14 marker::PhantomPinned, 15 mem::MaybeUninit, 16 ops::{Deref, DerefMut}, 17 pin::Pin, 18 }; 19 use pin_init::*; 20 use std::convert::Infallible; 21 22 #[pin_data(PinnedDrop)] 23 pub struct PThreadMutex<T> { 24 #[pin] 25 raw: UnsafeCell<libc::pthread_mutex_t>, 26 data: UnsafeCell<T>, 27 #[pin] 28 pin: PhantomPinned, 29 } 30 31 unsafe impl<T: Send> Send for PThreadMutex<T> {} 32 unsafe impl<T: Send> Sync for PThreadMutex<T> {} 33 34 #[pinned_drop] 35 impl<T> PinnedDrop for PThreadMutex<T> { 36 fn drop(self: Pin<&mut Self>) { 37 unsafe { 38 libc::pthread_mutex_destroy(self.raw.get()); 39 } 40 } 41 } 42 43 #[derive(Debug)] 44 pub enum Error { 45 #[allow(dead_code)] 46 IO(std::io::Error), 47 #[allow(dead_code)] 48 Alloc, 49 } 50 51 impl From<Infallible> for Error { 52 fn from(e: Infallible) -> Self { 53 match e {} 54 } 55 } 56 57 #[cfg(feature = "alloc")] 58 impl From<AllocError> for Error { 59 fn from(_: AllocError) -> Self { 60 Self::Alloc 61 } 62 } 63 64 impl<T> PThreadMutex<T> { 65 #[allow(dead_code)] 66 pub fn new(data: T) -> impl PinInit<Self, Error> { 67 fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> { 68 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| { 69 // we can cast, because `UnsafeCell` has the same layout as T. 70 let slot: *mut libc::pthread_mutex_t = slot.cast(); 71 let mut attr = MaybeUninit::uninit(); 72 let attr = attr.as_mut_ptr(); 73 // SAFETY: ptr is valid 74 let ret = unsafe { libc::pthread_mutexattr_init(attr) }; 75 if ret != 0 { 76 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 77 } 78 // SAFETY: attr is initialized 79 let ret = unsafe { 80 libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL) 81 }; 82 if ret != 0 { 83 // SAFETY: attr is initialized 84 unsafe { libc::pthread_mutexattr_destroy(attr) }; 85 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 86 } 87 // SAFETY: slot is valid 88 unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) }; 89 // SAFETY: attr and slot are valid ptrs and attr is initialized 90 let ret = unsafe { libc::pthread_mutex_init(slot, attr) }; 91 // SAFETY: attr was initialized 92 unsafe { libc::pthread_mutexattr_destroy(attr) }; 93 if ret != 0 { 94 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 95 } 96 Ok(()) 97 }; 98 // SAFETY: mutex has been initialized 99 unsafe { pin_init_from_closure(init) } 100 } 101 try_pin_init!(Self { 102 data: UnsafeCell::new(data), 103 raw <- init_raw(), 104 pin: PhantomPinned, 105 }? Error) 106 } 107 108 #[allow(dead_code)] 109 pub fn lock(&self) -> PThreadMutexGuard<'_, T> { 110 // SAFETY: raw is always initialized 111 unsafe { libc::pthread_mutex_lock(self.raw.get()) }; 112 PThreadMutexGuard { mtx: self } 113 } 114 } 115 116 pub struct PThreadMutexGuard<'a, T> { 117 mtx: &'a PThreadMutex<T>, 118 } 119 120 impl<T> Drop for PThreadMutexGuard<'_, T> { 121 fn drop(&mut self) { 122 // SAFETY: raw is always initialized 123 unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) }; 124 } 125 } 126 127 impl<T> Deref for PThreadMutexGuard<'_, T> { 128 type Target = T; 129 130 fn deref(&self) -> &Self::Target { 131 unsafe { &*self.mtx.data.get() } 132 } 133 } 134 135 impl<T> DerefMut for PThreadMutexGuard<'_, T> { 136 fn deref_mut(&mut self) -> &mut Self::Target { 137 unsafe { &mut *self.mtx.data.get() } 138 } 139 } 140 } 141 142 #[cfg_attr(test, test)] 143 #[cfg_attr(all(test, miri), ignore)] 144 fn main() { 145 #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))] 146 { 147 use core::pin::Pin; 148 use pin_init::*; 149 use pthread_mtx::*; 150 use std::{ 151 sync::Arc, 152 thread::{sleep, Builder}, 153 time::Duration, 154 }; 155 let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap(); 156 let mut handles = vec![]; 157 let thread_count = 20; 158 let workload = 1_000_000; 159 for i in 0..thread_count { 160 let mtx = mtx.clone(); 161 handles.push( 162 Builder::new() 163 .name(format!("worker #{i}")) 164 .spawn(move || { 165 for _ in 0..workload { 166 *mtx.lock() += 1; 167 } 168 println!("{i} halfway"); 169 sleep(Duration::from_millis((i as u64) * 10)); 170 for _ in 0..workload { 171 *mtx.lock() += 1; 172 } 173 println!("{i} finished"); 174 }) 175 .expect("should not fail"), 176 ); 177 } 178 for h in handles { 179 h.join().expect("thread panicked"); 180 } 181 println!("{:?}", &*mtx.lock()); 182 assert_eq!(*mtx.lock(), workload * thread_count * 2); 183 } 184 } 185