1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 #![allow(clippy::undocumented_unsafe_blocks)] 4 #![cfg_attr(feature = "alloc", feature(allocator_api))] 5 #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] 6 7 use core::{ 8 cell::{Cell, UnsafeCell}, 9 mem::MaybeUninit, 10 ops, 11 pin::Pin, 12 time::Duration, 13 }; 14 use pin_init::*; 15 use std::{ 16 sync::Arc, 17 thread::{sleep, Builder}, 18 }; 19 20 #[expect(unused_attributes)] 21 mod mutex; 22 use mutex::*; 23 24 pub struct StaticInit<T, I> { 25 cell: UnsafeCell<MaybeUninit<T>>, 26 init: Cell<Option<I>>, 27 lock: SpinLock, 28 present: Cell<bool>, 29 } 30 31 unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {} 32 unsafe impl<T: Send, I> Send for StaticInit<T, I> {} 33 34 impl<T, I: PinInit<T>> StaticInit<T, I> { 35 pub const fn new(init: I) -> Self { 36 Self { 37 cell: UnsafeCell::new(MaybeUninit::uninit()), 38 init: Cell::new(Some(init)), 39 lock: SpinLock::new(), 40 present: Cell::new(false), 41 } 42 } 43 } 44 45 impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> { 46 type Target = T; 47 fn deref(&self) -> &Self::Target { 48 if self.present.get() { 49 unsafe { (*self.cell.get()).assume_init_ref() } 50 } else { 51 println!("acquire spinlock on static init"); 52 let _guard = self.lock.acquire(); 53 println!("rechecking present..."); 54 std::thread::sleep(std::time::Duration::from_millis(200)); 55 if self.present.get() { 56 return unsafe { (*self.cell.get()).assume_init_ref() }; 57 } 58 println!("doing init"); 59 let ptr = self.cell.get().cast::<T>(); 60 match self.init.take() { 61 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() }, 62 None => unsafe { core::hint::unreachable_unchecked() }, 63 } 64 self.present.set(true); 65 unsafe { (*self.cell.get()).assume_init_ref() } 66 } 67 } 68 } 69 70 pub struct CountInit; 71 72 unsafe impl PinInit<CMutex<usize>> for CountInit { 73 unsafe fn __pinned_init( 74 self, 75 slot: *mut CMutex<usize>, 76 ) -> Result<(), core::convert::Infallible> { 77 let init = CMutex::new(0); 78 std::thread::sleep(std::time::Duration::from_millis(1000)); 79 unsafe { init.__pinned_init(slot) } 80 } 81 } 82 83 pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit); 84 85 #[cfg(not(any(feature = "std", feature = "alloc")))] 86 fn main() {} 87 88 #[cfg(any(feature = "std", feature = "alloc"))] 89 fn main() { 90 let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap(); 91 let mut handles = vec![]; 92 let thread_count = 20; 93 let workload = 1_000; 94 for i in 0..thread_count { 95 let mtx = mtx.clone(); 96 handles.push( 97 Builder::new() 98 .name(format!("worker #{i}")) 99 .spawn(move || { 100 for _ in 0..workload { 101 *COUNT.lock() += 1; 102 std::thread::sleep(std::time::Duration::from_millis(10)); 103 *mtx.lock() += 1; 104 std::thread::sleep(std::time::Duration::from_millis(10)); 105 *COUNT.lock() += 1; 106 } 107 println!("{i} halfway"); 108 sleep(Duration::from_millis((i as u64) * 10)); 109 for _ in 0..workload { 110 std::thread::sleep(std::time::Duration::from_millis(10)); 111 *mtx.lock() += 1; 112 } 113 println!("{i} finished"); 114 }) 115 .expect("should not fail"), 116 ); 117 } 118 for h in handles { 119 h.join().expect("thread panicked"); 120 } 121 println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock()); 122 assert_eq!(*mtx.lock(), workload * thread_count * 2); 123 } 124