184837cf6SBenno Lossin // SPDX-License-Identifier: Apache-2.0 OR MIT
284837cf6SBenno Lossin
384837cf6SBenno Lossin #![allow(clippy::undocumented_unsafe_blocks)]
484837cf6SBenno Lossin #![cfg_attr(feature = "alloc", feature(allocator_api))]
5*5c4167b4SBenno Lossin #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
684837cf6SBenno Lossin #![allow(clippy::missing_safety_doc)]
784837cf6SBenno Lossin
884837cf6SBenno Lossin use core::{
984837cf6SBenno Lossin cell::{Cell, UnsafeCell},
1084837cf6SBenno Lossin marker::PhantomPinned,
1184837cf6SBenno Lossin ops::{Deref, DerefMut},
1284837cf6SBenno Lossin pin::Pin,
1384837cf6SBenno Lossin sync::atomic::{AtomicBool, Ordering},
1484837cf6SBenno Lossin };
1584837cf6SBenno Lossin use std::{
1684837cf6SBenno Lossin sync::Arc,
1784837cf6SBenno Lossin thread::{self, park, sleep, Builder, Thread},
1884837cf6SBenno Lossin time::Duration,
1984837cf6SBenno Lossin };
2084837cf6SBenno Lossin
2184837cf6SBenno Lossin use pin_init::*;
2284837cf6SBenno Lossin #[expect(unused_attributes)]
2384837cf6SBenno Lossin #[path = "./linked_list.rs"]
2484837cf6SBenno Lossin pub mod linked_list;
2584837cf6SBenno Lossin use linked_list::*;
2684837cf6SBenno Lossin
/// A minimal test-and-test-and-set spin lock that protects no data of its own.
///
/// [`CMutex`] uses it to guard its `locked` flag and its wait list.
pub struct SpinLock {
    /// `true` while a [`SpinLockGuard`] for this lock is alive.
    inner: AtomicBool,
}

