xref: /linux/tools/testing/selftests/bpf/prog_tests/ctx_rewrite.c (revision c83b49383b595be50647f0c764a48c78b5f3c4f8)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <limits.h>
4 #include <stdio.h>
5 #include <string.h>
6 #include <ctype.h>
7 #include <regex.h>
8 #include <test_progs.h>
9 
10 #include "bpf/btf.h"
11 #include "bpf_util.h"
12 #include "linux/filter.h"
13 #include "disasm.h"
14 
/* Size of the buffer that first receives the verifier log and is then
 * reused for the disassembled text of the translated program.
 */
#define MAX_PROG_TEXT_SZ (32 * 1024)
16 
17 /* The code in this file serves the sole purpose of executing test cases
18  * specified in the test_cases array. Each test case specifies a program
19  * type, context field offset, and disassembly patterns that correspond
20  * to read and write instructions generated by
21  * verifier.c:convert_ctx_access() for accessing that field.
22  *
23  * For each test case, up to three programs are created:
24  * - One that uses BPF_LDX_MEM to read the context field.
25  * - One that uses BPF_STX_MEM to write to the context field.
26  * - One that uses BPF_ST_MEM to write to the context field.
27  *
28  * The disassembly of each program is then compared with the pattern
29  * specified in the test case.
30  */
/* Description of a single context-access rewrite check. */
struct test_case {
	/* Subtest name, built by the N() macro as "<PROG_TYPE>.<field>[extra]" */
	char *name;
	enum bpf_prog_type prog_type;
	enum bpf_attach_type expected_attach_type;
	/* Offset and size (in bytes) of the accessed field in the context struct */
	int field_offset;
	int field_sz;
	/* Program generated for BPF_ST_MEM uses value 42 by default,
	 * this field allows to specify custom value.
	 */
	struct {
		bool use;
		int value;
	} st_value;
	/* Pattern for BPF_LDX_MEM(field_sz, dst, ctx, field_offset);
	 * NULL means no LDX program is generated.
	 */
	char *read;
	/* Pattern shared by BPF_STX_MEM(field_sz, ctx, src, field_offset) and
	 *                   BPF_ST_MEM (field_sz, ctx, field_offset, imm),
	 * unless overridden by `write_st` / `write_stx` below.
	 */
	char *write;
	/* Pattern for BPF_ST_MEM(field_sz, ctx, field_offset, imm),
	 * takes priority over `write`.
	 */
	char *write_st;
	/* Pattern for BPF_STX_MEM (field_sz, ctx, src, field_offset),
	 * takes priority over `write`.
	 */
	char *write_stx;
};
59 
/* Expand to the common designated initializers of a struct test_case:
 * the subtest name ("<ptype>.<member>" followed by any extra string
 * literals), the program type, and the offset/size of `member` inside
 * the context type `ctx_type`.
 */
#define N(ptype, ctx_type, member, suffix...)			\
	.name = #ptype "." #member suffix,			\
	.prog_type = BPF_PROG_TYPE_##ptype,			\
	.field_offset = offsetof(ctx_type, member),		\
	.field_sz = sizeof(((ctx_type *)0)->member)
65 
/* Expected convert_ctx_access() translations, one entry per
 * (program type, context field). The pattern syntax is described in the
 * comment above match_pattern(). Entries without any write pattern only
 * generate the LDX program.
 */
