xref: /linux/rust/pin-init/examples/static_init.rs (revision 9e1e9d660255d7216067193d774f338d08d8528d)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![cfg_attr(USE_RUSTC_FEATURES, feature(lint_reasons))]
6 #![cfg_attr(USE_RUSTC_FEATURES, feature(raw_ref_op))]
7 #![allow(unused_imports)]
8 
use core::{
    cell::{Cell, UnsafeCell},
    mem::MaybeUninit,
    ops,
    pin::Pin,
    sync::atomic::{AtomicBool, Ordering},
    time::Duration,
};
use pin_init::*;
16 use pin_init::*;
17 #[cfg(feature = "std")]
18 use std::{
19     sync::Arc,
20     thread::{sleep, Builder},
21 };
22 
23 #[allow(unused_attributes)]
24 mod mutex;
25 use mutex::*;
26 
27 pub struct StaticInit<T, I> {
28     cell: UnsafeCell<MaybeUninit<T>>,
29     init: Cell<Option<I>>,
30     lock: SpinLock,
31     present: Cell<bool>,
32 }
33 
34 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
35 unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
36 
37 impl<T, I: PinInit<T>> StaticInit<T, I> {
38     pub const fn new(init: I) -> Self {
39         Self {
40             cell: UnsafeCell::new(MaybeUninit::uninit()),
41             init: Cell::new(Some(init)),
42             lock: SpinLock::new(),
43             present: Cell::new(false),
44         }
45     }
46 }
47 
48 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
49     type Target = T;
50     fn deref(&self) -> &Self::Target {
51         if self.present.get() {
52             unsafe { (*self.cell.get()).assume_init_ref() }
53         } else {
54             println!("acquire spinlock on static init");
55             let _guard = self.lock.acquire();
56             println!("rechecking present...");
57             std::thread::sleep(std::time::Duration::from_millis(200));
58             if self.present.get() {
59                 return unsafe { (*self.cell.get()).assume_init_ref() };
60             }
61             println!("doing init");
62             let ptr = self.cell.get().cast::<T>();
63             match self.init.take() {
64                 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
65                 None => unsafe { core::hint::unreachable_unchecked() },
66             }
67             self.present.set(true);
68             unsafe { (*self.cell.get()).assume_init_ref() }
69         }
70     }
71 }
72 
73 pub struct CountInit;
74 
75 unsafe impl PinInit<CMutex<usize>> for CountInit {
76     unsafe fn __pinned_init(
77         self,
78         slot: *mut CMutex<usize>,
79     ) -> Result<(), core::convert::Infallible> {
80         let init = CMutex::new(0);
81         std::thread::sleep(std::time::Duration::from_millis(1000));
82         unsafe { init.__pinned_init(slot) }
83     }
84 }
85 
86 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
87 
/// Spawns worker threads that update both a pinned, `Arc`-shared `CMutex` and the
/// lazily-initialized `COUNT` static, then checks both final totals.
///
/// Does nothing unless the `std` feature is enabled.
fn main() {
    #[cfg(feature = "std")]
    {
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
        let mut handles = vec![];
        let thread_count = 20;
        let workload = 1_000;
        for i in 0..thread_count {
            let mtx = mtx.clone();
            handles.push(
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        // Phase 1: interleave updates to COUNT and mtx. The first
                        // threads to touch COUNT race on its one-time initialization.
                        for _ in 0..workload {
                            *COUNT.lock() += 1;
                            sleep(Duration::from_millis(10));
                            *mtx.lock() += 1;
                            sleep(Duration::from_millis(10));
                            *COUNT.lock() += 1;
                        }
                        println!("{i} halfway");
                        // Stagger the threads before phase 2.
                        sleep(Duration::from_millis((i as u64) * 10));
                        // Phase 2: update mtx only.
                        for _ in 0..workload {
                            sleep(Duration::from_millis(10));
                            *mtx.lock() += 1;
                        }
                        println!("{i} finished");
                    })
                    .expect("should not fail"),
            );
        }
        for h in handles {
            h.join().expect("thread panicked");
        }
        // Read each total once; each guard is released at the end of its statement.
        let mtx_total = *mtx.lock();
        let count_total = *COUNT.lock();
        println!("{:?}, {:?}", mtx_total, count_total);
        // mtx: +1 per iteration in each of the two loops, per thread.
        assert_eq!(mtx_total, workload * thread_count * 2);
        // COUNT: +2 per iteration of the first loop, per thread.
        assert_eq!(count_total, workload * thread_count * 2);
    }
}
126