impl SpinLock {
    /// Spins until the lock is acquired and returns a guard that releases it on drop.
    ///
    /// While the lock is contended, the thread polls with cheap relaxed loads and yields to
    /// the scheduler between polls. It only retries the compare-exchange once the lock looks
    /// free.
    #[inline]
    #[must_use = "the spin lock is released as soon as the returned guard is dropped"]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        // The weak variant may fail spuriously. That is harmless here because a failure just
        // falls through to the polling loop and retries, and it is cheaper on some platforms.
        while self
            .inner
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates a new, unlocked spin lock.
    ///
    /// This is a `const fn`, so it can be used in `static`/`const` initializers.
    #[inline]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

impl Default for SpinLock {
    /// Equivalent to [`SpinLock::new`]: the lock starts out unlocked.
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

/// Guard proving that a [`SpinLock`] is held; dropping it releases the lock.
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        // `Release` pairs with the `Acquire` in `SpinLock::acquire`. Every write made while the
        // lock was held becomes visible to the next thread that acquires it.
        self.0.inner.store(false, Ordering::Release);
    }
}
6384837cf6SBenno Lossin
6484837cf6SBenno Lossin #[pin_data]
6584837cf6SBenno Lossin pub struct CMutex<T> {
6684837cf6SBenno Lossin #[pin]
6784837cf6SBenno Lossin wait_list: ListHead,
6884837cf6SBenno Lossin spin_lock: SpinLock,
6984837cf6SBenno Lossin locked: Cell<bool>,
7084837cf6SBenno Lossin #[pin]
7184837cf6SBenno Lossin data: UnsafeCell<T>,
7284837cf6SBenno Lossin }
7384837cf6SBenno Lossin
7484837cf6SBenno Lossin impl<T> CMutex<T> {
7584837cf6SBenno Lossin #[inline]
new(val: impl PinInit<T>) -> impl PinInit<Self>7684837cf6SBenno Lossin pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
7784837cf6SBenno Lossin pin_init!(CMutex {
7884837cf6SBenno Lossin wait_list <- ListHead::new(),
7984837cf6SBenno Lossin spin_lock: SpinLock::new(),
8084837cf6SBenno Lossin locked: Cell::new(false),
8184837cf6SBenno Lossin data <- unsafe {
8284837cf6SBenno Lossin pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
8384837cf6SBenno Lossin val.__pinned_init(slot.cast::<T>())
8484837cf6SBenno Lossin })
8584837cf6SBenno Lossin },
8684837cf6SBenno Lossin })
8784837cf6SBenno Lossin }
8884837cf6SBenno Lossin
8984837cf6SBenno Lossin #[inline]
lock(&self) -> Pin<CMutexGuard<'_, T>>9084837cf6SBenno Lossin pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
9184837cf6SBenno Lossin let mut sguard = self.spin_lock.acquire();
9284837cf6SBenno Lossin if self.locked.get() {
9384837cf6SBenno Lossin stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
9484837cf6SBenno Lossin // println!("wait list length: {}", self.wait_list.size());
9584837cf6SBenno Lossin while self.locked.get() {
9684837cf6SBenno Lossin drop(sguard);
9784837cf6SBenno Lossin park();
9884837cf6SBenno Lossin sguard = self.spin_lock.acquire();
9984837cf6SBenno Lossin }
10084837cf6SBenno Lossin // This does have an effect, as the ListHead inside wait_entry implements Drop!
10184837cf6SBenno Lossin #[expect(clippy::drop_non_drop)]
10284837cf6SBenno Lossin drop(wait_entry);
10384837cf6SBenno Lossin }
10484837cf6SBenno Lossin self.locked.set(true);
10584837cf6SBenno Lossin unsafe {
10684837cf6SBenno Lossin Pin::new_unchecked(CMutexGuard {
10784837cf6SBenno Lossin mtx: self,
10884837cf6SBenno Lossin _pin: PhantomPinned,
10984837cf6SBenno Lossin })
11084837cf6SBenno Lossin }
11184837cf6SBenno Lossin }
11284837cf6SBenno Lossin
11384837cf6SBenno Lossin #[allow(dead_code)]
get_data_mut(self: Pin<&mut Self>) -> &mut T11484837cf6SBenno Lossin pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
11584837cf6SBenno Lossin // SAFETY: we have an exclusive reference and thus nobody has access to data.
11684837cf6SBenno Lossin unsafe { &mut *self.data.get() }
11784837cf6SBenno Lossin }
11884837cf6SBenno Lossin }
11984837cf6SBenno Lossin
12084837cf6SBenno Lossin unsafe impl<T: Send> Send for CMutex<T> {}
12184837cf6SBenno Lossin unsafe impl<T: Send> Sync for CMutex<T> {}
12284837cf6SBenno Lossin
12384837cf6SBenno Lossin pub struct CMutexGuard<'a, T> {
12484837cf6SBenno Lossin mtx: &'a CMutex<T>,
12584837cf6SBenno Lossin _pin: PhantomPinned,
12684837cf6SBenno Lossin }
12784837cf6SBenno Lossin
12884837cf6SBenno Lossin impl<T> Drop for CMutexGuard<'_, T> {
12984837cf6SBenno Lossin #[inline]
drop(&mut self)13084837cf6SBenno Lossin fn drop(&mut self) {
13184837cf6SBenno Lossin let sguard = self.mtx.spin_lock.acquire();
13284837cf6SBenno Lossin self.mtx.locked.set(false);
13384837cf6SBenno Lossin if let Some(list_field) = self.mtx.wait_list.next() {
13484837cf6SBenno Lossin let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
13584837cf6SBenno Lossin unsafe { (*wait_entry).thread.unpark() };
13684837cf6SBenno Lossin }
13784837cf6SBenno Lossin drop(sguard);
13884837cf6SBenno Lossin }
13984837cf6SBenno Lossin }
14084837cf6SBenno Lossin
14184837cf6SBenno Lossin impl<T> Deref for CMutexGuard<'_, T> {
14284837cf6SBenno Lossin type Target = T;
14384837cf6SBenno Lossin
14484837cf6SBenno Lossin #[inline]
deref(&self) -> &Self::Target14584837cf6SBenno Lossin fn deref(&self) -> &Self::Target {
14684837cf6SBenno Lossin unsafe { &*self.mtx.data.get() }
14784837cf6SBenno Lossin }
14884837cf6SBenno Lossin }
14984837cf6SBenno Lossin
15084837cf6SBenno Lossin impl<T> DerefMut for CMutexGuard<'_, T> {
15184837cf6SBenno Lossin #[inline]
deref_mut(&mut self) -> &mut Self::Target15284837cf6SBenno Lossin fn deref_mut(&mut self) -> &mut Self::Target {
15384837cf6SBenno Lossin unsafe { &mut *self.mtx.data.get() }
15484837cf6SBenno Lossin }
15584837cf6SBenno Lossin }
15684837cf6SBenno Lossin
15784837cf6SBenno Lossin #[pin_data]
15884837cf6SBenno Lossin #[repr(C)]
15984837cf6SBenno Lossin struct WaitEntry {
16084837cf6SBenno Lossin #[pin]
16184837cf6SBenno Lossin wait_list: ListHead,
16284837cf6SBenno Lossin thread: Thread,
16384837cf6SBenno Lossin }
16484837cf6SBenno Lossin
16584837cf6SBenno Lossin impl WaitEntry {
16684837cf6SBenno Lossin #[inline]
insert_new(list: &ListHead) -> impl PinInit<Self> + '_16784837cf6SBenno Lossin fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
16884837cf6SBenno Lossin pin_init!(Self {
16984837cf6SBenno Lossin thread: thread::current(),
17084837cf6SBenno Lossin wait_list <- ListHead::insert_prev(list),
17184837cf6SBenno Lossin })
17284837cf6SBenno Lossin }
17384837cf6SBenno Lossin }
17484837cf6SBenno Lossin
/// Stub entry point used when neither the `std` nor the `alloc` feature is enabled.
/// The mutex demo only runs with one of those features.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {
    // Intentionally empty.
}
17784837cf6SBenno Lossin
/// Stress test for [`CMutex`].
///
/// `thread_count` named worker threads each increment a shared counter `workload` times.
/// Each then pauses for a staggered, per-thread delay and increments `workload` more times.
/// Finally, the total is checked. Under Miri the per-thread workload is reduced.
#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    let workload = if cfg!(miri) { 100 } else { 1_000 };

    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let mtx = mtx.clone();
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    // Each increment takes and releases the lock separately.
                    let bump = |times: usize| {
                        for _ in 0..times {
                            *mtx.lock() += 1;
                        }
                    };
                    bump(workload);
                    println!("{i} halfway");
                    // Stagger the second phase so the threads re-contend at different times.
                    sleep(Duration::from_millis((i as u64) * 10));
                    bump(workload);
                    println!("{i} finished");
                })
                .expect("should not fail")
        })
        .collect();

    for handle in handles {
        handle.join().expect("thread panicked");
    }
    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}
211