1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 #![allow(clippy::undocumented_unsafe_blocks)] 4 #![cfg_attr(feature = "alloc", feature(allocator_api))] 5 #![cfg_attr(USE_RUSTC_FEATURES, feature(lint_reasons))] 6 #![cfg_attr(USE_RUSTC_FEATURES, feature(raw_ref_op))] 7 #![allow(unused_imports)] 8 9 use core::{ 10 cell::{Cell, UnsafeCell}, 11 mem::MaybeUninit, 12 ops, 13 pin::Pin, 14 time::Duration, 15 }; 16 use pin_init::*; 17 #[cfg(feature = "std")] 18 use std::{ 19 sync::Arc, 20 thread::{sleep, Builder}, 21 }; 22 23 #[allow(unused_attributes)] 24 mod mutex; 25 use mutex::*; 26 27 pub struct StaticInit<T, I> { 28 cell: UnsafeCell<MaybeUninit<T>>, 29 init: Cell<Option<I>>, 30 lock: SpinLock, 31 present: Cell<bool>, 32 } 33 34 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {} 35 unsafe impl<T: Send, I> Send for StaticInit<T, I> {} 36 37 impl<T, I: PinInit<T>> StaticInit<T, I> { 38 pub const fn new(init: I) -> Self { 39 Self { 40 cell: UnsafeCell::new(MaybeUninit::uninit()), 41 init: Cell::new(Some(init)), 42 lock: SpinLock::new(), 43 present: Cell::new(false), 44 } 45 } 46 } 47 48 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> { 49 type Target = T; 50 fn deref(&self) -> &Self::Target { 51 if self.present.get() { 52 unsafe { (*self.cell.get()).assume_init_ref() } 53 } else { 54 println!("acquire spinlock on static init"); 55 let _guard = self.lock.acquire(); 56 println!("rechecking present..."); 57 std::thread::sleep(std::time::Duration::from_millis(200)); 58 if self.present.get() { 59 return unsafe { (*self.cell.get()).assume_init_ref() }; 60 } 61 println!("doing init"); 62 let ptr = self.cell.get().cast::<T>(); 63 match self.init.take() { 64 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() }, 65 None => unsafe { core::hint::unreachable_unchecked() }, 66 } 67 self.present.set(true); 68 unsafe { (*self.cell.get()).assume_init_ref() } 69 } 70 } 71 } 72 73 pub struct CountInit; 74 75 unsafe impl PinInit<CMutex<usize>> for CountInit { 76 unsafe fn __pinned_init( 77 self, 78 slot: *mut CMutex<usize>, 79 ) -> Result<(), core::convert::Infallible> { 80 let init = CMutex::new(0); 81 std::thread::sleep(std::time::Duration::from_millis(1000)); 82 unsafe { init.__pinned_init(slot) } 83 } 84 } 85 86 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit); 87 88 fn main() { 89 #[cfg(feature = "std")] 90 { 91 let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); 92 let mut handles = vec![]; 93 let thread_count = 20; 94 let workload = 1_000; 95 for i in 0..thread_count { 96 let mtx = mtx.clone(); 97 handles.push( 98 Builder::new() 99 .name(format!("worker #{i}")) 100 .spawn(move || { 101 for _ in 0..workload { 102 *COUNT.lock() += 1; 103 std::thread::sleep(std::time::Duration::from_millis(10)); 104 *mtx.lock() += 1; 105 std::thread::sleep(std::time::Duration::from_millis(10)); 106 *COUNT.lock() += 1; 107 } 108 println!("{i} halfway"); 109 sleep(Duration::from_millis((i as u64) * 10)); 110 for _ in 0..workload { 111 std::thread::sleep(std::time::Duration::from_millis(10)); 112 *mtx.lock() += 1; 113 } 114 println!("{i} finished"); 115 }) 116 .expect("should not fail"), 117 ); 118 } 119 for h in handles { 120 h.join().expect("thread panicked"); 121 } 122 println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock()); 123 assert_eq!(*mtx.lock(), workload * thread_count * 2); 124 } 125 } 126