static struct test_case test_cases[] = {
/* Sign extension on s390 changes the pattern */
#if defined(__x86_64__) || defined(__aarch64__)
	{
		N(SCHED_CLS, struct __sk_buff, tstamp),
		.read  = "r11 = *(u8 *)($ctx + sk_buff::__mono_tc_offset);"
			 "w11 &= 3;"
			 "if w11 != 0x3 goto pc+2;"
			 "$dst = 0;"
			 "goto pc+1;"
			 "$dst = *(u64 *)($ctx + sk_buff::tstamp);",
		.write = "r11 = *(u8 *)($ctx + sk_buff::__mono_tc_offset);"
			 "if w11 & 0x2 goto pc+1;"
			 "goto pc+2;"
			 "w11 &= -2;"
			 "*(u8 *)($ctx + sk_buff::__mono_tc_offset) = r11;"
			 "*(u64 *)($ctx + sk_buff::tstamp) = $src;",
	},
#endif
	{
		N(SCHED_CLS, struct __sk_buff, priority),
		.read  = "$dst = *(u32 *)($ctx + sk_buff::priority);",
		.write = "*(u32 *)($ctx + sk_buff::priority) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, mark),
		.read  = "$dst = *(u32 *)($ctx + sk_buff::mark);",
		.write = "*(u32 *)($ctx + sk_buff::mark) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, cb[0]),
		.read  = "$dst = *(u32 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::data));",
		.write = "*(u32 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::data)) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, tc_classid),
		.read  = "$dst = *(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid));",
		.write = "*(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid)) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, tc_index),
		.read  = "$dst = *(u16 *)($ctx + sk_buff::tc_index);",
		.write = "*(u16 *)($ctx + sk_buff::tc_index) = $src;",
	},
	{
		/* STX and ST translate differently: only the STX pattern
		 * contains a bounds check of $src.
		 */
		N(SCHED_CLS, struct __sk_buff, queue_mapping),
		.read      = "$dst = *(u16 *)($ctx + sk_buff::queue_mapping);",
		.write_stx = "if $src >= 0xffff goto pc+1;"
			     "*(u16 *)($ctx + sk_buff::queue_mapping) = $src;",
		.write_st  = "*(u16 *)($ctx + sk_buff::queue_mapping) = $src;",
	},
	{
		/* This is a corner case in filter.c:bpf_convert_ctx_access() */
		N(SCHED_CLS, struct __sk_buff, queue_mapping, ".ushrt_max"),
		.st_value = { true, USHRT_MAX },
		.write_st = "goto pc+0;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, bound_dev_if),
		.read  = "$dst = *(u32 *)($ctx + sock_common::skc_bound_dev_if);",
		.write = "*(u32 *)($ctx + sock_common::skc_bound_dev_if) = $src;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, mark),
		.read  = "$dst = *(u32 *)($ctx + sock::sk_mark);",
		.write = "*(u32 *)($ctx + sock::sk_mark) = $src;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, priority),
		.read  = "$dst = *(u32 *)($ctx + sock::sk_priority);",
		.write = "*(u32 *)($ctx + sock::sk_priority) = $src;",
	},
	{
		N(SOCK_OPS, struct bpf_sock_ops, replylong[0]),
		.read  = "$dst = *(u32 *)($ctx + bpf_sock_ops_kern::replylong);",
		.write = "*(u32 *)($ctx + bpf_sock_ops_kern::replylong) = $src;",
	},
	{
		/* The u32 access through the ppos pointer lands at byte
		 * offset 0 on little-endian and 4 on big-endian.
		 */
		N(CGROUP_SYSCTL, struct bpf_sysctl, file_pos),
#if __BYTE_ORDER == __LITTLE_ENDIAN
		.read  = "$dst = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "$dst = *(u32 *)($dst +0);",
		.write = "*(u64 *)($ctx + bpf_sysctl_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "*(u32 *)(r9 +0) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::tmp_reg);",
#else
		.read  = "$dst = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "$dst = *(u32 *)($dst +4);",
		.write = "*(u64 *)($ctx + bpf_sysctl_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "*(u32 *)(r9 +4) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::tmp_reg);",
#endif
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, sk),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::sk);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, level),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::level);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::level) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optname),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::optname);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::optname) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optlen),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::optlen);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::optlen) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, retval),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::current_task);"
			 "$dst = *(u64 *)($dst + task_struct::bpf_ctx);"
			 "$dst = *(u32 *)($dst + bpf_cg_run_ctx::retval);",
		.write = "*(u64 *)($ctx + bpf_sockopt_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sockopt_kern::current_task);"
			 "r9 = *(u64 *)(r9 + task_struct::bpf_ctx);"
			 "*(u32 *)(r9 + bpf_cg_run_ctx::retval) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sockopt_kern::tmp_reg);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optval),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::optval);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optval_end),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::optval_end);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
};

#undef N
209 
/* Compiled in test_ctx_rewrite() and used by match_pattern():
 * ident_regex matches an identifier at the start of a string,
 * field_regex matches a "<type>::<field>" reference at the start of a string.
 */
