xref: /linux/tools/testing/selftests/bpf/prog_tests/ctx_rewrite.c (revision 2a52ca7c98960aafb0eca9ef96b2d0c932171357)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <limits.h>
4 #include <stdio.h>
5 #include <string.h>
6 #include <ctype.h>
7 #include <regex.h>
8 #include <test_progs.h>
9 
10 #include "bpf/btf.h"
11 #include "bpf_util.h"
12 #include "linux/filter.h"
13 #include "disasm.h"
14 
/* Size of the buffer that first receives the verifier log and is then
 * reused for the disassembly of the loaded program (see match_program()).
 */
#define MAX_PROG_TEXT_SZ (32 * 1024)

/* The code in this file serves the sole purpose of executing test cases
 * specified in the test_cases array. Each test case specifies a program
 * type, context field offset, and disassembly patterns that correspond
 * to read and write instructions generated by
 * verifier.c:convert_ctx_access() for accessing that field.
 *
 * For each test case, up to three programs are created:
 * - One that uses BPF_LDX_MEM to read the context field.
 * - One that uses BPF_STX_MEM to write to the context field.
 * - One that uses BPF_ST_MEM to write to the context field.
 *
 * The disassembly of each program is then compared with the pattern
 * specified in the test case.
 */
struct test_case {
	/* Subtest name */
	char *name;
	/* Type the test programs are loaded as */
	enum bpf_prog_type prog_type;
	/* Passed as expected_attach_type when loading the programs */
	enum bpf_attach_type expected_attach_type;
	/* Byte offset and size of the accessed context field;
	 * field_sz must be 1, 2, 4 or 8.
	 */
	int field_offset;
	int field_sz;
	/* Program generated for BPF_ST_MEM uses value 42 by default,
	 * this field allows to specify custom value.
	 */
	struct {
		bool use;
		int value;
	} st_value;
	/* Pattern for BPF_LDX_MEM(field_sz, dst, ctx, field_offset);
	 * no read program is generated when NULL.
	 */
	char *read;
	/* Pattern for BPF_STX_MEM(field_sz, ctx, src, field_offset) and
	 *             BPF_ST_MEM (field_sz, ctx, src, field_offset).
	 * No write programs are generated when `write`, `write_st` and
	 * `write_stx` are all NULL.
	 */
	char *write;
	/* Pattern for BPF_ST_MEM(field_sz, ctx, src, field_offset),
	 * takes priority over `write`.
	 */
	char *write_st;
	/* Pattern for BPF_STX_MEM(field_sz, ctx, src, field_offset),
	 * takes priority over `write`.
	 */
	char *write_stx;
};
59 
/* Designated initializers shared by every test_case entry: the subtest
 * name is "<_prog_type>.<field>" followed by the optional `name_extra`
 * string, and the offset/size describe `field` inside context `type`.
 */
#define N(_prog_type, type, field, name_extra...)		\
	.name         = #_prog_type "." #field name_extra,	\
	.prog_type    = BPF_PROG_TYPE_##_prog_type,		\
	.field_offset = offsetof(type, field),			\
	.field_sz     = sizeof(((type *)0)->field)
65 
/* One subtest per entry. Pattern syntax is described in the comment on
 * match_pattern(); `$ctx`, `$dst` and `$src` are substituted per program
 * kind in run_one_testcase().
 */
