184837cf6SBenno Lossin // SPDX-License-Identifier: Apache-2.0 OR MIT
284837cf6SBenno Lossin
384837cf6SBenno Lossin #![allow(clippy::undocumented_unsafe_blocks)]
484837cf6SBenno Lossin #![cfg_attr(feature = "alloc", feature(allocator_api))]
5*5c4167b4SBenno Lossin #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
684837cf6SBenno Lossin
784837cf6SBenno Lossin use core::{
884837cf6SBenno Lossin cell::{Cell, UnsafeCell},
984837cf6SBenno Lossin mem::MaybeUninit,
1084837cf6SBenno Lossin ops,
1184837cf6SBenno Lossin pin::Pin,
1284837cf6SBenno Lossin time::Duration,
1384837cf6SBenno Lossin };
1484837cf6SBenno Lossin use pin_init::*;
1584837cf6SBenno Lossin use std::{
1684837cf6SBenno Lossin sync::Arc,
1784837cf6SBenno Lossin thread::{sleep, Builder},
1884837cf6SBenno Lossin };
1984837cf6SBenno Lossin
2084837cf6SBenno Lossin #[expect(unused_attributes)]
2184837cf6SBenno Lossin mod mutex;
2284837cf6SBenno Lossin use mutex::*;
2384837cf6SBenno Lossin
2484837cf6SBenno Lossin pub struct StaticInit<T, I> {
2584837cf6SBenno Lossin cell: UnsafeCell<MaybeUninit<T>>,
2684837cf6SBenno Lossin init: Cell<Option<I>>,
2784837cf6SBenno Lossin lock: SpinLock,
2884837cf6SBenno Lossin present: Cell<bool>,
2984837cf6SBenno Lossin }
3084837cf6SBenno Lossin
3184837cf6SBenno Lossin unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
3284837cf6SBenno Lossin unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
3384837cf6SBenno Lossin
3484837cf6SBenno Lossin impl<T, I: PinInit<T>> StaticInit<T, I> {
new(init: I) -> Self3584837cf6SBenno Lossin pub const fn new(init: I) -> Self {
3684837cf6SBenno Lossin Self {
3784837cf6SBenno Lossin cell: UnsafeCell::new(MaybeUninit::uninit()),
3884837cf6SBenno Lossin init: Cell::new(Some(init)),
3984837cf6SBenno Lossin lock: SpinLock::new(),
4084837cf6SBenno Lossin present: Cell::new(false),
4184837cf6SBenno Lossin }
4284837cf6SBenno Lossin }
4384837cf6SBenno Lossin }
4484837cf6SBenno Lossin
4584837cf6SBenno Lossin impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
4684837cf6SBenno Lossin type Target = T;
deref(&self) -> &Self::Target4784837cf6SBenno Lossin fn deref(&self) -> &Self::Target {
4884837cf6SBenno Lossin if self.present.get() {
4984837cf6SBenno Lossin unsafe { (*self.cell.get()).assume_init_ref() }
5084837cf6SBenno Lossin } else {
5184837cf6SBenno Lossin println!("acquire spinlock on static init");
5284837cf6SBenno Lossin let _guard = self.lock.acquire();
5384837cf6SBenno Lossin println!("rechecking present...");
5484837cf6SBenno Lossin std::thread::sleep(std::time::Duration::from_millis(200));
5584837cf6SBenno Lossin if self.present.get() {
5684837cf6SBenno Lossin return unsafe { (*self.cell.get()).assume_init_ref() };
5784837cf6SBenno Lossin }
5884837cf6SBenno Lossin println!("doing init");
5984837cf6SBenno Lossin let ptr = self.cell.get().cast::<T>();
6084837cf6SBenno Lossin match self.init.take() {
6184837cf6SBenno Lossin Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
6284837cf6SBenno Lossin None => unsafe { core::hint::unreachable_unchecked() },
6384837cf6SBenno Lossin }
6484837cf6SBenno Lossin self.present.set(true);
6584837cf6SBenno Lossin unsafe { (*self.cell.get()).assume_init_ref() }
6684837cf6SBenno Lossin }
6784837cf6SBenno Lossin }
6884837cf6SBenno Lossin }
6984837cf6SBenno Lossin
7084837cf6SBenno Lossin pub struct CountInit;
7184837cf6SBenno Lossin
7284837cf6SBenno Lossin unsafe impl PinInit<CMutex<usize>> for CountInit {
__pinned_init( self, slot: *mut CMutex<usize>, ) -> Result<(), core::convert::Infallible>7384837cf6SBenno Lossin unsafe fn __pinned_init(
7484837cf6SBenno Lossin self,
7584837cf6SBenno Lossin slot: *mut CMutex<usize>,
7684837cf6SBenno Lossin ) -> Result<(), core::convert::Infallible> {
7784837cf6SBenno Lossin let init = CMutex::new(0);
7884837cf6SBenno Lossin std::thread::sleep(std::time::Duration::from_millis(1000));
7984837cf6SBenno Lossin unsafe { init.__pinned_init(slot) }
8084837cf6SBenno Lossin }
8184837cf6SBenno Lossin }
8284837cf6SBenno Lossin
8384837cf6SBenno Lossin pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
8484837cf6SBenno Lossin
/// Fallback entry point used when neither the `std` nor the `alloc` feature is enabled.
/// The threaded demo below needs one of them, so this build does nothing.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {
    // Intentionally empty.
}
8784837cf6SBenno Lossin
/// Runs many threads that increment both a heap-pinned `CMutex` and the lazily initialized
/// `COUNT` static, then checks that no increment on either counter was lost.
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    let workload = 1_000;
    let mut handles = Vec::with_capacity(thread_count);
    for i in 0..thread_count {
        let mtx = mtx.clone();
        handles.push(
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    // Phase 1: one `mtx` increment and two `COUNT` increments per iteration.
                    for _ in 0..workload {
                        *COUNT.lock() += 1;
                        sleep(Duration::from_millis(10));
                        *mtx.lock() += 1;
                        sleep(Duration::from_millis(10));
                        *COUNT.lock() += 1;
                    }
                    println!("{i} halfway");
                    // Stagger the threads before the second phase.
                    sleep(Duration::from_millis((i as u64) * 10));
                    // Phase 2: one more `mtx` increment per iteration.
                    for _ in 0..workload {
                        sleep(Duration::from_millis(10));
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail"),
        );
    }
    for h in handles {
        h.join().expect("thread panicked");
    }
    // Each thread adds `2 * workload` to both counters.
    let expected = workload * thread_count * 2;
    println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
    assert_eq!(*mtx.lock(), expected);
    // `COUNT` is the value backed by `StaticInit`. Checking it verifies that the contended
    // first-use initialization produced a single, shared mutex.
    assert_eq!(*COUNT.lock(), expected);
}
124