xref: /linux/rust/pin-init/examples/static_init.rs (revision ec7714e4947909190ffb3041a03311a975350fe0)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
6 
7 use core::{
8     cell::{Cell, UnsafeCell},
9     mem::MaybeUninit,
10     ops,
11     pin::Pin,
12     time::Duration,
13 };
14 use pin_init::*;
15 use std::{
16     sync::Arc,
17     thread::{sleep, Builder},
18 };
19 
20 #[expect(unused_attributes)]
21 mod mutex;
22 use mutex::*;
23 
24 pub struct StaticInit<T, I> {
25     cell: UnsafeCell<MaybeUninit<T>>,
26     init: Cell<Option<I>>,
27     lock: SpinLock,
28     present: Cell<bool>,
29 }
30 
31 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
32 unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
33 
34 impl<T, I: PinInit<T>> StaticInit<T, I> {
new(init: I) -> Self35     pub const fn new(init: I) -> Self {
36         Self {
37             cell: UnsafeCell::new(MaybeUninit::uninit()),
38             init: Cell::new(Some(init)),
39             lock: SpinLock::new(),
40             present: Cell::new(false),
41         }
42     }
43 }
44 
45 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
46     type Target = T;
deref(&self) -> &Self::Target47     fn deref(&self) -> &Self::Target {
48         if self.present.get() {
49             unsafe { (*self.cell.get()).assume_init_ref() }
50         } else {
51             println!("acquire spinlock on static init");
52             let _guard = self.lock.acquire();
53             println!("rechecking present...");
54             std::thread::sleep(std::time::Duration::from_millis(200));
55             if self.present.get() {
56                 return unsafe { (*self.cell.get()).assume_init_ref() };
57             }
58             println!("doing init");
59             let ptr = self.cell.get().cast::<T>();
60             match self.init.take() {
61                 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
62                 None => unsafe { core::hint::unreachable_unchecked() },
63             }
64             self.present.set(true);
65             unsafe { (*self.cell.get()).assume_init_ref() }
66         }
67     }
68 }
69 
70 pub struct CountInit;
71 
72 unsafe impl PinInit<CMutex<usize>> for CountInit {
__pinned_init( self, slot: *mut CMutex<usize>, ) -> Result<(), core::convert::Infallible>73     unsafe fn __pinned_init(
74         self,
75         slot: *mut CMutex<usize>,
76     ) -> Result<(), core::convert::Infallible> {
77         let init = CMutex::new(0);
78         std::thread::sleep(std::time::Duration::from_millis(1000));
79         unsafe { init.__pinned_init(slot) }
80     }
81 }
82 
83 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
84 
/// Fallback entry point when neither `std` nor `alloc` is enabled; the demo
/// needs `Arc` and OS threads, so there is nothing to run.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}

/// Spawns `thread_count` workers that concurrently increment a heap-pinned
/// `CMutex` and the lazily-initialized global [`COUNT`], then checks that no
/// increment was lost on either counter.
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    let workload = 1_000;

    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let mtx = mtx.clone();
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    // Phase 1: interleave updates of COUNT and `mtx`. The first
                    // COUNT access from each thread may race with the lazy
                    // initialization happening on another thread.
                    for _ in 0..workload {
                        *COUNT.lock() += 1;
                        sleep(Duration::from_millis(10));
                        *mtx.lock() += 1;
                        sleep(Duration::from_millis(10));
                        *COUNT.lock() += 1;
                    }
                    println!("{i} halfway");
                    // Stagger the threads before phase 2.
                    sleep(Duration::from_millis((i as u64) * 10));
                    // Phase 2: only `mtx` is updated.
                    for _ in 0..workload {
                        sleep(Duration::from_millis(10));
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail")
        })
        .collect();

    for handle in handles {
        handle.join().expect("thread panicked");
    }

    let mtx_total = *mtx.lock();
    let count_total = *COUNT.lock();
    println!("{mtx_total:?}, {count_total:?}");
    // `mtx`: one increment per iteration in each of the two phases.
    assert_eq!(mtx_total, workload * thread_count * 2);
    // COUNT: two increments per phase-1 iteration. A double or lost lazy
    // initialization of COUNT would show up as a mismatch here.
    assert_eq!(count_total, workload * thread_count * 2);
}
124