xref: /linux/rust/pin-init/examples/static_init.rs (revision 430654211d566f86e8ee533ff1b01a42be6b602c)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 #![allow(clippy::undocumented_unsafe_blocks)]
4 #![cfg_attr(feature = "alloc", feature(allocator_api))]
5 #![allow(unused_imports)]
6 
7 use core::{
8     cell::{Cell, UnsafeCell},
9     mem::MaybeUninit,
10     ops,
11     pin::Pin,
12     time::Duration,
13 };
14 use pin_init::*;
15 #[cfg(feature = "std")]
16 use std::{
17     sync::Arc,
18     thread::{sleep, Builder},
19 };
20 
21 #[allow(unused_attributes)]
22 mod mutex;
23 use mutex::*;
24 
25 pub struct StaticInit<T, I> {
26     cell: UnsafeCell<MaybeUninit<T>>,
27     init: Cell<Option<I>>,
28     lock: SpinLock,
29     present: Cell<bool>,
30 }
31 
32 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
33 unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
34 
35 impl<T, I: PinInit<T>> StaticInit<T, I> {
36     pub const fn new(init: I) -> Self {
37         Self {
38             cell: UnsafeCell::new(MaybeUninit::uninit()),
39             init: Cell::new(Some(init)),
40             lock: SpinLock::new(),
41             present: Cell::new(false),
42         }
43     }
44 }
45 
46 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
47     type Target = T;
48     fn deref(&self) -> &Self::Target {
49         if self.present.get() {
50             unsafe { (*self.cell.get()).assume_init_ref() }
51         } else {
52             println!("acquire spinlock on static init");
53             let _guard = self.lock.acquire();
54             println!("rechecking present...");
55             std::thread::sleep(std::time::Duration::from_millis(200));
56             if self.present.get() {
57                 return unsafe { (*self.cell.get()).assume_init_ref() };
58             }
59             println!("doing init");
60             let ptr = self.cell.get().cast::<T>();
61             match self.init.take() {
62                 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
63                 None => unsafe { core::hint::unreachable_unchecked() },
64             }
65             self.present.set(true);
66             unsafe { (*self.cell.get()).assume_init_ref() }
67         }
68     }
69 }
70 
71 pub struct CountInit;
72 
73 unsafe impl PinInit<CMutex<usize>> for CountInit {
74     unsafe fn __pinned_init(
75         self,
76         slot: *mut CMutex<usize>,
77     ) -> Result<(), core::convert::Infallible> {
78         let init = CMutex::new(0);
79         std::thread::sleep(std::time::Duration::from_millis(1000));
80         unsafe { init.__pinned_init(slot) }
81     }
82 }
83 
84 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
85 
/// Hammers a heap-pinned mutex and the lazily initialized `COUNT` from many
/// threads, then checks that no increment was lost on either counter.
///
/// Only does work when the `std` feature is enabled.
fn main() {
    #[cfg(feature = "std")]
    {
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
        let thread_count = 20;
        let workload = 1_000;
        let mut handles = Vec::with_capacity(thread_count);
        for i in 0..thread_count {
            let mtx = mtx.clone();
            handles.push(
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        // Each iteration bumps COUNT twice and `mtx` once.
                        for _ in 0..workload {
                            *COUNT.lock() += 1;
                            sleep(Duration::from_millis(10));
                            *mtx.lock() += 1;
                            sleep(Duration::from_millis(10));
                            *COUNT.lock() += 1;
                        }
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        // Second phase touches only `mtx`, once per iteration.
                        for _ in 0..workload {
                            sleep(Duration::from_millis(10));
                            *mtx.lock() += 1;
                        }
                        println!("{i} finished");
                    })
                    .expect("should not fail"),
            );
        }
        for h in handles {
            h.join().expect("thread panicked");
        }
        // All workers have been joined, so each counter can be read once.
        let (local_total, global_total) = (*mtx.lock(), *COUNT.lock());
        println!("{:?}, {:?}", local_total, global_total);
        assert_eq!(local_total, workload * thread_count * 2);
        // A second initialization of COUNT would have reset it, so this also
        // checks that the lazy initializer ran exactly once.
        assert_eq!(global_total, workload * thread_count * 2);
    }
}
124