1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 #![allow(clippy::undocumented_unsafe_blocks)] 4 #![cfg_attr(feature = "alloc", feature(allocator_api))] 5 #![allow(unused_imports)] 6 7 use core::{ 8 cell::{Cell, UnsafeCell}, 9 mem::MaybeUninit, 10 ops, 11 pin::Pin, 12 time::Duration, 13 }; 14 use pin_init::*; 15 #[cfg(feature = "std")] 16 use std::{ 17 sync::Arc, 18 thread::{sleep, Builder}, 19 }; 20 21 #[allow(unused_attributes)] 22 mod mutex; 23 use mutex::*; 24 25 pub struct StaticInit<T, I> { 26 cell: UnsafeCell<MaybeUninit<T>>, 27 init: Cell<Option<I>>, 28 lock: SpinLock, 29 present: Cell<bool>, 30 } 31 32 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {} 33 unsafe impl<T: Send, I> Send for StaticInit<T, I> {} 34 35 impl<T, I: PinInit<T>> StaticInit<T, I> { 36 pub const fn new(init: I) -> Self { 37 Self { 38 cell: UnsafeCell::new(MaybeUninit::uninit()), 39 init: Cell::new(Some(init)), 40 lock: SpinLock::new(), 41 present: Cell::new(false), 42 } 43 } 44 } 45 46 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> { 47 type Target = T; 48 fn deref(&self) -> &Self::Target { 49 if self.present.get() { 50 unsafe { (*self.cell.get()).assume_init_ref() } 51 } else { 52 println!("acquire spinlock on static init"); 53 let _guard = self.lock.acquire(); 54 println!("rechecking present..."); 55 std::thread::sleep(std::time::Duration::from_millis(200)); 56 if self.present.get() { 57 return unsafe { (*self.cell.get()).assume_init_ref() }; 58 } 59 println!("doing init"); 60 let ptr = self.cell.get().cast::<T>(); 61 match self.init.take() { 62 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() }, 63 None => unsafe { core::hint::unreachable_unchecked() }, 64 } 65 self.present.set(true); 66 unsafe { (*self.cell.get()).assume_init_ref() } 67 } 68 } 69 } 70 71 pub struct CountInit; 72 73 unsafe impl PinInit<CMutex<usize>> for CountInit { 74 unsafe fn __pinned_init( 75 self, 76 slot: *mut CMutex<usize>, 77 ) -> Result<(), core::convert::Infallible> { 78 let init = CMutex::new(0); 79 std::thread::sleep(std::time::Duration::from_millis(1000)); 80 unsafe { init.__pinned_init(slot) } 81 } 82 } 83 84 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit); 85 86 fn main() { 87 #[cfg(feature = "std")] 88 { 89 let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); 90 let mut handles = vec![]; 91 let thread_count = 20; 92 let workload = 1_000; 93 for i in 0..thread_count { 94 let mtx = mtx.clone(); 95 handles.push( 96 Builder::new() 97 .name(format!("worker #{i}")) 98 .spawn(move || { 99 for _ in 0..workload { 100 *COUNT.lock() += 1; 101 std::thread::sleep(std::time::Duration::from_millis(10)); 102 *mtx.lock() += 1; 103 std::thread::sleep(std::time::Duration::from_millis(10)); 104 *COUNT.lock() += 1; 105 } 106 println!("{i} halfway"); 107 sleep(Duration::from_millis((i as u64) * 10)); 108 for _ in 0..workload { 109 std::thread::sleep(std::time::Duration::from_millis(10)); 110 *mtx.lock() += 1; 111 } 112 println!("{i} finished"); 113 }) 114 .expect("should not fail"), 115 ); 116 } 117 for h in handles { 118 h.join().expect("thread panicked"); 119 } 120 println!("{:?}, {:?}", *mtx.lock(), *COUNT.lock()); 121 assert_eq!(*mtx.lock(), workload * thread_count * 2); 122 } 123 } 124