1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
6
7 use core::{
8 cell::{Cell, UnsafeCell},
9 mem::MaybeUninit,
10 ops,
11 pin::Pin,
12 time::Duration,
13 };
14 use pin_init::*;
15 use std::{
16 sync::Arc,
17 thread::{sleep, Builder},
18 };
19
20 #[expect(unused_attributes)]
21 mod mutex;
22 use mutex::*;
23
/// A cell that lazily initializes its value `T` by running the pinned
/// initializer `I` on the first dereference — at most once.
pub struct StaticInit<T, I> {
    // Storage for the value; uninitialized until `present` is `true`.
    cell: UnsafeCell<MaybeUninit<T>>,
    // Pending initializer; consumed (`take`n) by the first initializing access.
    init: Cell<Option<I>>,
    // Serializes the slow (initialization) path of `deref`.
    lock: SpinLock,
    // `true` once `cell` holds a fully initialized `T`.
    present: Cell<bool>,
}
30
// SAFETY: shared access to `cell`/`init`/`present` is funneled through the
// `lock`-guarded slow path of `deref` (fast path only reads after init).
// NOTE(review): `I` is run on whichever thread first dereferences, yet no
// `I: Send` bound is required here — TODO confirm this is intentional for the
// example (for a `static`, the initializer is never actually moved between
// threads before first use).
unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
// SAFETY: moving the container moves the (possibly uninitialized) `T` and the
// pending initializer by value. NOTE(review): an `I: Send` bound would be the
// conservative choice — verify against intended usage.
unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
33
34 impl<T, I: PinInit<T>> StaticInit<T, I> {
new(init: I) -> Self35 pub const fn new(init: I) -> Self {
36 Self {
37 cell: UnsafeCell::new(MaybeUninit::uninit()),
38 init: Cell::new(Some(init)),
39 lock: SpinLock::new(),
40 present: Cell::new(false),
41 }
42 }
43 }
44
// Lazy initialization on first access, double-checked-locking style: a
// lock-free fast path reads `present`, the slow path re-checks it under
// `lock` before running the initializer.
impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        if self.present.get() {
            // Fast path: already initialized.
            // SAFETY: `present` is only set to `true` after the value in
            // `cell` has been fully initialized.
            unsafe { (*self.cell.get()).assume_init_ref() }
        } else {
            println!("acquire spinlock on static init");
            let _guard = self.lock.acquire();
            println!("rechecking present...");
            // NOTE(review): deliberate delay — widens the race window so
            // concurrent first accesses pile up on the lock for the demo.
            std::thread::sleep(std::time::Duration::from_millis(200));
            if self.present.get() {
                // Another thread completed the initialization while we were
                // waiting for the lock.
                // SAFETY: as above, `present == true` implies `cell` is
                // fully initialized.
                return unsafe { (*self.cell.get()).assume_init_ref() };
            }
            println!("doing init");
            let ptr = self.cell.get().cast::<T>();
            match self.init.take() {
                // SAFETY: `ptr` points at properly aligned, writable storage
                // inside `cell`, which lives (pinned) for `self`'s lifetime.
                Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
                // SAFETY: `init` is only taken here, under `lock` and with
                // `present == false`, so it must still be `Some`.
                None => unsafe { core::hint::unreachable_unchecked() },
            }
            self.present.set(true);
            // SAFETY: the value was initialized just above.
            unsafe { (*self.cell.get()).assume_init_ref() }
        }
    }
}
69
/// Initializer for [`COUNT`]: its `PinInit` impl delegates to
/// `CMutex::new(0)` after an artificial delay, to provoke contention.
pub struct CountInit;
71
// SAFETY: initialization is fully delegated to `CMutex::new(0)`'s `PinInit`
// impl, forwarding the caller's guarantees about `slot` unchanged.
unsafe impl PinInit<CMutex<usize>> for CountInit {
    unsafe fn __pinned_init(
        self,
        slot: *mut CMutex<usize>,
    ) -> Result<(), core::convert::Infallible> {
        let init = CMutex::new(0);
        // Artificial 1 s delay: other threads reach `StaticInit::deref`'s
        // "recheck" path while this initialization is still in flight.
        std::thread::sleep(std::time::Duration::from_millis(1000));
        // SAFETY: the caller guarantees `slot` is valid for writes and
        // pinned, which is exactly what the inner initializer requires.
        unsafe { init.__pinned_init(slot) }
    }
}
82
/// Global counter behind a lazily initialized mutex; first `deref` runs
/// `CountInit` (slow on purpose) under the internal spinlock.
pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
84
#[cfg(not(any(feature = "std", feature = "alloc")))]
// Without an allocator there is nothing to demonstrate; compile to a no-op.
fn main() {}
87
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    // Eagerly initialized local mutex, to compare against the lazily
    // initialized static `COUNT`.
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let mut handles = vec![];
    let thread_count = 20;
    let workload = 1_000;
    for i in 0..thread_count {
        let mtx = mtx.clone();
        handles.push(
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    // Phase 1: hammer both the static (forcing its lazy init
                    // under contention) and the local mutex. `COUNT` is
                    // incremented twice per iteration.
                    for _ in 0..workload {
                        *COUNT.lock() += 1;
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *mtx.lock() += 1;
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *COUNT.lock() += 1;
                    }
                    println!("{i} halfway");
                    // Stagger the threads so phase 2 overlaps less.
                    sleep(Duration::from_millis((i as u64) * 10));
                    // Phase 2: only the local mutex.
                    for _ in 0..workload {
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail"),
        );
    }
    for h in handles {
        h.join().expect("thread panicked");
    }
    println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
    // Fix: also assert the lazily initialized static — it receives exactly
    // two increments per phase-1 iteration per thread, but was previously
    // only printed, never checked.
    assert_eq!(*COUNT.lock(), workload * thread_count * 2);
}
124