1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 #![allow(clippy::undocumented_unsafe_blocks)] 4 #![cfg_attr(feature = "alloc", feature(allocator_api))] 5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] 6 #![allow(unused_imports)] 7 8 use core::{ 9 cell::{Cell, UnsafeCell}, 10 mem::MaybeUninit, 11 ops, 12 pin::Pin, 13 time::Duration, 14 }; 15 use pin_init::*; 16 #[cfg(feature = "std")] 17 use std::{ 18 sync::Arc, 19 thread::{sleep, Builder}, 20 }; 21 22 #[allow(unused_attributes)] 23 mod mutex; 24 use mutex::*; 25 26 pub struct StaticInit<T, I> { 27 cell: UnsafeCell<MaybeUninit<T>>, 28 init: Cell<Option<I>>, 29 lock: SpinLock, 30 present: Cell<bool>, 31 } 32 33 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {} 34 unsafe impl<T: Send, I> Send for StaticInit<T, I> {} 35 36 impl<T, I: PinInit<T>> StaticInit<T, I> { 37 pub const fn new(init: I) -> Self { 38 Self { 39 cell: UnsafeCell::new(MaybeUninit::uninit()), 40 init: Cell::new(Some(init)), 41 lock: SpinLock::new(), 42 present: Cell::new(false), 43 } 44 } 45 } 46 47 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> { 48 type Target = T; 49 fn deref(&self) -> &Self::Target { 50 if self.present.get() { 51 unsafe { (*self.cell.get()).assume_init_ref() } 52 } else { 53 println!("acquire spinlock on static init"); 54 let _guard = self.lock.acquire(); 55 println!("rechecking present..."); 56 std::thread::sleep(std::time::Duration::from_millis(200)); 57 if self.present.get() { 58 return unsafe { (*self.cell.get()).assume_init_ref() }; 59 } 60 println!("doing init"); 61 let ptr = self.cell.get().cast::<T>(); 62 match self.init.take() { 63 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() }, 64 None => unsafe { core::hint::unreachable_unchecked() }, 65 } 66 self.present.set(true); 67 unsafe { (*self.cell.get()).assume_init_ref() } 68 } 69 } 70 } 71 72 pub struct CountInit; 73 74 unsafe impl PinInit<CMutex<usize>> for CountInit { 75 unsafe fn __pinned_init( 76 self, 77 slot: *mut CMutex<usize>, 78 ) -> Result<(), core::convert::Infallible> { 79 let init = CMutex::new(0); 80 std::thread::sleep(std::time::Duration::from_millis(1000)); 81 unsafe { init.__pinned_init(slot) } 82 } 83 } 84 85 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit); 86 87 fn main() { 88 #[cfg(feature = "std")] 89 { 90 let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); 91 let mut handles = vec![]; 92 let thread_count = 20; 93 let workload = 1_000; 94 for i in 0..thread_count { 95 let mtx = mtx.clone(); 96 handles.push( 97 Builder::new() 98 .name(format!("worker #{i}")) 99 .spawn(move || { 100 for _ in 0..workload { 101 *COUNT.lock() += 1; 102 std::thread::sleep(std::time::Duration::from_millis(10)); 103 *mtx.lock() += 1; 104 std::thread::sleep(std::time::Duration::from_millis(10)); 105 *COUNT.lock() += 1; 106 } 107 println!("{i} halfway"); 108 sleep(Duration::from_millis((i as u64) * 10)); 109 for _ in 0..workload { 110 std::thread::sleep(std::time::Duration::from_millis(10)); 111 *mtx.lock() += 1; 112 } 113 println!("{i} finished"); 114 }) 115 .expect("should not fail"), 116 ); 117 } 118 for h in handles { 119 h.join().expect("thread panicked"); 120 } 121 println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock()); 122 assert_eq!(*mtx.lock(), workload * thread_count * 2); 123 } 124 } 125