static struct test_case test_cases[] = {
/* Sign extension on s390 changes the pattern */
#if defined(__x86_64__) || defined(__aarch64__)
	{
		N(SCHED_CLS, struct __sk_buff, tstamp),
		.read  = "r11 = *(u8 *)($ctx + sk_buff::__mono_tc_offset);"
			 "if w11 & 0x4 goto pc+1;"
			 "goto pc+4;"
			 "if w11 & 0x3 goto pc+1;"
			 "goto pc+2;"
			 "$dst = 0;"
			 "goto pc+1;"
			 "$dst = *(u64 *)($ctx + sk_buff::tstamp);",
		.write = "r11 = *(u8 *)($ctx + sk_buff::__mono_tc_offset);"
			 "if w11 & 0x4 goto pc+1;"
			 "goto pc+2;"
			 "w11 &= -4;"
			 "*(u8 *)($ctx + sk_buff::__mono_tc_offset) = r11;"
			 "*(u64 *)($ctx + sk_buff::tstamp) = $src;",
	},
#endif
	{
		N(SCHED_CLS, struct __sk_buff, priority),
		.read  = "$dst = *(u32 *)($ctx + sk_buff::priority);",
		.write = "*(u32 *)($ctx + sk_buff::priority) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, mark),
		.read  = "$dst = *(u32 *)($ctx + sk_buff::mark);",
		.write = "*(u32 *)($ctx + sk_buff::mark) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, cb[0]),
		.read  = "$dst = *(u32 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::data));",
		.write = "*(u32 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::data)) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, tc_classid),
		.read  = "$dst = *(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid));",
		.write = "*(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid)) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, tc_index),
		.read  = "$dst = *(u16 *)($ctx + sk_buff::tc_index);",
		.write = "*(u16 *)($ctx + sk_buff::tc_index) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, queue_mapping),
		.read      = "$dst = *(u16 *)($ctx + sk_buff::queue_mapping);",
		.write_stx = "if $src >= 0xffff goto pc+1;"
			     "*(u16 *)($ctx + sk_buff::queue_mapping) = $src;",
		.write_st  = "*(u16 *)($ctx + sk_buff::queue_mapping) = $src;",
	},
	{
		/* This is a corner case in filter.c:bpf_convert_ctx_access():
		 * a BPF_ST_MEM of USHRT_MAX is expected to be rewritten into
		 * a no-op jump.
		 */
		N(SCHED_CLS, struct __sk_buff, queue_mapping, ".ushrt_max"),
		.st_value = { true, USHRT_MAX },
		.write_st = "goto pc+0;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, bound_dev_if),
		.read  = "$dst = *(u32 *)($ctx + sock_common::skc_bound_dev_if);",
		.write = "*(u32 *)($ctx + sock_common::skc_bound_dev_if) = $src;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, mark),
		.read  = "$dst = *(u32 *)($ctx + sock::sk_mark);",
		.write = "*(u32 *)($ctx + sock::sk_mark) = $src;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, priority),
		.read  = "$dst = *(u32 *)($ctx + sock::sk_priority);",
		.write = "*(u32 *)($ctx + sock::sk_priority) = $src;",
	},
	{
		N(SOCK_OPS, struct bpf_sock_ops, replylong[0]),
		.read  = "$dst = *(u32 *)($ctx + bpf_sock_ops_kern::replylong);",
		.write = "*(u32 *)($ctx + bpf_sock_ops_kern::replylong) = $src;",
	},
	{
		/* The 32-bit access goes through the pointer loaded from
		 * bpf_sysctl_kern::ppos, at +0 or +4 depending on byte order.
		 */
		N(CGROUP_SYSCTL, struct bpf_sysctl, file_pos),
#if __BYTE_ORDER == __LITTLE_ENDIAN
		.read  = "$dst = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "$dst = *(u32 *)($dst +0);",
		.write = "*(u64 *)($ctx + bpf_sysctl_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "*(u32 *)(r9 +0) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::tmp_reg);",
#else
		.read  = "$dst = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "$dst = *(u32 *)($dst +4);",
		.write = "*(u64 *)($ctx + bpf_sysctl_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "*(u32 *)(r9 +4) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::tmp_reg);",
#endif
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, sk),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::sk);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, level),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::level);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::level) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optname),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::optname);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::optname) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optlen),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::optlen);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::optlen) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, retval),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::current_task);"
			 "$dst = *(u64 *)($dst + task_struct::bpf_ctx);"
			 "$dst = *(u32 *)($dst + bpf_cg_run_ctx::retval);",
		.write = "*(u64 *)($ctx + bpf_sockopt_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sockopt_kern::current_task);"
			 "r9 = *(u64 *)(r9 + task_struct::bpf_ctx);"
			 "*(u32 *)(r9 + bpf_cg_run_ctx::retval) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sockopt_kern::tmp_reg);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optval),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::optval);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optval_end),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::optval_end);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
};

/* N() is only meant for the table above */
#undef N
211 
/* Compiled in test_ctx_rewrite() and used by match_pattern():
 * - ident_regex matches an identifier (two or more characters) at the
 *   start of a string;
 * - field_regex matches "<type>::<field>" at the start of a string and
 *   captures <type> and <field> as sub-matches 1 and 2.
 */
static regex_t *ident_regex;
static regex_t *field_regex;
214 
/* Return a pointer to the first non-whitespace character of `str`
 * (the terminating NUL if `str` is all whitespace).
 */
