xref: /linux/tools/testing/selftests/bpf/progs/rbtree_search.c (revision da4fd657730c9510b848ef7a9cc7247bbb6a44c9)
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2025 Meta Platforms, Inc. and affiliates. */

#include <vmlinux.h>
#include <bpf/bpf_helpers.h>
#include "bpf_misc.h"
#include "bpf_experimental.h"

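/*
 * Each node carries a refcount and two rb_node links so the same object
 * can sit in two trees at once: groot0 orders it by key0 through r0,
 * groot1 orders it by key1 through r1.
 */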
struct node_data {
	struct bpf_refcount ref;
	struct bpf_rb_node r0;
	struct bpf_rb_node r1;
	int key0;
	int key1;
};

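/*
 * Each lock/root pair lives in its own .data.<name> section so the
 * verifier can tie the bpf_spin_lock to the bpf_rb_root it protects:
 * glock0 guards groot0 and glock1 guards groot1.
 */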
#define private(name) SEC(".data." #name) __hidden __attribute__((aligned(8)))
private(A) struct bpf_spin_lock glock0;
private(A) struct bpf_rb_root groot0 __contains(node_data, r0);

private(B) struct bpf_spin_lock glock1;
private(B) struct bpf_rb_root groot1 __contains(node_data, r1);

#define rb_entry(ptr, type, member) container_of(ptr, type, member)
#define NR_NODES 16

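/* Read at run time so the loop start below is not a compile-time constant. */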
int zero = 0;

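/* Comparators passed to bpf_rbtree_add(): groot0 is ordered by key0, groot1 by key1. */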
static bool less0(struct bpf_rb_node *a, const struct bpf_rb_node *b)
{
	struct node_data *node_a;
	struct node_data *node_b;

	node_a = rb_entry(a, struct node_data, r0);
	node_b = rb_entry(b, struct node_data, r0);

	return node_a->key0 < node_b->key0;
}

static bool less1(struct bpf_rb_node *a, const struct bpf_rb_node *b)
{
	struct node_data *node_a;
	struct node_data *node_b;

	node_a = rb_entry(a, struct node_data, r1);
	node_b = rb_entry(b, struct node_data, r1);

	return node_a->key1 < node_b->key1;
}

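/*
 * Success path: insert NR_NODES nodes into both trees, binary-search
 * groot0 for lookup_key while remembering every non-matching node
 * visited, remove those visited nodes, then remove the matching node
 * from groot1 through its second rb_node and drop all references.
 */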
SEC("syscall")
__retval(0)
long rbtree_search(void *ctx)
{
	struct bpf_rb_node *rb_n, *rb_m, *gc_ns[NR_NODES];
	long lookup_key = NR_NODES / 2;
	struct node_data *n, *m;
	int i, nr_gc = 0;

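	/*
	 * Build both trees: each node goes into groot0 keyed by key0 and,
	 * via the reference returned by bpf_refcount_acquire(), into
	 * groot1 keyed by key1.
	 */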
	for (i = zero; i < NR_NODES && can_loop; i++) {
		n = bpf_obj_new(typeof(*n));
		if (!n)
			return __LINE__;

		m = bpf_refcount_acquire(n);

		n->key0 = i;
		m->key1 = i;

		bpf_spin_lock(&glock0);
		bpf_rbtree_add(&groot0, &n->r0, less0);
		bpf_spin_unlock(&glock0);

		bpf_spin_lock(&glock1);
		bpf_rbtree_add(&groot1, &m->r1, less1);
		bpf_spin_unlock(&glock1);
	}

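	/*
	 * Descend from the root of groot0 toward lookup_key, queueing every
	 * non-matching node visited in gc_ns[] so it can be removed once the
	 * search is done.
	 */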
	n = NULL;
	bpf_spin_lock(&glock0);
	rb_n = bpf_rbtree_root(&groot0);
	while (can_loop) {
		if (!rb_n) {
			bpf_spin_unlock(&glock0);
			return __LINE__;
		}

		n = rb_entry(rb_n, struct node_data, r0);
		if (lookup_key == n->key0)
			break;
		if (nr_gc < NR_NODES)
			gc_ns[nr_gc++] = rb_n;
		if (lookup_key < n->key0)
			rb_n = bpf_rbtree_left(&groot0, rb_n);
		else
			rb_n = bpf_rbtree_right(&groot0, rb_n);
	}

	if (!n || lookup_key != n->key0) {
		bpf_spin_unlock(&glock0);
		return __LINE__;
	}

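	/*
	 * Still under glock0: remove the queued nodes.  bpf_rbtree_remove()
	 * returns an owned pointer (or NULL), which is dropped after the
	 * lock is released below.
	 */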
	for (i = 0; i < nr_gc; i++) {
		rb_n = gc_ns[i];
		gc_ns[i] = bpf_rbtree_remove(&groot0, rb_n);
	}

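	/*
	 * Take a reference on the matching node before releasing glock0 so
	 * it can still be removed from groot1 afterwards.
	 */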
	m = bpf_refcount_acquire(n);
	bpf_spin_unlock(&glock0);

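	/* Free the owned pointers returned by bpf_rbtree_remove() above. */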
	for (i = 0; i < nr_gc; i++) {
		rb_n = gc_ns[i];
		if (rb_n) {
			n = rb_entry(rb_n, struct node_data, r0);
			bpf_obj_drop(n);
		}
	}

	if (!m)
		return __LINE__;

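	/*
	 * Remove the matching node from the second tree through its r1 link,
	 * then drop both the acquired reference and the owned pointer
	 * returned by bpf_rbtree_remove().
	 */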
	bpf_spin_lock(&glock1);
	rb_m = bpf_rbtree_remove(&groot1, &m->r1);
	bpf_spin_unlock(&glock1);
	bpf_obj_drop(m);
	if (!rb_m)
		return __LINE__;
	bpf_obj_drop(rb_entry(rb_m, struct node_data, r1));

	return 0;
}

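/*
 * Negative tests: every program generated below must be rejected by the
 * verifier, and MSG is matched against the verifier log.  The dolock ==
 * true variants check how the return value of bpf_rbtree_root/left/right
 * is tracked (an RCU-protected, possibly-NULL, non-owning node pointer);
 * the dolock == false variants check that calling these kfuncs without
 * holding glock0 is rejected.
 */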
#define TEST_ROOT(dolock)				\
SEC("syscall")						\
__failure __msg(MSG)					\
long test_root_spinlock_##dolock(void *ctx)		\
{							\
	struct bpf_rb_node *rb_n;			\
	__u64 jiffies = 0;				\
							\
	if (dolock)					\
		bpf_spin_lock(&glock0);			\
	rb_n = bpf_rbtree_root(&groot0);		\
	if (rb_n)					\
		jiffies = bpf_jiffies64();		\
	if (dolock)					\
		bpf_spin_unlock(&glock0);		\
							\
	return !!jiffies;				\
}

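/*
 * Same shape as TEST_ROOT, but exercises bpf_rbtree_left()/bpf_rbtree_right()
 * on a node obtained via bpf_rbtree_root() and bpf_refcount_acquire().
 */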
#define TEST_LR(op, dolock)				\
SEC("syscall")						\
__failure __msg(MSG)					\
long test_##op##_spinlock_##dolock(void *ctx)		\
{							\
	struct bpf_rb_node *rb_n;			\
	struct node_data *n;				\
	__u64 jiffies = 0;				\
							\
	bpf_spin_lock(&glock0);				\
	rb_n = bpf_rbtree_root(&groot0);		\
	if (!rb_n) {					\
		bpf_spin_unlock(&glock0);		\
		return 1;				\
	}						\
	n = rb_entry(rb_n, struct node_data, r0);	\
	n = bpf_refcount_acquire(n);			\
	bpf_spin_unlock(&glock0);			\
	if (!n)						\
		return 1;				\
							\
	if (dolock)					\
		bpf_spin_lock(&glock0);			\
	rb_n = bpf_rbtree_##op(&groot0, &n->r0);	\
	if (rb_n)					\
		jiffies = bpf_jiffies64();		\
	if (dolock)					\
		bpf_spin_unlock(&glock0);		\
							\
	return !!jiffies;				\
}

/*
 * Use a separate MSG macro instead of passing the message to
 * TEST_XXX(..., MSG) to ensure the message itself does not end up in the
 * BPF prog's line info, which the verifier includes in its log.
 * Otherwise, test_loader would incorrectly match the prog line info
 * instead of the log generated by the verifier.
 */
#define MSG "call bpf_rbtree_root{{.+}}; R0{{(_w)?}}=rcu_ptr_or_null_node_data(id={{[0-9]+}},non_own_ref"
TEST_ROOT(true)
#undef MSG
#define MSG "call bpf_rbtree_{{(left|right).+}}; R0{{(_w)?}}=rcu_ptr_or_null_node_data(id={{[0-9]+}},non_own_ref"
TEST_LR(left,  true)
TEST_LR(right, true)
#undef MSG

#define MSG "bpf_spin_lock at off=0 must be held for bpf_rb_root"
TEST_ROOT(false)
TEST_LR(left, false)
TEST_LR(right, false)
#undef MSG

char _license[] SEC("license") = "GPL";