// SPDX-License-Identifier: Apache-2.0 OR MIT

#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]
#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
#![allow(clippy::missing_safety_doc)]

use core::{
    cell::{Cell, UnsafeCell},
    marker::PhantomPinned,
    ops::{Deref, DerefMut},
    pin::Pin,
    sync::atomic::{AtomicBool, Ordering},
};
use std::{
    sync::Arc,
    thread::{self, park, sleep, Builder, Thread},
    time::Duration,
};

use pin_init::*;
#[expect(unused_attributes)]
#[path = "./linked_list.rs"]
pub mod linked_list;
use linked_list::*;

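/// A minimal test-and-test-and-set spin lock. It protects the `wait_list`
/// and `locked` fields of [`CMutex`] below.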
pub struct SpinLock {
    inner: AtomicBool,
}

impl SpinLock {
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        // Try to take the lock; on contention, spin on a plain relaxed load
        // first instead of hammering the cache line with compare-exchanges.
        while self
            .inner
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    #[inline]
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

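/// RAII guard for [`SpinLock`]; releases the lock when dropped.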
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        // `Release` publishes all writes made while the lock was held.
        self.0.inner.store(false, Ordering::Release);
    }
}

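/// A pinned mutex backed by an intrusive wait list: contended threads
/// enqueue a stack-allocated [`WaitEntry`] and park until the current owner
/// wakes the next waiter on unlock.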
#[pin_data]
pub struct CMutex<T> {
    #[pin]
    wait_list: ListHead,
    spin_lock: SpinLock,
    locked: Cell<bool>,
    #[pin]
    data: UnsafeCell<T>,
}

impl<T> CMutex<T> {
    #[inline]
    pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
        pin_init!(CMutex {
            wait_list <- ListHead::new(),
            spin_lock: SpinLock::new(),
            locked: Cell::new(false),
            // SAFETY: `UnsafeCell<T>` is `repr(transparent)` over `T`, so
            // initializing the inner `T` in place initializes the cell too.
            data <- unsafe {
                pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
                    val.__pinned_init(slot.cast::<T>())
                })
            },
        })
    }

    #[inline]
    pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
        let mut sguard = self.spin_lock.acquire();
        if self.locked.get() {
            // Contended: enqueue this thread on the wait list and park
            // until the owner unparks us in `CMutexGuard::drop`.
            stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
            // println!("wait list length: {}", self.wait_list.size());
            while self.locked.get() {
                drop(sguard);
                park();
                sguard = self.spin_lock.acquire();
            }
            // This does have an effect, as the ListHead inside wait_entry implements Drop!
            #[expect(clippy::drop_non_drop)]
            drop(wait_entry);
        }
        self.locked.set(true);
        unsafe {
            Pin::new_unchecked(CMutexGuard {
                mtx: self,
                _pin: PhantomPinned,
            })
        }
    }

    #[allow(dead_code)]
    pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
        // SAFETY: we have an exclusive reference and thus nobody has access to data.
        unsafe { &mut *self.data.get() }
    }
}

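// SAFETY: the internal locking makes `CMutex` safe to use from multiple
// threads at once, so it is `Send`/`Sync` whenever `T` itself can be sent
// between threads.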
unsafe impl<T: Send> Send for CMutex<T> {}
unsafe impl<T: Send> Sync for CMutex<T> {}

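/// RAII guard for [`CMutex`]; dereferences to the protected data and wakes
/// the next waiter when dropped.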
pub struct CMutexGuard<'a, T> {
    mtx: &'a CMutex<T>,
    _pin: PhantomPinned,
}

impl<T> Drop for CMutexGuard<'_, T> {
    #[inline]
    fn drop(&mut self) {
        let sguard = self.mtx.spin_lock.acquire();
        self.mtx.locked.set(false);
        // Wake the first waiter, if any. `WaitEntry` is `#[repr(C)]` with
        // its `ListHead` first, so a pointer to the list field is also a
        // pointer to the entry itself.
        if let Some(list_field) = self.mtx.wait_list.next() {
            let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
            unsafe { (*wait_entry).thread.unpark() };
        }
        drop(sguard);
    }
}

impl<T> Deref for CMutexGuard<'_, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        unsafe { &*self.mtx.data.get() }
    }
}

impl<T> DerefMut for CMutexGuard<'_, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { &mut *self.mtx.data.get() }
    }
}

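/// A node in a [`CMutex`]'s wait list, pinned on the waiting thread's stack.
/// `#[repr(C)]` keeps the `ListHead` at offset zero so list pointers can be
/// cast back to `WaitEntry` pointers.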
#[pin_data]
#[repr(C)]
struct WaitEntry {
    #[pin]
    wait_list: ListHead,
    thread: Thread,
}

impl WaitEntry {
    #[inline]
    fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
        pin_init!(Self {
            thread: thread::current(),
            wait_list <- ListHead::insert_prev(list),
        })
    }
}

#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}

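// Stress test: each of `thread_count` threads increments the shared counter
// `workload` times, sleeps, then increments it `workload` more times; the
// final assertion checks that no increments were lost.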
#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let mut handles = vec![];
    let thread_count = 20;
    let workload = if cfg!(miri) { 100 } else { 1_000 };
    for i in 0..thread_count {
        let mtx = mtx.clone();
        handles.push(
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    for _ in 0..workload {
                        *mtx.lock() += 1;
                    }
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail"),
        );
    }
    for h in handles {
        h.join().expect("thread panicked");
    }
    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}