static regex_t *ident_regex;
static regex_t *field_regex;
212 
/* Return a pointer to the first non-whitespace character of `str`
 * (which may be its terminating NUL).
 */
static char *skip_space(char *str)
{
	/* <ctype.h> functions require a value representable as unsigned
	 * char; passing a negative plain char is undefined behavior.
	 */
	while (*str && isspace((unsigned char)*str))
		++str;
	return str;
}
219 
/* Return a pointer to the first character of `str` that is neither
 * whitespace nor ';' (which may be its terminating NUL).
 */
static char *skip_space_and_semi(char *str)
{
	/* Cast: isspace() is undefined for negative plain char values */
	while (*str && (isspace((unsigned char)*str) || *str == ';'))
		++str;
	return str;
}
226 
/* If `str` begins with `prefix`, return the position in `str` right after
 * the prefix; otherwise return NULL. An empty prefix always matches.
 */
static char *match_str(char *str, char *prefix)
{
	size_t plen = strlen(prefix);

	return strncmp(str, prefix, plen) == 0 ? str + plen : NULL;
}
237 
/* If `str` starts with a decimal integer equal to `num`, return the
 * position right after that integer; otherwise return NULL.
 * Note: strtol() also accepts leading whitespace and a sign.
 */
static char *match_number(char *str, int num)
{
	char *next;
	/* Keep the full long result: truncating it to int could make an
	 * out-of-range number in `str` spuriously compare equal to `num`.
	 */
	long snum = strtol(str, &next, 10);

	if (next == str || snum != num)
		return NULL;

	return next;
}
248 
249 static int find_field_offset_aux(struct btf *btf, int btf_id, char *field_name, int off)
250 {
251 	const struct btf_type *type = btf__type_by_id(btf, btf_id);
252 	const struct btf_member *m;
253 	__u16 mnum;
254 	int i;
255 
256 	if (!type) {
257 		PRINT_FAIL("Can't find btf_type for id %d\n", btf_id);
258 		return -1;
259 	}
260 
261 	if (!btf_is_struct(type) && !btf_is_union(type)) {
262 		PRINT_FAIL("BTF id %d is not struct or union\n", btf_id);
263 		return -1;
264 	}
265 
266 	m = btf_members(type);
267 	mnum = btf_vlen(type);
268 
269 	for (i = 0; i < mnum; ++i, ++m) {
270 		const char *mname = btf__name_by_offset(btf, m->name_off);
271 
272 		if (strcmp(mname, "") == 0) {
273 			int msize = find_field_offset_aux(btf, m->type, field_name,
274 							  off + m->offset);
275 			if (msize >= 0)
276 				return msize;
277 		}
278 
279 		if (strcmp(mname, field_name))
280 			continue;
281 
282 		return (off + m->offset) / 8;
283 	}
284 
285 	return -1;
286 }
287 
/* Resolve a "<type>::<field>" reference previously matched by field_regex.
 * matches[1] and matches[2] delimit the type and field identifiers inside
 * `pattern`. Returns the field's byte offset within the type, or -1 after
 * reporting a failure.
 */
static int find_field_offset(struct btf *btf, char *pattern, regmatch_t *matches)
{
	const regmatch_t *tm = &matches[1];
	const regmatch_t *fm = &matches[2];
	int tlen = tm->rm_eo - tm->rm_so;
	int flen = fm->rm_eo - fm->rm_so;
	char type_name[128];
	char field_name[128];
	int id, offset;

	/* Both identifiers plus a NUL must fit into the local buffers */
	if (tlen >= sizeof(type_name)) {
		PRINT_FAIL("Malformed pattern: type ident is too long: %d\n", tlen);
		return -1;
	}
	if (flen >= sizeof(field_name)) {
		PRINT_FAIL("Malformed pattern: field ident is too long: %d\n", flen);
		return -1;
	}

	snprintf(type_name, sizeof(type_name), "%.*s", tlen, pattern + tm->rm_so);
	snprintf(field_name, sizeof(field_name), "%.*s", flen, pattern + fm->rm_so);

	id = btf__find_by_name(btf, type_name);
	if (id < 0) {
		PRINT_FAIL("No BTF info for type %s\n", type_name);
		return -1;
	}

	offset = find_field_offset_aux(btf, id, field_name, 0);
	if (offset < 0) {
		PRINT_FAIL("No BTF info for field %s::%s\n", type_name, field_name);
		return -1;
	}

	return offset;
}
324 
/* Allocate and compile the POSIX extended regex `pat`.
 * Returns NULL, after reporting a failure, if allocation or compilation
 * fails. Release the result with free_regex().
 */
