xref: /linux/kernel/bpf/reuseport_array.c (revision 92d33063c081a82d25dd08a9cce03947c8ed9164)
1  // SPDX-License-Identifier: GPL-2.0
2  /*
3   * Copyright (c) 2018 Facebook
4   */
5  #include <linux/bpf.h>
6  #include <linux/err.h>
7  #include <linux/sock_diag.h>
8  #include <net/sock_reuseport.h>
9  #include <linux/btf_ids.h>
10  
/*
 * In-kernel representation of a reuseport socket array map: the generic
 * bpf_map header followed by one RCU-protected socket pointer per slot.
 */
struct reuseport_array {
	/*
	 * Must stay the first member: reuseport_array() converts a
	 * struct bpf_map pointer straight into a struct reuseport_array
	 * pointer by casting.
	 */
	struct bpf_map map;
	/*
	 * map.max_entries slots, allocated together with the header via
	 * struct_size() in reuseport_array_alloc(). An empty slot is NULL.
	 */
	struct sock __rcu *ptrs[];
};
15  
/*
 * reuseport_array - get the reuseport_array that embeds @map
 * @map: the bpf_map header of a reuseport array
 *
 * "map" is the first member of struct reuseport_array, so both objects
 * start at the same address and a plain pointer conversion suffices.
 */
static struct reuseport_array *reuseport_array(struct bpf_map *map)
{
	void *base = map;

