// SPDX-License-Identifier: Apache-2.0 OR MIT

// Inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>
484837cf6SBenno Lossin #![allow(clippy::undocumented_unsafe_blocks)]
584837cf6SBenno Lossin #![cfg_attr(feature = "alloc", feature(allocator_api))]
65c4167b4SBenno Lossin #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
75c4167b4SBenno Lossin
884837cf6SBenno Lossin #[cfg(not(windows))]
984837cf6SBenno Lossin mod pthread_mtx {
1084837cf6SBenno Lossin #[cfg(feature = "alloc")]
1184837cf6SBenno Lossin use core::alloc::AllocError;
1284837cf6SBenno Lossin use core::{
1384837cf6SBenno Lossin cell::UnsafeCell,
1484837cf6SBenno Lossin marker::PhantomPinned,
1584837cf6SBenno Lossin mem::MaybeUninit,
1684837cf6SBenno Lossin ops::{Deref, DerefMut},
1784837cf6SBenno Lossin pin::Pin,
1884837cf6SBenno Lossin };
1984837cf6SBenno Lossin use pin_init::*;
2084837cf6SBenno Lossin use std::convert::Infallible;
2184837cf6SBenno Lossin
2284837cf6SBenno Lossin #[pin_data(PinnedDrop)]
2384837cf6SBenno Lossin pub struct PThreadMutex<T> {
2484837cf6SBenno Lossin #[pin]
2584837cf6SBenno Lossin raw: UnsafeCell<libc::pthread_mutex_t>,
2684837cf6SBenno Lossin data: UnsafeCell<T>,
2784837cf6SBenno Lossin #[pin]
2884837cf6SBenno Lossin pin: PhantomPinned,
2984837cf6SBenno Lossin }
3084837cf6SBenno Lossin
3184837cf6SBenno Lossin unsafe impl<T: Send> Send for PThreadMutex<T> {}
3284837cf6SBenno Lossin unsafe impl<T: Send> Sync for PThreadMutex<T> {}
3384837cf6SBenno Lossin
3484837cf6SBenno Lossin #[pinned_drop]
3584837cf6SBenno Lossin impl<T> PinnedDrop for PThreadMutex<T> {
drop(self: Pin<&mut Self>)3684837cf6SBenno Lossin fn drop(self: Pin<&mut Self>) {
3784837cf6SBenno Lossin unsafe {
3884837cf6SBenno Lossin libc::pthread_mutex_destroy(self.raw.get());
3984837cf6SBenno Lossin }
4084837cf6SBenno Lossin }
4184837cf6SBenno Lossin }
4284837cf6SBenno Lossin
4384837cf6SBenno Lossin #[derive(Debug)]
4484837cf6SBenno Lossin pub enum Error {
45*39051adbSBenno Lossin #[allow(dead_code)]
4684837cf6SBenno Lossin IO(std::io::Error),
4784837cf6SBenno Lossin Alloc,
4884837cf6SBenno Lossin }
4984837cf6SBenno Lossin
5084837cf6SBenno Lossin impl From<Infallible> for Error {
from(e: Infallible) -> Self5184837cf6SBenno Lossin fn from(e: Infallible) -> Self {
5284837cf6SBenno Lossin match e {}
5384837cf6SBenno Lossin }
5484837cf6SBenno Lossin }
5584837cf6SBenno Lossin
5684837cf6SBenno Lossin #[cfg(feature = "alloc")]
5784837cf6SBenno Lossin impl From<AllocError> for Error {
from(_: AllocError) -> Self5884837cf6SBenno Lossin fn from(_: AllocError) -> Self {
5984837cf6SBenno Lossin Self::Alloc
6084837cf6SBenno Lossin }
6184837cf6SBenno Lossin }
6284837cf6SBenno Lossin
6384837cf6SBenno Lossin impl<T> PThreadMutex<T> {
new(data: T) -> impl PinInit<Self, Error>6484837cf6SBenno Lossin pub fn new(data: T) -> impl PinInit<Self, Error> {
6584837cf6SBenno Lossin fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
6684837cf6SBenno Lossin let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
6784837cf6SBenno Lossin // we can cast, because `UnsafeCell` has the same layout as T.
6884837cf6SBenno Lossin let slot: *mut libc::pthread_mutex_t = slot.cast();
6984837cf6SBenno Lossin let mut attr = MaybeUninit::uninit();
7084837cf6SBenno Lossin let attr = attr.as_mut_ptr();
7184837cf6SBenno Lossin // SAFETY: ptr is valid
7284837cf6SBenno Lossin let ret = unsafe { libc::pthread_mutexattr_init(attr) };
7384837cf6SBenno Lossin if ret != 0 {
7484837cf6SBenno Lossin return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
7584837cf6SBenno Lossin }
7684837cf6SBenno Lossin // SAFETY: attr is initialized
7784837cf6SBenno Lossin let ret = unsafe {
7884837cf6SBenno Lossin libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
7984837cf6SBenno Lossin };
8084837cf6SBenno Lossin if ret != 0 {
8184837cf6SBenno Lossin // SAFETY: attr is initialized
8284837cf6SBenno Lossin unsafe { libc::pthread_mutexattr_destroy(attr) };
8384837cf6SBenno Lossin return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
8484837cf6SBenno Lossin }
8584837cf6SBenno Lossin // SAFETY: slot is valid
8684837cf6SBenno Lossin unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
8784837cf6SBenno Lossin // SAFETY: attr and slot are valid ptrs and attr is initialized
8884837cf6SBenno Lossin let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
8984837cf6SBenno Lossin // SAFETY: attr was initialized
9084837cf6SBenno Lossin unsafe { libc::pthread_mutexattr_destroy(attr) };
9184837cf6SBenno Lossin if ret != 0 {
9284837cf6SBenno Lossin return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
9384837cf6SBenno Lossin }
9484837cf6SBenno Lossin Ok(())
9584837cf6SBenno Lossin };
9684837cf6SBenno Lossin // SAFETY: mutex has been initialized
9784837cf6SBenno Lossin unsafe { pin_init_from_closure(init) }
9884837cf6SBenno Lossin }
9984837cf6SBenno Lossin try_pin_init!(Self {
10084837cf6SBenno Lossin data: UnsafeCell::new(data),
10184837cf6SBenno Lossin raw <- init_raw(),
10284837cf6SBenno Lossin pin: PhantomPinned,
10384837cf6SBenno Lossin }? Error)
10484837cf6SBenno Lossin }
10584837cf6SBenno Lossin
lock(&self) -> PThreadMutexGuard<'_, T>10684837cf6SBenno Lossin pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
10784837cf6SBenno Lossin // SAFETY: raw is always initialized
10884837cf6SBenno Lossin unsafe { libc::pthread_mutex_lock(self.raw.get()) };
10984837cf6SBenno Lossin PThreadMutexGuard { mtx: self }
11084837cf6SBenno Lossin }
11184837cf6SBenno Lossin }
11284837cf6SBenno Lossin
11384837cf6SBenno Lossin pub struct PThreadMutexGuard<'a, T> {
11484837cf6SBenno Lossin mtx: &'a PThreadMutex<T>,
11584837cf6SBenno Lossin }
11684837cf6SBenno Lossin
11784837cf6SBenno Lossin impl<T> Drop for PThreadMutexGuard<'_, T> {
drop(&mut self)11884837cf6SBenno Lossin fn drop(&mut self) {
11984837cf6SBenno Lossin // SAFETY: raw is always initialized
12084837cf6SBenno Lossin unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
12184837cf6SBenno Lossin }
12284837cf6SBenno Lossin }
12384837cf6SBenno Lossin
12484837cf6SBenno Lossin impl<T> Deref for PThreadMutexGuard<'_, T> {
12584837cf6SBenno Lossin type Target = T;
12684837cf6SBenno Lossin
deref(&self) -> &Self::Target12784837cf6SBenno Lossin fn deref(&self) -> &Self::Target {
12884837cf6SBenno Lossin unsafe { &*self.mtx.data.get() }
12984837cf6SBenno Lossin }
13084837cf6SBenno Lossin }
13184837cf6SBenno Lossin
13284837cf6SBenno Lossin impl<T> DerefMut for PThreadMutexGuard<'_, T> {
deref_mut(&mut self) -> &mut Self::Target13384837cf6SBenno Lossin fn deref_mut(&mut self) -> &mut Self::Target {
13484837cf6SBenno Lossin unsafe { &mut *self.mtx.data.get() }
13584837cf6SBenno Lossin }
13684837cf6SBenno Lossin }
13784837cf6SBenno Lossin }
13884837cf6SBenno Lossin
#[cfg_attr(test, test)]
fn main() {
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };

        /// Number of worker threads hammering the shared counter.
        const WORKERS: usize = 20;
        /// Increments each worker performs in each of its two phases.
        const ROUNDS: usize = 1_000_000;

        let counter: Pin<Arc<PThreadMutex<usize>>> =
            Arc::try_pin_init(PThreadMutex::new(0)).unwrap();

        // Spawn every worker up front; each does two phases of increments separated by a
        // staggered pause.
        let workers: Vec<_> = (0..WORKERS)
            .map(|id| {
                let shared = counter.clone();
                Builder::new()
                    .name(format!("worker #{id}"))
                    .spawn(move || {
                        (0..ROUNDS).for_each(|_| *shared.lock() += 1);
                        println!("{id} halfway");
                        sleep(Duration::from_millis((id as u64) * 10));
                        (0..ROUNDS).for_each(|_| *shared.lock() += 1);
                        println!("{id} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        workers
            .into_iter()
            .for_each(|worker| worker.join().expect("thread panicked"));

        // All workers have been joined, so the counter can no longer change.
        let total = *counter.lock();
        println!("{total:?}");
        assert_eq!(total, ROUNDS * WORKERS * 2);
    }
}
181