xref: /linux/rust/kernel/alloc/vec_ext.rs (revision 24168c5e6dfbdd5b414f048f47f75d64533296ca)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 //! Extensions to [`Vec`] for fallible allocations.
4 
5 use super::{AllocError, Flags};
6 use alloc::vec::Vec;
7 use core::ptr;
8 
9 /// Extensions to [`Vec`].
10 pub trait VecExt<T>: Sized {
11     /// Creates a new [`Vec`] instance with at least the given capacity.
12     ///
13     /// # Examples
14     ///
15     /// ```
16     /// let v = Vec::<u32>::with_capacity(20, GFP_KERNEL)?;
17     ///
18     /// assert!(v.capacity() >= 20);
19     /// # Ok::<(), Error>(())
20     /// ```
21     fn with_capacity(capacity: usize, flags: Flags) -> Result<Self, AllocError>;
22 
23     /// Appends an element to the back of the [`Vec`] instance.
24     ///
25     /// # Examples
26     ///
27     /// ```
28     /// let mut v = Vec::new();
29     /// v.push(1, GFP_KERNEL)?;
30     /// assert_eq!(&v, &[1]);
31     ///
32     /// v.push(2, GFP_KERNEL)?;
33     /// assert_eq!(&v, &[1, 2]);
34     /// # Ok::<(), Error>(())
35     /// ```
36     fn push(&mut self, v: T, flags: Flags) -> Result<(), AllocError>;
37 
38     /// Pushes clones of the elements of slice into the [`Vec`] instance.
39     ///
40     /// # Examples
41     ///
42     /// ```
43     /// let mut v = Vec::new();
44     /// v.push(1, GFP_KERNEL)?;
45     ///
46     /// v.extend_from_slice(&[20, 30, 40], GFP_KERNEL)?;
47     /// assert_eq!(&v, &[1, 20, 30, 40]);
48     ///
49     /// v.extend_from_slice(&[50, 60], GFP_KERNEL)?;
50     /// assert_eq!(&v, &[1, 20, 30, 40, 50, 60]);
51     /// # Ok::<(), Error>(())
52     /// ```
53     fn extend_from_slice(&mut self, other: &[T], flags: Flags) -> Result<(), AllocError>
54     where
55         T: Clone;
56 
57     /// Ensures that the capacity exceeds the length by at least `additional` elements.
58     ///
59     /// # Examples
60     ///
61     /// ```
62     /// let mut v = Vec::new();
63     /// v.push(1, GFP_KERNEL)?;
64     ///
65     /// v.reserve(10, GFP_KERNEL)?;
66     /// let cap = v.capacity();
67     /// assert!(cap >= 10);
68     ///
69     /// v.reserve(10, GFP_KERNEL)?;
70     /// let new_cap = v.capacity();
71     /// assert_eq!(new_cap, cap);
72     ///
73     /// # Ok::<(), Error>(())
74     /// ```
75     fn reserve(&mut self, additional: usize, flags: Flags) -> Result<(), AllocError>;
76 }
77 
78 impl<T> VecExt<T> for Vec<T> {
79     fn with_capacity(capacity: usize, flags: Flags) -> Result<Self, AllocError> {
80         let mut v = Vec::new();
81         <Self as VecExt<_>>::reserve(&mut v, capacity, flags)?;
82         Ok(v)
83     }
84 
85     fn push(&mut self, v: T, flags: Flags) -> Result<(), AllocError> {
86         <Self as VecExt<_>>::reserve(self, 1, flags)?;
87         let s = self.spare_capacity_mut();
88         s[0].write(v);
89 
90         // SAFETY: We just initialised the first spare entry, so it is safe to increase the length
91         // by 1. We also know that the new length is <= capacity because of the previous call to
92         // `reserve` above.
93         unsafe { self.set_len(self.len() + 1) };
94         Ok(())
95     }
96 
97     fn extend_from_slice(&mut self, other: &[T], flags: Flags) -> Result<(), AllocError>
98     where
99         T: Clone,
100     {
101         <Self as VecExt<_>>::reserve(self, other.len(), flags)?;
102         for (slot, item) in core::iter::zip(self.spare_capacity_mut(), other) {
103             slot.write(item.clone());
104         }
105 
106         // SAFETY: We just initialised the `other.len()` spare entries, so it is safe to increase
107         // the length by the same amount. We also know that the new length is <= capacity because
108         // of the previous call to `reserve` above.
109         unsafe { self.set_len(self.len() + other.len()) };
110         Ok(())
111     }
112 
113     #[cfg(any(test, testlib))]
114     fn reserve(&mut self, additional: usize, _flags: Flags) -> Result<(), AllocError> {
115         Vec::reserve(self, additional);
116         Ok(())
117     }
118 
119     #[cfg(not(any(test, testlib)))]
120     fn reserve(&mut self, additional: usize, flags: Flags) -> Result<(), AllocError> {
121         let len = self.len();
122         let cap = self.capacity();
123 
124         if cap - len >= additional {
125             return Ok(());
126         }
127 
128         if core::mem::size_of::<T>() == 0 {
129             // The capacity is already `usize::MAX` for SZTs, we can't go higher.
130             return Err(AllocError);
131         }
132 
133         // We know cap is <= `isize::MAX` because `Layout::array` fails if the resulting byte size
134         // is greater than `isize::MAX`. So the multiplication by two won't overflow.
135         let new_cap = core::cmp::max(cap * 2, len.checked_add(additional).ok_or(AllocError)?);
136         let layout = core::alloc::Layout::array::<T>(new_cap).map_err(|_| AllocError)?;
137 
138         let (old_ptr, len, cap) = destructure(self);
139 
140         // We need to make sure that `ptr` is either NULL or comes from a previous call to
141         // `krealloc_aligned`. A `Vec<T>`'s `ptr` value is not guaranteed to be NULL and might be
142         // dangling after being created with `Vec::new`. Instead, we can rely on `Vec<T>`'s capacity
143         // to be zero if no memory has been allocated yet.
144         let ptr = if cap == 0 { ptr::null_mut() } else { old_ptr };
145 
146         // SAFETY: `ptr` is valid because it's either NULL or comes from a previous call to
147         // `krealloc_aligned`. We also verified that the type is not a ZST.
148         let new_ptr = unsafe { super::allocator::krealloc_aligned(ptr.cast(), layout, flags) };
149         if new_ptr.is_null() {
150             // SAFETY: We are just rebuilding the existing `Vec` with no changes.
151             unsafe { rebuild(self, old_ptr, len, cap) };
152             Err(AllocError)
153         } else {
154             // SAFETY: `ptr` has been reallocated with the layout for `new_cap` elements. New cap
155             // is greater than `cap`, so it continues to be >= `len`.
156             unsafe { rebuild(self, new_ptr.cast::<T>(), len, new_cap) };
157             Ok(())
158         }
159     }
160 }
161 
162 #[cfg(not(any(test, testlib)))]
fn destructure<T>(v: &mut Vec<T>) -> (*mut T, usize, usize) {
    // Move the contents out of `v`, which is left holding an empty `Vec`. Wrapping the moved-out
    // vector in `ManuallyDrop` keeps its buffer alive, so ownership passes to the raw parts
    // returned as `(pointer, length, capacity)`.
    let mut owned = core::mem::ManuallyDrop::new(core::mem::take(v));
    let (length, capacity) = (owned.len(), owned.capacity());
    (owned.as_mut_ptr(), length, capacity)
}
171 
172 /// Rebuilds a `Vec` from a pointer, length, and capacity.
173 ///
174 /// # Safety
175 ///
176 /// The same as [`Vec::from_raw_parts`].
177 #[cfg(not(any(test, testlib)))]
unsafe fn rebuild<T>(v: &mut Vec<T>, ptr: *mut T, len: usize, cap: usize) {
    // SAFETY: The caller upholds the requirements of `Vec::from_raw_parts` for `ptr`, `len` and
    // `cap`, as demanded by this function's safety contract.
    let restored = unsafe { Vec::from_raw_parts(ptr, len, cap) };
    // Install the reassembled vector in `v` and drop whatever `v` held before.
    drop(core::mem::replace(v, restored));
}
183