// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2024 Meta Platforms, Inc. and affiliates. */
#include <linux/interval_tree_generic.h>
#include <linux/slab.h>
#include <linux/bpf.h>
#include "range_tree.h"

/*
 * struct range_tree is a data structure used to allocate contiguous memory
 * ranges in bpf arena. Conceptually it is a large bitmap: a contiguous
 * sequence of set bits is represented by struct range_node, or 'rn' for short.
 * rn->rn_rbnode links the node into an interval tree, while
 * rn->rb_range_size links it into a second rbtree sorted by the size of the range.
 * __find_range() performs a binary search over the size-sorted rbtree and
 * returns the smallest range whose size is greater than or equal to the
 * requested size (best fit).
 * range_tree_clear/set() clears or sets a range of bits in this bitmap.
 * Adjacent ranges are merged or split at the same time.
 *
 * The split/merge logic is based on/borrowed from XFS's xbitmap32 added
 * in commit 6772fcc8890a ("xfs: convert xbitmap to interval tree").
 *
 * The implementation relies on an external lock to protect the rbtrees.
 * range_node-s are allocated and freed via kmalloc_nolock()/kfree_nolock().
 *
 * bpf arena uses range_tree to represent unallocated slots.
 * At init time:
 *   range_tree_set(rt, 0, max);
 * Then:
 *   start = range_tree_find(rt, len);
 *   if (start >= 0)
 *     range_tree_clear(rt, start, len);
 * to find a free range and mark its slots as allocated, and later:
 *   range_tree_set(rt, start, len);
 * to mark them as unallocated after use.
 */
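
/*
 * A minimal, illustrative sketch of the usage pattern described above.
 * range_tree_example_alloc() is a hypothetical helper that is not used by
 * this file; the caller is assumed to hold the external lock protecting
 * the tree. It finds a free (set) range of @len slots and marks it
 * allocated by clearing the bits; a matching range_tree_set(rt, start, len)
 * would later mark the slots as free again.
 */
static inline __maybe_unused s64
range_tree_example_alloc(struct range_tree *rt, u32 len)
{
	s64 start;
	int err;

	/* Best-fit search for a set range of at least @len bits */
	start = range_tree_find(rt, len);
	if (start < 0)
		return start; /* -ENOENT: no range large enough */

	/* Mark the slots as allocated by clearing the bits */
	err = range_tree_clear(rt, start, len);
	if (err)
		return err; /* -ENOMEM: splitting a range needed a new node */

	return start;
}
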
struct range_node {
	struct rb_node rn_rbnode;
	struct rb_node rb_range_size;
	u32 rn_start;
	u32 rn_last; /* inclusive */
	u32 __rn_subtree_last;
};

static struct range_node *rb_to_range_node(struct rb_node *rb)
{
	return rb_entry(rb, struct range_node, rb_range_size);
}

static u32 rn_size(struct range_node *rn)
{
	return rn->rn_last - rn->rn_start + 1;
}

/*
 * Find the smallest range that can fit the requested size (best fit).
 * The size rbtree keeps larger ranges to the left, so going right moves
 * to smaller ranges and going left to larger ones.
 */
static inline struct range_node *__find_range(struct range_tree *rt, u32 len)
{
	struct rb_node *rb = rt->range_size_root.rb_root.rb_node;
	struct range_node *best = NULL;

	while (rb) {
		struct range_node *rn = rb_to_range_node(rb);

		if (len <= rn_size(rn)) {
			/* Fits; remember it and look right for a tighter fit */
			best = rn;
			rb = rb->rb_right;
		} else {
			/* Too small; look left for a larger range */
			rb = rb->rb_left;
		}
	}

	return best;
}

/* Return the start of a best-fit set range of at least @len bits, or -ENOENT */
s64 range_tree_find(struct range_tree *rt, u32 len)
{
	struct range_node *rn;

	rn = __find_range(rt, len);
	if (!rn)
		return -ENOENT;
	return rn->rn_start;
}

/*
 * Insert the range into the rbtree sorted by range size. Larger ranges go
 * to the left, so the cached leftmost node is always the largest range.
 */
static inline void __range_size_insert(struct range_node *rn,
				       struct rb_root_cached *root)
{
	struct rb_node **link = &root->rb_root.rb_node, *rb = NULL;
	u64 size = rn_size(rn);
	bool leftmost = true;

	while (*link) {
		rb = *link;
		if (size > rn_size(rb_to_range_node(rb))) {
			link = &rb->rb_left;
		} else {
			link = &rb->rb_right;
			leftmost = false;
		}
	}

	rb_link_node(&rn->rb_range_size, rb, link);
	rb_insert_color_cached(&rn->rb_range_size, root, leftmost);
}

#define START(node) ((node)->rn_start)
#define LAST(node)  ((node)->rn_last)

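/*
 * INTERVAL_TREE_DEFINE() generates the augmented rbtree helpers used below:
 * __range_it_insert(), __range_it_remove() and __range_it_iter_first(),
 * keyed by the inclusive [rn_start, rn_last] interval and keeping
 * __rn_subtree_last up to date as the subtree augmentation.
 */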
INTERVAL_TREE_DEFINE(struct range_node, rn_rbnode, u32,
		     __rn_subtree_last, START, LAST,
		     static inline __maybe_unused,
		     __range_it)

static inline __maybe_unused void
range_it_insert(struct range_node *rn, struct range_tree *rt)
{
	__range_size_insert(rn, &rt->range_size_root);
	__range_it_insert(rn, &rt->it_root);
}

static inline __maybe_unused void
range_it_remove(struct range_node *rn, struct range_tree *rt)
{
	rb_erase_cached(&rn->rb_range_size, &rt->range_size_root);
	RB_CLEAR_NODE(&rn->rb_range_size);
	__range_it_remove(rn, &rt->it_root);
}

static inline __maybe_unused struct range_node *
range_it_iter_first(struct range_tree *rt, u32 start, u32 last)
{
	return __range_it_iter_first(&rt->it_root, start, last);
}

/* Clear a range of bits [start, start + len - 1] in this range tree */
int range_tree_clear(struct range_tree *rt, u32 start, u32 len)
{
	u32 last = start + len - 1;
	struct range_node *new_rn;
	struct range_node *rn;

	while ((rn = range_it_iter_first(rt, start, last))) {
		if (rn->rn_start < start && rn->rn_last > last) {
			u32 old_last = rn->rn_last;

			/* The node covers the entire clearing range: keep the part to its left */
			range_it_remove(rn, rt);
			rn->rn_last = start - 1;
			range_it_insert(rn, rt);

			/* Add a node for the part to the right of the cleared range */
			new_rn = kmalloc_nolock(sizeof(struct range_node), 0, NUMA_NO_NODE);
			if (!new_rn)
				return -ENOMEM;
			new_rn->rn_start = last + 1;
			new_rn->rn_last = old_last;
			range_it_insert(new_rn, rt);
		} else if (rn->rn_start < start) {
			/* Overlaps the left edge of the clearing range: trim its tail */
			range_it_remove(rn, rt);
			rn->rn_last = start - 1;
			range_it_insert(rn, rt);
		} else if (rn->rn_last > last) {
			/* Overlaps the right edge of the clearing range: trim its head */
			range_it_remove(rn, rt);
			rn->rn_start = last + 1;
			range_it_insert(rn, rt);
			break;
		} else {
			/* The node lies entirely within the clearing range: drop it */
			range_it_remove(rn, rt);
			kfree_nolock(rn);
		}
	}
	return 0;
}
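
/*
 * Illustrative example (not exercised here): starting from a single set
 * range [0, 9], range_tree_clear(rt, 3, 4) clears bits 3..6 and leaves
 * two set ranges [0, 2] and [7, 9]; a subsequent range_tree_set(rt, 3, 4)
 * merges them back into a single [0, 9] range.
 */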
/* Is the whole range set ? Returns 0 if it is, -ESRCH otherwise. */
int is_range_tree_set(struct range_tree *rt, u32 start, u32 len)
{
	u32 last = start + len - 1;
	struct range_node *left;

	/* Set ranges are kept merged, so a single node must cover [start, last] */
	left = range_it_iter_first(rt, start, last);
	if (left && left->rn_start <= start && left->rn_last >= last)
		return 0;
	return -ESRCH;
}

/* Set the range in this range tree */
int range_tree_set(struct range_tree *rt, u32 start, u32 len)
{
	u32 last = start + len - 1;
	struct range_node *right;
	struct range_node *left;
	int err;

	/* Is this whole range already set ? */
	left = range_it_iter_first(rt, start, last);
	if (left && left->rn_start <= start && left->rn_last >= last)
		return 0;

	/* Clear out everything in the range we want to set. */
	err = range_tree_clear(rt, start, len);
	if (err)
		return err;

	/* Do we have a left-adjacent range ? */
	left = range_it_iter_first(rt, start - 1, start - 1);
	if (left && left->rn_last + 1 != start)
		return -EFAULT;

	/* Do we have a right-adjacent range ? */
	right = range_it_iter_first(rt, last + 1, last + 1);
	if (right && right->rn_start != last + 1)
		return -EFAULT;

	if (left && right) {
		/* Combine left and right adjacent ranges */
		range_it_remove(left, rt);
		range_it_remove(right, rt);
		left->rn_last = right->rn_last;
		range_it_insert(left, rt);
		kfree_nolock(right);
	} else if (left) {
		/* Combine with the left range */
		range_it_remove(left, rt);
		left->rn_last = last;
		range_it_insert(left, rt);
	} else if (right) {
		/* Combine with the right range */
		range_it_remove(right, rt);
		right->rn_start = start;
		range_it_insert(right, rt);
	} else {
		left = kmalloc_nolock(sizeof(struct range_node), 0, NUMA_NO_NODE);
		if (!left)
			return -ENOMEM;
		left->rn_start = start;
		left->rn_last = last;
		range_it_insert(left, rt);
	}
	return 0;
}

/* Remove and free all range nodes */
void range_tree_destroy(struct range_tree *rt)
{
	struct range_node *rn;

	while ((rn = range_it_iter_first(rt, 0, -1U))) {
		range_it_remove(rn, rt);
		kfree_nolock(rn);
	}
}

void range_tree_init(struct range_tree *rt)
{
	rt->it_root = RB_ROOT_CACHED;
	rt->range_size_root = RB_ROOT_CACHED;
}
262