xref: /linux/net/core/bpf_sk_storage.c (revision b8265621f4888af9494e1d685620871ec81bc33d)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook  */
3 #include <linux/rculist.h>
4 #include <linux/list.h>
5 #include <linux/hash.h>
6 #include <linux/types.h>
7 #include <linux/spinlock.h>
8 #include <linux/bpf.h>
9 #include <net/bpf_sk_storage.h>
10 #include <net/sock.h>
11 #include <uapi/linux/sock_diag.h>
12 #include <uapi/linux/btf.h>
13 
/* Creation flags that bpf_sk_storage_map_alloc_check() lets through */
#define SK_STORAGE_CREATE_FLAG_MASK	(BPF_F_NO_PREALLOC | BPF_F_CLONE)
16 
17 struct bucket {
18 	struct hlist_head list;	/* elems of the map, linked via selem->map_node */
19 	raw_spinlock_t lock;	/* held while linking/unlinking on "list" */
20 };
21 
22 /* The map is not the primary owner of a bpf_sk_storage_elem.
23  * Instead, the sk->sk_bpf_storage is.
24  *
25  * The map (bpf_sk_storage_map) is for two purposes
26  * 1. Define the size of the "sk local storage".  It is
27  *    the map's value_size.
28  *
29  * 2. Maintain a list to keep track of all elems such
30  *    that they can be cleaned up during the map destruction.
31  *
32  * When a bpf local storage is being looked up for a
33  * particular sk,  the "bpf_map" pointer is actually used
34  * as the "key" to search in the list of elem in
35  * sk->sk_bpf_storage.
36  *
37  * Hence, consider sk->sk_bpf_storage is the mini-map
38  * with the "bpf_map" pointer as the searching key.
39  */
40 struct bpf_sk_storage_map {
41 	struct bpf_map map;
42 	/* Lookup elem does not require accessing the map.
43 	 *
44 	 * Updating/Deleting requires a bucket lock to
45 	 * link/unlink the elem from the map.  Having
46 	 * multiple buckets to reduce contention.
47 	 */
48 	struct bucket *buckets;
49 	u32 bucket_log;	/* log2 of the number of buckets */
50 	u16 elem_size;	/* sizeof(struct bpf_sk_storage_elem) + value_size */
51 	u16 cache_idx;	/* slot of bpf_sk_storage->cache[] used by this map */
52 };
53 
54 struct bpf_sk_storage_data {
55 	/* smap is used as the searching key when looking up
56 	 * from sk->sk_bpf_storage.
57 	 *
58 	 * Put it in the same cacheline as the data to minimize
59 	 * the number of cacheline accesses during the cache hit case.
60 	 */
61 	struct bpf_sk_storage_map __rcu *smap;
62 	u8 data[] __aligned(8);	/* map value; selem_alloc() copies map.value_size bytes here */
63 };
64 
65 /* Linked to bpf_sk_storage and bpf_sk_storage_map */
66 struct bpf_sk_storage_elem {
67 	struct hlist_node map_node;	/* Linked to bpf_sk_storage_map */
68 	struct hlist_node snode;	/* Linked to bpf_sk_storage */
69 	struct bpf_sk_storage __rcu *sk_storage;
70 	struct rcu_head rcu;	/* used by kfree_rcu() when the elem is freed */
71 	/* 8 bytes hole */
72 	/* The data is stored in another cacheline to minimize
73 	 * the number of cacheline accesses during a cache hit.
74 	 */
75 	struct bpf_sk_storage_data sdata ____cacheline_aligned;
76 };
77 
78 #define SELEM(_SDATA) container_of((_SDATA), struct bpf_sk_storage_elem, sdata)
79 #define SDATA(_SELEM) (&(_SELEM)->sdata)
80 #define BPF_SK_STORAGE_CACHE_SIZE	16
81 
82 static DEFINE_SPINLOCK(cache_idx_lock);
83 static u64 cache_idx_usage_counts[BPF_SK_STORAGE_CACHE_SIZE];
84 
85 struct bpf_sk_storage {
86 	struct bpf_sk_storage_data __rcu *cache[BPF_SK_STORAGE_CACHE_SIZE];
87 	struct hlist_head list;	/* List of bpf_sk_storage_elem */
88 	struct sock *sk;	/* The sk that owns the above "list" of
89 				 * bpf_sk_storage_elem.
90 				 */
91 	struct rcu_head rcu;	/* used by kfree_rcu() when this is freed */
92 	raw_spinlock_t lock;	/* Protect adding/removing from the "list" */
93 };
94 
95 static struct bucket *select_bucket(struct bpf_sk_storage_map *smap,
96 				    struct bpf_sk_storage_elem *selem)
97 {
98 	return &smap->buckets[hash_ptr(selem, smap->bucket_log)];
99 }
100 
101 static int omem_charge(struct sock *sk, unsigned int size)
102 {
103 	/* same check as in sock_kmalloc() */
104 	if (size <= sysctl_optmem_max &&
105 	    atomic_read(&sk->sk_omem_alloc) + size < sysctl_optmem_max) {
106 		atomic_add(size, &sk->sk_omem_alloc);
107 		return 0;
108 	}
109 
110 	return -ENOMEM;
111 }
112 
113 static bool selem_linked_to_sk(const struct bpf_sk_storage_elem *selem)
114 {
115 	return !hlist_unhashed(&selem->snode);
116 }
117 
118 static bool selem_linked_to_map(const struct bpf_sk_storage_elem *selem)
119 {
120 	return !hlist_unhashed(&selem->map_node);
121 }
122 
123 static struct bpf_sk_storage_elem *selem_alloc(struct bpf_sk_storage_map *smap,
124 					       struct sock *sk, void *value,
125 					       bool charge_omem)
126 {
127 	struct bpf_sk_storage_elem *selem;
128 
129 	if (charge_omem && omem_charge(sk, smap->elem_size))
130 		return NULL;
131 
132 	selem = kzalloc(smap->elem_size, GFP_ATOMIC | __GFP_NOWARN);
133 	if (selem) {
134 		if (value)
135 			memcpy(SDATA(selem)->data, value, smap->map.value_size);
136 		return selem;
137 	}
138 
139 	if (charge_omem)
140 		atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
141 
142 	return NULL;
143 }
144 
145 /* sk_storage->lock must be held and selem->sk_storage == sk_storage.
146  * The caller must ensure selem->smap is still valid to be
147  * dereferenced for its smap->elem_size and smap->cache_idx.
148  */
149 static bool __selem_unlink_sk(struct bpf_sk_storage *sk_storage,
150 			      struct bpf_sk_storage_elem *selem,
151 			      bool uncharge_omem)
152 {
153 	struct bpf_sk_storage_map *smap;
154 	bool free_sk_storage;
155 	struct sock *sk;
156 
157 	smap = rcu_dereference(SDATA(selem)->smap);
158 	sk = sk_storage->sk;
159 
160 	/* All uncharging on sk->sk_omem_alloc must be done first.
161 	 * sk may be freed once the last selem is unlinked from sk_storage.
162 	 */
163 	if (uncharge_omem)
164 		atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
165 
166 	free_sk_storage = hlist_is_singular_node(&selem->snode,
167 						 &sk_storage->list);
168 	if (free_sk_storage) {
169 		atomic_sub(sizeof(struct bpf_sk_storage), &sk->sk_omem_alloc);
170 		sk_storage->sk = NULL;
171 		/* After this RCU_INIT, sk may be freed and cannot be used */
172 		RCU_INIT_POINTER(sk->sk_bpf_storage, NULL);
173 
174 		/* sk_storage is not freed now.  sk_storage->lock is
175 		 * still held and raw_spin_unlock_bh(&sk_storage->lock)
176 		 * will be done by the caller.
177 		 *
178 		 * Although the unlock will be done under
179 		 * rcu_read_lock(),  it is more intutivie to
180 		 * read if kfree_rcu(sk_storage, rcu) is done
181 		 * after the raw_spin_unlock_bh(&sk_storage->lock).
182 		 *
183 		 * Hence, a "bool free_sk_storage" is returned
184 		 * to the caller which then calls the kfree_rcu()
185 		 * after unlock.
186 		 */
187 	}
188 	hlist_del_init_rcu(&selem->snode);
189 	if (rcu_access_pointer(sk_storage->cache[smap->cache_idx]) ==
190 	    SDATA(selem))
191 		RCU_INIT_POINTER(sk_storage->cache[smap->cache_idx], NULL);
192 
193 	kfree_rcu(selem, rcu);
194 
195 	return free_sk_storage;
196 }
197 
198 static void selem_unlink_sk(struct bpf_sk_storage_elem *selem)
199 {
200 	struct bpf_sk_storage *sk_storage;
201 	bool free_sk_storage = false;
202 
203 	if (unlikely(!selem_linked_to_sk(selem)))
204 		/* selem has already been unlinked from sk */
205 		return;
206 
207 	sk_storage = rcu_dereference(selem->sk_storage);
208 	raw_spin_lock_bh(&sk_storage->lock);
209 	if (likely(selem_linked_to_sk(selem)))
210 		free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
211 	raw_spin_unlock_bh(&sk_storage->lock);
212 
213 	if (free_sk_storage)
214 		kfree_rcu(sk_storage, rcu);
215 }
216 
217 static void __selem_link_sk(struct bpf_sk_storage *sk_storage,
218 			    struct bpf_sk_storage_elem *selem)
219 {
220 	RCU_INIT_POINTER(selem->sk_storage, sk_storage);
221 	hlist_add_head(&selem->snode, &sk_storage->list);
222 }
223 
224 static void selem_unlink_map(struct bpf_sk_storage_elem *selem)
225 {
226 	struct bpf_sk_storage_map *smap;
227 	struct bucket *b;
228 
229 	if (unlikely(!selem_linked_to_map(selem)))
230 		/* selem has already be unlinked from smap */
231 		return;
232 
233 	smap = rcu_dereference(SDATA(selem)->smap);
234 	b = select_bucket(smap, selem);
235 	raw_spin_lock_bh(&b->lock);
236 	if (likely(selem_linked_to_map(selem)))
237 		hlist_del_init_rcu(&selem->map_node);
238 	raw_spin_unlock_bh(&b->lock);
239 }
240 
241 static void selem_link_map(struct bpf_sk_storage_map *smap,
242 			   struct bpf_sk_storage_elem *selem)
243 {
244 	struct bucket *b = select_bucket(smap, selem);
245 
246 	raw_spin_lock_bh(&b->lock);
247 	RCU_INIT_POINTER(SDATA(selem)->smap, smap);
248 	hlist_add_head_rcu(&selem->map_node, &b->list);
249 	raw_spin_unlock_bh(&b->lock);
250 }
251 
/* Fully unlink selem: first from the map, then from the sk_storage.
 *
 * The order matters.  selem is freed once it has been unlinked from
 * the sk_storage, so the map unlink has to come first.
 */
