xref: /linux/net/core/bpf_sk_storage.c (revision 24ce659dcc02c21f8d6c0a7589c3320a4dfa8152)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook  */
3 #include <linux/rculist.h>
4 #include <linux/list.h>
5 #include <linux/hash.h>
6 #include <linux/types.h>
7 #include <linux/spinlock.h>
8 #include <linux/bpf.h>
9 #include <net/bpf_sk_storage.h>
10 #include <net/sock.h>
11 #include <uapi/linux/sock_diag.h>
12 #include <uapi/linux/btf.h>
13 
14 static atomic_t cache_idx;
15 
16 #define SK_STORAGE_CREATE_FLAG_MASK					\
17 	(BPF_F_NO_PREALLOC | BPF_F_CLONE)
18 
19 struct bucket {
20 	struct hlist_head list;
21 	raw_spinlock_t lock;
22 };
23 
24 /* The map is not the primary owner of a bpf_sk_storage_elem.
25  * Instead, the sk->sk_bpf_storage is.
26  *
27  * The map (bpf_sk_storage_map) serves two purposes:
28  * 1. Define the size of the "sk local storage".  It is
29  *    the map's value_size.
30  *
31  * 2. Maintain a list to keep track of all elems such
32  *    that they can be cleaned up during the map destruction.
33  *
34  * When a bpf local storage is being looked up for a
35  * particular sk, the "bpf_map" pointer is actually used
36  * as the "key" to search the list of elems in
37  * sk->sk_bpf_storage.
38  *
39  * Hence, consider sk->sk_bpf_storage as a mini-map
40  * with the "bpf_map" pointer as the search key.
41  */
42 struct bpf_sk_storage_map {
43 	struct bpf_map map;
44 	/* Lookup elem does not require accessing the map.
45 	 *
46 	 * Updating/Deleting requires a bucket lock to
47 	 * link/unlink the elem from the map.  Multiple
48 	 * buckets are used to reduce lock contention.
49 	 */
50 	struct bucket *buckets;
51 	u32 bucket_log;
52 	u16 elem_size;
53 	u16 cache_idx;
54 };
55 
56 struct bpf_sk_storage_data {
57 	/* smap is used as the searching key when looking up
58 	 * from sk->sk_bpf_storage.
59 	 *
60 	 * Put it in the same cacheline as the data to minimize
61 	 * the number of cacheline accesses during the cache-hit case.
62 	 */
63 	struct bpf_sk_storage_map __rcu *smap;
64 	u8 data[] __aligned(8);
65 };
66 
67 /* Linked to bpf_sk_storage and bpf_sk_storage_map */
68 struct bpf_sk_storage_elem {
69 	struct hlist_node map_node;	/* Linked to bpf_sk_storage_map */
70 	struct hlist_node snode;	/* Linked to bpf_sk_storage */
71 	struct bpf_sk_storage __rcu *sk_storage;
72 	struct rcu_head rcu;
73 	/* 8 bytes hole */
74 	/* The data is stored in another cacheline to minimize
75 	 * the number of cacheline accesses during a cache hit.
76 	 */
77 	struct bpf_sk_storage_data sdata ____cacheline_aligned;
78 };
79 
80 #define SELEM(_SDATA) container_of((_SDATA), struct bpf_sk_storage_elem, sdata)
81 #define SDATA(_SELEM) (&(_SELEM)->sdata)
82 #define BPF_SK_STORAGE_CACHE_SIZE	16
83 
84 struct bpf_sk_storage {
85 	struct bpf_sk_storage_data __rcu *cache[BPF_SK_STORAGE_CACHE_SIZE];
86 	struct hlist_head list;	/* List of bpf_sk_storage_elem */
87 	struct sock *sk;	/* The sk that owns the above "list" of
88 				 * bpf_sk_storage_elem.
89 				 */
90 	struct rcu_head rcu;
91 	raw_spinlock_t lock;	/* Protect adding/removing from the "list" */
92 };
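/*
 * A minimal sketch of how the three structs above relate, using only
 * names defined in this file:
 *
 *   sk->sk_bpf_storage ----> struct bpf_sk_storage
 *                              .cache[16]  hot bpf_sk_storage_data ptrs
 *                              .list       all bpf_sk_storage_elem of this sk
 *
 * Each elem's sdata.smap identifies the owning map, so looking up the
 * value for a (sk, map) pair is conceptually:
 *
 *	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode)
 *		if (rcu_access_pointer(SDATA(selem)->smap) == smap)
 *			return SDATA(selem);
 *
 * __sk_storage_lookup() below implements this walk, with a per-map
 * cache slot in front of it for the common case.
 */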
93 
94 static struct bucket *select_bucket(struct bpf_sk_storage_map *smap,
95 				    struct bpf_sk_storage_elem *selem)
96 {
97 	return &smap->buckets[hash_ptr(selem, smap->bucket_log)];
98 }
99 
100 static int omem_charge(struct sock *sk, unsigned int size)
101 {
102 	/* same check as in sock_kmalloc() */
103 	if (size <= sysctl_optmem_max &&
104 	    atomic_read(&sk->sk_omem_alloc) + size < sysctl_optmem_max) {
105 		atomic_add(size, &sk->sk_omem_alloc);
106 		return 0;
107 	}
108 
109 	return -ENOMEM;
110 }
111 
112 static bool selem_linked_to_sk(const struct bpf_sk_storage_elem *selem)
113 {
114 	return !hlist_unhashed(&selem->snode);
115 }
116 
117 static bool selem_linked_to_map(const struct bpf_sk_storage_elem *selem)
118 {
119 	return !hlist_unhashed(&selem->map_node);
120 }
121 
122 static struct bpf_sk_storage_elem *selem_alloc(struct bpf_sk_storage_map *smap,
123 					       struct sock *sk, void *value,
124 					       bool charge_omem)
125 {
126 	struct bpf_sk_storage_elem *selem;
127 
128 	if (charge_omem && omem_charge(sk, smap->elem_size))
129 		return NULL;
130 
131 	selem = kzalloc(smap->elem_size, GFP_ATOMIC | __GFP_NOWARN);
132 	if (selem) {
133 		if (value)
134 			memcpy(SDATA(selem)->data, value, smap->map.value_size);
135 		return selem;
136 	}
137 
138 	if (charge_omem)
139 		atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
140 
141 	return NULL;
142 }
143 
144 /* sk_storage->lock must be held and selem->sk_storage == sk_storage.
145  * The caller must ensure selem->smap is still valid to be
146  * dereferenced for its smap->elem_size and smap->cache_idx.
147  */
148 static bool __selem_unlink_sk(struct bpf_sk_storage *sk_storage,
149 			      struct bpf_sk_storage_elem *selem,
150 			      bool uncharge_omem)
151 {
152 	struct bpf_sk_storage_map *smap;
153 	bool free_sk_storage;
154 	struct sock *sk;
155 
156 	smap = rcu_dereference(SDATA(selem)->smap);
157 	sk = sk_storage->sk;
158 
159 	/* All uncharging on sk->sk_omem_alloc must be done first.
160 	 * sk may be freed once the last selem is unlinked from sk_storage.
161 	 */
162 	if (uncharge_omem)
163 		atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
164 
165 	free_sk_storage = hlist_is_singular_node(&selem->snode,
166 						 &sk_storage->list);
167 	if (free_sk_storage) {
168 		atomic_sub(sizeof(struct bpf_sk_storage), &sk->sk_omem_alloc);
169 		sk_storage->sk = NULL;
170 		/* After this RCU_INIT, sk may be freed and cannot be used */
171 		RCU_INIT_POINTER(sk->sk_bpf_storage, NULL);
172 
173 		/* sk_storage is not freed now.  sk_storage->lock is
174 		 * still held and raw_spin_unlock_bh(&sk_storage->lock)
175 		 * will be done by the caller.
176 		 *
177 		 * Although the unlock will be done under
178 		 * rcu_read_lock(), it is more intuitive to
179 		 * read if kfree_rcu(sk_storage, rcu) is done
180 		 * after the raw_spin_unlock_bh(&sk_storage->lock).
181 		 *
182 		 * Hence, a "bool free_sk_storage" is returned
183 		 * to the caller which then calls the kfree_rcu()
184 		 * after unlock.
185 		 */
186 	}
187 	hlist_del_init_rcu(&selem->snode);
188 	if (rcu_access_pointer(sk_storage->cache[smap->cache_idx]) ==
189 	    SDATA(selem))
190 		RCU_INIT_POINTER(sk_storage->cache[smap->cache_idx], NULL);
191 
192 	kfree_rcu(selem, rcu);
193 
194 	return free_sk_storage;
195 }
196 
197 static void selem_unlink_sk(struct bpf_sk_storage_elem *selem)
198 {
199 	struct bpf_sk_storage *sk_storage;
200 	bool free_sk_storage = false;
201 
202 	if (unlikely(!selem_linked_to_sk(selem)))
203 		/* selem has already been unlinked from sk */
204 		return;
205 
206 	sk_storage = rcu_dereference(selem->sk_storage);
207 	raw_spin_lock_bh(&sk_storage->lock);
208 	if (likely(selem_linked_to_sk(selem)))
209 		free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
210 	raw_spin_unlock_bh(&sk_storage->lock);
211 
212 	if (free_sk_storage)
213 		kfree_rcu(sk_storage, rcu);
214 }
215 
216 static void __selem_link_sk(struct bpf_sk_storage *sk_storage,
217 			    struct bpf_sk_storage_elem *selem)
218 {
219 	RCU_INIT_POINTER(selem->sk_storage, sk_storage);
220 	hlist_add_head(&selem->snode, &sk_storage->list);
221 }
222 
223 static void selem_unlink_map(struct bpf_sk_storage_elem *selem)
224 {
225 	struct bpf_sk_storage_map *smap;
226 	struct bucket *b;
227 
228 	if (unlikely(!selem_linked_to_map(selem)))
229 		/* selem has already been unlinked from smap */
230 		return;
231 
232 	smap = rcu_dereference(SDATA(selem)->smap);
233 	b = select_bucket(smap, selem);
234 	raw_spin_lock_bh(&b->lock);
235 	if (likely(selem_linked_to_map(selem)))
236 		hlist_del_init_rcu(&selem->map_node);
237 	raw_spin_unlock_bh(&b->lock);
238 }
239 
240 static void selem_link_map(struct bpf_sk_storage_map *smap,
241 			   struct bpf_sk_storage_elem *selem)
242 {
243 	struct bucket *b = select_bucket(smap, selem);
244 
245 	raw_spin_lock_bh(&b->lock);
246 	RCU_INIT_POINTER(SDATA(selem)->smap, smap);
247 	hlist_add_head_rcu(&selem->map_node, &b->list);
248 	raw_spin_unlock_bh(&b->lock);
249 }
250 
251 static void selem_unlink(struct bpf_sk_storage_elem *selem)
252 {
253 	/* Always unlink from map before unlinking from sk_storage
254 	 * because selem will be freed once it is successfully unlinked from
255 	 * the sk_storage.
256 	 */
257 	selem_unlink_map(selem);
258 	selem_unlink_sk(selem);
259 }
260 
261 static struct bpf_sk_storage_data *
262 __sk_storage_lookup(struct bpf_sk_storage *sk_storage,
263 		    struct bpf_sk_storage_map *smap,
264 		    bool cacheit_lockit)
265 {
266 	struct bpf_sk_storage_data *sdata;
267 	struct bpf_sk_storage_elem *selem;
268 
269 	/* Fast path (cache hit) */
270 	sdata = rcu_dereference(sk_storage->cache[smap->cache_idx]);
271 	if (sdata && rcu_access_pointer(sdata->smap) == smap)
272 		return sdata;
273 
274 	/* Slow path (cache miss) */
275 	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode)
276 		if (rcu_access_pointer(SDATA(selem)->smap) == smap)
277 			break;
278 
279 	if (!selem)
280 		return NULL;
281 
282 	sdata = SDATA(selem);
283 	if (cacheit_lockit) {
284 		/* spinlock is needed to avoid racing with a parallel
285 		 * delete.  Otherwise, publishing an already deleted
286 		 * sdata to the cache would become a use-after-free
287 		 * problem in the next __sk_storage_lookup().
288 		 */
289 		raw_spin_lock_bh(&sk_storage->lock);
290 		if (selem_linked_to_sk(selem))
291 			rcu_assign_pointer(sk_storage->cache[smap->cache_idx],
292 					   sdata);
293 		raw_spin_unlock_bh(&sk_storage->lock);
294 	}
295 
296 	return sdata;
297 }
298 
299 static struct bpf_sk_storage_data *
300 sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
301 {
302 	struct bpf_sk_storage *sk_storage;
303 	struct bpf_sk_storage_map *smap;
304 
305 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
306 	if (!sk_storage)
307 		return NULL;
308 
309 	smap = (struct bpf_sk_storage_map *)map;
310 	return __sk_storage_lookup(sk_storage, smap, cacheit_lockit);
311 }
312 
313 static int check_flags(const struct bpf_sk_storage_data *old_sdata,
314 		       u64 map_flags)
315 {
316 	if (old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_NOEXIST)
317 		/* elem already exists */
318 		return -EEXIST;
319 
320 	if (!old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_EXIST)
321 		/* elem doesn't exist, cannot update it */
322 		return -ENOENT;
323 
324 	return 0;
325 }
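/*
 * A short illustration of the flag handling above (a sketch, using the
 * standard uapi flag values):
 *
 *	check_flags(old_sdata, BPF_ANY);			// 0
 *	check_flags(old_sdata, BPF_EXIST);			// 0
 *	check_flags(old_sdata, BPF_NOEXIST);			// -EEXIST
 *	check_flags(NULL, BPF_NOEXIST);				// 0
 *	check_flags(NULL, BPF_EXIST);				// -ENOENT
 *	check_flags(old_sdata, BPF_F_LOCK | BPF_EXIST);		// 0, BPF_F_LOCK is masked off
 */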
326 
327 static int sk_storage_alloc(struct sock *sk,
328 			    struct bpf_sk_storage_map *smap,
329 			    struct bpf_sk_storage_elem *first_selem)
330 {
331 	struct bpf_sk_storage *prev_sk_storage, *sk_storage;
332 	int err;
333 
334 	err = omem_charge(sk, sizeof(*sk_storage));
335 	if (err)
336 		return err;
337 
338 	sk_storage = kzalloc(sizeof(*sk_storage), GFP_ATOMIC | __GFP_NOWARN);
339 	if (!sk_storage) {
340 		err = -ENOMEM;
341 		goto uncharge;
342 	}
343 	INIT_HLIST_HEAD(&sk_storage->list);
344 	raw_spin_lock_init(&sk_storage->lock);
345 	sk_storage->sk = sk;
346 
347 	__selem_link_sk(sk_storage, first_selem);
348 	selem_link_map(smap, first_selem);
349 	/* Publish sk_storage to sk.  sk->sk_lock cannot be acquired.
350 	 * Hence, an atomic op (cmpxchg) is used to set sk->sk_bpf_storage
351 	 * from NULL to the newly allocated sk_storage ptr.
352 	 *
353 	 * From now on, the sk->sk_bpf_storage pointer is protected
354 	 * by the sk_storage->lock.  Hence,  when freeing
355 	 * the sk->sk_bpf_storage, the sk_storage->lock must
356 	 * be held before setting sk->sk_bpf_storage to NULL.
357 	 */
358 	prev_sk_storage = cmpxchg((struct bpf_sk_storage **)&sk->sk_bpf_storage,
359 				  NULL, sk_storage);
360 	if (unlikely(prev_sk_storage)) {
361 		selem_unlink_map(first_selem);
362 		err = -EAGAIN;
363 		goto uncharge;
364 
365 		/* Note that even though first_selem was linked to smap's
366 		 * bucket->list, first_selem can be freed immediately
367 		 * (instead of kfree_rcu) because
368 		 * bpf_sk_storage_map_free() does a
369 		 * synchronize_rcu() before walking the bucket->list.
370 		 * Hence, no one is accessing selem from the
371 		 * bucket->list under rcu_read_lock().
372 		 */
373 	}
374 
375 	return 0;
376 
377 uncharge:
378 	kfree(sk_storage);
379 	atomic_sub(sizeof(*sk_storage), &sk->sk_omem_alloc);
380 	return err;
381 }
382 
383 /* The sk cannot be going away while a new elem is being linked
384  * to sk->sk_bpf_storage (i.e. sk->sk_refcnt cannot be 0).
385  * Otherwise, the new elem would be leaked (and cause other memory
386  * issues during map destruction).
387  */
388 static struct bpf_sk_storage_data *sk_storage_update(struct sock *sk,
389 						     struct bpf_map *map,
390 						     void *value,
391 						     u64 map_flags)
392 {
393 	struct bpf_sk_storage_data *old_sdata = NULL;
394 	struct bpf_sk_storage_elem *selem;
395 	struct bpf_sk_storage *sk_storage;
396 	struct bpf_sk_storage_map *smap;
397 	int err;
398 
399 	/* BPF_EXIST and BPF_NOEXIST cannot both be set */
400 	if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
401 	    /* BPF_F_LOCK can only be used in a value with spin_lock */
402 	    unlikely((map_flags & BPF_F_LOCK) && !map_value_has_spin_lock(map)))
403 		return ERR_PTR(-EINVAL);
404 
405 	smap = (struct bpf_sk_storage_map *)map;
406 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
407 	if (!sk_storage || hlist_empty(&sk_storage->list)) {
408 		/* Very first elem for this sk */
409 		err = check_flags(NULL, map_flags);
410 		if (err)
411 			return ERR_PTR(err);
412 
413 		selem = selem_alloc(smap, sk, value, true);
414 		if (!selem)
415 			return ERR_PTR(-ENOMEM);
416 
417 		err = sk_storage_alloc(sk, smap, selem);
418 		if (err) {
419 			kfree(selem);
420 			atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
421 			return ERR_PTR(err);
422 		}
423 
424 		return SDATA(selem);
425 	}
426 
427 	if ((map_flags & BPF_F_LOCK) && !(map_flags & BPF_NOEXIST)) {
428 		/* Try to find an old_sdata and update it in place,
429 		 * so that taking the sk_storage->lock and changing
430 		 * the lists can be avoided.
431 		 */
432 		old_sdata = __sk_storage_lookup(sk_storage, smap, false);
433 		err = check_flags(old_sdata, map_flags);
434 		if (err)
435 			return ERR_PTR(err);
436 		if (old_sdata && selem_linked_to_sk(SELEM(old_sdata))) {
437 			copy_map_value_locked(map, old_sdata->data,
438 					      value, false);
439 			return old_sdata;
440 		}
441 	}
442 
443 	raw_spin_lock_bh(&sk_storage->lock);
444 
445 	/* Recheck sk_storage->list under sk_storage->lock */
446 	if (unlikely(hlist_empty(&sk_storage->list))) {
447 		/* A parallel del is happening and sk_storage is going
448 		 * away.  It was just checked above, so this is very
449 		 * unlikely.  Return instead of retrying to keep things
450 		 * simple.
451 		 */
452 		err = -EAGAIN;
453 		goto unlock_err;
454 	}
455 
456 	old_sdata = __sk_storage_lookup(sk_storage, smap, false);
457 	err = check_flags(old_sdata, map_flags);
458 	if (err)
459 		goto unlock_err;
460 
461 	if (old_sdata && (map_flags & BPF_F_LOCK)) {
462 		copy_map_value_locked(map, old_sdata->data, value, false);
463 		selem = SELEM(old_sdata);
464 		goto unlock;
465 	}
466 
467 	/* sk_storage->lock is held, so we are sure the old_sdata
468 	 * can be unlinked and uncharged successfully later.
469 	 * Instead of charging the new selem now and then uncharging
470 	 * the old selem later (which could cause an unnecessary
471 	 * charge failure), take no charge at all here when old_sdata
472 	 * exists (the "!old_sdata" check); old_sdata is then not
473 	 * uncharged later in __selem_unlink_sk().
474 	 */
475 	selem = selem_alloc(smap, sk, value, !old_sdata);
476 	if (!selem) {
477 		err = -ENOMEM;
478 		goto unlock_err;
479 	}
480 
481 	/* First, link the new selem to the map */
482 	selem_link_map(smap, selem);
483 
484 	/* Second, link (and publish) the new selem to sk_storage */
485 	__selem_link_sk(sk_storage, selem);
486 
487 	/* Third, remove old selem, SELEM(old_sdata) */
488 	if (old_sdata) {
489 		selem_unlink_map(SELEM(old_sdata));
490 		__selem_unlink_sk(sk_storage, SELEM(old_sdata), false);
491 	}
492 
493 unlock:
494 	raw_spin_unlock_bh(&sk_storage->lock);
495 	return SDATA(selem);
496 
497 unlock_err:
498 	raw_spin_unlock_bh(&sk_storage->lock);
499 	return ERR_PTR(err);
500 }
501 
502 static int sk_storage_delete(struct sock *sk, struct bpf_map *map)
503 {
504 	struct bpf_sk_storage_data *sdata;
505 
506 	sdata = sk_storage_lookup(sk, map, false);
507 	if (!sdata)
508 		return -ENOENT;
509 
510 	selem_unlink(SELEM(sdata));
511 
512 	return 0;
513 }
514 
515 /* Called by __sk_destruct() & bpf_sk_storage_clone() */
516 void bpf_sk_storage_free(struct sock *sk)
517 {
518 	struct bpf_sk_storage_elem *selem;
519 	struct bpf_sk_storage *sk_storage;
520 	bool free_sk_storage = false;
521 	struct hlist_node *n;
522 
523 	rcu_read_lock();
524 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
525 	if (!sk_storage) {
526 		rcu_read_unlock();
527 		return;
528 	}
529 
530 	/* Neither the bpf_prog nor the bpf-map's syscall
531 	 * could be modifying the sk_storage->list now.
532 	 * Thus, no elem can be added-to or deleted-from the
533 	 * sk_storage->list by the bpf_prog or by the bpf-map's syscall.
534 	 *
535 	 * It races only with bpf_sk_storage_map_free()
536 	 * when unlinking an elem from the sk_storage->list and
537 	 * the map's bucket->list.
538 	 */
539 	raw_spin_lock_bh(&sk_storage->lock);
540 	hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) {
541 		/* Always unlink from map before unlinking from
542 		 * sk_storage.
543 		 */
544 		selem_unlink_map(selem);
545 		free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
546 	}
547 	raw_spin_unlock_bh(&sk_storage->lock);
548 	rcu_read_unlock();
549 
550 	if (free_sk_storage)
551 		kfree_rcu(sk_storage, rcu);
552 }
553 
554 static void bpf_sk_storage_map_free(struct bpf_map *map)
555 {
556 	struct bpf_sk_storage_elem *selem;
557 	struct bpf_sk_storage_map *smap;
558 	struct bucket *b;
559 	unsigned int i;
560 
561 	smap = (struct bpf_sk_storage_map *)map;
562 
563 	/* Note that this map might be concurrently cloned from
564 	 * bpf_sk_storage_clone. Wait for any existing bpf_sk_storage_clone
565 	 * RCU read section to finish before proceeding. New RCU
566 	 * read sections should be prevented via bpf_map_inc_not_zero.
567 	 */
568 	synchronize_rcu();
569 
570 	/* bpf prog and the userspace can no longer access this map
571 	 * now.  No new selem (of this map) can be added
572 	 * to the sk->sk_bpf_storage or to the map bucket's list.
573 	 *
574 	 * The elem of this map can be cleaned up here
575 	 * or
576 	 * by bpf_sk_storage_free() during __sk_destruct().
577 	 */
578 	for (i = 0; i < (1U << smap->bucket_log); i++) {
579 		b = &smap->buckets[i];
580 
581 		rcu_read_lock();
582 		/* No one is adding to b->list now */
583 		while ((selem = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(&b->list)),
584 						 struct bpf_sk_storage_elem,
585 						 map_node))) {
586 			selem_unlink(selem);
587 			cond_resched_rcu();
588 		}
589 		rcu_read_unlock();
590 	}
591 
592 	/* bpf_sk_storage_free() may still need to access the map.
593 	 * e.g. bpf_sk_storage_free() has unlinked selem from the map
594 	 * which then made the above while((selem = ...)) loop
595 	 * exit immediately.
596 	 *
597 	 * However, the bpf_sk_storage_free() still needs to access
598 	 * the smap->elem_size to do the uncharging in
599 	 * __selem_unlink_sk().
600 	 *
601 	 * Hence, wait another rcu grace period for the
602 	 * bpf_sk_storage_free() to finish.
603 	 */
604 	synchronize_rcu();
605 
606 	kvfree(smap->buckets);
607 	kfree(map);
608 }
609 
610 /* U16_MAX is much more than enough for sk local storage
611  * considering a tcp_sock is ~2k.
612  */
613 #define MAX_VALUE_SIZE							\
614 	min_t(u32,							\
615 	      (KMALLOC_MAX_SIZE - MAX_BPF_STACK - sizeof(struct bpf_sk_storage_elem)), \
616 	      (U16_MAX - sizeof(struct bpf_sk_storage_elem)))
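/*
 * Rough arithmetic behind the bound above (a sketch; MAX_BPF_STACK is
 * 512 bytes): the U16_MAX operand caps value_size so that
 * sizeof(struct bpf_sk_storage_elem) + value_size still fits in the u16
 * smap->elem_size used by selem_alloc() and omem_charge(); the
 * KMALLOC_MAX_SIZE operand keeps the kzalloc() in selem_alloc() within
 * what the slab allocator can serve.
 */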
617 
618 static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr)
619 {
620 	if (attr->map_flags & ~SK_STORAGE_CREATE_FLAG_MASK ||
621 	    !(attr->map_flags & BPF_F_NO_PREALLOC) ||
622 	    attr->max_entries ||
623 	    attr->key_size != sizeof(int) || !attr->value_size ||
624 	    /* Enforce BTF for userspace sk dumping */
625 	    !attr->btf_key_type_id || !attr->btf_value_type_id)
626 		return -EINVAL;
627 
628 	if (!capable(CAP_SYS_ADMIN))
629 		return -EPERM;
630 
631 	if (attr->value_size > MAX_VALUE_SIZE)
632 		return -E2BIG;
633 
634 	return 0;
635 }
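/*
 * A map definition that satisfies the checks above could look like this
 * in a BPF program (a sketch using libbpf's BTF-defined map syntax;
 * "sk_stg_map" and "struct my_storage" are made-up names):
 *
 *	struct my_storage {
 *		__u32 pkt_count;
 *	};
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_SK_STORAGE);
 *		__uint(map_flags, BPF_F_NO_PREALLOC);
 *		__type(key, int);
 *		__type(value, struct my_storage);
 *	} sk_stg_map SEC(".maps");
 *
 * max_entries stays 0 and libbpf supplies the BTF key/value type ids,
 * as required above.
 */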
636 
637 static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
638 {
639 	struct bpf_sk_storage_map *smap;
640 	unsigned int i;
641 	u32 nbuckets;
642 	u64 cost;
643 	int ret;
644 
645 	smap = kzalloc(sizeof(*smap), GFP_USER | __GFP_NOWARN);
646 	if (!smap)
647 		return ERR_PTR(-ENOMEM);
648 	bpf_map_init_from_attr(&smap->map, attr);
649 
650 	nbuckets = roundup_pow_of_two(num_possible_cpus());
651 	/* Use at least 2 buckets, select_bucket() is undefined behavior with 1 bucket */
652 	nbuckets = max_t(u32, 2, nbuckets);
653 	smap->bucket_log = ilog2(nbuckets);
654 	cost = sizeof(*smap->buckets) * nbuckets + sizeof(*smap);
655 
656 	ret = bpf_map_charge_init(&smap->map.memory, cost);
657 	if (ret < 0) {
658 		kfree(smap);
659 		return ERR_PTR(ret);
660 	}
661 
662 	smap->buckets = kvcalloc(nbuckets, sizeof(*smap->buckets),
663 				 GFP_USER | __GFP_NOWARN);
664 	if (!smap->buckets) {
665 		bpf_map_charge_finish(&smap->map.memory);
666 		kfree(smap);
667 		return ERR_PTR(-ENOMEM);
668 	}
669 
670 	for (i = 0; i < nbuckets; i++) {
671 		INIT_HLIST_HEAD(&smap->buckets[i].list);
672 		raw_spin_lock_init(&smap->buckets[i].lock);
673 	}
674 
675 	smap->elem_size = sizeof(struct bpf_sk_storage_elem) + attr->value_size;
676 	smap->cache_idx = (unsigned int)atomic_inc_return(&cache_idx) %
677 		BPF_SK_STORAGE_CACHE_SIZE;
678 
679 	return &smap->map;
680 }
681 
682 static int notsupp_get_next_key(struct bpf_map *map, void *key,
683 				void *next_key)
684 {
685 	return -ENOTSUPP;
686 }
687 
688 static int bpf_sk_storage_map_check_btf(const struct bpf_map *map,
689 					const struct btf *btf,
690 					const struct btf_type *key_type,
691 					const struct btf_type *value_type)
692 {
693 	u32 int_data;
694 
695 	if (BTF_INFO_KIND(key_type->info) != BTF_KIND_INT)
696 		return -EINVAL;
697 
698 	int_data = *(u32 *)(key_type + 1);
699 	if (BTF_INT_BITS(int_data) != 32 || BTF_INT_OFFSET(int_data))
700 		return -EINVAL;
701 
702 	return 0;
703 }
704 
705 static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key)
706 {
707 	struct bpf_sk_storage_data *sdata;
708 	struct socket *sock;
709 	int fd, err;
710 
711 	fd = *(int *)key;
712 	sock = sockfd_lookup(fd, &err);
713 	if (sock) {
714 		sdata = sk_storage_lookup(sock->sk, map, true);
715 		sockfd_put(sock);
716 		return sdata ? sdata->data : NULL;
717 	}
718 
719 	return ERR_PTR(err);
720 }
721 
722 static int bpf_fd_sk_storage_update_elem(struct bpf_map *map, void *key,
723 					 void *value, u64 map_flags)
724 {
725 	struct bpf_sk_storage_data *sdata;
726 	struct socket *sock;
727 	int fd, err;
728 
729 	fd = *(int *)key;
730 	sock = sockfd_lookup(fd, &err);
731 	if (sock) {
732 		sdata = sk_storage_update(sock->sk, map, value, map_flags);
733 		sockfd_put(sock);
734 		return PTR_ERR_OR_ZERO(sdata);
735 	}
736 
737 	return err;
738 }
739 
740 static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
741 {
742 	struct socket *sock;
743 	int fd, err;
744 
745 	fd = *(int *)key;
746 	sock = sockfd_lookup(fd, &err);
747 	if (sock) {
748 		err = sk_storage_delete(sock->sk, map);
749 		sockfd_put(sock);
750 		return err;
751 	}
752 
753 	return err;
754 }
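/*
 * From userspace, the fd-based handlers above are reached through the
 * regular map syscalls with a socket fd as the key (a sketch using the
 * libbpf wrappers; "map_fd", "sock_fd" and "struct my_storage" are
 * placeholder names):
 *
 *	struct my_storage val = { .pkt_count = 0 }, out;
 *
 *	bpf_map_update_elem(map_fd, &sock_fd, &val, BPF_NOEXIST);
 *	bpf_map_lookup_elem(map_fd, &sock_fd, &out);
 *	bpf_map_delete_elem(map_fd, &sock_fd);
 */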
755 
756 static struct bpf_sk_storage_elem *
757 bpf_sk_storage_clone_elem(struct sock *newsk,
758 			  struct bpf_sk_storage_map *smap,
759 			  struct bpf_sk_storage_elem *selem)
760 {
761 	struct bpf_sk_storage_elem *copy_selem;
762 
763 	copy_selem = selem_alloc(smap, newsk, NULL, true);
764 	if (!copy_selem)
765 		return NULL;
766 
767 	if (map_value_has_spin_lock(&smap->map))
768 		copy_map_value_locked(&smap->map, SDATA(copy_selem)->data,
769 				      SDATA(selem)->data, true);
770 	else
771 		copy_map_value(&smap->map, SDATA(copy_selem)->data,
772 			       SDATA(selem)->data);
773 
774 	return copy_selem;
775 }
776 
777 int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk)
778 {
779 	struct bpf_sk_storage *new_sk_storage = NULL;
780 	struct bpf_sk_storage *sk_storage;
781 	struct bpf_sk_storage_elem *selem;
782 	int ret = 0;
783 
784 	RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);
785 
786 	rcu_read_lock();
787 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
788 
789 	if (!sk_storage || hlist_empty(&sk_storage->list))
790 		goto out;
791 
792 	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
793 		struct bpf_sk_storage_elem *copy_selem;
794 		struct bpf_sk_storage_map *smap;
795 		struct bpf_map *map;
796 
797 		smap = rcu_dereference(SDATA(selem)->smap);
798 		if (!(smap->map.map_flags & BPF_F_CLONE))
799 			continue;
800 
801 		/* Note that for lockless listeners, adding a new element
802 		 * here can race with cleanup in bpf_sk_storage_map_free().
803 		 * Try to grab map refcnt to make sure that it's still
804 		 * alive and prevent concurrent removal.
805 		 */
806 		map = bpf_map_inc_not_zero(&smap->map);
807 		if (IS_ERR(map))
808 			continue;
809 
810 		copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem);
811 		if (!copy_selem) {
812 			ret = -ENOMEM;
813 			bpf_map_put(map);
814 			goto out;
815 		}
816 
817 		if (new_sk_storage) {
818 			selem_link_map(smap, copy_selem);
819 			__selem_link_sk(new_sk_storage, copy_selem);
820 		} else {
821 			ret = sk_storage_alloc(newsk, smap, copy_selem);
822 			if (ret) {
823 				kfree(copy_selem);
824 				atomic_sub(smap->elem_size,
825 					   &newsk->sk_omem_alloc);
826 				bpf_map_put(map);
827 				goto out;
828 			}
829 
830 			new_sk_storage = rcu_dereference(copy_selem->sk_storage);
831 		}
832 		bpf_map_put(map);
833 	}
834 
835 out:
836 	rcu_read_unlock();
837 
838 	/* In case of an error, don't free anything explicitly here, the
839 	 * caller is responsible to call bpf_sk_storage_free.
840 	 */
841 
842 	return ret;
843 }
844 
845 BPF_CALL_4(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
846 	   void *, value, u64, flags)
847 {
848 	struct bpf_sk_storage_data *sdata;
849 
850 	if (flags > BPF_SK_STORAGE_GET_F_CREATE)
851 		return (unsigned long)NULL;
852 
853 	sdata = sk_storage_lookup(sk, map, true);
854 	if (sdata)
855 		return (unsigned long)sdata->data;
856 
857 	if (flags == BPF_SK_STORAGE_GET_F_CREATE &&
858 	    /* Cannot add a new elem to an sk that is going away.
859 	     * Otherwise, the new elem may be leaked
860 	     * (and cause other memory issues during map
861 	     *  destruction).
862 	     */
863 	    refcount_inc_not_zero(&sk->sk_refcnt)) {
864 		sdata = sk_storage_update(sk, map, value, BPF_NOEXIST);
865 		/* sk must be a fullsock (guaranteed by verifier),
866 		 * so sock_gen_put() is unnecessary.
867 		 */
868 		sock_put(sk);
869 		return IS_ERR(sdata) ?
870 			(unsigned long)NULL : (unsigned long)sdata->data;
871 	}
872 
873 	return (unsigned long)NULL;
874 }
875 
876 BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
877 {
878 	if (refcount_inc_not_zero(&sk->sk_refcnt)) {
879 		int err;
880 
881 		err = sk_storage_delete(sk, map);
882 		sock_put(sk);
883 		return err;
884 	}
885 
886 	return -ENOENT;
887 }
888 
889 const struct bpf_map_ops sk_storage_map_ops = {
890 	.map_alloc_check = bpf_sk_storage_map_alloc_check,
891 	.map_alloc = bpf_sk_storage_map_alloc,
892 	.map_free = bpf_sk_storage_map_free,
893 	.map_get_next_key = notsupp_get_next_key,
894 	.map_lookup_elem = bpf_fd_sk_storage_lookup_elem,
895 	.map_update_elem = bpf_fd_sk_storage_update_elem,
896 	.map_delete_elem = bpf_fd_sk_storage_delete_elem,
897 	.map_check_btf = bpf_sk_storage_map_check_btf,
898 };
899 
900 const struct bpf_func_proto bpf_sk_storage_get_proto = {
901 	.func		= bpf_sk_storage_get,
902 	.gpl_only	= false,
903 	.ret_type	= RET_PTR_TO_MAP_VALUE_OR_NULL,
904 	.arg1_type	= ARG_CONST_MAP_PTR,
905 	.arg2_type	= ARG_PTR_TO_SOCKET,
906 	.arg3_type	= ARG_PTR_TO_MAP_VALUE_OR_NULL,
907 	.arg4_type	= ARG_ANYTHING,
908 };
909 
910 const struct bpf_func_proto bpf_sk_storage_delete_proto = {
911 	.func		= bpf_sk_storage_delete,
912 	.gpl_only	= false,
913 	.ret_type	= RET_INTEGER,
914 	.arg1_type	= ARG_CONST_MAP_PTR,
915 	.arg2_type	= ARG_PTR_TO_SOCKET,
916 };
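/*
 * In a BPF program, the two helpers wired up above are typically used
 * as follows (a sketch; "sk_stg_map" and "struct my_storage" are the
 * placeholder names from the earlier map-definition sketch, and "sk" is
 * assumed to be a fullsock pointer available to the program):
 *
 *	struct my_storage *stg;
 *
 *	stg = bpf_sk_storage_get(&sk_stg_map, sk, NULL,
 *				 BPF_SK_STORAGE_GET_F_CREATE);
 *	if (stg)
 *		stg->pkt_count++;
 *
 *	bpf_sk_storage_delete(&sk_stg_map, sk);
 */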
917 
918 struct bpf_sk_storage_diag {
919 	u32 nr_maps;
920 	struct bpf_map *maps[];
921 };
922 
923 /* The reply will be like:
924  * INET_DIAG_BPF_SK_STORAGES (nla_nest)
925  *	SK_DIAG_BPF_STORAGE (nla_nest)
926  *		SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
927  *		SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
928  *	SK_DIAG_BPF_STORAGE (nla_nest)
929  *		SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
930  *		SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
931  *	....
932  */
933 static int nla_value_size(u32 value_size)
934 {
935 	/* SK_DIAG_BPF_STORAGE (nla_nest)
936 	 *	SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
937 	 *	SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
938 	 */
939 	return nla_total_size(0) + nla_total_size(sizeof(u32)) +
940 		nla_total_size_64bit(value_size);
941 }
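/*
 * For example (a sketch), a 64-byte map value contributes
 * nla_total_size(0) + nla_total_size(4) + nla_total_size_64bit(64)
 * to the reply size: the SK_DIAG_BPF_STORAGE nest header plus the
 * MAP_ID and MAP_VALUE attributes listed above.
 */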
942 
943 void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
944 {
945 	u32 i;
946 
947 	if (!diag)
948 		return;
949 
950 	for (i = 0; i < diag->nr_maps; i++)
951 		bpf_map_put(diag->maps[i]);
952 
953 	kfree(diag);
954 }
955 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);
956 
957 static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
958 			   const struct bpf_map *map)
959 {
960 	u32 i;
961 
962 	for (i = 0; i < diag->nr_maps; i++) {
963 		if (diag->maps[i] == map)
964 			return true;
965 	}
966 
967 	return false;
968 }
969 
970 struct bpf_sk_storage_diag *
971 bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
972 {
973 	struct bpf_sk_storage_diag *diag;
974 	struct nlattr *nla;
975 	u32 nr_maps = 0;
976 	int rem, err;
977 
978 	/* bpf_sk_storage_map is currently limited to CAP_SYS_ADMIN,
979 	 * matching the check done on the map_alloc_check() side.
980 	 */
981 	if (!capable(CAP_SYS_ADMIN))
982 		return ERR_PTR(-EPERM);
983 
984 	nla_for_each_nested(nla, nla_stgs, rem) {
985 		if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
986 			nr_maps++;
987 	}
988 
989 	diag = kzalloc(sizeof(*diag) + sizeof(diag->maps[0]) * nr_maps,
990 		       GFP_KERNEL);
991 	if (!diag)
992 		return ERR_PTR(-ENOMEM);
993 
994 	nla_for_each_nested(nla, nla_stgs, rem) {
995 		struct bpf_map *map;
996 		int map_fd;
997 
998 		if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
999 			continue;
1000 
1001 		map_fd = nla_get_u32(nla);
1002 		map = bpf_map_get(map_fd);
1003 		if (IS_ERR(map)) {
1004 			err = PTR_ERR(map);
1005 			goto err_free;
1006 		}
1007 		if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
1008 			bpf_map_put(map);
1009 			err = -EINVAL;
1010 			goto err_free;
1011 		}
1012 		if (diag_check_dup(diag, map)) {
1013 			bpf_map_put(map);
1014 			err = -EEXIST;
1015 			goto err_free;
1016 		}
1017 		diag->maps[diag->nr_maps++] = map;
1018 	}
1019 
1020 	return diag;
1021 
1022 err_free:
1023 	bpf_sk_storage_diag_free(diag);
1024 	return ERR_PTR(err);
1025 }
1026 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);
1027 
1028 static int diag_get(struct bpf_sk_storage_data *sdata, struct sk_buff *skb)
1029 {
1030 	struct nlattr *nla_stg, *nla_value;
1031 	struct bpf_sk_storage_map *smap;
1032 
1033 	/* The value cannot exceed the max nlattr payload */
1034 	BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < MAX_VALUE_SIZE);
1035 
1036 	nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
1037 	if (!nla_stg)
1038 		return -EMSGSIZE;
1039 
1040 	smap = rcu_dereference(sdata->smap);
1041 	if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
1042 		goto errout;
1043 
1044 	nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
1045 				      smap->map.value_size,
1046 				      SK_DIAG_BPF_STORAGE_PAD);
1047 	if (!nla_value)
1048 		goto errout;
1049 
1050 	if (map_value_has_spin_lock(&smap->map))
1051 		copy_map_value_locked(&smap->map, nla_data(nla_value),
1052 				      sdata->data, true);
1053 	else
1054 		copy_map_value(&smap->map, nla_data(nla_value), sdata->data);
1055 
1056 	nla_nest_end(skb, nla_stg);
1057 	return 0;
1058 
1059 errout:
1060 	nla_nest_cancel(skb, nla_stg);
1061 	return -EMSGSIZE;
1062 }
1063 
1064 static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
1065 				       int stg_array_type,
1066 				       unsigned int *res_diag_size)
1067 {
1068 	/* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1069 	unsigned int diag_size = nla_total_size(0);
1070 	struct bpf_sk_storage *sk_storage;
1071 	struct bpf_sk_storage_elem *selem;
1072 	struct bpf_sk_storage_map *smap;
1073 	struct nlattr *nla_stgs;
1074 	unsigned int saved_len;
1075 	int err = 0;
1076 
1077 	rcu_read_lock();
1078 
1079 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
1080 	if (!sk_storage || hlist_empty(&sk_storage->list)) {
1081 		rcu_read_unlock();
1082 		return 0;
1083 	}
1084 
1085 	nla_stgs = nla_nest_start(skb, stg_array_type);
1086 	if (!nla_stgs)
1087 		/* Continue to learn diag_size */
1088 		err = -EMSGSIZE;
1089 
1090 	saved_len = skb->len;
1091 	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
1092 		smap = rcu_dereference(SDATA(selem)->smap);
1093 		diag_size += nla_value_size(smap->map.value_size);
1094 
1095 		if (nla_stgs && diag_get(SDATA(selem), skb))
1096 			/* Continue to learn diag_size */
1097 			err = -EMSGSIZE;
1098 	}
1099 
1100 	rcu_read_unlock();
1101 
1102 	if (nla_stgs) {
1103 		if (saved_len == skb->len)
1104 			nla_nest_cancel(skb, nla_stgs);
1105 		else
1106 			nla_nest_end(skb, nla_stgs);
1107 	}
1108 
1109 	if (diag_size == nla_total_size(0)) {
1110 		*res_diag_size = 0;
1111 		return 0;
1112 	}
1113 
1114 	*res_diag_size = diag_size;
1115 	return err;
1116 }
1117 
1118 int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
1119 			    struct sock *sk, struct sk_buff *skb,
1120 			    int stg_array_type,
1121 			    unsigned int *res_diag_size)
1122 {
1123 	/* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1124 	unsigned int diag_size = nla_total_size(0);
1125 	struct bpf_sk_storage *sk_storage;
1126 	struct bpf_sk_storage_data *sdata;
1127 	struct nlattr *nla_stgs;
1128 	unsigned int saved_len;
1129 	int err = 0;
1130 	u32 i;
1131 
1132 	*res_diag_size = 0;
1133 
1134 	/* No map has been specified.  Dump all. */
1135 	if (!diag->nr_maps)
1136 		return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
1137 						   res_diag_size);
1138 
1139 	rcu_read_lock();
1140 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
1141 	if (!sk_storage || hlist_empty(&sk_storage->list)) {
1142 		rcu_read_unlock();
1143 		return 0;
1144 	}
1145 
1146 	nla_stgs = nla_nest_start(skb, stg_array_type);
1147 	if (!nla_stgs)
1148 		/* Continue to learn diag_size */
1149 		err = -EMSGSIZE;
1150 
1151 	saved_len = skb->len;
1152 	for (i = 0; i < diag->nr_maps; i++) {
1153 		sdata = __sk_storage_lookup(sk_storage,
1154 				(struct bpf_sk_storage_map *)diag->maps[i],
1155 				false);
1156 
1157 		if (!sdata)
1158 			continue;
1159 
1160 		diag_size += nla_value_size(diag->maps[i]->value_size);
1161 
1162 		if (nla_stgs && diag_get(sdata, skb))
1163 			/* Continue to learn diag_size */
1164 			err = -EMSGSIZE;
1165 	}
1166 	rcu_read_unlock();
1167 
1168 	if (nla_stgs) {
1169 		if (saved_len == skb->len)
1170 			nla_nest_cancel(skb, nla_stgs);
1171 		else
1172 			nla_nest_end(skb, nla_stgs);
1173 	}
1174 
1175 	if (diag_size == nla_total_size(0)) {
1176 		*res_diag_size = 0;
1177 		return 0;
1178 	}
1179 
1180 	*res_diag_size = diag_size;
1181 	return err;
1182 }
1183 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);
1184