	return base;
}
20  
21  /* The caller must hold the reuseport_lock */
22  void bpf_sk_reuseport_detach(struct sock *sk)
23  {
24  	struct sock __rcu **socks;
25  
26  	write_lock_bh(&sk->sk_callback_lock);
27  	socks = __locked_read_sk_user_data_with_flags(sk, SK_USER_DATA_BPF);
28  	if (socks) {
29  		WRITE_ONCE(sk->sk_user_data, NULL);
30  		/*
31  		 * Do not move this NULL assignment outside of
32  		 * sk->sk_callback_lock because there is
33  		 * a race with reuseport_array_free()
34  		 * which does not hold the reuseport_lock.
35  		 */
36  		RCU_INIT_POINTER(*socks, NULL);
37  	}
38  	write_unlock_bh(&sk->sk_callback_lock);
39  }
40  
41  static int reuseport_array_alloc_check(union bpf_attr *attr)
42  {
43  	if (attr->value_size != sizeof(u32) &&
44  	    attr->value_size != sizeof(u64))
45  		return -EINVAL;
46  
47  	return array_map_alloc_check(attr);
48  }
49  
50  static void *reuseport_array_lookup_elem(struct bpf_map *map, void *key)
51  {
52  	struct reuseport_array *array = reuseport_array(map);
53  	u32 index = *(u32 *)key;
54  
55  	if (unlikely(index >= array->map.max_entries))
56  		return NULL;
57  
58  	return rcu_dereference(array->ptrs[index]);
59  }
60  
61  /* Called from syscall only */
62  static int reuseport_array_delete_elem(struct bpf_map *map, void *key)
63  {
64  	struct reuseport_array *array = reuseport_array(map);
65  	u32 index = *(u32 *)key;
66  	struct sock *sk;
67  	int err;
68  
69  	if (index >= map->max_entries)
70  		return -E2BIG;
71  
72  	if (!rcu_access_pointer(array->ptrs[index]))
73  		return -ENOENT;
74  
75  	spin_lock_bh(&reuseport_lock);
76  
77  	sk = rcu_dereference_protected(array->ptrs[index],
78  				       lockdep_is_held(&reuseport_lock));
79  	if (sk) {
80  		write_lock_bh(&sk->sk_callback_lock);
81  		WRITE_ONCE(sk->sk_user_data, NULL);
82  		RCU_INIT_POINTER(array->ptrs[index], NULL);
83  		write_unlock_bh(&sk->sk_callback_lock);
84  		err = 0;
85  	} else {
86  		err = -ENOENT;
87  	}
88  
89  	spin_unlock_bh(&reuseport_lock);
90  
91  	return err;
92  }
93  
94  static void reuseport_array_free(struct bpf_map *map)
95  {
96  	struct reuseport_array *array = reuseport_array(map);
97  	struct sock *sk;
98  	u32 i;
99  
100  	/*
101  	 * ops->map_*_elem() will not be able to access this
102  	 * array now. Hence, this function only races with
103  	 * bpf_sk_reuseport_detach() which was triggered by
104  	 * close() or disconnect().
105  	 *
106  	 * This function and bpf_sk_reuseport_detach() are
107  	 * both removing sk from "array".  Who removes it
108  	 * first does not matter.
109  	 *
110  	 * The only concern here is bpf_sk_reuseport_detach()
111  	 * may access "array" which is being freed here.
112  	 * bpf_sk_reuseport_detach() access this "array"
113  	 * through sk->sk_user_data _and_ with sk->sk_callback_lock
114  	 * held which is enough because this "array" is not freed
115  	 * until all sk->sk_user_data has stopped referencing this "array".
116  	 *
117  	 * Hence, due to the above, taking "reuseport_lock" is not
118  	 * needed here.
119  	 */
120  
121  	/*
122  	 * Since reuseport_lock is not taken, sk is accessed under
123  	 * rcu_read_lock()
124  	 */
125  	rcu_read_lock();
126  	for (i = 0; i < map->max_entries; i++) {
127  		sk = rcu_dereference(array->ptrs[i]);
128  		if (sk) {
129  			write_lock_bh(&sk->sk_callback_lock);
130  			/*
131  			 * No need for WRITE_ONCE(). At this point,
132  			 * no one is reading it without taking the
133  			 * sk->sk_callback_lock.
134  			 */
135  			sk->sk_user_data = NULL;
136  			write_unlock_bh(&sk->sk_callback_lock);
137  			RCU_INIT_POINTER(array->ptrs[i], NULL);
138  		}
139  	}
140  	rcu_read_unlock();
141  
142  	/*
143  	 * Once reaching here, all sk->sk_user_data is not
144  	 * referencing this "array". "array" can be freed now.
145  	 */
146  	bpf_map_area_free(array);
147  }
148  
149  static struct bpf_map *reuseport_array_alloc(union bpf_attr *attr)
150  {
151  	int numa_node = bpf_map_attr_numa_node(attr);
152  	struct reuseport_array *array;
153  
154  	if (!bpf_capable())
155  		return ERR_PTR(-EPERM);
156  
157  	/* allocate all map elements and zero-initialize them */
158  	array = bpf_map_area_alloc(struct_size(array, ptrs, attr->max_entries), numa_node);
159  	if (!array)
160  		return ERR_PTR(-ENOMEM);
161  
162  	/* copy mandatory map attributes */
163  	bpf_map_init_from_attr(&array->map, attr);
164  
165  	return &array->map;
166  }
167  
168  int bpf_fd_reuseport_array_lookup_elem(struct bpf_map *map, void *key,
169  				       void *value)
170  {
171  	struct sock *sk;
172  	int err;
173  
174  	if (map->value_size != sizeof(u64))
175  		return -ENOSPC;
176  
177  	rcu_read_lock();
178  	sk = reuseport_array_lookup_elem(map, key);
179  	if (sk) {
180  		*(u64 *)value = __sock_gen_cookie(sk);
181  		err = 0;
182  	} else {
183  		err = -ENOENT;
184  	}
185  	rcu_read_unlock();
186  
187  	return err;
188  }
189  
190  static int
191  reuseport_array_update_check(const struct reuseport_array *array,
192  			     const struct sock *nsk,
193  			     const struct sock *osk,
194  			     const struct sock_reuseport *nsk_reuse,
195  			     u32 map_flags)
196  {
197  	if (osk && map_flags == BPF_NOEXIST)
198  		return -EEXIST;
199  
200  	if (!osk && map_flags == BPF_EXIST)
201  		return -ENOENT;
202  
203  	if (nsk->sk_protocol != IPPROTO_UDP && nsk->sk_protocol != IPPROTO_TCP)
204  		return -ENOTSUPP;
205  
206  	if (nsk->sk_family != AF_INET && nsk->sk_family != AF_INET6)
207  		return -ENOTSUPP;
208  
209  	if (nsk->sk_type != SOCK_STREAM && nsk->sk_type != SOCK_DGRAM)
210  		return -ENOTSUPP;
211  
212  	/*
213  	 * sk must be hashed (i.e. listening in the TCP case or binded
214  	 * in the UDP case) and
215  	 * it must also be a SO_REUSEPORT sk (i.e. reuse cannot be NULL).
216  	 *
217  	 * Also, sk will be used in bpf helper that is protected by
218  	 * rcu_read_lock().
219  	 */
220  	if (!sock_flag(nsk, SOCK_RCU_FREE) || !sk_hashed(nsk) || !nsk_reuse)
221  		return -EINVAL;
222  
223  	/* READ_ONCE because the sk->sk_callback_lock may not be held here */
224  	if (READ_ONCE(nsk->sk_user_data))
225  		return -EBUSY;
226  
227  	return 0;
228  }
229  
/*
 * bpf_fd_reuseport_array_update_elem - store a socket, given by fd, in a slot
 * @map: the reuseport array
 * @key: pointer to a u32 slot index
 * @value: pointer to the socket fd, stored as a u64 or an int depending on
 *         map->value_size
 * @map_flags: must not exceed BPF_EXIST; BPF_NOEXIST / BPF_EXIST are
 *             enforced by reuseport_array_update_check()
 *
 * Called from syscall only.
 * "nsk" is kept alive by the file reference taken by sockfd_lookup() and
 * dropped by fput() on exit.
 * "osk" and "reuse" are protected by reuseport_lock.
 *
 * Lock order: reuseport_lock, then nsk->sk_callback_lock. The replaced
 * socket's sk_callback_lock is only taken after nsk's has been released,
 * while reuseport_lock is still held.
 *
 * Return: 0 on success, otherwise a negative errno.
 */