static void selem_unlink(struct bpf_sk_storage_elem *selem)
{
	selem_unlink_map(selem);
	selem_unlink_sk(selem);
}
261 
262 static struct bpf_sk_storage_data *
263 __sk_storage_lookup(struct bpf_sk_storage *sk_storage,
264 		    struct bpf_sk_storage_map *smap,
265 		    bool cacheit_lockit)
266 {
267 	struct bpf_sk_storage_data *sdata;
268 	struct bpf_sk_storage_elem *selem;
269 
270 	/* Fast path (cache hit) */
271 	sdata = rcu_dereference(sk_storage->cache[smap->cache_idx]);
272 	if (sdata && rcu_access_pointer(sdata->smap) == smap)
273 		return sdata;
274 
275 	/* Slow path (cache miss) */
276 	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode)
277 		if (rcu_access_pointer(SDATA(selem)->smap) == smap)
278 			break;
279 
280 	if (!selem)
281 		return NULL;
282 
283 	sdata = SDATA(selem);
284 	if (cacheit_lockit) {
285 		/* spinlock is needed to avoid racing with the
286 		 * parallel delete.  Otherwise, publishing an already
287 		 * deleted sdata to the cache will become a use-after-free
288 		 * problem in the next __sk_storage_lookup().
289 		 */
290 		raw_spin_lock_bh(&sk_storage->lock);
291 		if (selem_linked_to_sk(selem))
292 			rcu_assign_pointer(sk_storage->cache[smap->cache_idx],
293 					   sdata);
294 		raw_spin_unlock_bh(&sk_storage->lock);
295 	}
296 
297 	return sdata;
298 }
299 
300 static struct bpf_sk_storage_data *
301 sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
302 {
303 	struct bpf_sk_storage *sk_storage;
304 	struct bpf_sk_storage_map *smap;
305 
306 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
307 	if (!sk_storage)
308 		return NULL;
309 
310 	smap = (struct bpf_sk_storage_map *)map;
311 	return __sk_storage_lookup(sk_storage, smap, cacheit_lockit);
312 }
313 
314 static int check_flags(const struct bpf_sk_storage_data *old_sdata,
315 		       u64 map_flags)
316 {
317 	if (old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_NOEXIST)
318 		/* elem already exists */
319 		return -EEXIST;
320 
321 	if (!old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_EXIST)
322 		/* elem doesn't exist, cannot update it */
323 		return -ENOENT;
324 
325 	return 0;
326 }
327 
328 static int sk_storage_alloc(struct sock *sk,
329 			    struct bpf_sk_storage_map *smap,
330 			    struct bpf_sk_storage_elem *first_selem)
331 {
332 	struct bpf_sk_storage *prev_sk_storage, *sk_storage;
333 	int err;
334 
335 	err = omem_charge(sk, sizeof(*sk_storage));
336 	if (err)
337 		return err;
338 
339 	sk_storage = kzalloc(sizeof(*sk_storage), GFP_ATOMIC | __GFP_NOWARN);
340 	if (!sk_storage) {
341 		err = -ENOMEM;
342 		goto uncharge;
343 	}
344 	INIT_HLIST_HEAD(&sk_storage->list);
345 	raw_spin_lock_init(&sk_storage->lock);
346 	sk_storage->sk = sk;
347 
348 	__selem_link_sk(sk_storage, first_selem);
349 	selem_link_map(smap, first_selem);
350 	/* Publish sk_storage to sk.  sk->sk_lock cannot be acquired.
351 	 * Hence, atomic ops is used to set sk->sk_bpf_storage
352 	 * from NULL to the newly allocated sk_storage ptr.
353 	 *
354 	 * From now on, the sk->sk_bpf_storage pointer is protected
355 	 * by the sk_storage->lock.  Hence,  when freeing
356 	 * the sk->sk_bpf_storage, the sk_storage->lock must
357 	 * be held before setting sk->sk_bpf_storage to NULL.
358 	 */
359 	prev_sk_storage = cmpxchg((struct bpf_sk_storage **)&sk->sk_bpf_storage,
360 				  NULL, sk_storage);
361 	if (unlikely(prev_sk_storage)) {
362 		selem_unlink_map(first_selem);
363 		err = -EAGAIN;
364 		goto uncharge;
365 
366 		/* Note that even first_selem was linked to smap's
367 		 * bucket->list, first_selem can be freed immediately
368 		 * (instead of kfree_rcu) because
369 		 * bpf_sk_storage_map_free() does a
370 		 * synchronize_rcu() before walking the bucket->list.
371 		 * Hence, no one is accessing selem from the
372 		 * bucket->list under rcu_read_lock().
373 		 */
374 	}
375 
376 	return 0;
377 
378 uncharge:
379 	kfree(sk_storage);
380 	atomic_sub(sizeof(*sk_storage), &sk->sk_omem_alloc);
381 	return err;
382 }
383 
384 /* sk cannot be going away because it is linking new elem
385  * to sk->sk_bpf_storage. (i.e. sk->sk_refcnt cannot be 0).
386  * Otherwise, it will become a leak (and other memory issues
387  * during map destruction).
388  */
389 static struct bpf_sk_storage_data *sk_storage_update(struct sock *sk,
390 						     struct bpf_map *map,
391 						     void *value,
392 						     u64 map_flags)
393 {
394 	struct bpf_sk_storage_data *old_sdata = NULL;
395 	struct bpf_sk_storage_elem *selem;
396 	struct bpf_sk_storage *sk_storage;
397 	struct bpf_sk_storage_map *smap;
398 	int err;
399 
400 	/* BPF_EXIST and BPF_NOEXIST cannot be both set */
401 	if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
402 	    /* BPF_F_LOCK can only be used in a value with spin_lock */
403 	    unlikely((map_flags & BPF_F_LOCK) && !map_value_has_spin_lock(map)))
404 		return ERR_PTR(-EINVAL);
405 
406 	smap = (struct bpf_sk_storage_map *)map;
407 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
408 	if (!sk_storage || hlist_empty(&sk_storage->list)) {
409 		/* Very first elem for this sk */
410 		err = check_flags(NULL, map_flags);
411 		if (err)
412 			return ERR_PTR(err);
413 
414 		selem = selem_alloc(smap, sk, value, true);
415 		if (!selem)
416 			return ERR_PTR(-ENOMEM);
417 
418 		err = sk_storage_alloc(sk, smap, selem);
419 		if (err) {
420 			kfree(selem);
421 			atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
422 			return ERR_PTR(err);
423 		}
424 
425 		return SDATA(selem);
426 	}
427 
428 	if ((map_flags & BPF_F_LOCK) && !(map_flags & BPF_NOEXIST)) {
429 		/* Hoping to find an old_sdata to do inline update
430 		 * such that it can avoid taking the sk_storage->lock
431 		 * and changing the lists.
432 		 */
433 		old_sdata = __sk_storage_lookup(sk_storage, smap, false);
434 		err = check_flags(old_sdata, map_flags);
435 		if (err)
436 			return ERR_PTR(err);
437 		if (old_sdata && selem_linked_to_sk(SELEM(old_sdata))) {
438 			copy_map_value_locked(map, old_sdata->data,
439 					      value, false);
440 			return old_sdata;
441 		}
442 	}
443 
444 	raw_spin_lock_bh(&sk_storage->lock);
445 
446 	/* Recheck sk_storage->list under sk_storage->lock */
447 	if (unlikely(hlist_empty(&sk_storage->list))) {
448 		/* A parallel del is happening and sk_storage is going
449 		 * away.  It has just been checked before, so very
450 		 * unlikely.  Return instead of retry to keep things
451 		 * simple.
452 		 */
453 		err = -EAGAIN;
454 		goto unlock_err;
455 	}
456 
457 	old_sdata = __sk_storage_lookup(sk_storage, smap, false);
458 	err = check_flags(old_sdata, map_flags);
459 	if (err)
460 		goto unlock_err;
461 
462 	if (old_sdata && (map_flags & BPF_F_LOCK)) {
463 		copy_map_value_locked(map, old_sdata->data, value, false);
464 		selem = SELEM(old_sdata);
465 		goto unlock;
466 	}
467 
468 	/* sk_storage->lock is held.  Hence, we are sure
469 	 * we can unlink and uncharge the old_sdata successfully
470 	 * later.  Hence, instead of charging the new selem now
471 	 * and then uncharge the old selem later (which may cause
472 	 * a potential but unnecessary charge failure),  avoid taking
473 	 * a charge at all here (the "!old_sdata" check) and the
474 	 * old_sdata will not be uncharged later during __selem_unlink_sk().
475 	 */
476 	selem = selem_alloc(smap, sk, value, !old_sdata);
477 	if (!selem) {
478 		err = -ENOMEM;
479 		goto unlock_err;
480 	}
481 
482 	/* First, link the new selem to the map */
483 	selem_link_map(smap, selem);
484 
485 	/* Second, link (and publish) the new selem to sk_storage */
486 	__selem_link_sk(sk_storage, selem);
487 
488 	/* Third, remove old selem, SELEM(old_sdata) */
489 	if (old_sdata) {
490 		selem_unlink_map(SELEM(old_sdata));
491 		__selem_unlink_sk(sk_storage, SELEM(old_sdata), false);
492 	}
493 
494 unlock:
495 	raw_spin_unlock_bh(&sk_storage->lock);
496 	return SDATA(selem);
497 
498 unlock_err:
499 	raw_spin_unlock_bh(&sk_storage->lock);
500 	return ERR_PTR(err);
501 }
502 
503 static int sk_storage_delete(struct sock *sk, struct bpf_map *map)
504 {
505 	struct bpf_sk_storage_data *sdata;
506 
507 	sdata = sk_storage_lookup(sk, map, false);
508 	if (!sdata)
509 		return -ENOENT;
510 
511 	selem_unlink(SELEM(sdata));
512 
513 	return 0;
514 }
515 
516 static u16 cache_idx_get(void)
517 {
518 	u64 min_usage = U64_MAX;
519 	u16 i, res = 0;
520 
521 	spin_lock(&cache_idx_lock);
522 
523 	for (i = 0; i < BPF_SK_STORAGE_CACHE_SIZE; i++) {
524 		if (cache_idx_usage_counts[i] < min_usage) {
525 			min_usage = cache_idx_usage_counts[i];
526 			res = i;
527 
528 			/* Found a free cache_idx */
529 			if (!min_usage)
530 				break;
531 		}
532 	}
533 	cache_idx_usage_counts[res]++;
534 
535 	spin_unlock(&cache_idx_lock);
536 
537 	return res;
538 }
539 
540 static void cache_idx_free(u16 idx)
541 {
542 	spin_lock(&cache_idx_lock);
543 	cache_idx_usage_counts[idx]--;
544 	spin_unlock(&cache_idx_lock);
545 }
546 
547 /* Called by __sk_destruct() & bpf_sk_storage_clone() */
548 void bpf_sk_storage_free(struct sock *sk)
549 {
550 	struct bpf_sk_storage_elem *selem;
551 	struct bpf_sk_storage *sk_storage;
552 	bool free_sk_storage = false;
553 	struct hlist_node *n;
554 
555 	rcu_read_lock();
556 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
557 	if (!sk_storage) {
558 		rcu_read_unlock();
559 		return;
560 	}
561 
562 	/* Netiher the bpf_prog nor the bpf-map's syscall
563 	 * could be modifying the sk_storage->list now.
564 	 * Thus, no elem can be added-to or deleted-from the
565 	 * sk_storage->list by the bpf_prog or by the bpf-map's syscall.
566 	 *
567 	 * It is racing with bpf_sk_storage_map_free() alone
568 	 * when unlinking elem from the sk_storage->list and
569 	 * the map's bucket->list.
570 	 */
571 	raw_spin_lock_bh(&sk_storage->lock);
572 	hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) {
573 		/* Always unlink from map before unlinking from
574 		 * sk_storage.
575 		 */
576 		selem_unlink_map(selem);
577 		free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
578 	}
579 	raw_spin_unlock_bh(&sk_storage->lock);
580 	rcu_read_unlock();
581 
582 	if (free_sk_storage)
583 		kfree_rcu(sk_storage, rcu);
584 }
585 
586 static void bpf_sk_storage_map_free(struct bpf_map *map)
587 {
588 	struct bpf_sk_storage_elem *selem;
589 	struct bpf_sk_storage_map *smap;
590 	struct bucket *b;
591 	unsigned int i;
592 
593 	smap = (struct bpf_sk_storage_map *)map;
594 
595 	cache_idx_free(smap->cache_idx);
596 
597 	/* Note that this map might be concurrently cloned from
598 	 * bpf_sk_storage_clone. Wait for any existing bpf_sk_storage_clone
599 	 * RCU read section to finish before proceeding. New RCU
600 	 * read sections should be prevented via bpf_map_inc_not_zero.
601 	 */
602 	synchronize_rcu();
603 
604 	/* bpf prog and the userspace can no longer access this map
605 	 * now.  No new selem (of this map) can be added
606 	 * to the sk->sk_bpf_storage or to the map bucket's list.
607 	 *
608 	 * The elem of this map can be cleaned up here
609 	 * or
610 	 * by bpf_sk_storage_free() during __sk_destruct().
611 	 */
612 	for (i = 0; i < (1U << smap->bucket_log); i++) {
613 		b = &smap->buckets[i];
614 
615 		rcu_read_lock();
616 		/* No one is adding to b->list now */
617 		while ((selem = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(&b->list)),
618 						 struct bpf_sk_storage_elem,
619 						 map_node))) {
620 			selem_unlink(selem);
621 			cond_resched_rcu();
622 		}
623 		rcu_read_unlock();
624 	}
625 
626 	/* bpf_sk_storage_free() may still need to access the map.
627 	 * e.g. bpf_sk_storage_free() has unlinked selem from the map
628 	 * which then made the above while((selem = ...)) loop
629 	 * exited immediately.
630 	 *
631 	 * However, the bpf_sk_storage_free() still needs to access
632 	 * the smap->elem_size to do the uncharging in
633 	 * __selem_unlink_sk().
634 	 *
635 	 * Hence, wait another rcu grace period for the
636 	 * bpf_sk_storage_free() to finish.
637 	 */
638 	synchronize_rcu();
639 
640 	kvfree(smap->buckets);
641 	kfree(map);
642 }
643 
/* Largest value_size a map may ask for.  The U16_MAX bound is far more
 * than sk local storage needs, given that a tcp_sock is ~2k.
 */
