xref: /linux/rust/pin-init/examples/static_init.rs (revision ec7714e4947909190ffb3041a03311a975350fe0)
184837cf6SBenno Lossin // SPDX-License-Identifier: Apache-2.0 OR MIT
284837cf6SBenno Lossin 
384837cf6SBenno Lossin #![allow(clippy::undocumented_unsafe_blocks)]
484837cf6SBenno Lossin #![cfg_attr(feature = "alloc", feature(allocator_api))]
5*5c4167b4SBenno Lossin #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
684837cf6SBenno Lossin 
784837cf6SBenno Lossin use core::{
884837cf6SBenno Lossin     cell::{Cell, UnsafeCell},
984837cf6SBenno Lossin     mem::MaybeUninit,
1084837cf6SBenno Lossin     ops,
1184837cf6SBenno Lossin     pin::Pin,
1284837cf6SBenno Lossin     time::Duration,
1384837cf6SBenno Lossin };
1484837cf6SBenno Lossin use pin_init::*;
1584837cf6SBenno Lossin use std::{
1684837cf6SBenno Lossin     sync::Arc,
1784837cf6SBenno Lossin     thread::{sleep, Builder},
1884837cf6SBenno Lossin };
1984837cf6SBenno Lossin 
2084837cf6SBenno Lossin #[expect(unused_attributes)]
2184837cf6SBenno Lossin mod mutex;
2284837cf6SBenno Lossin use mutex::*;
2384837cf6SBenno Lossin 
2484837cf6SBenno Lossin pub struct StaticInit<T, I> {
2584837cf6SBenno Lossin     cell: UnsafeCell<MaybeUninit<T>>,
2684837cf6SBenno Lossin     init: Cell<Option<I>>,
2784837cf6SBenno Lossin     lock: SpinLock,
2884837cf6SBenno Lossin     present: Cell<bool>,
2984837cf6SBenno Lossin }
3084837cf6SBenno Lossin 
3184837cf6SBenno Lossin unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
3284837cf6SBenno Lossin unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
3384837cf6SBenno Lossin 
3484837cf6SBenno Lossin impl<T, I: PinInit<T>> StaticInit<T, I> {
new(init: I) -> Self3584837cf6SBenno Lossin     pub const fn new(init: I) -> Self {
3684837cf6SBenno Lossin         Self {
3784837cf6SBenno Lossin             cell: UnsafeCell::new(MaybeUninit::uninit()),
3884837cf6SBenno Lossin             init: Cell::new(Some(init)),
3984837cf6SBenno Lossin             lock: SpinLock::new(),
4084837cf6SBenno Lossin             present: Cell::new(false),
4184837cf6SBenno Lossin         }
4284837cf6SBenno Lossin     }
4384837cf6SBenno Lossin }
4484837cf6SBenno Lossin 
4584837cf6SBenno Lossin impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
4684837cf6SBenno Lossin     type Target = T;
deref(&self) -> &Self::Target4784837cf6SBenno Lossin     fn deref(&self) -> &Self::Target {
4884837cf6SBenno Lossin         if self.present.get() {
4984837cf6SBenno Lossin             unsafe { (*self.cell.get()).assume_init_ref() }
5084837cf6SBenno Lossin         } else {
5184837cf6SBenno Lossin             println!("acquire spinlock on static init");
5284837cf6SBenno Lossin             let _guard = self.lock.acquire();
5384837cf6SBenno Lossin             println!("rechecking present...");
5484837cf6SBenno Lossin             std::thread::sleep(std::time::Duration::from_millis(200));
5584837cf6SBenno Lossin             if self.present.get() {
5684837cf6SBenno Lossin                 return unsafe { (*self.cell.get()).assume_init_ref() };
5784837cf6SBenno Lossin             }
5884837cf6SBenno Lossin             println!("doing init");
5984837cf6SBenno Lossin             let ptr = self.cell.get().cast::<T>();
6084837cf6SBenno Lossin             match self.init.take() {
6184837cf6SBenno Lossin                 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
6284837cf6SBenno Lossin                 None => unsafe { core::hint::unreachable_unchecked() },
6384837cf6SBenno Lossin             }
6484837cf6SBenno Lossin             self.present.set(true);
6584837cf6SBenno Lossin             unsafe { (*self.cell.get()).assume_init_ref() }
6684837cf6SBenno Lossin         }
6784837cf6SBenno Lossin     }
6884837cf6SBenno Lossin }
6984837cf6SBenno Lossin 
7084837cf6SBenno Lossin pub struct CountInit;
7184837cf6SBenno Lossin 
7284837cf6SBenno Lossin unsafe impl PinInit<CMutex<usize>> for CountInit {
__pinned_init( self, slot: *mut CMutex<usize>, ) -> Result<(), core::convert::Infallible>7384837cf6SBenno Lossin     unsafe fn __pinned_init(
7484837cf6SBenno Lossin         self,
7584837cf6SBenno Lossin         slot: *mut CMutex<usize>,
7684837cf6SBenno Lossin     ) -> Result<(), core::convert::Infallible> {
7784837cf6SBenno Lossin         let init = CMutex::new(0);
7884837cf6SBenno Lossin         std::thread::sleep(std::time::Duration::from_millis(1000));
7984837cf6SBenno Lossin         unsafe { init.__pinned_init(slot) }
8084837cf6SBenno Lossin     }
8184837cf6SBenno Lossin }
8284837cf6SBenno Lossin 
8384837cf6SBenno Lossin pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
8484837cf6SBenno Lossin 
/// Without `std` or `alloc` this example has nothing to run.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}

/// Spawns worker threads that hammer both a heap-pinned `CMutex` and the lazily initialized
/// `COUNT` static, then checks that neither counter lost an increment.
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    let workload = 1_000;
    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let mtx = mtx.clone();
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    for _ in 0..workload {
                        *COUNT.lock() += 1;
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *mtx.lock() += 1;
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *COUNT.lock() += 1;
                    }
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail")
        })
        .collect();
    for h in handles {
        h.join().expect("thread panicked");
    }
    println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
    // `mtx` is bumped once per iteration of each of the two loops in every worker.
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
    // `COUNT` is bumped twice per iteration of each worker's first loop. The total would be off
    // if an increment were lost or the static were initialized more than once.
    assert_eq!(*COUNT.lock(), workload * thread_count * 2);
}
124