static regex_t *compile_regex(char *pat)
{
	char errbuf[512];
	regex_t *re = malloc(sizeof(*re));
	int rc;

	if (re == NULL) {
		PRINT_FAIL("Can't alloc regex\n");
		return NULL;
	}

	rc = regcomp(re, pat, REG_EXTENDED);
	if (rc == 0)
		return re;

	regerror(rc, re, errbuf, sizeof(errbuf));
	PRINT_FAIL("Can't compile regex: %s\n", errbuf);
	free(re);
	return NULL;
}
348 
/* Release a regex obtained from compile_regex(); NULL is accepted. */
static void free_regex(regex_t *re)
{
	if (re) {
		regfree(re);
		free(re);
	}
}
357 
358 static u32 max_line_len(char *str)
359 {
360 	u32 max_line = 0;
361 	char *next = str;
362 
363 	while (next) {
364 		next = strchr(str, '\n');
365 		if (next) {
366 			max_line = max_t(u32, max_line, (next - str));
367 			str = next + 1;
368 		} else {
369 			max_line = max_t(u32, max_line, strlen(str));
370 		}
371 	}
372 
373 	return min(max_line, 60u);
374 }
375 
/* Print strings `pattern_origin` and `text_origin` side by side,
 * assume `pattern_pos` and `text_pos` designate location within
 * corresponding origin string where match diverges.
 * Each divergence position is marked with '^' once, under the line
 * where it occurs (column 0 included).
 * The output should look like:
 *
 *   Can't match disassembly(left) with pattern(right):
 *   r2 = *(u64 *)(r1 +0)  ;  $dst = *(u64 *)($ctx + bpf_sockopt_kern::sk1)
 *                     ^                             ^
 *   r0 = 0                ;
 *   exit                  ;
 */