#define MAX_VALUE_SIZE							\
	min_t(u32,							\
	      (KMALLOC_MAX_SIZE - MAX_BPF_STACK -			\
	       sizeof(struct bpf_sk_storage_elem)),			\
	      (U16_MAX - sizeof(struct bpf_sk_storage_elem)))
651 
652 static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr)
653 {
654 	if (attr->map_flags & ~SK_STORAGE_CREATE_FLAG_MASK ||
655 	    !(attr->map_flags & BPF_F_NO_PREALLOC) ||
656 	    attr->max_entries ||
657 	    attr->key_size != sizeof(int) || !attr->value_size ||
658 	    /* Enforce BTF for userspace sk dumping */
659 	    !attr->btf_key_type_id || !attr->btf_value_type_id)
660 		return -EINVAL;
661 
662 	if (!bpf_capable())
663 		return -EPERM;
664 
665 	if (attr->value_size > MAX_VALUE_SIZE)
666 		return -E2BIG;
667 
668 	return 0;
669 }
670 
671 static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
672 {
673 	struct bpf_sk_storage_map *smap;
674 	unsigned int i;
675 	u32 nbuckets;
676 	u64 cost;
677 	int ret;
678 
679 	smap = kzalloc(sizeof(*smap), GFP_USER | __GFP_NOWARN);
680 	if (!smap)
681 		return ERR_PTR(-ENOMEM);
682 	bpf_map_init_from_attr(&smap->map, attr);
683 
684 	nbuckets = roundup_pow_of_two(num_possible_cpus());
685 	/* Use at least 2 buckets, select_bucket() is undefined behavior with 1 bucket */
686 	nbuckets = max_t(u32, 2, nbuckets);
687 	smap->bucket_log = ilog2(nbuckets);
688 	cost = sizeof(*smap->buckets) * nbuckets + sizeof(*smap);
689 
690 	ret = bpf_map_charge_init(&smap->map.memory, cost);
691 	if (ret < 0) {
692 		kfree(smap);
693 		return ERR_PTR(ret);
694 	}
695 
696 	smap->buckets = kvcalloc(sizeof(*smap->buckets), nbuckets,
697 				 GFP_USER | __GFP_NOWARN);
698 	if (!smap->buckets) {
699 		bpf_map_charge_finish(&smap->map.memory);
700 		kfree(smap);
701 		return ERR_PTR(-ENOMEM);
702 	}
703 
704 	for (i = 0; i < nbuckets; i++) {
705 		INIT_HLIST_HEAD(&smap->buckets[i].list);
706 		raw_spin_lock_init(&smap->buckets[i].lock);
707 	}
708 
709 	smap->elem_size = sizeof(struct bpf_sk_storage_elem) + attr->value_size;
710 	smap->cache_idx = cache_idx_get();
711 
712 	return &smap->map;
713 }
714 
715 static int notsupp_get_next_key(struct bpf_map *map, void *key,
716 				void *next_key)
717 {
718 	return -ENOTSUPP;
719 }
720 
721 static int bpf_sk_storage_map_check_btf(const struct bpf_map *map,
722 					const struct btf *btf,
723 					const struct btf_type *key_type,
724 					const struct btf_type *value_type)
725 {
726 	u32 int_data;
727 
728 	if (BTF_INFO_KIND(key_type->info) != BTF_KIND_INT)
729 		return -EINVAL;
730 
731 	int_data = *(u32 *)(key_type + 1);
732 	if (BTF_INT_BITS(int_data) != 32 || BTF_INT_OFFSET(int_data))
733 		return -EINVAL;
734 
735 	return 0;
736 }
737 
738 static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key)
739 {
740 	struct bpf_sk_storage_data *sdata;
741 	struct socket *sock;
742 	int fd, err;
743 
744 	fd = *(int *)key;
745 	sock = sockfd_lookup(fd, &err);
746 	if (sock) {
747 		sdata = sk_storage_lookup(sock->sk, map, true);
748 		sockfd_put(sock);
749 		return sdata ? sdata->data : NULL;
750 	}
751 
752 	return ERR_PTR(err);
753 }
754 
755 static int bpf_fd_sk_storage_update_elem(struct bpf_map *map, void *key,
756 					 void *value, u64 map_flags)
757 {
758 	struct bpf_sk_storage_data *sdata;
759 	struct socket *sock;
760 	int fd, err;
761 
762 	fd = *(int *)key;
763 	sock = sockfd_lookup(fd, &err);
764 	if (sock) {
765 		sdata = sk_storage_update(sock->sk, map, value, map_flags);
766 		sockfd_put(sock);
767 		return PTR_ERR_OR_ZERO(sdata);
768 	}
769 
770 	return err;
771 }
772 
773 static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
774 {
775 	struct socket *sock;
776 	int fd, err;
777 
778 	fd = *(int *)key;
779 	sock = sockfd_lookup(fd, &err);
780 	if (sock) {
781 		err = sk_storage_delete(sock->sk, map);
782 		sockfd_put(sock);
783 		return err;
784 	}
785 
786 	return err;
787 }
788 
789 static struct bpf_sk_storage_elem *
790 bpf_sk_storage_clone_elem(struct sock *newsk,
791 			  struct bpf_sk_storage_map *smap,
792 			  struct bpf_sk_storage_elem *selem)
793 {
794 	struct bpf_sk_storage_elem *copy_selem;
795 
796 	copy_selem = selem_alloc(smap, newsk, NULL, true);
797 	if (!copy_selem)
798 		return NULL;
799 
800 	if (map_value_has_spin_lock(&smap->map))
801 		copy_map_value_locked(&smap->map, SDATA(copy_selem)->data,
802 				      SDATA(selem)->data, true);
803 	else
804 		copy_map_value(&smap->map, SDATA(copy_selem)->data,
805 			       SDATA(selem)->data);
806 
807 	return copy_selem;
808 }
809 
810 int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk)
811 {
812 	struct bpf_sk_storage *new_sk_storage = NULL;
813 	struct bpf_sk_storage *sk_storage;
814 	struct bpf_sk_storage_elem *selem;
815 	int ret = 0;
816 
817 	RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);
818 
819 	rcu_read_lock();
820 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
821 
822 	if (!sk_storage || hlist_empty(&sk_storage->list))
823 		goto out;
824 
825 	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
826 		struct bpf_sk_storage_elem *copy_selem;
827 		struct bpf_sk_storage_map *smap;
828 		struct bpf_map *map;
829 
830 		smap = rcu_dereference(SDATA(selem)->smap);
831 		if (!(smap->map.map_flags & BPF_F_CLONE))
832 			continue;
833 
834 		/* Note that for lockless listeners adding new element
835 		 * here can race with cleanup in bpf_sk_storage_map_free.
836 		 * Try to grab map refcnt to make sure that it's still
837 		 * alive and prevent concurrent removal.
838 		 */
839 		map = bpf_map_inc_not_zero(&smap->map);
840 		if (IS_ERR(map))
841 			continue;
842 
843 		copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem);
844 		if (!copy_selem) {
845 			ret = -ENOMEM;
846 			bpf_map_put(map);
847 			goto out;
848 		}
849 
850 		if (new_sk_storage) {
851 			selem_link_map(smap, copy_selem);
852 			__selem_link_sk(new_sk_storage, copy_selem);
853 		} else {
854 			ret = sk_storage_alloc(newsk, smap, copy_selem);
855 			if (ret) {
856 				kfree(copy_selem);
857 				atomic_sub(smap->elem_size,
858 					   &newsk->sk_omem_alloc);
859 				bpf_map_put(map);
860 				goto out;
861 			}
862 
863 			new_sk_storage = rcu_dereference(copy_selem->sk_storage);
864 		}
865 		bpf_map_put(map);
866 	}
867 
868 out:
869 	rcu_read_unlock();
870 
871 	/* In case of an error, don't free anything explicitly here, the
872 	 * caller is responsible to call bpf_sk_storage_free.
873 	 */
874 
875 	return ret;
876 }
877 
878 BPF_CALL_4(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
879 	   void *, value, u64, flags)
880 {
881 	struct bpf_sk_storage_data *sdata;
882 
883 	if (flags > BPF_SK_STORAGE_GET_F_CREATE)
884 		return (unsigned long)NULL;
885 
886 	sdata = sk_storage_lookup(sk, map, true);
887 	if (sdata)
888 		return (unsigned long)sdata->data;
889 
890 	if (flags == BPF_SK_STORAGE_GET_F_CREATE &&
891 	    /* Cannot add new elem to a going away sk.
892 	     * Otherwise, the new elem may become a leak
893 	     * (and also other memory issues during map
894 	     *  destruction).
895 	     */
896 	    refcount_inc_not_zero(&sk->sk_refcnt)) {
897 		sdata = sk_storage_update(sk, map, value, BPF_NOEXIST);
898 		/* sk must be a fullsock (guaranteed by verifier),
899 		 * so sock_gen_put() is unnecessary.
900 		 */
901 		sock_put(sk);
902 		return IS_ERR(sdata) ?
903 			(unsigned long)NULL : (unsigned long)sdata->data;
904 	}
905 
906 	return (unsigned long)NULL;
907 }
908 
909 BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
910 {
911 	if (refcount_inc_not_zero(&sk->sk_refcnt)) {
912 		int err;
913 
914 		err = sk_storage_delete(sk, map);
915 		sock_put(sk);
916 		return err;
917 	}
918 
919 	return -ENOENT;
920 }
921 
922 static int sk_storage_map_btf_id;
923 const struct bpf_map_ops sk_storage_map_ops = {
924 	.map_alloc_check = bpf_sk_storage_map_alloc_check,
925 	.map_alloc = bpf_sk_storage_map_alloc,
926 	.map_free = bpf_sk_storage_map_free,
927 	.map_get_next_key = notsupp_get_next_key,
928 	.map_lookup_elem = bpf_fd_sk_storage_lookup_elem,
929 	.map_update_elem = bpf_fd_sk_storage_update_elem,
930 	.map_delete_elem = bpf_fd_sk_storage_delete_elem,
931 	.map_check_btf = bpf_sk_storage_map_check_btf,
932 	.map_btf_name = "bpf_sk_storage_map",
933 	.map_btf_id = &sk_storage_map_btf_id,
934 };
935 
936 const struct bpf_func_proto bpf_sk_storage_get_proto = {
937 	.func		= bpf_sk_storage_get,
938 	.gpl_only	= false,
939 	.ret_type	= RET_PTR_TO_MAP_VALUE_OR_NULL,
940 	.arg1_type	= ARG_CONST_MAP_PTR,
941 	.arg2_type	= ARG_PTR_TO_SOCKET,
942 	.arg3_type	= ARG_PTR_TO_MAP_VALUE_OR_NULL,
943 	.arg4_type	= ARG_ANYTHING,
944 };
945 
946 const struct bpf_func_proto bpf_sk_storage_delete_proto = {
947 	.func		= bpf_sk_storage_delete,
948 	.gpl_only	= false,
949 	.ret_type	= RET_INTEGER,
950 	.arg1_type	= ARG_CONST_MAP_PTR,
951 	.arg2_type	= ARG_PTR_TO_SOCKET,
952 };
953 
954 struct bpf_sk_storage_diag {
955 	u32 nr_maps;	/* number of valid entries in maps[] */
956 	struct bpf_map *maps[];	/* each entry is put by bpf_sk_storage_diag_free() */
957 };
958 
959 /* The reply will be like:
960  * INET_DIAG_BPF_SK_STORAGES (nla_nest)
961  *	SK_DIAG_BPF_STORAGE (nla_nest)
962  *		SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
963  *		SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
964  *	SK_DIAG_BPF_STORAGE (nla_nest)
965  *		SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
966  *		SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
967  *	....
968  */
969 static int nla_value_size(u32 value_size)
970 {
971 	/* SK_DIAG_BPF_STORAGE (nla_nest)
972 	 *	SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
973 	 *	SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
974 	 */
975 	return nla_total_size(0) + nla_total_size(sizeof(u32)) +
976 		nla_total_size_64bit(value_size);
977 }
978 
979 void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
980 {
981 	u32 i;
982 
983 	if (!diag)
984 		return;
985 
986 	for (i = 0; i < diag->nr_maps; i++)
987 		bpf_map_put(diag->maps[i]);
988 
989 	kfree(diag);
990 }
991 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);
992 
993 static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
994 			   const struct bpf_map *map)
995 {
996 	u32 i;
997 
998 	for (i = 0; i < diag->nr_maps; i++) {
999 		if (diag->maps[i] == map)
1000 			return true;
1001 	}
1002 
1003 	return false;
1004 }
1005 
1006 struct bpf_sk_storage_diag *
1007 bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
1008 {
1009 	struct bpf_sk_storage_diag *diag;
1010 	struct nlattr *nla;
1011 	u32 nr_maps = 0;
1012 	int rem, err;
1013 
1014 	/* bpf_sk_storage_map is currently limited to CAP_SYS_ADMIN as
1015 	 * the map_alloc_check() side also does.
1016 	 */
1017 	if (!bpf_capable())
1018 		return ERR_PTR(-EPERM);
1019 
1020 	nla_for_each_nested(nla, nla_stgs, rem) {
1021 		if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
1022 			nr_maps++;
1023 	}
1024 
1025 	diag = kzalloc(sizeof(*diag) + sizeof(diag->maps[0]) * nr_maps,
1026 		       GFP_KERNEL);
1027 	if (!diag)
1028 		return ERR_PTR(-ENOMEM);
1029 
1030 	nla_for_each_nested(nla, nla_stgs, rem) {
1031 		struct bpf_map *map;
1032 		int map_fd;
1033 
1034 		if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
1035 			continue;
1036 
1037 		map_fd = nla_get_u32(nla);
1038 		map = bpf_map_get(map_fd);
1039 		if (IS_ERR(map)) {
1040 			err = PTR_ERR(map);
1041 			goto err_free;
1042 		}
1043 		if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
1044 			bpf_map_put(map);
1045 			err = -EINVAL;
1046 			goto err_free;
1047 		}
1048 		if (diag_check_dup(diag, map)) {
1049 			bpf_map_put(map);
1050 			err = -EEXIST;
1051 			goto err_free;
1052 		}
1053 		diag->maps[diag->nr_maps++] = map;
1054 	}
1055 
1056 	return diag;
1057 
1058 err_free:
1059 	bpf_sk_storage_diag_free(diag);
1060 	return ERR_PTR(err);
1061 }
1062 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);
1063 
1064 static int diag_get(struct bpf_sk_storage_data *sdata, struct sk_buff *skb)
1065 {
1066 	struct nlattr *nla_stg, *nla_value;
1067 	struct bpf_sk_storage_map *smap;
1068 
1069 	/* It cannot exceed max nlattr's payload */
1070 	BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < MAX_VALUE_SIZE);
1071 
1072 	nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
1073 	if (!nla_stg)
1074 		return -EMSGSIZE;
1075 
1076 	smap = rcu_dereference(sdata->smap);
1077 	if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
1078 		goto errout;
1079 
1080 	nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
1081 				      smap->map.value_size,
1082 				      SK_DIAG_BPF_STORAGE_PAD);
1083 	if (!nla_value)
1084 		goto errout;
1085 
1086 	if (map_value_has_spin_lock(&smap->map))
1087 		copy_map_value_locked(&smap->map, nla_data(nla_value),
1088 				      sdata->data, true);
1089 	else
1090 		copy_map_value(&smap->map, nla_data(nla_value), sdata->data);
1091 
1092 	nla_nest_end(skb, nla_stg);
1093 	return 0;
1094 
1095 errout:
1096 	nla_nest_cancel(skb, nla_stg);
1097 	return -EMSGSIZE;
1098 }
1099 
1100 static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
1101 				       int stg_array_type,
1102 				       unsigned int *res_diag_size)
1103 {
1104 	/* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1105 	unsigned int diag_size = nla_total_size(0);
1106 	struct bpf_sk_storage *sk_storage;
1107 	struct bpf_sk_storage_elem *selem;
1108 	struct bpf_sk_storage_map *smap;
1109 	struct nlattr *nla_stgs;
1110 	unsigned int saved_len;
1111 	int err = 0;
1112 
1113 	rcu_read_lock();
1114 
1115 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
1116 	if (!sk_storage || hlist_empty(&sk_storage->list)) {
1117 		rcu_read_unlock();
1118 		return 0;
1119 	}
1120 
1121 	nla_stgs = nla_nest_start(skb, stg_array_type);
1122 	if (!nla_stgs)
1123 		/* Continue to learn diag_size */
1124 		err = -EMSGSIZE;
1125 
1126 	saved_len = skb->len;
1127 	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
1128 		smap = rcu_dereference(SDATA(selem)->smap);
1129 		diag_size += nla_value_size(smap->map.value_size);
1130 
1131 		if (nla_stgs && diag_get(SDATA(selem), skb))
1132 			/* Continue to learn diag_size */
1133 			err = -EMSGSIZE;
1134 	}
1135 
1136 	rcu_read_unlock();
1137 
1138 	if (nla_stgs) {
1139 		if (saved_len == skb->len)
1140 			nla_nest_cancel(skb, nla_stgs);
1141 		else
1142 			nla_nest_end(skb, nla_stgs);
1143 	}
1144 
1145 	if (diag_size == nla_total_size(0)) {
1146 		*res_diag_size = 0;
1147 		return 0;
1148 	}
1149 
1150 	*res_diag_size = diag_size;
1151 	return err;
1152 }
1153 
1154 int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
1155 			    struct sock *sk, struct sk_buff *skb,
1156 			    int stg_array_type,
1157 			    unsigned int *res_diag_size)
1158 {
1159 	/* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1160 	unsigned int diag_size = nla_total_size(0);
1161 	struct bpf_sk_storage *sk_storage;
1162 	struct bpf_sk_storage_data *sdata;
1163 	struct nlattr *nla_stgs;
1164 	unsigned int saved_len;
1165 	int err = 0;
1166 	u32 i;
1167 
1168 	*res_diag_size = 0;
1169 
1170 	/* No map has been specified.  Dump all. */
1171 	if (!diag->nr_maps)
1172 		return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
1173 						   res_diag_size);
1174 
1175 	rcu_read_lock();
1176 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
1177 	if (!sk_storage || hlist_empty(&sk_storage->list)) {
1178 		rcu_read_unlock();
1179 		return 0;
1180 	}
1181 
1182 	nla_stgs = nla_nest_start(skb, stg_array_type);
1183 	if (!nla_stgs)
1184 		/* Continue to learn diag_size */
1185 		err = -EMSGSIZE;
1186 
1187 	saved_len = skb->len;
1188 	for (i = 0; i < diag->nr_maps; i++) {
1189 		sdata = __sk_storage_lookup(sk_storage,
1190 				(struct bpf_sk_storage_map *)diag->maps[i],
1191 				false);
1192 
1193 		if (!sdata)
1194 			continue;
1195 
1196 		diag_size += nla_value_size(diag->maps[i]->value_size);
1197 
1198 		if (nla_stgs && diag_get(sdata, skb))
1199 			/* Continue to learn diag_size */
1200 			err = -EMSGSIZE;
1201 	}
1202 	rcu_read_unlock();
1203 
1204 	if (nla_stgs) {
1205 		if (saved_len == skb->len)
1206 			nla_nest_cancel(skb, nla_stgs);
1207 		else
1208 			nla_nest_end(skb, nla_stgs);
1209 	}
1210 
1211 	if (diag_size == nla_total_size(0)) {
1212 		*res_diag_size = 0;
1213 		return 0;
1214 	}
1215 
1216 	*res_diag_size = diag_size;
1217 	return err;
1218 }
1219 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);
1220