xref: /linux/kernel/bpf/bpf_struct_ops.c (revision 2b0cfa6e49566c8fa6759734cf821aa6e8271a9e)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (c) 2019 Facebook */
3 
4 #include <linux/bpf.h>
5 #include <linux/bpf_verifier.h>
6 #include <linux/btf.h>
7 #include <linux/filter.h>
8 #include <linux/slab.h>
9 #include <linux/numa.h>
10 #include <linux/seq_file.h>
11 #include <linux/refcount.h>
12 #include <linux/mutex.h>
13 #include <linux/btf_ids.h>
14 #include <linux/rcupdate_wait.h>
15 
16 struct bpf_struct_ops_value {
17 	struct bpf_struct_ops_common_value common;
18 	char data[] ____cacheline_aligned_in_smp;
19 };
20 
21 struct bpf_struct_ops_map {
22 	struct bpf_map map;
23 	struct rcu_head rcu;
24 	const struct bpf_struct_ops_desc *st_ops_desc;
25 	/* protect map_update */
26 	struct mutex lock;
27 	/* links has all the bpf_links that are populated
28 	 * to the func ptrs of the kernel's struct
29 	 * (in kvalue.data).
30 	 */
31 	struct bpf_link **links;
32 	u32 links_cnt;
33 	/* image is a page that has all the trampolines
34 	 * that store the func args before calling the bpf_prog.
35 	 * A PAGE_SIZE "image" is enough to store all trampolines for
36 	 * "links[]".
37 	 */
38 	void *image;
39 	/* The owner module's btf. */
40 	struct btf *btf;
41 	/* uvalue->data stores the kernel struct
42 	 * (e.g. tcp_congestion_ops) that is more useful
43 	 * to userspace than the kvalue.  For example,
44 	 * the bpf_prog's id is stored instead of the kernel
45 	 * address of a func ptr.
46 	 */
47 	struct bpf_struct_ops_value *uvalue;
48 	/* kvalue.data stores the actual kernel struct
49 	 * (e.g. tcp_congestion_ops) that will be
50 	 * registered to the kernel subsystem.
51 	 */
52 	struct bpf_struct_ops_value kvalue;
53 };
54 
55 struct bpf_struct_ops_link {
56 	struct bpf_link link;
57 	struct bpf_map __rcu *map;
58 };
59 
60 static DEFINE_MUTEX(update_mutex);
61 
62 #define VALUE_PREFIX "bpf_struct_ops_"
63 #define VALUE_PREFIX_LEN (sizeof(VALUE_PREFIX) - 1)
64 
65 const struct bpf_verifier_ops bpf_struct_ops_verifier_ops = {
66 };
67 
68 const struct bpf_prog_ops bpf_struct_ops_prog_ops = {
69 #ifdef CONFIG_NET
70 	.test_run = bpf_struct_ops_test_run,
71 #endif
72 };
73 
74 BTF_ID_LIST(st_ops_ids)
75 BTF_ID(struct, module)
76 BTF_ID(struct, bpf_struct_ops_common_value)
77 
78 enum {
79 	IDX_MODULE_ID,
80 	IDX_ST_OPS_COMMON_VALUE_ID,
81 };
82 
83 extern struct btf *btf_vmlinux;
84 
85 static bool is_valid_value_type(struct btf *btf, s32 value_id,
86 				const struct btf_type *type,
87 				const char *value_name)
88 {
89 	const struct btf_type *common_value_type;
90 	const struct btf_member *member;
91 	const struct btf_type *vt, *mt;
92 
93 	vt = btf_type_by_id(btf, value_id);
94 	if (btf_vlen(vt) != 2) {
95 		pr_warn("The number of %s's members should be 2, but we get %d\n",
96 			value_name, btf_vlen(vt));
97 		return false;
98 	}
99 	member = btf_type_member(vt);
100 	mt = btf_type_by_id(btf, member->type);
101 	common_value_type = btf_type_by_id(btf_vmlinux,
102 					   st_ops_ids[IDX_ST_OPS_COMMON_VALUE_ID]);
103 	if (mt != common_value_type) {
104 		pr_warn("The first member of %s should be bpf_struct_ops_common_value\n",
105 			value_name);
106 		return false;
107 	}
108 	member++;
109 	mt = btf_type_by_id(btf, member->type);
110 	if (mt != type) {
111 		pr_warn("The second member of %s should be %s\n",
112 			value_name, btf_name_by_offset(btf, type->name_off));
113 		return false;
114 	}
115 
116 	return true;
117 }
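/* For illustration, a minimal sketch of the layout that
 * is_valid_value_type() just checked, assuming a subsystem named
 * "tcp_congestion_ops" (the names below are only an example):
 *
 *	struct bpf_struct_ops_tcp_congestion_ops {
 *		struct bpf_struct_ops_common_value common;
 *		struct tcp_congestion_ops data ____cacheline_aligned_in_smp;
 *	};
 *
 * i.e. exactly two members: the common value header first, then the
 * kernel struct itself.
 */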
118 
119 int bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc,
120 			     struct btf *btf,
121 			     struct bpf_verifier_log *log)
122 {
123 	struct bpf_struct_ops *st_ops = st_ops_desc->st_ops;
124 	const struct btf_member *member;
125 	const struct btf_type *t;
126 	s32 type_id, value_id;
127 	char value_name[128];
128 	const char *mname;
129 	int i;
130 
131 	if (strlen(st_ops->name) + VALUE_PREFIX_LEN >=
132 	    sizeof(value_name)) {
133 		pr_warn("struct_ops name %s is too long\n",
134 			st_ops->name);
135 		return -EINVAL;
136 	}
137 	sprintf(value_name, "%s%s", VALUE_PREFIX, st_ops->name);
138 
139 	type_id = btf_find_by_name_kind(btf, st_ops->name,
140 					BTF_KIND_STRUCT);
141 	if (type_id < 0) {
142 		pr_warn("Cannot find struct %s in %s\n",
143 			st_ops->name, btf_get_name(btf));
144 		return -EINVAL;
145 	}
146 	t = btf_type_by_id(btf, type_id);
147 	if (btf_type_vlen(t) > BPF_STRUCT_OPS_MAX_NR_MEMBERS) {
148 		pr_warn("Cannot support #%u members in struct %s\n",
149 			btf_type_vlen(t), st_ops->name);
150 		return -EINVAL;
151 	}
152 
153 	value_id = btf_find_by_name_kind(btf, value_name,
154 					 BTF_KIND_STRUCT);
155 	if (value_id < 0) {
156 		pr_warn("Cannot find struct %s in %s\n",
157 			value_name, btf_get_name(btf));
158 		return -EINVAL;
159 	}
160 	if (!is_valid_value_type(btf, value_id, t, value_name))
161 		return -EINVAL;
162 
163 	for_each_member(i, t, member) {
164 		const struct btf_type *func_proto;
165 
166 		mname = btf_name_by_offset(btf, member->name_off);
167 		if (!*mname) {
168 			pr_warn("anon member in struct %s is not supported\n",
169 				st_ops->name);
170 			return -EOPNOTSUPP;
171 		}
172 
173 		if (__btf_member_bitfield_size(t, member)) {
174 			pr_warn("bit field member %s in struct %s is not supported\n",
175 				mname, st_ops->name);
176 			return -EOPNOTSUPP;
177 		}
178 
179 		func_proto = btf_type_resolve_func_ptr(btf,
180 						       member->type,
181 						       NULL);
182 		if (func_proto &&
183 		    btf_distill_func_proto(log, btf,
184 					   func_proto, mname,
185 					   &st_ops->func_models[i])) {
186 			pr_warn("Error in parsing func ptr %s in struct %s\n",
187 				mname, st_ops->name);
188 			return -EINVAL;
189 		}
190 	}
191 
192 	if (i == btf_type_vlen(t)) {
193 		if (st_ops->init(btf)) {
194 			pr_warn("Error in init bpf_struct_ops %s\n",
195 				st_ops->name);
196 			return -EINVAL;
197 		} else {
198 			st_ops_desc->type_id = type_id;
199 			st_ops_desc->type = t;
200 			st_ops_desc->value_id = value_id;
201 			st_ops_desc->value_type = btf_type_by_id(btf,
202 								 value_id);
203 		}
204 	}
205 
206 	return 0;
207 }
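/* A hedged sketch (not taken from this file) of the subsystem side.
 * Only callbacks that this file actually invokes are listed, and the
 * tcp-cc naming is purely illustrative:
 *
 *	static struct bpf_struct_ops bpf_tcp_congestion_ops = {
 *		.init		= ...,	called by bpf_struct_ops_desc_init() above
 *		.init_member	= ...,	copy/validate a single member on update
 *		.reg		= ...,	hand kvalue.data to the subsystem
 *		.unreg		= ...,
 *		.update		= ...,	optional, used by BPF_F_LINK link updates
 *		.validate	= ...,	optional whole-struct check before reg
 *		.name		= "tcp_congestion_ops",
 *		.owner		= THIS_MODULE,
 *	};
 */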
208 
209 static int bpf_struct_ops_map_get_next_key(struct bpf_map *map, void *key,
210 					   void *next_key)
211 {
212 	if (key && *(u32 *)key == 0)
213 		return -ENOENT;
214 
215 	*(u32 *)next_key = 0;
216 	return 0;
217 }
218 
219 int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key,
220 				       void *value)
221 {
222 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
223 	struct bpf_struct_ops_value *uvalue, *kvalue;
224 	enum bpf_struct_ops_state state;
225 	s64 refcnt;
226 
227 	if (unlikely(*(u32 *)key != 0))
228 		return -ENOENT;
229 
230 	kvalue = &st_map->kvalue;
231 	/* Pair with smp_store_release() during map_update */
232 	state = smp_load_acquire(&kvalue->common.state);
233 	if (state == BPF_STRUCT_OPS_STATE_INIT) {
234 		memset(value, 0, map->value_size);
235 		return 0;
236 	}
237 
238 	/* No lock is needed.  state and refcnt do not need
239 	 * to be updated together atomically.
240 	 */
241 	uvalue = value;
242 	memcpy(uvalue, st_map->uvalue, map->value_size);
243 	uvalue->common.state = state;
244 
245 	/* This value offers user space a general estimate of how
246 	 * many sockets are still utilizing this struct_ops for TCP
247 	 * congestion control. The number might not be exact, but it
248 	 * should sufficiently meet our present goals.
249 	 */
250 	refcnt = atomic64_read(&map->refcnt) - atomic64_read(&map->usercnt);
251 	refcount_set(&uvalue->common.refcnt, max_t(s64, refcnt, 0));
252 
253 	return 0;
254 }
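/* A minimal user-space sketch of the lookup above (hedged; the value
 * struct name is assumed and the libbpf wrapper is used for brevity):
 *
 *	struct bpf_struct_ops_tcp_congestion_ops val;
 *	__u32 key = 0;
 *
 *	if (!bpf_map_lookup_elem(map_fd, &key, &val)) {
 *		... inspect val.common.state, val.common.refcnt and, for
 *		func ptr members, the bpf_prog IDs stored in val.data ...
 *	}
 */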
255 
256 static void *bpf_struct_ops_map_lookup_elem(struct bpf_map *map, void *key)
257 {
258 	return ERR_PTR(-EINVAL);
259 }
260 
261 static void bpf_struct_ops_map_put_progs(struct bpf_struct_ops_map *st_map)
262 {
263 	u32 i;
264 
265 	for (i = 0; i < st_map->links_cnt; i++) {
266 		if (st_map->links[i]) {
267 			bpf_link_put(st_map->links[i]);
268 			st_map->links[i] = NULL;
269 		}
270 	}
271 }
272 
273 static int check_zero_holes(const struct btf *btf, const struct btf_type *t, void *data)
274 {
275 	const struct btf_member *member;
276 	u32 i, moff, msize, prev_mend = 0;
277 	const struct btf_type *mtype;
278 
279 	for_each_member(i, t, member) {
280 		moff = __btf_member_bit_offset(t, member) / 8;
281 		if (moff > prev_mend &&
282 		    memchr_inv(data + prev_mend, 0, moff - prev_mend))
283 			return -EINVAL;
284 
285 		mtype = btf_type_by_id(btf, member->type);
286 		mtype = btf_resolve_size(btf, mtype, &msize);
287 		if (IS_ERR(mtype))
288 			return PTR_ERR(mtype);
289 		prev_mend = moff + msize;
290 	}
291 
292 	if (t->size > prev_mend &&
293 	    memchr_inv(data + prev_mend, 0, t->size - prev_mend))
294 		return -EINVAL;
295 
296 	return 0;
297 }
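/* Example (for illustration only): with a member layout like
 *
 *	struct foo_ops {
 *		u8  flag;	offset 0, size 1
 *		u32 val;	offset 4 on common ABIs
 *	};
 *
 * check_zero_holes() requires the three padding bytes between "flag"
 * and "val", as well as any tail padding, to be zero in the
 * user-supplied value.
 */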
298 
299 static void bpf_struct_ops_link_release(struct bpf_link *link)
300 {
301 }
302 
303 static void bpf_struct_ops_link_dealloc(struct bpf_link *link)
304 {
305 	struct bpf_tramp_link *tlink = container_of(link, struct bpf_tramp_link, link);
306 
307 	kfree(tlink);
308 }
309 
310 const struct bpf_link_ops bpf_struct_ops_link_lops = {
311 	.release = bpf_struct_ops_link_release,
312 	.dealloc = bpf_struct_ops_link_dealloc,
313 };
314 
315 int bpf_struct_ops_prepare_trampoline(struct bpf_tramp_links *tlinks,
316 				      struct bpf_tramp_link *link,
317 				      const struct btf_func_model *model,
318 				      void *stub_func, void *image, void *image_end)
319 {
320 	u32 flags = BPF_TRAMP_F_INDIRECT;
321 	int size;
322 
323 	tlinks[BPF_TRAMP_FENTRY].links[0] = link;
324 	tlinks[BPF_TRAMP_FENTRY].nr_links = 1;
325 
326 	if (model->ret_size > 0)
327 		flags |= BPF_TRAMP_F_RET_FENTRY_RET;
328 
329 	size = arch_bpf_trampoline_size(model, flags, tlinks, NULL);
330 	if (size < 0)
331 		return size;
332 	if (size > (unsigned long)image_end - (unsigned long)image)
333 		return -E2BIG;
334 	return arch_prepare_bpf_trampoline(NULL, image, image_end,
335 					   model, flags, tlinks, stub_func);
336 }
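/* The image written above is what the kernel subsystem ends up calling
 * through the func ptr in kvalue.data: an arch-specific trampoline that
 * saves the arguments described by "model", invokes the prog as an
 * fentry-style attachment and, when the func proto has a return value
 * (BPF_TRAMP_F_RET_FENTRY_RET), propagates the prog's return value to
 * the caller.
 */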
337 
338 static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
339 					   void *value, u64 flags)
340 {
341 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
342 	const struct bpf_struct_ops_desc *st_ops_desc = st_map->st_ops_desc;
343 	const struct bpf_struct_ops *st_ops = st_ops_desc->st_ops;
344 	struct bpf_struct_ops_value *uvalue, *kvalue;
345 	const struct btf_type *module_type;
346 	const struct btf_member *member;
347 	const struct btf_type *t = st_ops_desc->type;
348 	struct bpf_tramp_links *tlinks;
349 	void *udata, *kdata;
350 	int prog_fd, err;
351 	void *image, *image_end;
352 	u32 i;
353 
354 	if (flags)
355 		return -EINVAL;
356 
357 	if (*(u32 *)key != 0)
358 		return -E2BIG;
359 
360 	err = check_zero_holes(st_map->btf, st_ops_desc->value_type, value);
361 	if (err)
362 		return err;
363 
364 	uvalue = value;
365 	err = check_zero_holes(st_map->btf, t, uvalue->data);
366 	if (err)
367 		return err;
368 
369 	if (uvalue->common.state || refcount_read(&uvalue->common.refcnt))
370 		return -EINVAL;
371 
372 	tlinks = kcalloc(BPF_TRAMP_MAX, sizeof(*tlinks), GFP_KERNEL);
373 	if (!tlinks)
374 		return -ENOMEM;
375 
376 	uvalue = (struct bpf_struct_ops_value *)st_map->uvalue;
377 	kvalue = (struct bpf_struct_ops_value *)&st_map->kvalue;
378 
379 	mutex_lock(&st_map->lock);
380 
381 	if (kvalue->common.state != BPF_STRUCT_OPS_STATE_INIT) {
382 		err = -EBUSY;
383 		goto unlock;
384 	}
385 
386 	memcpy(uvalue, value, map->value_size);
387 
388 	udata = &uvalue->data;
389 	kdata = &kvalue->data;
390 	image = st_map->image;
391 	image_end = st_map->image + PAGE_SIZE;
392 
393 	module_type = btf_type_by_id(btf_vmlinux, st_ops_ids[IDX_MODULE_ID]);
394 	for_each_member(i, t, member) {
395 		const struct btf_type *mtype, *ptype;
396 		struct bpf_prog *prog;
397 		struct bpf_tramp_link *link;
398 		u32 moff;
399 
400 		moff = __btf_member_bit_offset(t, member) / 8;
401 		ptype = btf_type_resolve_ptr(st_map->btf, member->type, NULL);
402 		if (ptype == module_type) {
403 			if (*(void **)(udata + moff))
404 				goto reset_unlock;
405 			*(void **)(kdata + moff) = BPF_MODULE_OWNER;
406 			continue;
407 		}
408 
409 		err = st_ops->init_member(t, member, kdata, udata);
410 		if (err < 0)
411 			goto reset_unlock;
412 
413 		/* The ->init_member() has handled this member */
414 		if (err > 0)
415 			continue;
416 
417 		/* If st_ops->init_member does not handle it,
418 		 * we will only handle func ptrs and zeroed members
419 		 * here.  Reject everything else.
420 		 */
421 
422 		/* All non-func-ptr members must be 0 */
423 		if (!ptype || !btf_type_is_func_proto(ptype)) {
424 			u32 msize;
425 
426 			mtype = btf_type_by_id(st_map->btf, member->type);
427 			mtype = btf_resolve_size(st_map->btf, mtype, &msize);
428 			if (IS_ERR(mtype)) {
429 				err = PTR_ERR(mtype);
430 				goto reset_unlock;
431 			}
432 
433 			if (memchr_inv(udata + moff, 0, msize)) {
434 				err = -EINVAL;
435 				goto reset_unlock;
436 			}
437 
438 			continue;
439 		}
440 
441 		prog_fd = (int)(*(unsigned long *)(udata + moff));
442 		/* Similar check to the one on attr->attach_prog_fd */
443 		if (!prog_fd)
444 			continue;
445 
446 		prog = bpf_prog_get(prog_fd);
447 		if (IS_ERR(prog)) {
448 			err = PTR_ERR(prog);
449 			goto reset_unlock;
450 		}
451 
452 		if (prog->type != BPF_PROG_TYPE_STRUCT_OPS ||
453 		    prog->aux->attach_btf_id != st_ops_desc->type_id ||
454 		    prog->expected_attach_type != i) {
455 			bpf_prog_put(prog);
456 			err = -EINVAL;
457 			goto reset_unlock;
458 		}
459 
460 		link = kzalloc(sizeof(*link), GFP_USER);
461 		if (!link) {
462 			bpf_prog_put(prog);
463 			err = -ENOMEM;
464 			goto reset_unlock;
465 		}
466 		bpf_link_init(&link->link, BPF_LINK_TYPE_STRUCT_OPS,
467 			      &bpf_struct_ops_link_lops, prog);
468 		st_map->links[i] = &link->link;
469 
470 		err = bpf_struct_ops_prepare_trampoline(tlinks, link,
471 							&st_ops->func_models[i],
472 							*(void **)(st_ops->cfi_stubs + moff),
473 							image, image_end);
474 		if (err < 0)
475 			goto reset_unlock;
476 
477 		*(void **)(kdata + moff) = image + cfi_get_offset();
478 		image += err;
479 
480 		/* store the prog id in udata */
481 		*(unsigned long *)(udata + moff) = prog->aux->id;
482 	}
483 
484 	if (st_map->map.map_flags & BPF_F_LINK) {
485 		err = 0;
486 		if (st_ops->validate) {
487 			err = st_ops->validate(kdata);
488 			if (err)
489 				goto reset_unlock;
490 		}
491 		arch_protect_bpf_trampoline(st_map->image, PAGE_SIZE);
492 		/* Let bpf_link handle registration & unregistration.
493 		 *
494 		 * Pair with smp_load_acquire() during lookup_elem().
495 		 */
496 		smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_READY);
497 		goto unlock;
498 	}
499 
500 	arch_protect_bpf_trampoline(st_map->image, PAGE_SIZE);
501 	err = st_ops->reg(kdata);
502 	if (likely(!err)) {
503 		/* This refcnt increment on the map here after
504 		 * 'st_ops->reg()' is safe since the state of the
505 		 * map must be set to INIT at this moment, and thus
506 		 * bpf_struct_ops_map_delete_elem() can't unregister
507 		 * or transition it to TOBEFREE concurrently.
508 		 */
509 		bpf_map_inc(map);
510 		/* Pair with smp_load_acquire() during lookup_elem().
511 		 * It ensures the above udata updates (e.g. prog->aux->id)
512 		 * can be seen once BPF_STRUCT_OPS_STATE_INUSE is set.
513 		 */
514 		smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_INUSE);
515 		goto unlock;
516 	}
517 
518 	/* Error during st_ops->reg(). Can happen if this struct_ops needs to be
519 	 * verified as a whole, after all init_member() calls. Can also happen if
520 	 * there was a race in registering the struct_ops (under the same name) to
521 	 * a sub-system through different struct_ops's maps.
522 	 */
523 	arch_unprotect_bpf_trampoline(st_map->image, PAGE_SIZE);
524 
525 reset_unlock:
526 	bpf_struct_ops_map_put_progs(st_map);
527 	memset(uvalue, 0, map->value_size);
528 	memset(kvalue, 0, map->value_size);
529 unlock:
530 	kfree(tlinks);
531 	mutex_unlock(&st_map->lock);
532 	return err;
533 }
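/* A hedged sketch of how this update is normally driven from user space
 * via libbpf; the program and ops names are illustrative only:
 *
 *	SEC(".struct_ops")
 *	struct tcp_congestion_ops dctcp = {
 *		.init		= (void *)dctcp_init,
 *		.ssthresh	= (void *)dctcp_ssthresh,
 *		.name		= "bpf_dctcp",
 *	};
 *
 * libbpf creates the BPF_MAP_TYPE_STRUCT_OPS map from this section,
 * fills the value with the loaded prog fds for the func ptr members and
 * issues the bpf_map_update_elem() that lands here (typically from
 * bpf_map__attach_struct_ops()).
 */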
534 
535 static long bpf_struct_ops_map_delete_elem(struct bpf_map *map, void *key)
536 {
537 	enum bpf_struct_ops_state prev_state;
538 	struct bpf_struct_ops_map *st_map;
539 
540 	st_map = (struct bpf_struct_ops_map *)map;
541 	if (st_map->map.map_flags & BPF_F_LINK)
542 		return -EOPNOTSUPP;
543 
544 	prev_state = cmpxchg(&st_map->kvalue.common.state,
545 			     BPF_STRUCT_OPS_STATE_INUSE,
546 			     BPF_STRUCT_OPS_STATE_TOBEFREE);
547 	switch (prev_state) {
548 	case BPF_STRUCT_OPS_STATE_INUSE:
549 		st_map->st_ops_desc->st_ops->unreg(&st_map->kvalue.data);
550 		bpf_map_put(map);
551 		return 0;
552 	case BPF_STRUCT_OPS_STATE_TOBEFREE:
553 		return -EINPROGRESS;
554 	case BPF_STRUCT_OPS_STATE_INIT:
555 		return -ENOENT;
556 	default:
557 		WARN_ON_ONCE(1);
558 		/* Should never happen.  Treat it as not found. */
559 		return -ENOENT;
560 	}
561 }
562 
563 static void bpf_struct_ops_map_seq_show_elem(struct bpf_map *map, void *key,
564 					     struct seq_file *m)
565 {
566 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
567 	void *value;
568 	int err;
569 
570 	value = kmalloc(map->value_size, GFP_USER | __GFP_NOWARN);
571 	if (!value)
572 		return;
573 
574 	err = bpf_struct_ops_map_sys_lookup_elem(map, key, value);
575 	if (!err) {
576 		btf_type_seq_show(st_map->btf,
577 				  map->btf_vmlinux_value_type_id,
578 				  value, m);
579 		seq_puts(m, "\n");
580 	}
581 
582 	kfree(value);
583 }
584 
585 static void __bpf_struct_ops_map_free(struct bpf_map *map)
586 {
587 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
588 
589 	if (st_map->links)
590 		bpf_struct_ops_map_put_progs(st_map);
591 	bpf_map_area_free(st_map->links);
592 	if (st_map->image) {
593 		arch_free_bpf_trampoline(st_map->image, PAGE_SIZE);
594 		bpf_jit_uncharge_modmem(PAGE_SIZE);
595 	}
596 	bpf_map_area_free(st_map->uvalue);
597 	bpf_map_area_free(st_map);
598 }
599 
600 static void bpf_struct_ops_map_free(struct bpf_map *map)
601 {
602 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
603 
604 	/* st_ops->owner was acquired during map_alloc to implicitly hold
605 	 * the btf's refcnt. The acquire was only done when btf_is_module()
606 	 * was true; st_map->btf cannot be NULL here.
607 	 */
608 	if (btf_is_module(st_map->btf))
609 		module_put(st_map->st_ops_desc->st_ops->owner);
610 
611 	/* The struct_ops's function may switch to another struct_ops.
612 	 *
613 	 * For example, bpf_tcp_cc_x->init() may switch to
614 	 * another tcp_cc_y by calling
615 	 * setsockopt(TCP_CONGESTION, "tcp_cc_y").
616 	 * During the switch, bpf_struct_ops_put(tcp_cc_x) is called
617 	 * and its refcount may reach 0, which then frees its
618 	 * trampoline image while tcp_cc_x is still running.
619 	 *
620 	 * A vanilla rcu gp is to wait for all bpf-tcp-cc progs
621 	 * to finish. bpf-tcp-cc progs are non-sleepable.
622 	 * A rcu_tasks gp is to wait for the last few insns
623 	 * in the trampoline image to finish before releasing
624 	 * the trampoline image.
625 	 */
626 	synchronize_rcu_mult(call_rcu, call_rcu_tasks);
627 
628 	__bpf_struct_ops_map_free(map);
629 }
630 
631 static int bpf_struct_ops_map_alloc_check(union bpf_attr *attr)
632 {
633 	if (attr->key_size != sizeof(unsigned int) || attr->max_entries != 1 ||
634 	    (attr->map_flags & ~(BPF_F_LINK | BPF_F_VTYPE_BTF_OBJ_FD)) ||
635 	    !attr->btf_vmlinux_value_type_id)
636 		return -EINVAL;
637 	return 0;
638 }
639 
640 static struct bpf_map *bpf_struct_ops_map_alloc(union bpf_attr *attr)
641 {
642 	const struct bpf_struct_ops_desc *st_ops_desc;
643 	size_t st_map_size;
644 	struct bpf_struct_ops_map *st_map;
645 	const struct btf_type *t, *vt;
646 	struct module *mod = NULL;
647 	struct bpf_map *map;
648 	struct btf *btf;
649 	int ret;
650 
651 	if (attr->map_flags & BPF_F_VTYPE_BTF_OBJ_FD) {
652 		/* The map holds the btf for its whole lifetime. */
653 		btf = btf_get_by_fd(attr->value_type_btf_obj_fd);
654 		if (IS_ERR(btf))
655 			return ERR_CAST(btf);
656 		if (!btf_is_module(btf)) {
657 			btf_put(btf);
658 			return ERR_PTR(-EINVAL);
659 		}
660 
661 		mod = btf_try_get_module(btf);
662 		/* mod holds a refcnt to btf. We don't need an extra refcnt
663 		 * here.
664 		 */
665 		btf_put(btf);
666 		if (!mod)
667 			return ERR_PTR(-EINVAL);
668 	} else {
669 		btf = bpf_get_btf_vmlinux();
670 		if (IS_ERR(btf))
671 			return ERR_CAST(btf);
672 		if (!btf)
673 			return ERR_PTR(-ENOTSUPP);
674 	}
675 
676 	st_ops_desc = bpf_struct_ops_find_value(btf, attr->btf_vmlinux_value_type_id);
677 	if (!st_ops_desc) {
678 		ret = -ENOTSUPP;
679 		goto errout;
680 	}
681 
682 	vt = st_ops_desc->value_type;
683 	if (attr->value_size != vt->size) {
684 		ret = -EINVAL;
685 		goto errout;
686 	}
687 
688 	t = st_ops_desc->type;
689 
690 	st_map_size = sizeof(*st_map) +
691 		/* kvalue stores the
692 		 * struct bpf_struct_ops_tcp_congestion_ops
693 		 */
694 		(vt->size - sizeof(struct bpf_struct_ops_value));
695 
696 	st_map = bpf_map_area_alloc(st_map_size, NUMA_NO_NODE);
697 	if (!st_map) {
698 		ret = -ENOMEM;
699 		goto errout;
700 	}
701 
702 	st_map->st_ops_desc = st_ops_desc;
703 	map = &st_map->map;
704 
705 	ret = bpf_jit_charge_modmem(PAGE_SIZE);
706 	if (ret)
707 		goto errout_free;
708 
709 	st_map->image = arch_alloc_bpf_trampoline(PAGE_SIZE);
710 	if (!st_map->image) {
711 		/* __bpf_struct_ops_map_free() uses st_map->image as a flag
712 		 * for "charged or not". In this case, we need to uncharge
713 		 * here.
714 		 */
715 		bpf_jit_uncharge_modmem(PAGE_SIZE);
716 		ret = -ENOMEM;
717 		goto errout_free;
718 	}
719 	st_map->uvalue = bpf_map_area_alloc(vt->size, NUMA_NO_NODE);
720 	st_map->links_cnt = btf_type_vlen(t);
721 	st_map->links =
722 		bpf_map_area_alloc(st_map->links_cnt * sizeof(struct bpf_link *),
723 				   NUMA_NO_NODE);
724 	if (!st_map->uvalue || !st_map->links) {
725 		ret = -ENOMEM;
726 		goto errout_free;
727 	}
728 	st_map->btf = btf;
729 
730 	mutex_init(&st_map->lock);
731 	bpf_map_init_from_attr(map, attr);
732 
733 	return map;
734 
735 errout_free:
736 	__bpf_struct_ops_map_free(map);
737 errout:
738 	module_put(mod);
739 
740 	return ERR_PTR(ret);
741 }
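/* A rough sketch of the map_create attributes this path expects
 * (hedged; shown for a vmlinux-defined struct_ops):
 *
 *	union bpf_attr attr = {
 *		.map_type			= BPF_MAP_TYPE_STRUCT_OPS,
 *		.key_size			= 4,
 *		.value_size			= size of the value type,
 *		.max_entries			= 1,
 *		.btf_vmlinux_value_type_id	= BTF id of bpf_struct_ops_xxx,
 *	};
 *
 * For a struct_ops defined by a module, user space additionally sets
 * BPF_F_VTYPE_BTF_OBJ_FD in map_flags and passes the module's btf fd in
 * value_type_btf_obj_fd.
 */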
742 
743 static u64 bpf_struct_ops_map_mem_usage(const struct bpf_map *map)
744 {
745 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
746 	const struct bpf_struct_ops_desc *st_ops_desc = st_map->st_ops_desc;
747 	const struct btf_type *vt = st_ops_desc->value_type;
748 	u64 usage;
749 
750 	usage = sizeof(*st_map) +
751 			vt->size - sizeof(struct bpf_struct_ops_value);
752 	usage += vt->size;
753 	usage += btf_type_vlen(vt) * sizeof(struct bpf_link *);
754 	usage += PAGE_SIZE;
755 	return usage;
756 }
757 
758 BTF_ID_LIST_SINGLE(bpf_struct_ops_map_btf_ids, struct, bpf_struct_ops_map)
759 const struct bpf_map_ops bpf_struct_ops_map_ops = {
760 	.map_alloc_check = bpf_struct_ops_map_alloc_check,
761 	.map_alloc = bpf_struct_ops_map_alloc,
762 	.map_free = bpf_struct_ops_map_free,
763 	.map_get_next_key = bpf_struct_ops_map_get_next_key,
764 	.map_lookup_elem = bpf_struct_ops_map_lookup_elem,
765 	.map_delete_elem = bpf_struct_ops_map_delete_elem,
766 	.map_update_elem = bpf_struct_ops_map_update_elem,
767 	.map_seq_show_elem = bpf_struct_ops_map_seq_show_elem,
768 	.map_mem_usage = bpf_struct_ops_map_mem_usage,
769 	.map_btf_id = &bpf_struct_ops_map_btf_ids[0],
770 };
771 
772 /* "const void *" because some subsystem is
773  * passing a const (e.g. const struct tcp_congestion_ops *)
774  */
775 bool bpf_struct_ops_get(const void *kdata)
776 {
777 	struct bpf_struct_ops_value *kvalue;
778 	struct bpf_struct_ops_map *st_map;
779 	struct bpf_map *map;
780 
781 	kvalue = container_of(kdata, struct bpf_struct_ops_value, data);
782 	st_map = container_of(kvalue, struct bpf_struct_ops_map, kvalue);
783 
784 	map = __bpf_map_inc_not_zero(&st_map->map, false);
785 	return !IS_ERR(map);
786 }
787 
788 void bpf_struct_ops_put(const void *kdata)
789 {
790 	struct bpf_struct_ops_value *kvalue;
791 	struct bpf_struct_ops_map *st_map;
792 
793 	kvalue = container_of(kdata, struct bpf_struct_ops_value, data);
794 	st_map = container_of(kvalue, struct bpf_struct_ops_map, kvalue);
795 
796 	bpf_map_put(&st_map->map);
797 }
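/* A hedged usage note: a subsystem that keeps using kdata after ->reg()
 * returns (e.g. tcp-cc, where sockets hold on to the ops) pins the map
 * with bpf_struct_ops_get() when it starts using kdata and drops it
 * with bpf_struct_ops_put() when the last user goes away, so the
 * trampoline image in the map outlives every in-flight caller.
 */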
798 
799 static bool bpf_struct_ops_valid_to_reg(struct bpf_map *map)
800 {
801 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
802 
803 	return map->map_type == BPF_MAP_TYPE_STRUCT_OPS &&
804 		map->map_flags & BPF_F_LINK &&
805 		/* Pair with smp_store_release() during map_update */
806 		smp_load_acquire(&st_map->kvalue.common.state) == BPF_STRUCT_OPS_STATE_READY;
807 }
808 
809 static void bpf_struct_ops_map_link_dealloc(struct bpf_link *link)
810 {
811 	struct bpf_struct_ops_link *st_link;
812 	struct bpf_struct_ops_map *st_map;
813 
814 	st_link = container_of(link, struct bpf_struct_ops_link, link);
815 	st_map = (struct bpf_struct_ops_map *)
816 		rcu_dereference_protected(st_link->map, true);
817 	if (st_map) {
818 		/* st_link->map can be NULL if
819 		 * bpf_struct_ops_link_create() fails to register.
820 		 */
821 		st_map->st_ops_desc->st_ops->unreg(&st_map->kvalue.data);
822 		bpf_map_put(&st_map->map);
823 	}
824 	kfree(st_link);
825 }
826 
827 static void bpf_struct_ops_map_link_show_fdinfo(const struct bpf_link *link,
828 					    struct seq_file *seq)
829 {
830 	struct bpf_struct_ops_link *st_link;
831 	struct bpf_map *map;
832 
833 	st_link = container_of(link, struct bpf_struct_ops_link, link);
834 	rcu_read_lock();
835 	map = rcu_dereference(st_link->map);
836 	seq_printf(seq, "map_id:\t%d\n", map->id);
837 	rcu_read_unlock();
838 }
839 
840 static int bpf_struct_ops_map_link_fill_link_info(const struct bpf_link *link,
841 					       struct bpf_link_info *info)
842 {
843 	struct bpf_struct_ops_link *st_link;
844 	struct bpf_map *map;
845 
846 	st_link = container_of(link, struct bpf_struct_ops_link, link);
847 	rcu_read_lock();
848 	map = rcu_dereference(st_link->map);
849 	info->struct_ops.map_id = map->id;
850 	rcu_read_unlock();
851 	return 0;
852 }
853 
854 static int bpf_struct_ops_map_link_update(struct bpf_link *link, struct bpf_map *new_map,
855 					  struct bpf_map *expected_old_map)
856 {
857 	struct bpf_struct_ops_map *st_map, *old_st_map;
858 	struct bpf_map *old_map;
859 	struct bpf_struct_ops_link *st_link;
860 	int err;
861 
862 	st_link = container_of(link, struct bpf_struct_ops_link, link);
863 	st_map = container_of(new_map, struct bpf_struct_ops_map, map);
864 
865 	if (!bpf_struct_ops_valid_to_reg(new_map))
866 		return -EINVAL;
867 
868 	if (!st_map->st_ops_desc->st_ops->update)
869 		return -EOPNOTSUPP;
870 
871 	mutex_lock(&update_mutex);
872 
873 	old_map = rcu_dereference_protected(st_link->map, lockdep_is_held(&update_mutex));
874 	if (expected_old_map && old_map != expected_old_map) {
875 		err = -EPERM;
876 		goto err_out;
877 	}
878 
879 	old_st_map = container_of(old_map, struct bpf_struct_ops_map, map);
880 	/* The new and old struct_ops must be the same type. */
881 	if (st_map->st_ops_desc != old_st_map->st_ops_desc) {
882 		err = -EINVAL;
883 		goto err_out;
884 	}
885 
886 	err = st_map->st_ops_desc->st_ops->update(st_map->kvalue.data, old_st_map->kvalue.data);
887 	if (err)
888 		goto err_out;
889 
890 	bpf_map_inc(new_map);
891 	rcu_assign_pointer(st_link->map, new_map);
892 	bpf_map_put(old_map);
893 
894 err_out:
895 	mutex_unlock(&update_mutex);
896 
897 	return err;
898 }
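/* A hedged note: this is reached through the BPF_LINK_UPDATE command on
 * a struct_ops link (libbpf exposes it as bpf_link__update_map()); the
 * new map must itself be a BPF_F_LINK struct_ops map in the READY state
 * and of the same struct_ops type as the old one.
 */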
899 
900 static const struct bpf_link_ops bpf_struct_ops_map_lops = {
901 	.dealloc = bpf_struct_ops_map_link_dealloc,
902 	.show_fdinfo = bpf_struct_ops_map_link_show_fdinfo,
903 	.fill_link_info = bpf_struct_ops_map_link_fill_link_info,
904 	.update_map = bpf_struct_ops_map_link_update,
905 };
906 
907 int bpf_struct_ops_link_create(union bpf_attr *attr)
908 {
909 	struct bpf_struct_ops_link *link = NULL;
910 	struct bpf_link_primer link_primer;
911 	struct bpf_struct_ops_map *st_map;
912 	struct bpf_map *map;
913 	int err;
914 
915 	map = bpf_map_get(attr->link_create.map_fd);
916 	if (IS_ERR(map))
917 		return PTR_ERR(map);
918 
919 	st_map = (struct bpf_struct_ops_map *)map;
920 
921 	if (!bpf_struct_ops_valid_to_reg(map)) {
922 		err = -EINVAL;
923 		goto err_out;
924 	}
925 
926 	link = kzalloc(sizeof(*link), GFP_USER);
927 	if (!link) {
928 		err = -ENOMEM;
929 		goto err_out;
930 	}
931 	bpf_link_init(&link->link, BPF_LINK_TYPE_STRUCT_OPS, &bpf_struct_ops_map_lops, NULL);
932 
933 	err = bpf_link_prime(&link->link, &link_primer);
934 	if (err)
935 		goto err_out;
936 
937 	err = st_map->st_ops_desc->st_ops->reg(st_map->kvalue.data);
938 	if (err) {
939 		bpf_link_cleanup(&link_primer);
940 		link = NULL;
941 		goto err_out;
942 	}
943 	RCU_INIT_POINTER(link->map, map);
944 
945 	return bpf_link_settle(&link_primer);
946 
947 err_out:
948 	bpf_map_put(map);
949 	kfree(link);
950 	return err;
951 }
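/* A minimal user-space sketch of the BPF_F_LINK flow (hedged; the
 * skeleton and map names are illustrative):
 *
 *	the struct_ops map is created with BPF_F_LINK and updated as
 *	usual (its state becomes READY instead of INUSE), then
 *
 *		link = bpf_map__attach_struct_ops(skel->maps.dctcp);
 *
 *	issues BPF_LINK_CREATE with link_create.map_fd and lands in
 *	bpf_struct_ops_link_create() above. Destroying the link (or
 *	closing its last fd) unregisters the struct_ops via ->unreg() in
 *	bpf_struct_ops_map_link_dealloc().
 */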
952 
953 void bpf_map_struct_ops_info_fill(struct bpf_map_info *info, struct bpf_map *map)
954 {
955 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
956 
957 	info->btf_vmlinux_id = btf_obj_id(st_map->btf);
958 }
959