static void print_match_error(FILE *out,
			      char *pattern_origin, char *text_origin,
			      char *pattern_pos, char *text_pos)
{
	char *pattern = pattern_origin;
	char *text = text_origin;
	int middle = max_line_len(text) + 2;

	fprintf(out, "Can't match disassembly(left) with pattern(right):\n");
	while (*pattern || *text) {
		int column = 0;
		int mark1 = -1;	/* column of text_pos on this line, -1: none */
		int mark2 = -1;	/* column of pattern_pos on this line, -1: none */

		/* Print one line from text */
		while (*text && *text != '\n') {
			if (text == text_pos)
				mark1 = column;
			fputc(*text, out);
			++text;
			++column;
		}
		if (text == text_pos)
			mark1 = column;

		/* Pad to the middle */
		while (column < middle) {
			fputc(' ', out);
			++column;
		}
		fputs(";  ", out);
		column += 3;

		/* Print one line from pattern, pattern lines are terminated by ';' */
		while (*pattern && *pattern != ';') {
			if (pattern == pattern_pos)
				mark2 = column;
			fputc(*pattern, out);
			++pattern;
			++column;
		}
		if (pattern == pattern_pos)
			mark2 = column;

		fputc('\n', out);
		if (*pattern)
			++pattern;
		if (*text)
			++text;

		/* Forget a position once it is marked: a cursor that reached
		 * the terminating NUL stops advancing and would otherwise be
		 * re-marked on every remaining line.
		 */
		if (mark1 >= 0)
			text_pos = NULL;
		if (mark2 >= 0)
			pattern_pos = NULL;

		/* If pattern and text diverge at this line, print an
		 * additional line with '^' marks, highlighting
		 * positions where match fails. Column 0 is a valid
		 * mark position, only -1 means "no mark".
		 */
		if (mark1 >= 0 || mark2 >= 0) {
			for (column = 0; column <= max(mark1, mark2); ++column) {
				if (column == mark1 || column == mark2)
					fputc('^', out);
				else
					fputc(' ', out);
			}
			fputc('\n', out);
		}
	}
}
452 
453 /* Test if `text` matches `pattern`. Pattern consists of the following elements:
454  *
455  * - Field offset references:
456  *
457  *     <type>::<field>
458  *
459  *   When such reference is encountered BTF is used to compute numerical
460  *   value for the offset of <field> in <type>. The `text` is expected to
461  *   contain matching numerical value.
462  *
463  * - Field groups:
464  *
465  *     $(<type>::<field> [+ <type>::<field>]*)
466  *
467  *   Allows to specify an offset that is a sum of multiple field offsets.
468  *   The `text` is expected to contain matching numerical value.
469  *
470  * - Variable references, e.g. `$src`, `$dst`, `$ctx`.
471  *   These are substitutions specified in `reg_map` array.
472  *   If a substring of pattern is equal to `reg_map[i][0]` the `text` is
473  *   expected to contain `reg_map[i][1]` in the matching position.
474  *
475  * - Whitespace is ignored, ';' counts as whitespace for `pattern`.
476  *
477  * - Any other characters, `pattern` and `text` should match one-to-one.
478  *
479  * Example of a pattern:
480  *
481  *                    __________ fields group ________________
482  *                   '                                        '
483  *   *(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid)) = $src;
484  *            ^^^^                   '______________________'
485  *     variable reference             field offset reference
486  */
487 static bool match_pattern(struct btf *btf, char *pattern, char *text, char *reg_map[][2])
488 {
489 	char *pattern_origin = pattern;
490 	char *text_origin = text;
491 	regmatch_t matches[3];
492 
493 _continue:
494 	while (*pattern) {
495 		if (!*text)
496 			goto err;
497 
498 		/* Skip whitespace */
499 		if (isspace(*pattern) || *pattern == ';') {
500 			if (!isspace(*text) && text != text_origin && isalnum(text[-1]))
501 				goto err;
502 			pattern = skip_space_and_semi(pattern);
503 			text = skip_space(text);
504 			continue;
505 		}
506 
507 		/* Check for variable references */
508 		for (int i = 0; reg_map[i][0]; ++i) {
509 			char *pattern_next, *text_next;
510 
511 			pattern_next = match_str(pattern, reg_map[i][0]);
512 			if (!pattern_next)
513 				continue;
514 
515 			text_next = match_str(text, reg_map[i][1]);
516 			if (!text_next)
517 				goto err;
518 
519 			pattern = pattern_next;
520 			text = text_next;
521 			goto _continue;
522 		}
523 
524 		/* Match field group:
525 		 *   $(sk_buff::cb + qdisc_skb_cb::tc_classid)
526 		 */
527 		if (strncmp(pattern, "$(", 2) == 0) {
528 			char *group_start = pattern, *text_next;
529 			int acc_offset = 0;
530 
531 			pattern += 2;
532 
533 			for (;;) {
534 				int field_offset;
535 
536 				pattern = skip_space(pattern);
537 				if (!*pattern) {
538 					PRINT_FAIL("Unexpected end of pattern\n");
539 					goto err;
540 				}
541 
542 				if (*pattern == ')') {
543 					++pattern;
544 					break;
545 				}
546 
547 				if (*pattern == '+') {
548 					++pattern;
549 					continue;
550 				}
551 
552 				printf("pattern: %s\n", pattern);
553 				if (regexec(field_regex, pattern, 3, matches, 0) != 0) {
554 					PRINT_FAIL("Field reference expected\n");
555 					goto err;
556 				}
557 
558 				field_offset = find_field_offset(btf, pattern, matches);
559 				if (field_offset < 0)
560 					goto err;
561 
562 				pattern += matches[0].rm_eo;
563 				acc_offset += field_offset;
564 			}
565 
566 			text_next = match_number(text, acc_offset);
567 			if (!text_next) {
568 				PRINT_FAIL("No match for group offset %.*s (%d)\n",
569 					   (int)(pattern - group_start),
570 					   group_start,
571 					   acc_offset);
572 				goto err;
573 			}
574 			text = text_next;
575 		}
576 
577 		/* Match field reference:
578 		 *   sk_buff::cb
579 		 */
580 		if (regexec(field_regex, pattern, 3, matches, 0) == 0) {
581 			int field_offset;
582 			char *text_next;
583 
584 			field_offset = find_field_offset(btf, pattern, matches);
585 			if (field_offset < 0)
586 				goto err;
587 
588 			text_next = match_number(text, field_offset);
589 			if (!text_next) {
590 				PRINT_FAIL("No match for field offset %.*s (%d)\n",
591 					   (int)matches[0].rm_eo, pattern, field_offset);
592 				goto err;
593 			}
594 
595 			pattern += matches[0].rm_eo;
596 			text = text_next;
597 			continue;
598 		}
599 
600 		/* If pattern points to identifier not followed by '::'
601 		 * skip the identifier to avoid n^2 application of the
602 		 * field reference rule.
603 		 */
604 		if (regexec(ident_regex, pattern, 1, matches, 0) == 0) {
605 			if (strncmp(pattern, text, matches[0].rm_eo) != 0)
606 				goto err;
607 
608 			pattern += matches[0].rm_eo;
609 			text += matches[0].rm_eo;
610 			continue;
611 		}
612 
613 		/* Match literally */
614 		if (*pattern != *text)
615 			goto err;
616 
617 		++pattern;
618 		++text;
619 	}
620 
621 	return true;
622 
623 err:
624 	test__fail();
625 	print_match_error(stdout, pattern_origin, text_origin, pattern, text);
626 	return false;
627 }
628 
629 /* Request BPF program instructions after all rewrites are applied,
630  * e.g. verifier.c:convert_ctx_access() is done.
631  */
632 static int get_xlated_program(int fd_prog, struct bpf_insn **buf, __u32 *cnt)
633 {
634 	struct bpf_prog_info info = {};
635 	__u32 info_len = sizeof(info);
636 	__u32 xlated_prog_len;
637 	__u32 buf_element_size = sizeof(struct bpf_insn);
638 
639 	if (bpf_prog_get_info_by_fd(fd_prog, &info, &info_len)) {
640 		perror("bpf_prog_get_info_by_fd failed");
641 		return -1;
642 	}
643 
644 	xlated_prog_len = info.xlated_prog_len;
645 	if (xlated_prog_len % buf_element_size) {
646 		printf("Program length %d is not multiple of %d\n",
647 		       xlated_prog_len, buf_element_size);
648 		return -1;
649 	}
650 
651 	*cnt = xlated_prog_len / buf_element_size;
652 	*buf = calloc(*cnt, buf_element_size);
653 	if (!buf) {
654 		perror("can't allocate xlated program buffer");
655 		return -ENOMEM;
656 	}
657 
658 	bzero(&info, sizeof(info));
659 	info.xlated_prog_len = xlated_prog_len;
660 	info.xlated_prog_insns = (__u64)(unsigned long)*buf;
661 	if (bpf_prog_get_info_by_fd(fd_prog, &info, &info_len)) {
662 		perror("second bpf_prog_get_info_by_fd failed");
663 		goto out_free_buf;
664 	}
665 
666 	return 0;
667 
668 out_free_buf:
669 	free(*buf);
670 	return -1;
671 }
672 
/* printf-style output callback for the BPF disassembler: formats `fmt`
 * with the variadic arguments and writes it to the FILE * passed as
 * `private_data`.
 */
