1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 // inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs> 4 #![allow(clippy::undocumented_unsafe_blocks)] 5 #![cfg_attr(feature = "alloc", feature(allocator_api))] 6 7 #[cfg(not(windows))] 8 mod pthread_mtx { 9 #[cfg(feature = "alloc")] 10 use core::alloc::AllocError; 11 use core::{ 12 cell::UnsafeCell, 13 marker::PhantomPinned, 14 mem::MaybeUninit, 15 ops::{Deref, DerefMut}, 16 pin::Pin, 17 }; 18 use pin_init::*; 19 use std::convert::Infallible; 20 21 #[pin_data(PinnedDrop)] 22 pub struct PThreadMutex<T> { 23 #[pin] 24 raw: UnsafeCell<libc::pthread_mutex_t>, 25 data: UnsafeCell<T>, 26 #[pin] 27 pin: PhantomPinned, 28 } 29 30 unsafe impl<T: Send> Send for PThreadMutex<T> {} 31 unsafe impl<T: Send> Sync for PThreadMutex<T> {} 32 33 #[pinned_drop] 34 impl<T> PinnedDrop for PThreadMutex<T> { 35 fn drop(self: Pin<&mut Self>) { 36 unsafe { 37 libc::pthread_mutex_destroy(self.raw.get()); 38 } 39 } 40 } 41 42 #[derive(Debug)] 43 pub enum Error { 44 #[allow(dead_code)] 45 IO(std::io::Error), 46 #[allow(dead_code)] 47 Alloc, 48 } 49 50 impl From<Infallible> for Error { 51 fn from(e: Infallible) -> Self { 52 match e {} 53 } 54 } 55 56 #[cfg(feature = "alloc")] 57 impl From<AllocError> for Error { 58 fn from(_: AllocError) -> Self { 59 Self::Alloc 60 } 61 } 62 63 impl<T> PThreadMutex<T> { 64 #[allow(dead_code)] 65 pub fn new(data: T) -> impl PinInit<Self, Error> { 66 fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> { 67 let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| { 68 // we can cast, because `UnsafeCell` has the same layout as T. 69 let slot: *mut libc::pthread_mutex_t = slot.cast(); 70 let mut attr = MaybeUninit::uninit(); 71 let attr = attr.as_mut_ptr(); 72 // SAFETY: ptr is valid 73 let ret = unsafe { libc::pthread_mutexattr_init(attr) }; 74 if ret != 0 { 75 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 76 } 77 // SAFETY: attr is initialized 78 let ret = unsafe { 79 libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL) 80 }; 81 if ret != 0 { 82 // SAFETY: attr is initialized 83 unsafe { libc::pthread_mutexattr_destroy(attr) }; 84 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 85 } 86 // SAFETY: slot is valid 87 unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) }; 88 // SAFETY: attr and slot are valid ptrs and attr is initialized 89 let ret = unsafe { libc::pthread_mutex_init(slot, attr) }; 90 // SAFETY: attr was initialized 91 unsafe { libc::pthread_mutexattr_destroy(attr) }; 92 if ret != 0 { 93 return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); 94 } 95 Ok(()) 96 }; 97 // SAFETY: mutex has been initialized 98 unsafe { pin_init_from_closure(init) } 99 } 100 pin_init!(Self { 101 data: UnsafeCell::new(data), 102 raw <- init_raw(), 103 pin: PhantomPinned, 104 }? Error) 105 } 106 107 #[allow(dead_code)] 108 pub fn lock(&self) -> PThreadMutexGuard<'_, T> { 109 // SAFETY: raw is always initialized 110 unsafe { libc::pthread_mutex_lock(self.raw.get()) }; 111 PThreadMutexGuard { mtx: self } 112 } 113 } 114 115 pub struct PThreadMutexGuard<'a, T> { 116 mtx: &'a PThreadMutex<T>, 117 } 118 119 impl<T> Drop for PThreadMutexGuard<'_, T> { 120 fn drop(&mut self) { 121 // SAFETY: raw is always initialized 122 unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) }; 123 } 124 } 125 126 impl<T> Deref for PThreadMutexGuard<'_, T> { 127 type Target = T; 128 129 fn deref(&self) -> &Self::Target { 130 unsafe { &*self.mtx.data.get() } 131 } 132 } 133 134 impl<T> DerefMut for PThreadMutexGuard<'_, T> { 135 fn deref_mut(&mut self) -> &mut Self::Target { 136 unsafe { &mut *self.mtx.data.get() } 137 } 138 } 139 } 140 141 #[cfg_attr(test, test)] 142 #[cfg_attr(all(test, miri), ignore)] 143 fn main() { 144 #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))] 145 { 146 use core::pin::Pin; 147 use pin_init::*; 148 use pthread_mtx::*; 149 use std::{ 150 sync::Arc, 151 thread::{sleep, Builder}, 152 time::Duration, 153 }; 154 let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap(); 155 let mut handles = vec![]; 156 let thread_count = 20; 157 let workload = 1_000_000; 158 for i in 0..thread_count { 159 let mtx = mtx.clone(); 160 handles.push( 161 Builder::new() 162 .name(format!("worker #{i}")) 163 .spawn(move || { 164 for _ in 0..workload { 165 *mtx.lock() += 1; 166 } 167 println!("{i} halfway"); 168 sleep(Duration::from_millis((i as u64) * 10)); 169 for _ in 0..workload { 170 *mtx.lock() += 1; 171 } 172 println!("{i} finished"); 173 }) 174 .expect("should not fail"), 175 ); 176 } 177 for h in handles { 178 h.join().expect("thread panicked"); 179 } 180 println!("{:?}", *mtx.lock()); 181 assert_eq!(*mtx.lock(), workload * thread_count * 2); 182 } 183 } 184