1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 #![allow(clippy::undocumented_unsafe_blocks)] 4 #![cfg_attr(feature = "alloc", feature(allocator_api))] 5 6 use core::{ 7 cell::{Cell, UnsafeCell}, 8 mem::MaybeUninit, 9 ops, 10 pin::Pin, 11 time::Duration, 12 }; 13 use pin_init::*; 14 use std::{ 15 sync::Arc, 16 thread::{sleep, Builder}, 17 }; 18 19 #[expect(unused_attributes)] 20 mod mutex; 21 use mutex::*; 22 23 pub struct StaticInit<T, I> { 24 cell: UnsafeCell<MaybeUninit<T>>, 25 init: Cell<Option<I>>, 26 lock: SpinLock, 27 present: Cell<bool>, 28 } 29 30 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {} 31 unsafe impl<T: Send, I> Send for StaticInit<T, I> {} 32 33 impl<T, I: PinInit<T>> StaticInit<T, I> { 34 pub const fn new(init: I) -> Self { 35 Self { 36 cell: UnsafeCell::new(MaybeUninit::uninit()), 37 init: Cell::new(Some(init)), 38 lock: SpinLock::new(), 39 present: Cell::new(false), 40 } 41 } 42 } 43 44 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> { 45 type Target = T; 46 fn deref(&self) -> &Self::Target { 47 if self.present.get() { 48 unsafe { (*self.cell.get()).assume_init_ref() } 49 } else { 50 println!("acquire spinlock on static init"); 51 let _guard = self.lock.acquire(); 52 println!("rechecking present..."); 53 std::thread::sleep(std::time::Duration::from_millis(200)); 54 if self.present.get() { 55 return unsafe { (*self.cell.get()).assume_init_ref() }; 56 } 57 println!("doing init"); 58 let ptr = self.cell.get().cast::<T>(); 59 match self.init.take() { 60 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() }, 61 None => unsafe { core::hint::unreachable_unchecked() }, 62 } 63 self.present.set(true); 64 unsafe { (*self.cell.get()).assume_init_ref() } 65 } 66 } 67 } 68 69 pub struct CountInit; 70 71 unsafe impl PinInit<CMutex<usize>> for CountInit { 72 unsafe fn __pinned_init( 73 self, 74 slot: *mut CMutex<usize>, 75 ) -> Result<(), core::convert::Infallible> { 76 let init = CMutex::new(0); 77 std::thread::sleep(std::time::Duration::from_millis(1000)); 78 unsafe { init.__pinned_init(slot) } 79 } 80 } 81 82 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit); 83 84 #[cfg(not(any(feature = "std", feature = "alloc")))] 85 fn main() {} 86 87 #[cfg(any(feature = "std", feature = "alloc"))] 88 fn main() { 89 let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); 90 let mut handles = vec![]; 91 let thread_count = 20; 92 let workload = 1_000; 93 for i in 0..thread_count { 94 let mtx = mtx.clone(); 95 handles.push( 96 Builder::new() 97 .name(format!("worker #{i}")) 98 .spawn(move || { 99 for _ in 0..workload { 100 *COUNT.lock() += 1; 101 std::thread::sleep(std::time::Duration::from_millis(10)); 102 *mtx.lock() += 1; 103 std::thread::sleep(std::time::Duration::from_millis(10)); 104 *COUNT.lock() += 1; 105 } 106 println!("{i} halfway"); 107 sleep(Duration::from_millis((i as u64) * 10)); 108 for _ in 0..workload { 109 std::thread::sleep(std::time::Duration::from_millis(10)); 110 *mtx.lock() += 1; 111 } 112 println!("{i} finished"); 113 }) 114 .expect("should not fail"), 115 ); 116 } 117 for h in handles { 118 h.join().expect("thread panicked"); 119 } 120 println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock()); 121 assert_eq!(*mtx.lock(), workload * thread_count * 2); 122 } 123