int bpf_fd_reuseport_array_update_elem(struct bpf_map *map, void *key,
				       void *value, u64 map_flags)
{
	struct reuseport_array *array = reuseport_array(map);
	struct sock *free_osk = NULL, *osk, *nsk;
	struct sock_reuseport *reuse;
	u32 index = *(u32 *)key;
	uintptr_t sk_user_data;
	struct socket *socket;
	int err, fd;

	if (map_flags > BPF_EXIST)
		return -EINVAL;

	if (index >= map->max_entries)
		return -E2BIG;

	/* Decode the fd; a u64 value must still fit in a non-negative int. */
	if (map->value_size == sizeof(u64)) {
		u64 fd64 = *(u64 *)value;

		if (fd64 > S32_MAX)
			return -EINVAL;
		fd = fd64;
	} else {
		fd = *(int *)value;
	}

	socket = sockfd_lookup(fd, &err);
	if (!socket)
		return err;

	nsk = socket->sk;
	if (!nsk) {
		err = -EINVAL;
		goto put_file;
	}

	/*
	 * Quick checks before taking reuseport_lock. They are repeated below
	 * under the locks, because the slot and sk_reuseport_cb may change
	 * in between.
	 */
	err = reuseport_array_update_check(array, nsk,
					   rcu_access_pointer(array->ptrs[index]),
					   rcu_access_pointer(nsk->sk_reuseport_cb),
					   map_flags);
	if (err)
		goto put_file;

	spin_lock_bh(&reuseport_lock);
	/*
	 * Some of the checks only need reuseport_lock, but for simplicity
	 * they are also done under sk_callback_lock.
	 */
	write_lock_bh(&nsk->sk_callback_lock);

	osk = rcu_dereference_protected(array->ptrs[index],
					lockdep_is_held(&reuseport_lock));
	reuse = rcu_dereference_protected(nsk->sk_reuseport_cb,
					  lockdep_is_held(&reuseport_lock));
	err = reuseport_array_update_check(array, nsk, osk, reuse, map_flags);
	if (err)
		goto put_file_unlock;

	/*
	 * Point sk_user_data at the slot, tagged with SK_USER_DATA_BPF, so
	 * bpf_sk_reuseport_detach() can find and clear the slot later.
	 */
	sk_user_data = (uintptr_t)&array->ptrs[index] | SK_USER_DATA_NOCOPY |
		SK_USER_DATA_BPF;
	WRITE_ONCE(nsk->sk_user_data, (void *)sk_user_data);
	rcu_assign_pointer(array->ptrs[index], nsk);
	/* The socket that was replaced, if any, must drop its back-pointer. */
	free_osk = osk;
	err = 0;

put_file_unlock:
	write_unlock_bh(&nsk->sk_callback_lock);

	/* Still under reuseport_lock; nsk's callback lock is already released. */
	if (free_osk) {
		write_lock_bh(&free_osk->sk_callback_lock);
		WRITE_ONCE(free_osk->sk_user_data, NULL);
		write_unlock_bh(&free_osk->sk_callback_lock);
	}

	spin_unlock_bh(&reuseport_lock);
put_file:
	/* Drop the file reference taken by sockfd_lookup(). */
	fput(socket->file);
	return err;
}
317  
318  /* Called from syscall */
319  static int reuseport_array_get_next_key(struct bpf_map *map, void *key,
320  					void *next_key)
321  {
322  	struct reuseport_array *array = reuseport_array(map);
323  	u32 index = key ? *(u32 *)key : U32_MAX;
324  	u32 *next = (u32 *)next_key;
325  
326  	if (index >= array->map.max_entries) {
327  		*next = 0;
328  		return 0;
329  	}
330  
331  	if (index == array->map.max_entries - 1)
332  		return -ENOENT;
333  
334  	*next = index + 1;
335  	return 0;
336  }
337  
338  BTF_ID_LIST_SINGLE(reuseport_array_map_btf_ids, struct, reuseport_array)
339  const struct bpf_map_ops reuseport_array_ops = {
340  	.map_meta_equal = bpf_map_meta_equal,
341  	.map_alloc_check = reuseport_array_alloc_check,
342  	.map_alloc = reuseport_array_alloc,
343  	.map_free = reuseport_array_free,
344  	.map_lookup_elem = reuseport_array_lookup_elem,
345  	.map_get_next_key = reuseport_array_get_next_key,
346  	.map_delete_elem = reuseport_array_delete_elem,
347  	.map_btf_id = &reuseport_array_map_btf_ids[0],
348  };
349