static void print_insn(void *private_data, const char *fmt, ...)
{
	FILE *stream = private_data;
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stream, fmt, ap);
	va_end(ap);
}
681 
682 /* Disassemble instructions to a stream */
683 static void print_xlated(FILE *out, struct bpf_insn *insn, __u32 len)
684 {
685 	const struct bpf_insn_cbs cbs = {
686 		.cb_print	= print_insn,
687 		.cb_call	= NULL,
688 		.cb_imm		= NULL,
689 		.private_data	= out,
690 	};
691 	bool double_insn = false;
692 	int i;
693 
694 	for (i = 0; i < len; i++) {
695 		if (double_insn) {
696 			double_insn = false;
697 			continue;
698 		}
699 
700 		double_insn = insn[i].code == (BPF_LD | BPF_IMM | BPF_DW);
701 		print_bpf_insn(&cbs, insn + i, true);
702 	}
703 }
704 
/* The BPF disassembler shared with the kernel starts every instruction
 * line with a 5-character "(FF) " prefix, FF standing for the instruction
 * `code` byte. Strip that prefix from each line of `str` in place,
 * examining at most `size` bytes (and never past the terminating NUL).
 */
static void remove_insn_prefix(char *str, int size)
{
	const int prefix_size = 5;
	int limit = strlen(str);
	int src = prefix_size;
	int dst = 0;

	if (size < limit)
		limit = size;

	while (src < limit) {
		char ch = str[src++];

		if (ch == '\0')
			break;
		str[dst++] = ch;
		/* A new line begins: jump over its prefix as well */
		if (ch == '\n')
			src += prefix_size;
	}
	str[dst] = '\0';
}
729 
/* A program for match_program() to load and compare against a pattern. */
struct prog_info {
	/* Label used in assertion messages ("LDX", "STX" or "ST") */
	char *prog_kind;
	enum bpf_prog_type prog_type;
	enum bpf_attach_type expected_attach_type;
	/* Instructions to load and their number */
	struct bpf_insn *prog;
	u32 prog_len;
};
737 
738 static void match_program(struct btf *btf,
739 			  struct prog_info *pinfo,
740 			  char *pattern,
741 			  char *reg_map[][2],
742 			  bool skip_first_insn)
743 {
744 	struct bpf_insn *buf = NULL;
745 	int err = 0, prog_fd = 0;
746 	FILE *prog_out = NULL;
747 	char *text = NULL;
748 	__u32 cnt = 0;
749 
750 	text = calloc(MAX_PROG_TEXT_SZ, 1);
751 	if (!text) {
752 		PRINT_FAIL("Can't allocate %d bytes\n", MAX_PROG_TEXT_SZ);
753 		goto out;
754 	}
755 
756 	// TODO: log level
757 	LIBBPF_OPTS(bpf_prog_load_opts, opts);
758 	opts.log_buf = text;
759 	opts.log_size = MAX_PROG_TEXT_SZ;
760 	opts.log_level = 1 | 2 | 4;
761 	opts.expected_attach_type = pinfo->expected_attach_type;
762 
763 	prog_fd = bpf_prog_load(pinfo->prog_type, NULL, "GPL",
764 				pinfo->prog, pinfo->prog_len, &opts);
765 	if (prog_fd < 0) {
766 		PRINT_FAIL("Can't load program, errno %d (%s), verifier log:\n%s\n",
767 			   errno, strerror(errno), text);
768 		goto out;
769 	}
770 
771 	memset(text, 0, MAX_PROG_TEXT_SZ);
772 
773 	err = get_xlated_program(prog_fd, &buf, &cnt);
774 	if (err) {
775 		PRINT_FAIL("Can't load back BPF program\n");
776 		goto out;
777 	}
778 
779 	prog_out = fmemopen(text, MAX_PROG_TEXT_SZ - 1, "w");
780 	if (!prog_out) {
781 		PRINT_FAIL("Can't open memory stream\n");
782 		goto out;
783 	}
784 	if (skip_first_insn)
785 		print_xlated(prog_out, buf + 1, cnt - 1);
786 	else
787 		print_xlated(prog_out, buf, cnt);
788 	fclose(prog_out);
789 	remove_insn_prefix(text, MAX_PROG_TEXT_SZ);
790 
791 	ASSERT_TRUE(match_pattern(btf, pattern, text, reg_map),
792 		    pinfo->prog_kind);
793 
794 out:
795 	if (prog_fd)
796 		close(prog_fd);
797 	free(buf);
798 	free(text);
799 }
800 
801 static void run_one_testcase(struct btf *btf, struct test_case *test)
802 {
803 	struct prog_info pinfo = {};
804 	int bpf_sz;
805 
806 	if (!test__start_subtest(test->name))
807 		return;
808 
809 	switch (test->field_sz) {
810 	case 8:
811 		bpf_sz = BPF_DW;
812 		break;
813 	case 4:
814 		bpf_sz = BPF_W;
815 		break;
816 	case 2:
817 		bpf_sz = BPF_H;
818 		break;
819 	case 1:
820 		bpf_sz = BPF_B;
821 		break;
822 	default:
823 		PRINT_FAIL("Unexpected field size: %d, want 8,4,2 or 1\n", test->field_sz);
824 		return;
825 	}
826 
827 	pinfo.prog_type = test->prog_type;
828 	pinfo.expected_attach_type = test->expected_attach_type;
829 
830 	if (test->read) {
831 		struct bpf_insn ldx_prog[] = {
832 			BPF_LDX_MEM(bpf_sz, BPF_REG_2, BPF_REG_1, test->field_offset),
833 			BPF_MOV64_IMM(BPF_REG_0, 0),
834 			BPF_EXIT_INSN(),
835 		};
836 		char *reg_map[][2] = {
837 			{ "$ctx", "r1" },
838 			{ "$dst", "r2" },
839 			{}
840 		};
841 
842 		pinfo.prog_kind = "LDX";
843 		pinfo.prog = ldx_prog;
844 		pinfo.prog_len = ARRAY_SIZE(ldx_prog);
845 		match_program(btf, &pinfo, test->read, reg_map, false);
846 	}
847 
848 	if (test->write || test->write_st || test->write_stx) {
849 		struct bpf_insn stx_prog[] = {
850 			BPF_MOV64_IMM(BPF_REG_2, 0),
851 			BPF_STX_MEM(bpf_sz, BPF_REG_1, BPF_REG_2, test->field_offset),
852 			BPF_MOV64_IMM(BPF_REG_0, 0),
853 			BPF_EXIT_INSN(),
854 		};
855 		char *stx_reg_map[][2] = {
856 			{ "$ctx", "r1" },
857 			{ "$src", "r2" },
858 			{}
859 		};
860 		struct bpf_insn st_prog[] = {
861 			BPF_ST_MEM(bpf_sz, BPF_REG_1, test->field_offset,
862 				   test->st_value.use ? test->st_value.value : 42),
863 			BPF_MOV64_IMM(BPF_REG_0, 0),
864 			BPF_EXIT_INSN(),
865 		};
866 		char *st_reg_map[][2] = {
867 			{ "$ctx", "r1" },
868 			{ "$src", "42" },
869 			{}
870 		};
871 
872 		if (test->write || test->write_stx) {
873 			char *pattern = test->write_stx ? test->write_stx : test->write;
874 
875 			pinfo.prog_kind = "STX";
876 			pinfo.prog = stx_prog;
877 			pinfo.prog_len = ARRAY_SIZE(stx_prog);
878 			match_program(btf, &pinfo, pattern, stx_reg_map, true);
879 		}
880 
881 		if (test->write || test->write_st) {
882 			char *pattern = test->write_st ? test->write_st : test->write;
883 
884 			pinfo.prog_kind = "ST";
885 			pinfo.prog = st_prog;
886 			pinfo.prog_len = ARRAY_SIZE(st_prog);
887 			match_program(btf, &pinfo, pattern, st_reg_map, false);
888 		}
889 	}
890 
891 	test__end_subtest();
892 }
893 
894 void test_ctx_rewrite(void)
895 {
896 	struct btf *btf;
897 	int i;
898 
899 	field_regex = compile_regex("^([[:alpha:]_][[:alnum:]_]+)::([[:alpha:]_][[:alnum:]_]+)");
900 	ident_regex = compile_regex("^[[:alpha:]_][[:alnum:]_]+");
901 	if (!field_regex || !ident_regex)
902 		return;
903 
904 	btf = btf__load_vmlinux_btf();
905 	if (!btf) {
906 		PRINT_FAIL("Can't load vmlinux BTF, errno %d (%s)\n", errno, strerror(errno));
907 		goto out;
908 	}
909 
910 	for (i = 0; i < ARRAY_SIZE(test_cases); ++i)
911 		run_one_testcase(btf, &test_cases[i]);
912 
913 out:
914 	btf__free(btf);
915 	free_regex(field_regex);
916 	free_regex(ident_regex);
917 }
918