xref: /linux/rust/kernel/alloc/vec_ext.rs (revision 906fd46a65383cd639e5eec72a047efc33045d86)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 //! Extensions to [`Vec`] for fallible allocations.
4 
5 use super::{AllocError, Flags};
6 use alloc::vec::Vec;
7 
8 /// Extensions to [`Vec`].
9 pub trait VecExt<T>: Sized {
10     /// Creates a new [`Vec`] instance with at least the given capacity.
11     ///
12     /// # Examples
13     ///
14     /// ```
15     /// let v = Vec::<u32>::with_capacity(20, GFP_KERNEL)?;
16     ///
17     /// assert!(v.capacity() >= 20);
18     /// # Ok::<(), Error>(())
19     /// ```
20     fn with_capacity(capacity: usize, flags: Flags) -> Result<Self, AllocError>;
21 
22     /// Appends an element to the back of the [`Vec`] instance.
23     ///
24     /// # Examples
25     ///
26     /// ```
27     /// let mut v = Vec::new();
28     /// v.push(1, GFP_KERNEL)?;
29     /// assert_eq!(&v, &[1]);
30     ///
31     /// v.push(2, GFP_KERNEL)?;
32     /// assert_eq!(&v, &[1, 2]);
33     /// # Ok::<(), Error>(())
34     /// ```
35     fn push(&mut self, v: T, flags: Flags) -> Result<(), AllocError>;
36 
37     /// Pushes clones of the elements of slice into the [`Vec`] instance.
38     ///
39     /// # Examples
40     ///
41     /// ```
42     /// let mut v = Vec::new();
43     /// v.push(1, GFP_KERNEL)?;
44     ///
45     /// v.extend_from_slice(&[20, 30, 40], GFP_KERNEL)?;
46     /// assert_eq!(&v, &[1, 20, 30, 40]);
47     ///
48     /// v.extend_from_slice(&[50, 60], GFP_KERNEL)?;
49     /// assert_eq!(&v, &[1, 20, 30, 40, 50, 60]);
50     /// # Ok::<(), Error>(())
51     /// ```
52     fn extend_from_slice(&mut self, other: &[T], flags: Flags) -> Result<(), AllocError>
53     where
54         T: Clone;
55 
56     /// Ensures that the capacity exceeds the length by at least `additional` elements.
57     ///
58     /// # Examples
59     ///
60     /// ```
61     /// let mut v = Vec::new();
62     /// v.push(1, GFP_KERNEL)?;
63     ///
64     /// v.reserve(10, GFP_KERNEL)?;
65     /// let cap = v.capacity();
66     /// assert!(cap >= 10);
67     ///
68     /// v.reserve(10, GFP_KERNEL)?;
69     /// let new_cap = v.capacity();
70     /// assert_eq!(new_cap, cap);
71     ///
72     /// # Ok::<(), Error>(())
73     /// ```
74     fn reserve(&mut self, additional: usize, flags: Flags) -> Result<(), AllocError>;
75 }
76 
77 impl<T> VecExt<T> for Vec<T> {
78     fn with_capacity(capacity: usize, flags: Flags) -> Result<Self, AllocError> {
79         let mut v = Vec::new();
80         <Self as VecExt<_>>::reserve(&mut v, capacity, flags)?;
81         Ok(v)
82     }
83 
84     fn push(&mut self, v: T, flags: Flags) -> Result<(), AllocError> {
85         <Self as VecExt<_>>::reserve(self, 1, flags)?;
86         let s = self.spare_capacity_mut();
87         s[0].write(v);
88 
89         // SAFETY: We just initialised the first spare entry, so it is safe to increase the length
90         // by 1. We also know that the new length is <= capacity because of the previous call to
91         // `reserve` above.
92         unsafe { self.set_len(self.len() + 1) };
93         Ok(())
94     }
95 
96     fn extend_from_slice(&mut self, other: &[T], flags: Flags) -> Result<(), AllocError>
97     where
98         T: Clone,
99     {
100         <Self as VecExt<_>>::reserve(self, other.len(), flags)?;
101         for (slot, item) in core::iter::zip(self.spare_capacity_mut(), other) {
102             slot.write(item.clone());
103         }
104 
105         // SAFETY: We just initialised the `other.len()` spare entries, so it is safe to increase
106         // the length by the same amount. We also know that the new length is <= capacity because
107         // of the previous call to `reserve` above.
108         unsafe { self.set_len(self.len() + other.len()) };
109         Ok(())
110     }
111 
112     #[cfg(any(test, testlib))]
113     fn reserve(&mut self, additional: usize, _flags: Flags) -> Result<(), AllocError> {
114         Vec::reserve(self, additional);
115         Ok(())
116     }
117 
118     #[cfg(not(any(test, testlib)))]
119     fn reserve(&mut self, additional: usize, flags: Flags) -> Result<(), AllocError> {
120         let len = self.len();
121         let cap = self.capacity();
122 
123         if cap - len >= additional {
124             return Ok(());
125         }
126 
127         if core::mem::size_of::<T>() == 0 {
128             // The capacity is already `usize::MAX` for SZTs, we can't go higher.
129             return Err(AllocError);
130         }
131 
132         // We know cap is <= `isize::MAX` because `Layout::array` fails if the resulting byte size
133         // is greater than `isize::MAX`. So the multiplication by two won't overflow.
134         let new_cap = core::cmp::max(cap * 2, len.checked_add(additional).ok_or(AllocError)?);
135         let layout = core::alloc::Layout::array::<T>(new_cap).map_err(|_| AllocError)?;
136 
137         let (old_ptr, len, cap) = destructure(self);
138 
139         // We need to make sure that `ptr` is either NULL or comes from a previous call to
140         // `krealloc_aligned`. A `Vec<T>`'s `ptr` value is not guaranteed to be NULL and might be
141         // dangling after being created with `Vec::new`. Instead, we can rely on `Vec<T>`'s capacity
142         // to be zero if no memory has been allocated yet.
143         let ptr = if cap == 0 {
144             core::ptr::null_mut()
145         } else {
146             old_ptr
147         };
148 
149         // SAFETY: `ptr` is valid because it's either NULL or comes from a previous call to
150         // `krealloc_aligned`. We also verified that the type is not a ZST.
151         let new_ptr = unsafe { super::allocator::krealloc_aligned(ptr.cast(), layout, flags) };
152         if new_ptr.is_null() {
153             // SAFETY: We are just rebuilding the existing `Vec` with no changes.
154             unsafe { rebuild(self, old_ptr, len, cap) };
155             Err(AllocError)
156         } else {
157             // SAFETY: `ptr` has been reallocated with the layout for `new_cap` elements. New cap
158             // is greater than `cap`, so it continues to be >= `len`.
159             unsafe { rebuild(self, new_ptr.cast::<T>(), len, new_cap) };
160             Ok(())
161         }
162     }
163 }
164 
/// Moves the allocation out of `v` and returns it as raw parts `(pointer, length, capacity)`.
///
/// Afterwards `v` holds a new, empty `Vec`. The returned buffer is no longer owned by any `Vec`,
/// so the caller is responsible for handing it back (for example via `rebuild`).
#[cfg(not(any(test, testlib)))]
fn destructure<T>(v: &mut Vec<T>) -> (*mut T, usize, usize) {
    // `take` swaps an empty `Vec` into `v`; `ManuallyDrop` keeps the taken buffer from being freed
    // when it goes out of scope here.
    let mut taken = core::mem::ManuallyDrop::new(core::mem::take(v));
    let (length, capacity) = (taken.len(), taken.capacity());
    (taken.as_mut_ptr(), length, capacity)
}
174 
/// Rebuilds a `Vec` from a pointer, length, and capacity and stores it in `v`.
///
/// Whatever `v` held before the call is dropped.
///
/// # Safety
///
/// The same as [`Vec::from_raw_parts`].
#[cfg(not(any(test, testlib)))]
unsafe fn rebuild<T>(v: &mut Vec<T>, ptr: *mut T, len: usize, cap: usize) {
    // SAFETY: The caller upholds the requirements of `from_raw_parts`, which are exactly this
    // function's safety requirements.
    let rebuilt = unsafe { Vec::from_raw_parts(ptr, len, cap) };
    // Install the rebuilt vector, then release the previous contents of `v`.
    drop(core::mem::replace(v, rebuilt));
}
186