xref: /linux/kernel/bpf/bpf_struct_ops.c (revision e28c5efc31397af17bc5a7d55b963f59bcde0166)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (c) 2019 Facebook */
3 
4 #include <linux/bpf.h>
5 #include <linux/bpf_verifier.h>
6 #include <linux/btf.h>
7 #include <linux/filter.h>
8 #include <linux/slab.h>
9 #include <linux/numa.h>
10 #include <linux/seq_file.h>
11 #include <linux/refcount.h>
12 #include <linux/mutex.h>
13 #include <linux/btf_ids.h>
14 #include <linux/rcupdate_wait.h>
15 
16 struct bpf_struct_ops_value {
17 	struct bpf_struct_ops_common_value common;
18 	char data[] ____cacheline_aligned_in_smp;
19 };
20 
21 struct bpf_struct_ops_map {
22 	struct bpf_map map;
23 	struct rcu_head rcu;
24 	const struct bpf_struct_ops_desc *st_ops_desc;
25 	/* protect map_update */
26 	struct mutex lock;
27 	/* link has all the bpf_links that is populated
28 	 * to the func ptr of the kernel's struct
29 	 * (in kvalue.data).
30 	 */
31 	struct bpf_link **links;
32 	u32 links_cnt;
33 	/* image is a page that has all the trampolines
34 	 * that stores the func args before calling the bpf_prog.
35 	 * A PAGE_SIZE "image" is enough to store all trampoline for
36 	 * "links[]".
37 	 */
38 	void *image;
39 	/* The owner moduler's btf. */
40 	struct btf *btf;
41 	/* uvalue->data stores the kernel struct
42 	 * (e.g. tcp_congestion_ops) that is more useful
43 	 * to userspace than the kvalue.  For example,
44 	 * the bpf_prog's id is stored instead of the kernel
45 	 * address of a func ptr.
46 	 */
47 	struct bpf_struct_ops_value *uvalue;
48 	/* kvalue.data stores the actual kernel's struct
49 	 * (e.g. tcp_congestion_ops) that will be
50 	 * registered to the kernel subsystem.
51 	 */
52 	struct bpf_struct_ops_value kvalue;
53 };
54 
55 struct bpf_struct_ops_link {
56 	struct bpf_link link;
57 	struct bpf_map __rcu *map;
58 };
59 
60 static DEFINE_MUTEX(update_mutex);
61 
62 #define VALUE_PREFIX "bpf_struct_ops_"
63 #define VALUE_PREFIX_LEN (sizeof(VALUE_PREFIX) - 1)
64 
65 const struct bpf_verifier_ops bpf_struct_ops_verifier_ops = {
66 };
67 
68 const struct bpf_prog_ops bpf_struct_ops_prog_ops = {
69 #ifdef CONFIG_NET
70 	.test_run = bpf_struct_ops_test_run,
71 #endif
72 };
73 
74 BTF_ID_LIST(st_ops_ids)
75 BTF_ID(struct, module)
76 BTF_ID(struct, bpf_struct_ops_common_value)
77 
78 enum {
79 	IDX_MODULE_ID,
80 	IDX_ST_OPS_COMMON_VALUE_ID,
81 };
82 
83 extern struct btf *btf_vmlinux;
84 
85 static bool is_valid_value_type(struct btf *btf, s32 value_id,
86 				const struct btf_type *type,
87 				const char *value_name)
88 {
89 	const struct btf_type *common_value_type;
90 	const struct btf_member *member;
91 	const struct btf_type *vt, *mt;
92 
93 	vt = btf_type_by_id(btf, value_id);
94 	if (btf_vlen(vt) != 2) {
95 		pr_warn("The number of %s's members should be 2, but we get %d\n",
96 			value_name, btf_vlen(vt));
97 		return false;
98 	}
99 	member = btf_type_member(vt);
100 	mt = btf_type_by_id(btf, member->type);
101 	common_value_type = btf_type_by_id(btf_vmlinux,
102 					   st_ops_ids[IDX_ST_OPS_COMMON_VALUE_ID]);
103 	if (mt != common_value_type) {
104 		pr_warn("The first member of %s should be bpf_struct_ops_common_value\n",
105 			value_name);
106 		return false;
107 	}
108 	member++;
109 	mt = btf_type_by_id(btf, member->type);
110 	if (mt != type) {
111 		pr_warn("The second member of %s should be %s\n",
112 			value_name, btf_name_by_offset(btf, type->name_off));
113 		return false;
114 	}
115 
116 	return true;
117 }
118 
119 #define MAYBE_NULL_SUFFIX "__nullable"
120 #define MAX_STUB_NAME 128
121 
122 /* Return the type info of a stub function, if it exists.
123  *
124  * The name of a stub function is made up of the name of the struct_ops and
125  * the name of the function pointer member, separated by "__". For example,
126  * if the struct_ops type is named "foo_ops" and the function pointer
127  * member is named "bar", the stub function name would be "foo_ops__bar".
128  */
129 static const struct btf_type *
130 find_stub_func_proto(const struct btf *btf, const char *st_op_name,
131 		     const char *member_name)
132 {
133 	char stub_func_name[MAX_STUB_NAME];
134 	const struct btf_type *func_type;
135 	s32 btf_id;
136 	int cp;
137 
138 	cp = snprintf(stub_func_name, MAX_STUB_NAME, "%s__%s",
139 		      st_op_name, member_name);
140 	if (cp >= MAX_STUB_NAME) {
141 		pr_warn("Stub function name too long\n");
142 		return NULL;
143 	}
144 	btf_id = btf_find_by_name_kind(btf, stub_func_name, BTF_KIND_FUNC);
145 	if (btf_id < 0)
146 		return NULL;
147 	func_type = btf_type_by_id(btf, btf_id);
148 	if (!func_type)
149 		return NULL;
150 
151 	return btf_type_by_id(btf, func_type->type); /* FUNC_PROTO */
152 }
153 
154 /* Prepare argument info for every nullable argument of a member of a
155  * struct_ops type.
156  *
157  * Initialize a struct bpf_struct_ops_arg_info according to type info of
158  * the arguments of a stub function. (Check kCFI for more information about
159  * stub functions.)
160  *
161  * Each member in the struct_ops type has a struct bpf_struct_ops_arg_info
162  * to provide an array of struct bpf_ctx_arg_aux, which in turn provides
163  * the information that used by the verifier to check the arguments of the
164  * BPF struct_ops program assigned to the member. Here, we only care about
165  * the arguments that are marked as __nullable.
166  *
167  * The array of struct bpf_ctx_arg_aux is eventually assigned to
168  * prog->aux->ctx_arg_info of BPF struct_ops programs and passed to the
169  * verifier. (See check_struct_ops_btf_id())
170  *
171  * arg_info->info will be the list of struct bpf_ctx_arg_aux if success. If
172  * fails, it will be kept untouched.
173  */
174 static int prepare_arg_info(struct btf *btf,
175 			    const char *st_ops_name,
176 			    const char *member_name,
177 			    const struct btf_type *func_proto,
178 			    struct bpf_struct_ops_arg_info *arg_info)
179 {
180 	const struct btf_type *stub_func_proto, *pointed_type;
181 	const struct btf_param *stub_args, *args;
182 	struct bpf_ctx_arg_aux *info, *info_buf;
183 	u32 nargs, arg_no, info_cnt = 0;
184 	u32 arg_btf_id;
185 	int offset;
186 
187 	stub_func_proto = find_stub_func_proto(btf, st_ops_name, member_name);
188 	if (!stub_func_proto)
189 		return 0;
190 
191 	/* Check if the number of arguments of the stub function is the same
192 	 * as the number of arguments of the function pointer.
193 	 */
194 	nargs = btf_type_vlen(func_proto);
195 	if (nargs != btf_type_vlen(stub_func_proto)) {
196 		pr_warn("the number of arguments of the stub function %s__%s does not match the number of arguments of the member %s of struct %s\n",
197 			st_ops_name, member_name, member_name, st_ops_name);
198 		return -EINVAL;
199 	}
200 
201 	if (!nargs)
202 		return 0;
203 
204 	args = btf_params(func_proto);
205 	stub_args = btf_params(stub_func_proto);
206 
207 	info_buf = kcalloc(nargs, sizeof(*info_buf), GFP_KERNEL);
208 	if (!info_buf)
209 		return -ENOMEM;
210 
211 	/* Prepare info for every nullable argument */
212 	info = info_buf;
213 	for (arg_no = 0; arg_no < nargs; arg_no++) {
214 		/* Skip arguments that is not suffixed with
215 		 * "__nullable".
216 		 */
217 		if (!btf_param_match_suffix(btf, &stub_args[arg_no],
218 					    MAYBE_NULL_SUFFIX))
219 			continue;
220 
221 		/* Should be a pointer to struct */
222 		pointed_type = btf_type_resolve_ptr(btf,
223 						    args[arg_no].type,
224 						    &arg_btf_id);
225 		if (!pointed_type ||
226 		    !btf_type_is_struct(pointed_type)) {
227 			pr_warn("stub function %s__%s has %s tagging to an unsupported type\n",
228 				st_ops_name, member_name, MAYBE_NULL_SUFFIX);
229 			goto err_out;
230 		}
231 
232 		offset = btf_ctx_arg_offset(btf, func_proto, arg_no);
233 		if (offset < 0) {
234 			pr_warn("stub function %s__%s has an invalid trampoline ctx offset for arg#%u\n",
235 				st_ops_name, member_name, arg_no);
236 			goto err_out;
237 		}
238 
239 		if (args[arg_no].type != stub_args[arg_no].type) {
240 			pr_warn("arg#%u type in stub function %s__%s does not match with its original func_proto\n",
241 				arg_no, st_ops_name, member_name);
242 			goto err_out;
243 		}
244 
245 		/* Fill the information of the new argument */
246 		info->reg_type =
247 			PTR_TRUSTED | PTR_TO_BTF_ID | PTR_MAYBE_NULL;
248 		info->btf_id = arg_btf_id;
249 		info->btf = btf;
250 		info->offset = offset;
251 
252 		info++;
253 		info_cnt++;
254 	}
255 
256 	if (info_cnt) {
257 		arg_info->info = info_buf;
258 		arg_info->cnt = info_cnt;
259 	} else {
260 		kfree(info_buf);
261 	}
262 
263 	return 0;
264 
265 err_out:
266 	kfree(info_buf);
267 
268 	return -EINVAL;
269 }
270 
271 /* Clean up the arg_info in a struct bpf_struct_ops_desc. */
272 void bpf_struct_ops_desc_release(struct bpf_struct_ops_desc *st_ops_desc)
273 {
274 	struct bpf_struct_ops_arg_info *arg_info;
275 	int i;
276 
277 	arg_info = st_ops_desc->arg_info;
278 	for (i = 0; i < btf_type_vlen(st_ops_desc->type); i++)
279 		kfree(arg_info[i].info);
280 
281 	kfree(arg_info);
282 }
283 
284 int bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc,
285 			     struct btf *btf,
286 			     struct bpf_verifier_log *log)
287 {
288 	struct bpf_struct_ops *st_ops = st_ops_desc->st_ops;
289 	struct bpf_struct_ops_arg_info *arg_info;
290 	const struct btf_member *member;
291 	const struct btf_type *t;
292 	s32 type_id, value_id;
293 	char value_name[128];
294 	const char *mname;
295 	int i, err;
296 
297 	if (strlen(st_ops->name) + VALUE_PREFIX_LEN >=
298 	    sizeof(value_name)) {
299 		pr_warn("struct_ops name %s is too long\n",
300 			st_ops->name);
301 		return -EINVAL;
302 	}
303 	sprintf(value_name, "%s%s", VALUE_PREFIX, st_ops->name);
304 
305 	if (!st_ops->cfi_stubs) {
306 		pr_warn("struct_ops for %s has no cfi_stubs\n", st_ops->name);
307 		return -EINVAL;
308 	}
309 
310 	type_id = btf_find_by_name_kind(btf, st_ops->name,
311 					BTF_KIND_STRUCT);
312 	if (type_id < 0) {
313 		pr_warn("Cannot find struct %s in %s\n",
314 			st_ops->name, btf_get_name(btf));
315 		return -EINVAL;
316 	}
317 	t = btf_type_by_id(btf, type_id);
318 	if (btf_type_vlen(t) > BPF_STRUCT_OPS_MAX_NR_MEMBERS) {
319 		pr_warn("Cannot support #%u members in struct %s\n",
320 			btf_type_vlen(t), st_ops->name);
321 		return -EINVAL;
322 	}
323 
324 	value_id = btf_find_by_name_kind(btf, value_name,
325 					 BTF_KIND_STRUCT);
326 	if (value_id < 0) {
327 		pr_warn("Cannot find struct %s in %s\n",
328 			value_name, btf_get_name(btf));
329 		return -EINVAL;
330 	}
331 	if (!is_valid_value_type(btf, value_id, t, value_name))
332 		return -EINVAL;
333 
334 	arg_info = kcalloc(btf_type_vlen(t), sizeof(*arg_info),
335 			   GFP_KERNEL);
336 	if (!arg_info)
337 		return -ENOMEM;
338 
339 	st_ops_desc->arg_info = arg_info;
340 	st_ops_desc->type = t;
341 	st_ops_desc->type_id = type_id;
342 	st_ops_desc->value_id = value_id;
343 	st_ops_desc->value_type = btf_type_by_id(btf, value_id);
344 
345 	for_each_member(i, t, member) {
346 		const struct btf_type *func_proto;
347 
348 		mname = btf_name_by_offset(btf, member->name_off);
349 		if (!*mname) {
350 			pr_warn("anon member in struct %s is not supported\n",
351 				st_ops->name);
352 			err = -EOPNOTSUPP;
353 			goto errout;
354 		}
355 
356 		if (__btf_member_bitfield_size(t, member)) {
357 			pr_warn("bit field member %s in struct %s is not supported\n",
358 				mname, st_ops->name);
359 			err = -EOPNOTSUPP;
360 			goto errout;
361 		}
362 
363 		func_proto = btf_type_resolve_func_ptr(btf,
364 						       member->type,
365 						       NULL);
366 		if (!func_proto)
367 			continue;
368 
369 		if (btf_distill_func_proto(log, btf,
370 					   func_proto, mname,
371 					   &st_ops->func_models[i])) {
372 			pr_warn("Error in parsing func ptr %s in struct %s\n",
373 				mname, st_ops->name);
374 			err = -EINVAL;
375 			goto errout;
376 		}
377 
378 		err = prepare_arg_info(btf, st_ops->name, mname,
379 				       func_proto,
380 				       arg_info + i);
381 		if (err)
382 			goto errout;
383 	}
384 
385 	if (st_ops->init(btf)) {
386 		pr_warn("Error in init bpf_struct_ops %s\n",
387 			st_ops->name);
388 		err = -EINVAL;
389 		goto errout;
390 	}
391 
392 	return 0;
393 
394 errout:
395 	bpf_struct_ops_desc_release(st_ops_desc);
396 
397 	return err;
398 }
399 
400 static int bpf_struct_ops_map_get_next_key(struct bpf_map *map, void *key,
401 					   void *next_key)
402 {
403 	if (key && *(u32 *)key == 0)
404 		return -ENOENT;
405 
406 	*(u32 *)next_key = 0;
407 	return 0;
408 }
409 
410 int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key,
411 				       void *value)
412 {
413 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
414 	struct bpf_struct_ops_value *uvalue, *kvalue;
415 	enum bpf_struct_ops_state state;
416 	s64 refcnt;
417 
418 	if (unlikely(*(u32 *)key != 0))
419 		return -ENOENT;
420 
421 	kvalue = &st_map->kvalue;
422 	/* Pair with smp_store_release() during map_update */
423 	state = smp_load_acquire(&kvalue->common.state);
424 	if (state == BPF_STRUCT_OPS_STATE_INIT) {
425 		memset(value, 0, map->value_size);
426 		return 0;
427 	}
428 
429 	/* No lock is needed.  state and refcnt do not need
430 	 * to be updated together under atomic context.
431 	 */
432 	uvalue = value;
433 	memcpy(uvalue, st_map->uvalue, map->value_size);
434 	uvalue->common.state = state;
435 
436 	/* This value offers the user space a general estimate of how
437 	 * many sockets are still utilizing this struct_ops for TCP
438 	 * congestion control. The number might not be exact, but it
439 	 * should sufficiently meet our present goals.
440 	 */
441 	refcnt = atomic64_read(&map->refcnt) - atomic64_read(&map->usercnt);
442 	refcount_set(&uvalue->common.refcnt, max_t(s64, refcnt, 0));
443 
444 	return 0;
445 }
446 
447 static void *bpf_struct_ops_map_lookup_elem(struct bpf_map *map, void *key)
448 {
449 	return ERR_PTR(-EINVAL);
450 }
451 
452 static void bpf_struct_ops_map_put_progs(struct bpf_struct_ops_map *st_map)
453 {
454 	u32 i;
455 
456 	for (i = 0; i < st_map->links_cnt; i++) {
457 		if (st_map->links[i]) {
458 			bpf_link_put(st_map->links[i]);
459 			st_map->links[i] = NULL;
460 		}
461 	}
462 }
463 
464 static int check_zero_holes(const struct btf *btf, const struct btf_type *t, void *data)
465 {
466 	const struct btf_member *member;
467 	u32 i, moff, msize, prev_mend = 0;
468 	const struct btf_type *mtype;
469 
470 	for_each_member(i, t, member) {
471 		moff = __btf_member_bit_offset(t, member) / 8;
472 		if (moff > prev_mend &&
473 		    memchr_inv(data + prev_mend, 0, moff - prev_mend))
474 			return -EINVAL;
475 
476 		mtype = btf_type_by_id(btf, member->type);
477 		mtype = btf_resolve_size(btf, mtype, &msize);
478 		if (IS_ERR(mtype))
479 			return PTR_ERR(mtype);
480 		prev_mend = moff + msize;
481 	}
482 
483 	if (t->size > prev_mend &&
484 	    memchr_inv(data + prev_mend, 0, t->size - prev_mend))
485 		return -EINVAL;
486 
487 	return 0;
488 }
489 
490 static void bpf_struct_ops_link_release(struct bpf_link *link)
491 {
492 }
493 
494 static void bpf_struct_ops_link_dealloc(struct bpf_link *link)
495 {
496 	struct bpf_tramp_link *tlink = container_of(link, struct bpf_tramp_link, link);
497 
498 	kfree(tlink);
499 }
500 
501 const struct bpf_link_ops bpf_struct_ops_link_lops = {
502 	.release = bpf_struct_ops_link_release,
503 	.dealloc = bpf_struct_ops_link_dealloc,
504 };
505 
506 int bpf_struct_ops_prepare_trampoline(struct bpf_tramp_links *tlinks,
507 				      struct bpf_tramp_link *link,
508 				      const struct btf_func_model *model,
509 				      void *stub_func, void *image, void *image_end)
510 {
511 	u32 flags = BPF_TRAMP_F_INDIRECT;
512 	int size;
513 
514 	tlinks[BPF_TRAMP_FENTRY].links[0] = link;
515 	tlinks[BPF_TRAMP_FENTRY].nr_links = 1;
516 
517 	if (model->ret_size > 0)
518 		flags |= BPF_TRAMP_F_RET_FENTRY_RET;
519 
520 	size = arch_bpf_trampoline_size(model, flags, tlinks, NULL);
521 	if (size < 0)
522 		return size;
523 	if (size > (unsigned long)image_end - (unsigned long)image)
524 		return -E2BIG;
525 	return arch_prepare_bpf_trampoline(NULL, image, image_end,
526 					   model, flags, tlinks, stub_func);
527 }
528 
529 static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
530 					   void *value, u64 flags)
531 {
532 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
533 	const struct bpf_struct_ops_desc *st_ops_desc = st_map->st_ops_desc;
534 	const struct bpf_struct_ops *st_ops = st_ops_desc->st_ops;
535 	struct bpf_struct_ops_value *uvalue, *kvalue;
536 	const struct btf_type *module_type;
537 	const struct btf_member *member;
538 	const struct btf_type *t = st_ops_desc->type;
539 	struct bpf_tramp_links *tlinks;
540 	void *udata, *kdata;
541 	int prog_fd, err;
542 	void *image, *image_end;
543 	u32 i;
544 
545 	if (flags)
546 		return -EINVAL;
547 
548 	if (*(u32 *)key != 0)
549 		return -E2BIG;
550 
551 	err = check_zero_holes(st_map->btf, st_ops_desc->value_type, value);
552 	if (err)
553 		return err;
554 
555 	uvalue = value;
556 	err = check_zero_holes(st_map->btf, t, uvalue->data);
557 	if (err)
558 		return err;
559 
560 	if (uvalue->common.state || refcount_read(&uvalue->common.refcnt))
561 		return -EINVAL;
562 
563 	tlinks = kcalloc(BPF_TRAMP_MAX, sizeof(*tlinks), GFP_KERNEL);
564 	if (!tlinks)
565 		return -ENOMEM;
566 
567 	uvalue = (struct bpf_struct_ops_value *)st_map->uvalue;
568 	kvalue = (struct bpf_struct_ops_value *)&st_map->kvalue;
569 
570 	mutex_lock(&st_map->lock);
571 
572 	if (kvalue->common.state != BPF_STRUCT_OPS_STATE_INIT) {
573 		err = -EBUSY;
574 		goto unlock;
575 	}
576 
577 	memcpy(uvalue, value, map->value_size);
578 
579 	udata = &uvalue->data;
580 	kdata = &kvalue->data;
581 	image = st_map->image;
582 	image_end = st_map->image + PAGE_SIZE;
583 
584 	module_type = btf_type_by_id(btf_vmlinux, st_ops_ids[IDX_MODULE_ID]);
585 	for_each_member(i, t, member) {
586 		const struct btf_type *mtype, *ptype;
587 		struct bpf_prog *prog;
588 		struct bpf_tramp_link *link;
589 		u32 moff;
590 
591 		moff = __btf_member_bit_offset(t, member) / 8;
592 		ptype = btf_type_resolve_ptr(st_map->btf, member->type, NULL);
593 		if (ptype == module_type) {
594 			if (*(void **)(udata + moff))
595 				goto reset_unlock;
596 			*(void **)(kdata + moff) = BPF_MODULE_OWNER;
597 			continue;
598 		}
599 
600 		err = st_ops->init_member(t, member, kdata, udata);
601 		if (err < 0)
602 			goto reset_unlock;
603 
604 		/* The ->init_member() has handled this member */
605 		if (err > 0)
606 			continue;
607 
608 		/* If st_ops->init_member does not handle it,
609 		 * we will only handle func ptrs and zero-ed members
610 		 * here.  Reject everything else.
611 		 */
612 
613 		/* All non func ptr member must be 0 */
614 		if (!ptype || !btf_type_is_func_proto(ptype)) {
615 			u32 msize;
616 
617 			mtype = btf_type_by_id(st_map->btf, member->type);
618 			mtype = btf_resolve_size(st_map->btf, mtype, &msize);
619 			if (IS_ERR(mtype)) {
620 				err = PTR_ERR(mtype);
621 				goto reset_unlock;
622 			}
623 
624 			if (memchr_inv(udata + moff, 0, msize)) {
625 				err = -EINVAL;
626 				goto reset_unlock;
627 			}
628 
629 			continue;
630 		}
631 
632 		prog_fd = (int)(*(unsigned long *)(udata + moff));
633 		/* Similar check as the attr->attach_prog_fd */
634 		if (!prog_fd)
635 			continue;
636 
637 		prog = bpf_prog_get(prog_fd);
638 		if (IS_ERR(prog)) {
639 			err = PTR_ERR(prog);
640 			goto reset_unlock;
641 		}
642 
643 		if (prog->type != BPF_PROG_TYPE_STRUCT_OPS ||
644 		    prog->aux->attach_btf_id != st_ops_desc->type_id ||
645 		    prog->expected_attach_type != i) {
646 			bpf_prog_put(prog);
647 			err = -EINVAL;
648 			goto reset_unlock;
649 		}
650 
651 		link = kzalloc(sizeof(*link), GFP_USER);
652 		if (!link) {
653 			bpf_prog_put(prog);
654 			err = -ENOMEM;
655 			goto reset_unlock;
656 		}
657 		bpf_link_init(&link->link, BPF_LINK_TYPE_STRUCT_OPS,
658 			      &bpf_struct_ops_link_lops, prog);
659 		st_map->links[i] = &link->link;
660 
661 		err = bpf_struct_ops_prepare_trampoline(tlinks, link,
662 							&st_ops->func_models[i],
663 							*(void **)(st_ops->cfi_stubs + moff),
664 							image, image_end);
665 		if (err < 0)
666 			goto reset_unlock;
667 
668 		*(void **)(kdata + moff) = image + cfi_get_offset();
669 		image += err;
670 
671 		/* put prog_id to udata */
672 		*(unsigned long *)(udata + moff) = prog->aux->id;
673 	}
674 
675 	if (st_map->map.map_flags & BPF_F_LINK) {
676 		err = 0;
677 		if (st_ops->validate) {
678 			err = st_ops->validate(kdata);
679 			if (err)
680 				goto reset_unlock;
681 		}
682 		arch_protect_bpf_trampoline(st_map->image, PAGE_SIZE);
683 		/* Let bpf_link handle registration & unregistration.
684 		 *
685 		 * Pair with smp_load_acquire() during lookup_elem().
686 		 */
687 		smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_READY);
688 		goto unlock;
689 	}
690 
691 	arch_protect_bpf_trampoline(st_map->image, PAGE_SIZE);
692 	err = st_ops->reg(kdata);
693 	if (likely(!err)) {
694 		/* This refcnt increment on the map here after
695 		 * 'st_ops->reg()' is secure since the state of the
696 		 * map must be set to INIT at this moment, and thus
697 		 * bpf_struct_ops_map_delete_elem() can't unregister
698 		 * or transition it to TOBEFREE concurrently.
699 		 */
700 		bpf_map_inc(map);
701 		/* Pair with smp_load_acquire() during lookup_elem().
702 		 * It ensures the above udata updates (e.g. prog->aux->id)
703 		 * can be seen once BPF_STRUCT_OPS_STATE_INUSE is set.
704 		 */
705 		smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_INUSE);
706 		goto unlock;
707 	}
708 
709 	/* Error during st_ops->reg(). Can happen if this struct_ops needs to be
710 	 * verified as a whole, after all init_member() calls. Can also happen if
711 	 * there was a race in registering the struct_ops (under the same name) to
712 	 * a sub-system through different struct_ops's maps.
713 	 */
714 	arch_unprotect_bpf_trampoline(st_map->image, PAGE_SIZE);
715 
716 reset_unlock:
717 	bpf_struct_ops_map_put_progs(st_map);
718 	memset(uvalue, 0, map->value_size);
719 	memset(kvalue, 0, map->value_size);
720 unlock:
721 	kfree(tlinks);
722 	mutex_unlock(&st_map->lock);
723 	return err;
724 }
725 
726 static long bpf_struct_ops_map_delete_elem(struct bpf_map *map, void *key)
727 {
728 	enum bpf_struct_ops_state prev_state;
729 	struct bpf_struct_ops_map *st_map;
730 
731 	st_map = (struct bpf_struct_ops_map *)map;
732 	if (st_map->map.map_flags & BPF_F_LINK)
733 		return -EOPNOTSUPP;
734 
735 	prev_state = cmpxchg(&st_map->kvalue.common.state,
736 			     BPF_STRUCT_OPS_STATE_INUSE,
737 			     BPF_STRUCT_OPS_STATE_TOBEFREE);
738 	switch (prev_state) {
739 	case BPF_STRUCT_OPS_STATE_INUSE:
740 		st_map->st_ops_desc->st_ops->unreg(&st_map->kvalue.data);
741 		bpf_map_put(map);
742 		return 0;
743 	case BPF_STRUCT_OPS_STATE_TOBEFREE:
744 		return -EINPROGRESS;
745 	case BPF_STRUCT_OPS_STATE_INIT:
746 		return -ENOENT;
747 	default:
748 		WARN_ON_ONCE(1);
749 		/* Should never happen.  Treat it as not found. */
750 		return -ENOENT;
751 	}
752 }
753 
754 static void bpf_struct_ops_map_seq_show_elem(struct bpf_map *map, void *key,
755 					     struct seq_file *m)
756 {
757 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
758 	void *value;
759 	int err;
760 
761 	value = kmalloc(map->value_size, GFP_USER | __GFP_NOWARN);
762 	if (!value)
763 		return;
764 
765 	err = bpf_struct_ops_map_sys_lookup_elem(map, key, value);
766 	if (!err) {
767 		btf_type_seq_show(st_map->btf,
768 				  map->btf_vmlinux_value_type_id,
769 				  value, m);
770 		seq_puts(m, "\n");
771 	}
772 
773 	kfree(value);
774 }
775 
776 static void __bpf_struct_ops_map_free(struct bpf_map *map)
777 {
778 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
779 
780 	if (st_map->links)
781 		bpf_struct_ops_map_put_progs(st_map);
782 	bpf_map_area_free(st_map->links);
783 	if (st_map->image) {
784 		arch_free_bpf_trampoline(st_map->image, PAGE_SIZE);
785 		bpf_jit_uncharge_modmem(PAGE_SIZE);
786 	}
787 	bpf_map_area_free(st_map->uvalue);
788 	bpf_map_area_free(st_map);
789 }
790 
791 static void bpf_struct_ops_map_free(struct bpf_map *map)
792 {
793 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
794 
795 	/* st_ops->owner was acquired during map_alloc to implicitly holds
796 	 * the btf's refcnt. The acquire was only done when btf_is_module()
797 	 * st_map->btf cannot be NULL here.
798 	 */
799 	if (btf_is_module(st_map->btf))
800 		module_put(st_map->st_ops_desc->st_ops->owner);
801 
802 	/* The struct_ops's function may switch to another struct_ops.
803 	 *
804 	 * For example, bpf_tcp_cc_x->init() may switch to
805 	 * another tcp_cc_y by calling
806 	 * setsockopt(TCP_CONGESTION, "tcp_cc_y").
807 	 * During the switch,  bpf_struct_ops_put(tcp_cc_x) is called
808 	 * and its refcount may reach 0 which then free its
809 	 * trampoline image while tcp_cc_x is still running.
810 	 *
811 	 * A vanilla rcu gp is to wait for all bpf-tcp-cc prog
812 	 * to finish. bpf-tcp-cc prog is non sleepable.
813 	 * A rcu_tasks gp is to wait for the last few insn
814 	 * in the tramopline image to finish before releasing
815 	 * the trampoline image.
816 	 */
817 	synchronize_rcu_mult(call_rcu, call_rcu_tasks);
818 
819 	__bpf_struct_ops_map_free(map);
820 }
821 
822 static int bpf_struct_ops_map_alloc_check(union bpf_attr *attr)
823 {
824 	if (attr->key_size != sizeof(unsigned int) || attr->max_entries != 1 ||
825 	    (attr->map_flags & ~(BPF_F_LINK | BPF_F_VTYPE_BTF_OBJ_FD)) ||
826 	    !attr->btf_vmlinux_value_type_id)
827 		return -EINVAL;
828 	return 0;
829 }
830 
831 static struct bpf_map *bpf_struct_ops_map_alloc(union bpf_attr *attr)
832 {
833 	const struct bpf_struct_ops_desc *st_ops_desc;
834 	size_t st_map_size;
835 	struct bpf_struct_ops_map *st_map;
836 	const struct btf_type *t, *vt;
837 	struct module *mod = NULL;
838 	struct bpf_map *map;
839 	struct btf *btf;
840 	int ret;
841 
842 	if (attr->map_flags & BPF_F_VTYPE_BTF_OBJ_FD) {
843 		/* The map holds btf for its whole life time. */
844 		btf = btf_get_by_fd(attr->value_type_btf_obj_fd);
845 		if (IS_ERR(btf))
846 			return ERR_CAST(btf);
847 		if (!btf_is_module(btf)) {
848 			btf_put(btf);
849 			return ERR_PTR(-EINVAL);
850 		}
851 
852 		mod = btf_try_get_module(btf);
853 		/* mod holds a refcnt to btf. We don't need an extra refcnt
854 		 * here.
855 		 */
856 		btf_put(btf);
857 		if (!mod)
858 			return ERR_PTR(-EINVAL);
859 	} else {
860 		btf = bpf_get_btf_vmlinux();
861 		if (IS_ERR(btf))
862 			return ERR_CAST(btf);
863 		if (!btf)
864 			return ERR_PTR(-ENOTSUPP);
865 	}
866 
867 	st_ops_desc = bpf_struct_ops_find_value(btf, attr->btf_vmlinux_value_type_id);
868 	if (!st_ops_desc) {
869 		ret = -ENOTSUPP;
870 		goto errout;
871 	}
872 
873 	vt = st_ops_desc->value_type;
874 	if (attr->value_size != vt->size) {
875 		ret = -EINVAL;
876 		goto errout;
877 	}
878 
879 	t = st_ops_desc->type;
880 
881 	st_map_size = sizeof(*st_map) +
882 		/* kvalue stores the
883 		 * struct bpf_struct_ops_tcp_congestions_ops
884 		 */
885 		(vt->size - sizeof(struct bpf_struct_ops_value));
886 
887 	st_map = bpf_map_area_alloc(st_map_size, NUMA_NO_NODE);
888 	if (!st_map) {
889 		ret = -ENOMEM;
890 		goto errout;
891 	}
892 
893 	st_map->st_ops_desc = st_ops_desc;
894 	map = &st_map->map;
895 
896 	ret = bpf_jit_charge_modmem(PAGE_SIZE);
897 	if (ret)
898 		goto errout_free;
899 
900 	st_map->image = arch_alloc_bpf_trampoline(PAGE_SIZE);
901 	if (!st_map->image) {
902 		/* __bpf_struct_ops_map_free() uses st_map->image as flag
903 		 * for "charged or not". In this case, we need to unchange
904 		 * here.
905 		 */
906 		bpf_jit_uncharge_modmem(PAGE_SIZE);
907 		ret = -ENOMEM;
908 		goto errout_free;
909 	}
910 	st_map->uvalue = bpf_map_area_alloc(vt->size, NUMA_NO_NODE);
911 	st_map->links_cnt = btf_type_vlen(t);
912 	st_map->links =
913 		bpf_map_area_alloc(st_map->links_cnt * sizeof(struct bpf_links *),
914 				   NUMA_NO_NODE);
915 	if (!st_map->uvalue || !st_map->links) {
916 		ret = -ENOMEM;
917 		goto errout_free;
918 	}
919 	st_map->btf = btf;
920 
921 	mutex_init(&st_map->lock);
922 	bpf_map_init_from_attr(map, attr);
923 
924 	return map;
925 
926 errout_free:
927 	__bpf_struct_ops_map_free(map);
928 errout:
929 	module_put(mod);
930 
931 	return ERR_PTR(ret);
932 }
933 
934 static u64 bpf_struct_ops_map_mem_usage(const struct bpf_map *map)
935 {
936 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
937 	const struct bpf_struct_ops_desc *st_ops_desc = st_map->st_ops_desc;
938 	const struct btf_type *vt = st_ops_desc->value_type;
939 	u64 usage;
940 
941 	usage = sizeof(*st_map) +
942 			vt->size - sizeof(struct bpf_struct_ops_value);
943 	usage += vt->size;
944 	usage += btf_type_vlen(vt) * sizeof(struct bpf_links *);
945 	usage += PAGE_SIZE;
946 	return usage;
947 }
948 
949 BTF_ID_LIST_SINGLE(bpf_struct_ops_map_btf_ids, struct, bpf_struct_ops_map)
950 const struct bpf_map_ops bpf_struct_ops_map_ops = {
951 	.map_alloc_check = bpf_struct_ops_map_alloc_check,
952 	.map_alloc = bpf_struct_ops_map_alloc,
953 	.map_free = bpf_struct_ops_map_free,
954 	.map_get_next_key = bpf_struct_ops_map_get_next_key,
955 	.map_lookup_elem = bpf_struct_ops_map_lookup_elem,
956 	.map_delete_elem = bpf_struct_ops_map_delete_elem,
957 	.map_update_elem = bpf_struct_ops_map_update_elem,
958 	.map_seq_show_elem = bpf_struct_ops_map_seq_show_elem,
959 	.map_mem_usage = bpf_struct_ops_map_mem_usage,
960 	.map_btf_id = &bpf_struct_ops_map_btf_ids[0],
961 };
962 
963 /* "const void *" because some subsystem is
964  * passing a const (e.g. const struct tcp_congestion_ops *)
965  */
966 bool bpf_struct_ops_get(const void *kdata)
967 {
968 	struct bpf_struct_ops_value *kvalue;
969 	struct bpf_struct_ops_map *st_map;
970 	struct bpf_map *map;
971 
972 	kvalue = container_of(kdata, struct bpf_struct_ops_value, data);
973 	st_map = container_of(kvalue, struct bpf_struct_ops_map, kvalue);
974 
975 	map = __bpf_map_inc_not_zero(&st_map->map, false);
976 	return !IS_ERR(map);
977 }
978 
979 void bpf_struct_ops_put(const void *kdata)
980 {
981 	struct bpf_struct_ops_value *kvalue;
982 	struct bpf_struct_ops_map *st_map;
983 
984 	kvalue = container_of(kdata, struct bpf_struct_ops_value, data);
985 	st_map = container_of(kvalue, struct bpf_struct_ops_map, kvalue);
986 
987 	bpf_map_put(&st_map->map);
988 }
989 
990 static bool bpf_struct_ops_valid_to_reg(struct bpf_map *map)
991 {
992 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
993 
994 	return map->map_type == BPF_MAP_TYPE_STRUCT_OPS &&
995 		map->map_flags & BPF_F_LINK &&
996 		/* Pair with smp_store_release() during map_update */
997 		smp_load_acquire(&st_map->kvalue.common.state) == BPF_STRUCT_OPS_STATE_READY;
998 }
999 
1000 static void bpf_struct_ops_map_link_dealloc(struct bpf_link *link)
1001 {
1002 	struct bpf_struct_ops_link *st_link;
1003 	struct bpf_struct_ops_map *st_map;
1004 
1005 	st_link = container_of(link, struct bpf_struct_ops_link, link);
1006 	st_map = (struct bpf_struct_ops_map *)
1007 		rcu_dereference_protected(st_link->map, true);
1008 	if (st_map) {
1009 		/* st_link->map can be NULL if
1010 		 * bpf_struct_ops_link_create() fails to register.
1011 		 */
1012 		st_map->st_ops_desc->st_ops->unreg(&st_map->kvalue.data);
1013 		bpf_map_put(&st_map->map);
1014 	}
1015 	kfree(st_link);
1016 }
1017 
1018 static void bpf_struct_ops_map_link_show_fdinfo(const struct bpf_link *link,
1019 					    struct seq_file *seq)
1020 {
1021 	struct bpf_struct_ops_link *st_link;
1022 	struct bpf_map *map;
1023 
1024 	st_link = container_of(link, struct bpf_struct_ops_link, link);
1025 	rcu_read_lock();
1026 	map = rcu_dereference(st_link->map);
1027 	seq_printf(seq, "map_id:\t%d\n", map->id);
1028 	rcu_read_unlock();
1029 }
1030 
1031 static int bpf_struct_ops_map_link_fill_link_info(const struct bpf_link *link,
1032 					       struct bpf_link_info *info)
1033 {
1034 	struct bpf_struct_ops_link *st_link;
1035 	struct bpf_map *map;
1036 
1037 	st_link = container_of(link, struct bpf_struct_ops_link, link);
1038 	rcu_read_lock();
1039 	map = rcu_dereference(st_link->map);
1040 	info->struct_ops.map_id = map->id;
1041 	rcu_read_unlock();
1042 	return 0;
1043 }
1044 
1045 static int bpf_struct_ops_map_link_update(struct bpf_link *link, struct bpf_map *new_map,
1046 					  struct bpf_map *expected_old_map)
1047 {
1048 	struct bpf_struct_ops_map *st_map, *old_st_map;
1049 	struct bpf_map *old_map;
1050 	struct bpf_struct_ops_link *st_link;
1051 	int err;
1052 
1053 	st_link = container_of(link, struct bpf_struct_ops_link, link);
1054 	st_map = container_of(new_map, struct bpf_struct_ops_map, map);
1055 
1056 	if (!bpf_struct_ops_valid_to_reg(new_map))
1057 		return -EINVAL;
1058 
1059 	if (!st_map->st_ops_desc->st_ops->update)
1060 		return -EOPNOTSUPP;
1061 
1062 	mutex_lock(&update_mutex);
1063 
1064 	old_map = rcu_dereference_protected(st_link->map, lockdep_is_held(&update_mutex));
1065 	if (expected_old_map && old_map != expected_old_map) {
1066 		err = -EPERM;
1067 		goto err_out;
1068 	}
1069 
1070 	old_st_map = container_of(old_map, struct bpf_struct_ops_map, map);
1071 	/* The new and old struct_ops must be the same type. */
1072 	if (st_map->st_ops_desc != old_st_map->st_ops_desc) {
1073 		err = -EINVAL;
1074 		goto err_out;
1075 	}
1076 
1077 	err = st_map->st_ops_desc->st_ops->update(st_map->kvalue.data, old_st_map->kvalue.data);
1078 	if (err)
1079 		goto err_out;
1080 
1081 	bpf_map_inc(new_map);
1082 	rcu_assign_pointer(st_link->map, new_map);
1083 	bpf_map_put(old_map);
1084 
1085 err_out:
1086 	mutex_unlock(&update_mutex);
1087 
1088 	return err;
1089 }
1090 
1091 static const struct bpf_link_ops bpf_struct_ops_map_lops = {
1092 	.dealloc = bpf_struct_ops_map_link_dealloc,
1093 	.show_fdinfo = bpf_struct_ops_map_link_show_fdinfo,
1094 	.fill_link_info = bpf_struct_ops_map_link_fill_link_info,
1095 	.update_map = bpf_struct_ops_map_link_update,
1096 };
1097 
1098 int bpf_struct_ops_link_create(union bpf_attr *attr)
1099 {
1100 	struct bpf_struct_ops_link *link = NULL;
1101 	struct bpf_link_primer link_primer;
1102 	struct bpf_struct_ops_map *st_map;
1103 	struct bpf_map *map;
1104 	int err;
1105 
1106 	map = bpf_map_get(attr->link_create.map_fd);
1107 	if (IS_ERR(map))
1108 		return PTR_ERR(map);
1109 
1110 	st_map = (struct bpf_struct_ops_map *)map;
1111 
1112 	if (!bpf_struct_ops_valid_to_reg(map)) {
1113 		err = -EINVAL;
1114 		goto err_out;
1115 	}
1116 
1117 	link = kzalloc(sizeof(*link), GFP_USER);
1118 	if (!link) {
1119 		err = -ENOMEM;
1120 		goto err_out;
1121 	}
1122 	bpf_link_init(&link->link, BPF_LINK_TYPE_STRUCT_OPS, &bpf_struct_ops_map_lops, NULL);
1123 
1124 	err = bpf_link_prime(&link->link, &link_primer);
1125 	if (err)
1126 		goto err_out;
1127 
1128 	err = st_map->st_ops_desc->st_ops->reg(st_map->kvalue.data);
1129 	if (err) {
1130 		bpf_link_cleanup(&link_primer);
1131 		link = NULL;
1132 		goto err_out;
1133 	}
1134 	RCU_INIT_POINTER(link->map, map);
1135 
1136 	return bpf_link_settle(&link_primer);
1137 
1138 err_out:
1139 	bpf_map_put(map);
1140 	kfree(link);
1141 	return err;
1142 }
1143 
1144 void bpf_map_struct_ops_info_fill(struct bpf_map_info *info, struct bpf_map *map)
1145 {
1146 	struct bpf_struct_ops_map *st_map = (struct bpf_struct_ops_map *)map;
1147 
1148 	info->btf_vmlinux_id = btf_obj_id(st_map->btf);
1149 }
1150