static char *skip_space(char *str)
{
	/* <ctype.h> functions require a value representable as unsigned
	 * char; passing a negative plain char is undefined behavior.
	 */
	while (*str && isspace((unsigned char)*str))
		++str;
	return str;
}
220 }
221 
222 static char *skip_space_and_semi(char *str)
223 {
224 	while (*str && (isspace(*str) || *str == ';'))
225 		++str;
226 	return str;
227 }
228 
229 static char *match_str(char *str, char *prefix)
230 {
231 	while (*str && *prefix && *str == *prefix) {
232 		++str;
233 		++prefix;
234 	}
235 	if (*prefix)
236 		return NULL;
237 	return str;
238 }
239 
240 static char *match_number(char *str, int num)
241 {
242 	char *next;
243 	int snum = strtol(str, &next, 10);
244 
245 	if (next - str == 0 || num != snum)
246 		return NULL;
247 
248 	return next;
249 }
250 
251 static int find_field_offset_aux(struct btf *btf, int btf_id, char *field_name, int off)
252 {
253 	const struct btf_type *type = btf__type_by_id(btf, btf_id);
254 	const struct btf_member *m;
255 	__u16 mnum;
256 	int i;
257 
258 	if (!type) {
259 		PRINT_FAIL("Can't find btf_type for id %d\n", btf_id);
260 		return -1;
261 	}
262 
263 	if (!btf_is_struct(type) && !btf_is_union(type)) {
264 		PRINT_FAIL("BTF id %d is not struct or union\n", btf_id);
265 		return -1;
266 	}
267 
268 	m = btf_members(type);
269 	mnum = btf_vlen(type);
270 
271 	for (i = 0; i < mnum; ++i, ++m) {
272 		const char *mname = btf__name_by_offset(btf, m->name_off);
273 
274 		if (strcmp(mname, "") == 0) {
275 			int msize = find_field_offset_aux(btf, m->type, field_name,
276 							  off + m->offset);
277 			if (msize >= 0)
278 				return msize;
279 		}
280 
281 		if (strcmp(mname, field_name))
282 			continue;
283 
284 		return (off + m->offset) / 8;
285 	}
286 
287 	return -1;
288 }
289 
290 static int find_field_offset(struct btf *btf, char *pattern, regmatch_t *matches)
291 {
292 	int type_sz  = matches[1].rm_eo - matches[1].rm_so;
293 	int field_sz = matches[2].rm_eo - matches[2].rm_so;
294 	char *type   = pattern + matches[1].rm_so;
295 	char *field  = pattern + matches[2].rm_so;
296 	char field_str[128] = {};
297 	char type_str[128] = {};
298 	int btf_id, field_offset;
299 
300 	if (type_sz >= sizeof(type_str)) {
301 		PRINT_FAIL("Malformed pattern: type ident is too long: %d\n", type_sz);
302 		return -1;
303 	}
304 
305 	if (field_sz >= sizeof(field_str)) {
306 		PRINT_FAIL("Malformed pattern: field ident is too long: %d\n", field_sz);
307 		return -1;
308 	}
309 
310 	strncpy(type_str, type, type_sz);
311 	strncpy(field_str, field, field_sz);
312 	btf_id = btf__find_by_name(btf, type_str);
313 	if (btf_id < 0) {
314 		PRINT_FAIL("No BTF info for type %s\n", type_str);
315 		return -1;
316 	}
317 
318 	field_offset = find_field_offset_aux(btf, btf_id, field_str, 0);
319 	if (field_offset < 0) {
320 		PRINT_FAIL("No BTF info for field %s::%s\n", type_str, field_str);
321 		return -1;
322 	}
323 
324 	return field_offset;
325 }
326 
327 static regex_t *compile_regex(char *pat)
328 {
329 	regex_t *re;
330 	int err;
331 
332 	re = malloc(sizeof(regex_t));
333 	if (!re) {
334 		PRINT_FAIL("Can't alloc regex\n");
335 		return NULL;
336 	}
337 
338 	err = regcomp(re, pat, REG_EXTENDED);
339 	if (err) {
340 		char errbuf[512];
341 
342 		regerror(err, re, errbuf, sizeof(errbuf));
343 		PRINT_FAIL("Can't compile regex: %s\n", errbuf);
344 		free(re);
345 		return NULL;
346 	}
347 
348 	return re;
349 }
350 
351 static void free_regex(regex_t *re)
352 {
353 	if (!re)
354 		return;
355 
356 	regfree(re);
357 	free(re);
358 }
359 
360 static u32 max_line_len(char *str)
361 {
362 	u32 max_line = 0;
363 	char *next = str;
364 
365 	while (next) {
366 		next = strchr(str, '\n');
367 		if (next) {
368 			max_line = max_t(u32, max_line, (next - str));
369 			str = next + 1;
370 		} else {
371 			max_line = max_t(u32, max_line, strlen(str));
372 		}
373 	}
374 
375 	return min(max_line, 60u);
376 }
377 
378 /* Print strings `pattern_origin` and `text_origin` side by side,
379  * assume `pattern_pos` and `text_pos` designate location within
380  * corresponding origin string where match diverges.
381  * The output should look like:
382  *
383  *   Can't match disassembly(left) with pattern(right):
384  *   r2 = *(u64 *)(r1 +0)  ;  $dst = *(u64 *)($ctx + bpf_sockopt_kern::sk1)
385  *                     ^                             ^
386  *   r0 = 0                ;
387  *   exit                  ;
388  */
389 static void print_match_error(FILE *out,
390 			      char *pattern_origin, char *text_origin,
391 			      char *pattern_pos, char *text_pos)
392 {
393 	char *pattern = pattern_origin;
394 	char *text = text_origin;
395 	int middle = max_line_len(text) + 2;
396 
397 	fprintf(out, "Can't match disassembly(left) with pattern(right):\n");
398 	while (*pattern || *text) {
399 		int column = 0;
400 		int mark1 = -1;
401 		int mark2 = -1;
402 
403 		/* Print one line from text */
404 		while (*text && *text != '\n') {
405 			if (text == text_pos)
406 				mark1 = column;
407 			fputc(*text, out);
408 			++text;
409 			++column;
410 		}
411 		if (text == text_pos)
412 			mark1 = column;
413 
414 		/* Pad to the middle */
415 		while (column < middle) {
416 			fputc(' ', out);
417 			++column;
418 		}
419 		fputs(";  ", out);
420 		column += 3;
421 
422 		/* Print one line from pattern, pattern lines are terminated by ';' */
423 		while (*pattern && *pattern != ';') {
424 			if (pattern == pattern_pos)
425 				mark2 = column;
426 			fputc(*pattern, out);
427 			++pattern;
428 			++column;
429 		}
430 		if (pattern == pattern_pos)
431 			mark2 = column;
432 
433 		fputc('\n', out);
434 		if (*pattern)
435 			++pattern;
436 		if (*text)
437 			++text;
438 
439 		/* If pattern and text diverge at this line, print an
440 		 * additional line with '^' marks, highlighting
441 		 * positions where match fails.
442 		 */
443 		if (mark1 > 0 || mark2 > 0) {
444 			for (column = 0; column <= max(mark1, mark2); ++column) {
445 				if (column == mark1 || column == mark2)
446 					fputc('^', out);
447 				else
448 					fputc(' ', out);
449 			}
450 			fputc('\n', out);
451 		}
452 	}
453 }
454 
455 /* Test if `text` matches `pattern`. Pattern consists of the following elements:
456  *
457  * - Field offset references:
458  *
459  *     <type>::<field>
460  *
461  *   When such reference is encountered BTF is used to compute numerical
462  *   value for the offset of <field> in <type>. The `text` is expected to
463  *   contain matching numerical value.
464  *
465  * - Field groups:
466  *
467  *     $(<type>::<field> [+ <type>::<field>]*)
468  *
469  *   Allows to specify an offset that is a sum of multiple field offsets.
470  *   The `text` is expected to contain matching numerical value.
471  *
472  * - Variable references, e.g. `$src`, `$dst`, `$ctx`.
473  *   These are substitutions specified in `reg_map` array.
474  *   If a substring of pattern is equal to `reg_map[i][0]` the `text` is
475  *   expected to contain `reg_map[i][1]` in the matching position.
476  *
477  * - Whitespace is ignored, ';' counts as whitespace for `pattern`.
478  *
479  * - Any other characters, `pattern` and `text` should match one-to-one.
480  *
481  * Example of a pattern:
482  *
483  *                    __________ fields group ________________
484  *                   '                                        '
485  *   *(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid)) = $src;
486  *            ^^^^                   '______________________'
487  *     variable reference             field offset reference
488  */
489 static bool match_pattern(struct btf *btf, char *pattern, char *text, char *reg_map[][2])
490 {
491 	char *pattern_origin = pattern;
492 	char *text_origin = text;
493 	regmatch_t matches[3];
494 
495 _continue:
496 	while (*pattern) {
497 		if (!*text)
498 			goto err;
499 
500 		/* Skip whitespace */
501 		if (isspace(*pattern) || *pattern == ';') {
502 			if (!isspace(*text) && text != text_origin && isalnum(text[-1]))
503 				goto err;
504 			pattern = skip_space_and_semi(pattern);
505 			text = skip_space(text);
506 			continue;
507 		}
508 
509 		/* Check for variable references */
510 		for (int i = 0; reg_map[i][0]; ++i) {
511 			char *pattern_next, *text_next;
512 
513 			pattern_next = match_str(pattern, reg_map[i][0]);
514 			if (!pattern_next)
515 				continue;
516 
517 			text_next = match_str(text, reg_map[i][1]);
518 			if (!text_next)
519 				goto err;
520 
521 			pattern = pattern_next;
522 			text = text_next;
523 			goto _continue;
524 		}
525 
526 		/* Match field group:
527 		 *   $(sk_buff::cb + qdisc_skb_cb::tc_classid)
528 		 */
529 		if (strncmp(pattern, "$(", 2) == 0) {
530 			char *group_start = pattern, *text_next;
531 			int acc_offset = 0;
532 
533 			pattern += 2;
534 
535 			for (;;) {
536 				int field_offset;
537 
538 				pattern = skip_space(pattern);
539 				if (!*pattern) {
540 					PRINT_FAIL("Unexpected end of pattern\n");
541 					goto err;
542 				}
543 
544 				if (*pattern == ')') {
545 					++pattern;
546 					break;
547 				}
548 
549 				if (*pattern == '+') {
550 					++pattern;
551 					continue;
552 				}
553 
554 				printf("pattern: %s\n", pattern);
555 				if (regexec(field_regex, pattern, 3, matches, 0) != 0) {
556 					PRINT_FAIL("Field reference expected\n");
557 					goto err;
558 				}
559 
560 				field_offset = find_field_offset(btf, pattern, matches);
561 				if (field_offset < 0)
562 					goto err;
563 
564 				pattern += matches[0].rm_eo;
565 				acc_offset += field_offset;
566 			}
567 
568 			text_next = match_number(text, acc_offset);
569 			if (!text_next) {
570 				PRINT_FAIL("No match for group offset %.*s (%d)\n",
571 					   (int)(pattern - group_start),
572 					   group_start,
573 					   acc_offset);
574 				goto err;
575 			}
576 			text = text_next;
577 		}
578 
579 		/* Match field reference:
580 		 *   sk_buff::cb
581 		 */
582 		if (regexec(field_regex, pattern, 3, matches, 0) == 0) {
583 			int field_offset;
584 			char *text_next;
585 
586 			field_offset = find_field_offset(btf, pattern, matches);
587 			if (field_offset < 0)
588 				goto err;
589 
590 			text_next = match_number(text, field_offset);
591 			if (!text_next) {
592 				PRINT_FAIL("No match for field offset %.*s (%d)\n",
593 					   (int)matches[0].rm_eo, pattern, field_offset);
594 				goto err;
595 			}
596 
597 			pattern += matches[0].rm_eo;
598 			text = text_next;
599 			continue;
600 		}
601 
602 		/* If pattern points to identifier not followed by '::'
603 		 * skip the identifier to avoid n^2 application of the
604 		 * field reference rule.
605 		 */
606 		if (regexec(ident_regex, pattern, 1, matches, 0) == 0) {
607 			if (strncmp(pattern, text, matches[0].rm_eo) != 0)
608 				goto err;
609 
610 			pattern += matches[0].rm_eo;
611 			text += matches[0].rm_eo;
612 			continue;
613 		}
614 
615 		/* Match literally */
616 		if (*pattern != *text)
617 			goto err;
618 
619 		++pattern;
620 		++text;
621 	}
622 
623 	return true;
624 
625 err:
626 	test__fail();
627 	print_match_error(stdout, pattern_origin, text_origin, pattern, text);
628 	return false;
629 }
630 
631 static void print_insn(void *private_data, const char *fmt, ...)
632 {
633 	va_list args;
634 
635 	va_start(args, fmt);
636 	vfprintf((FILE *)private_data, fmt, args);
637 	va_end(args);
638 }
639 
640 /* Disassemble instructions to a stream */
641 static void print_xlated(FILE *out, struct bpf_insn *insn, __u32 len)
642 {
643 	const struct bpf_insn_cbs cbs = {
644 		.cb_print	= print_insn,
645 		.cb_call	= NULL,
646 		.cb_imm		= NULL,
647 		.private_data	= out,
648 	};
649 	bool double_insn = false;
650 	int i;
651 
652 	for (i = 0; i < len; i++) {
653 		if (double_insn) {
654 			double_insn = false;
655 			continue;
656 		}
657 
658 		double_insn = insn[i].code == (BPF_LD | BPF_IMM | BPF_DW);
659 		print_bpf_insn(&cbs, insn + i, true);
660 	}
661 }
662 
663 /* We share code with kernel BPF disassembler, it adds '(FF) ' prefix
664  * for each instruction (FF stands for instruction `code` byte).
665  * This function removes the prefix inplace for each line in `str`.
666  */
667 static void remove_insn_prefix(char *str, int size)
668 {
669 	const int prefix_size = 5;
670 
671 	int write_pos = 0, read_pos = prefix_size;
672 	int len = strlen(str);
673 	char c;
674 
675 	size = min(size, len);
676 
677 	while (read_pos < size) {
678 		c = str[read_pos++];
679 		if (c == 0)
680 			break;
681 		str[write_pos++] = c;
682 		if (c == '\n')
683 			read_pos += prefix_size;
684 	}
685 	str[write_pos] = 0;
686 }
687 
688 struct prog_info {
689 	char *prog_kind;
690 	enum bpf_prog_type prog_type;
691 	enum bpf_attach_type expected_attach_type;
692 	struct bpf_insn *prog;
693 	u32 prog_len;
694 };
695 
696 static void match_program(struct btf *btf,
697 			  struct prog_info *pinfo,
698 			  char *pattern,
699 			  char *reg_map[][2],
700 			  bool skip_first_insn)
701 {
702 	struct bpf_insn *buf = NULL;
703 	int err = 0, prog_fd = 0;
704 	FILE *prog_out = NULL;
705 	char *text = NULL;
706 	__u32 cnt = 0;
707 
708 	text = calloc(MAX_PROG_TEXT_SZ, 1);
709 	if (!text) {
710 		PRINT_FAIL("Can't allocate %d bytes\n", MAX_PROG_TEXT_SZ);
711 		goto out;
712 	}
713 
714 	// TODO: log level
715 	LIBBPF_OPTS(bpf_prog_load_opts, opts);
716 	opts.log_buf = text;
717 	opts.log_size = MAX_PROG_TEXT_SZ;
718 	opts.log_level = 1 | 2 | 4;
719 	opts.expected_attach_type = pinfo->expected_attach_type;
720 
721 	prog_fd = bpf_prog_load(pinfo->prog_type, NULL, "GPL",
722 				pinfo->prog, pinfo->prog_len, &opts);
723 	if (prog_fd < 0) {
724 		PRINT_FAIL("Can't load program, errno %d (%s), verifier log:\n%s\n",
725 			   errno, strerror(errno), text);
726 		goto out;
727 	}
728 
729 	memset(text, 0, MAX_PROG_TEXT_SZ);
730 
731 	err = get_xlated_program(prog_fd, &buf, &cnt);
732 	if (err) {
733 		PRINT_FAIL("Can't load back BPF program\n");
734 		goto out;
735 	}
736 
737 	prog_out = fmemopen(text, MAX_PROG_TEXT_SZ - 1, "w");
738 	if (!prog_out) {
739 		PRINT_FAIL("Can't open memory stream\n");
740 		goto out;
741 	}
742 	if (skip_first_insn)
743 		print_xlated(prog_out, buf + 1, cnt - 1);
744 	else
745 		print_xlated(prog_out, buf, cnt);
746 	fclose(prog_out);
747 	remove_insn_prefix(text, MAX_PROG_TEXT_SZ);
748 
749 	ASSERT_TRUE(match_pattern(btf, pattern, text, reg_map),
750 		    pinfo->prog_kind);
751 
752 out:
753 	if (prog_fd)
754 		close(prog_fd);
755 	free(buf);
756 	free(text);
757 }
758 
759 static void run_one_testcase(struct btf *btf, struct test_case *test)
760 {
761 	struct prog_info pinfo = {};
762 	int bpf_sz;
763 
764 	if (!test__start_subtest(test->name))
765 		return;
766 
767 	switch (test->field_sz) {
768 	case 8:
769 		bpf_sz = BPF_DW;
770 		break;
771 	case 4:
772 		bpf_sz = BPF_W;
773 		break;
774 	case 2:
775 		bpf_sz = BPF_H;
776 		break;
777 	case 1:
778 		bpf_sz = BPF_B;
779 		break;
780 	default:
781 		PRINT_FAIL("Unexpected field size: %d, want 8,4,2 or 1\n", test->field_sz);
782 		return;
783 	}
784 
785 	pinfo.prog_type = test->prog_type;
786 	pinfo.expected_attach_type = test->expected_attach_type;
787 
788 	if (test->read) {
789 		struct bpf_insn ldx_prog[] = {
790 			BPF_LDX_MEM(bpf_sz, BPF_REG_2, BPF_REG_1, test->field_offset),
791 			BPF_MOV64_IMM(BPF_REG_0, 0),
792 			BPF_EXIT_INSN(),
793 		};
794 		char *reg_map[][2] = {
795 			{ "$ctx", "r1" },
796 			{ "$dst", "r2" },
797 			{}
798 		};
799 
800 		pinfo.prog_kind = "LDX";
801 		pinfo.prog = ldx_prog;
802 		pinfo.prog_len = ARRAY_SIZE(ldx_prog);
803 		match_program(btf, &pinfo, test->read, reg_map, false);
804 	}
805 
806 	if (test->write || test->write_st || test->write_stx) {
807 		struct bpf_insn stx_prog[] = {
808 			BPF_MOV64_IMM(BPF_REG_2, 0),
809 			BPF_STX_MEM(bpf_sz, BPF_REG_1, BPF_REG_2, test->field_offset),
810 			BPF_MOV64_IMM(BPF_REG_0, 0),
811 			BPF_EXIT_INSN(),
812 		};
813 		char *stx_reg_map[][2] = {
814 			{ "$ctx", "r1" },
815 			{ "$src", "r2" },
816 			{}
817 		};
818 		struct bpf_insn st_prog[] = {
819 			BPF_ST_MEM(bpf_sz, BPF_REG_1, test->field_offset,
820 				   test->st_value.use ? test->st_value.value : 42),
821 			BPF_MOV64_IMM(BPF_REG_0, 0),
822 			BPF_EXIT_INSN(),
823 		};
824 		char *st_reg_map[][2] = {
825 			{ "$ctx", "r1" },
826 			{ "$src", "42" },
827 			{}
828 		};
829 
830 		if (test->write || test->write_stx) {
831 			char *pattern = test->write_stx ? test->write_stx : test->write;
832 
833 			pinfo.prog_kind = "STX";
834 			pinfo.prog = stx_prog;
835 			pinfo.prog_len = ARRAY_SIZE(stx_prog);
836 			match_program(btf, &pinfo, pattern, stx_reg_map, true);
837 		}
838 
839 		if (test->write || test->write_st) {
840 			char *pattern = test->write_st ? test->write_st : test->write;
841 
842 			pinfo.prog_kind = "ST";
843 			pinfo.prog = st_prog;
844 			pinfo.prog_len = ARRAY_SIZE(st_prog);
845 			match_program(btf, &pinfo, pattern, st_reg_map, false);
846 		}
847 	}
848 
849 	test__end_subtest();
850 }
851 
852 void test_ctx_rewrite(void)
853 {
854 	struct btf *btf;
855 	int i;
856 
857 	field_regex = compile_regex("^([[:alpha:]_][[:alnum:]_]+)::([[:alpha:]_][[:alnum:]_]+)");
858 	ident_regex = compile_regex("^[[:alpha:]_][[:alnum:]_]+");
859 	if (!field_regex || !ident_regex)
860 		return;
861 
862 	btf = btf__load_vmlinux_btf();
863 	if (!btf) {
864 		PRINT_FAIL("Can't load vmlinux BTF, errno %d (%s)\n", errno, strerror(errno));
865 		goto out;
866 	}
867 
868 	for (i = 0; i < ARRAY_SIZE(test_cases); ++i)
869 		run_one_testcase(btf, &test_cases[i]);
870 
871 out:
872 	btf__free(btf);
873 	free_regex(field_regex);
874 	free_regex(ident_regex);
875 }
876