xref: /linux/tools/testing/selftests/bpf/prog_tests/ctx_rewrite.c (revision 246ad6e5ee259669692bdb7fb353e8c5d5bba628)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <limits.h>
4 #include <stdio.h>
5 #include <string.h>
6 #include <ctype.h>
7 #include <regex.h>
8 #include <test_progs.h>
9 
10 #include "bpf/btf.h"
11 #include "bpf_util.h"
12 #include "linux/filter.h"
13 #include "linux/kernel.h"
14 #include "disasm_helpers.h"
15 
/* Size in bytes of the scratch text buffer used per loaded program: it
 * first receives the verifier log and is then reused for the program
 * disassembly.
 */
#define MAX_PROG_TEXT_SZ (32 * 1024)
17 
18 /* The code in this file serves the sole purpose of executing test cases
19  * specified in the test_cases array. Each test case specifies a program
20  * type, context field offset, and disassembly patterns that correspond
21  * to read and write instructions generated by
22  * verifier.c:convert_ctx_access() for accessing that field.
23  *
24  * For each test case, up to three programs are created:
25  * - One that uses BPF_LDX_MEM to read the context field.
26  * - One that uses BPF_STX_MEM to write to the context field.
27  * - One that uses BPF_ST_MEM to write to the context field.
28  *
29  * The disassembly of each program is then compared with the pattern
30  * specified in the test case.
31  */
/* Description of a single context-field rewrite check. */
struct test_case {
	/* Subtest name */
	char *name;
	/* Program type the generated programs are loaded as */
	enum bpf_prog_type prog_type;
	/* Expected attach type passed along when loading the programs */
	enum bpf_attach_type expected_attach_type;
	/* Offset of the accessed field, used as the instruction offset */
	int field_offset;
	/* Size of the accessed field in bytes, must be 8, 4, 2 or 1 */
	int field_sz;
	/* The program generated for BPF_ST_MEM stores the value 42 by
	 * default; set `use` to store `value` instead.
	 */
	struct {
		bool use;
		int value;
	} st_value;
	/* Pattern for BPF_LDX_MEM(field_sz, dst, ctx, field_offset) */
	char *read;
	/* Pattern for BPF_STX_MEM(field_sz, ctx, src, field_offset) and
	 *             BPF_ST_MEM (field_sz, ctx, src, field_offset)
	 */
	char *write;
	/* Pattern for BPF_ST_MEM(field_sz, ctx, src, field_offset),
	 * takes priority over `write`.
	 */
	char *write_st;
	/* Pattern for BPF_STX_MEM (field_sz, ctx, src, field_offset),
	 * takes priority over `write`.
	 */
	char *write_stx;
};
60 
/* Designated-initializer helper for test_cases entries.
 *
 * Fills in:
 * - .name as "<_prog_type>.<field>" with the optional `name_extra`
 *   string literal appended;
 * - .prog_type as BPF_PROG_TYPE_<_prog_type>;
 * - .field_offset and .field_sz from the layout of `field` in `type`.
 */
#define N(_prog_type, type, field, name_extra...)	\
	.name = #_prog_type "." #field name_extra,	\
	.prog_type = BPF_PROG_TYPE_##_prog_type,	\
	.field_offset = offsetof(type, field),		\
	.field_sz = sizeof(typeof(((type *)NULL)->field))
66 
/* Table of context fields to check. Every entry names a field of a
 * program context structure and gives the disassembly expected after
 * the access to that field has been rewritten. Patterns use the syntax
 * accepted by match_pattern().
 */
static struct test_case test_cases[] = {
/* Sign extension on s390 changes the pattern */
#if defined(__x86_64__) || defined(__aarch64__)
	{
		/* The expected code inspects the byte at
		 * sk_buff::__mono_tc_offset before touching sk_buff::tstamp.
		 */
		N(SCHED_CLS, struct __sk_buff, tstamp),
		.read  = "r12 = *(u8 *)($ctx + sk_buff::__mono_tc_offset);"
			 "if w12 & 0x4 goto pc+1;"
			 "goto pc+4;"
			 "if w12 & 0x3 goto pc+1;"
			 "goto pc+2;"
			 "$dst = 0;"
			 "goto pc+1;"
			 "$dst = *(u64 *)($ctx + sk_buff::tstamp);",
		.write = "r12 = *(u8 *)($ctx + sk_buff::__mono_tc_offset);"
			 "if w12 & 0x4 goto pc+1;"
			 "goto pc+2;"
			 "w12 &= -4;"
			 "*(u8 *)($ctx + sk_buff::__mono_tc_offset) = r12;"
			 "*(u64 *)($ctx + sk_buff::tstamp) = $src;",
	},
#endif
	/* __sk_buff fields expected to map to a single sk_buff access */
	{
		N(SCHED_CLS, struct __sk_buff, priority),
		.read  = "$dst = *(u32 *)($ctx + sk_buff::priority);",
		.write = "*(u32 *)($ctx + sk_buff::priority) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, mark),
		.read  = "$dst = *(u32 *)($ctx + sk_buff::mark);",
		.write = "*(u32 *)($ctx + sk_buff::mark) = $src;",
	},
	{
		/* Offset is the sum of two field offsets, hence the $(...) group */
		N(SCHED_CLS, struct __sk_buff, cb[0]),
		.read  = "$dst = *(u32 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::data));",
		.write = "*(u32 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::data)) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, tc_classid),
		.read  = "$dst = *(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid));",
		.write = "*(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid)) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, tc_index),
		.read  = "$dst = *(u16 *)($ctx + sk_buff::tc_index);",
		.write = "*(u16 *)($ctx + sk_buff::tc_index) = $src;",
	},
	{
		/* STX and ST stores are expected to be rewritten differently:
		 * only the STX form carries a check of the source register.
		 */
		N(SCHED_CLS, struct __sk_buff, queue_mapping),
		.read      = "$dst = *(u16 *)($ctx + sk_buff::queue_mapping);",
		.write_stx = "if $src >= 0xffff goto pc+1;"
			     "*(u16 *)($ctx + sk_buff::queue_mapping) = $src;",
		.write_st  = "*(u16 *)($ctx + sk_buff::queue_mapping) = $src;",
	},
	{
		/* This is a corner case in filter.c:bpf_convert_ctx_access() */
		N(SCHED_CLS, struct __sk_buff, queue_mapping, ".ushrt_max"),
		.st_value = { true, USHRT_MAX },
		.write_st = "goto pc+0;",
	},
	/* bpf_sock fields */
	{
		N(CGROUP_SOCK, struct bpf_sock, bound_dev_if),
		.read  = "$dst = *(u32 *)($ctx + sock_common::skc_bound_dev_if);",
		.write = "*(u32 *)($ctx + sock_common::skc_bound_dev_if) = $src;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, mark),
		.read  = "$dst = *(u32 *)($ctx + sock::sk_mark);",
		.write = "*(u32 *)($ctx + sock::sk_mark) = $src;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, priority),
		.read  = "$dst = *(u32 *)($ctx + sock::sk_priority);",
		.write = "*(u32 *)($ctx + sock::sk_priority) = $src;",
	},
	{
		N(SOCK_OPS, struct bpf_sock_ops, replylong[0]),
		.read  = "$dst = *(u32 *)($ctx + bpf_sock_ops_kern::replylong);",
		.write = "*(u32 *)($ctx + bpf_sock_ops_kern::replylong) = $src;",
	},
	{
		/* The expected code loads the bpf_sysctl_kern::ppos pointer and
		 * then accesses 32 bits through it: at offset 0 on little-endian
		 * hosts and at offset 4 otherwise. The write pattern saves and
		 * restores r9 via bpf_sysctl_kern::tmp_reg.
		 */
		N(CGROUP_SYSCTL, struct bpf_sysctl, file_pos),
#if __BYTE_ORDER == __LITTLE_ENDIAN
		.read  = "$dst = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "$dst = *(u32 *)($dst +0);",
		.write = "*(u64 *)($ctx + bpf_sysctl_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "*(u32 *)(r9 +0) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::tmp_reg);",
#else
		.read  = "$dst = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "$dst = *(u32 *)($dst +4);",
		.write = "*(u64 *)($ctx + bpf_sysctl_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "*(u32 *)(r9 +4) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::tmp_reg);",
#endif
	},
	/* bpf_sockopt fields; entries without a write pattern are only
	 * checked for reads.
	 */
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, sk),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::sk);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, level),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::level);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::level) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optname),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::optname);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::optname) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optlen),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::optlen);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::optlen) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		/* retval is expected to be reached through a chain of pointers:
		 * current_task, then task_struct::bpf_ctx, then
		 * bpf_cg_run_ctx::retval.
		 */
		N(CGROUP_SOCKOPT, struct bpf_sockopt, retval),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::current_task);"
			 "$dst = *(u64 *)($dst + task_struct::bpf_ctx);"
			 "$dst = *(u32 *)($dst + bpf_cg_run_ctx::retval);",
		.write = "*(u64 *)($ctx + bpf_sockopt_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sockopt_kern::current_task);"
			 "r9 = *(u64 *)(r9 + task_struct::bpf_ctx);"
			 "*(u32 *)(r9 + bpf_cg_run_ctx::retval) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sockopt_kern::tmp_reg);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optval),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::optval);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optval_end),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::optval_end);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
};
210 
211 #undef N
212 
/* Regular expressions shared by the pattern matcher. Both are compiled
 * and released in test_ctx_rewrite():
 * - ident_regex matches an identifier at the start of a string;
 * - field_regex matches a `<type>::<field>` reference at the start of
 *   a string, capturing the type and the field names.
 */
static regex_t *ident_regex;
static regex_t *field_regex;
215 
/* Skip leading whitespace.
 *
 * Returns a pointer to the first character of `str` that is not
 * whitespace; this is the terminating NUL when `str` holds nothing but
 * whitespace.
 */
static char *skip_space(char *str)
{
	/* The <ctype.h> classifiers require a value representable as
	 * unsigned char, so cast: plain char may be signed and bytes with
	 * the high bit set would otherwise be passed as negative values.
	 * The NUL terminator is not whitespace and ends the loop.
	 */
	while (isspace((unsigned char)*str))
		++str;
	return str;
}
222 
/* Skip leading whitespace and ';' characters.
 *
 * Returns a pointer to the first character of `str` that is neither
 * whitespace nor ';'; this is the terminating NUL when nothing else is
 * left.
 */
static char *skip_space_and_semi(char *str)
{
	/* Cast for <ctype.h>: plain char may be signed and the classifier
	 * must not be handed a negative value.
	 */
	while (isspace((unsigned char)*str) || *str == ';')
		++str;
	return str;
}
229 
/* Check that `str` begins with `prefix`.
 *
 * Returns a pointer to the character of `str` that follows the prefix,
 * or NULL if `str` does not start with `prefix`. An empty prefix always
 * matches and yields `str` itself.
 */
static char *match_str(char *str, char *prefix)
{
	for (; *prefix; ++str, ++prefix) {
		/* Also covers `str` ending first: NUL differs from *prefix */
		if (*str != *prefix)
			return NULL;
	}
	return str;
}
240 
/* Check that `str` begins with the decimal number `num`.
 *
 * The number is parsed with strtol(), so leading whitespace and a sign
 * are accepted.
 *
 * Returns a pointer to the character following the parsed number, or
 * NULL if `str` does not start with a number or the number is not equal
 * to `num`.
 */
static char *match_number(char *str, int num)
{
	char *next;
	/* Keep the full strtol() result: narrowing it to int would let a
	 * value that is out of int range compare equal to `num`.
	 */
	long parsed = strtol(str, &next, 10);

	if (next == str || parsed != num)
		return NULL;

	return next;
}
251 
252 static int find_field_offset_aux(struct btf *btf, int btf_id, char *field_name, int off)
253 {
254 	const struct btf_type *type = btf__type_by_id(btf, btf_id);
255 	const struct btf_member *m;
256 	__u32 mnum, i;
257 
258 	if (!type) {
259 		PRINT_FAIL("Can't find btf_type for id %d\n", btf_id);
260 		return -1;
261 	}
262 
263 	if (!btf_is_struct(type) && !btf_is_union(type)) {
264 		PRINT_FAIL("BTF id %d is not struct or union\n", btf_id);
265 		return -1;
266 	}
267 
268 	m = btf_members(type);
269 	mnum = btf_vlen(type);
270 
271 	for (i = 0; i < mnum; ++i, ++m) {
272 		const char *mname = btf__name_by_offset(btf, m->name_off);
273 
274 		if (strcmp(mname, "") == 0) {
275 			int msize = find_field_offset_aux(btf, m->type, field_name,
276 							  off + m->offset);
277 			if (msize >= 0)
278 				return msize;
279 		}
280 
281 		if (strcmp(mname, field_name))
282 			continue;
283 
284 		return (off + m->offset) / 8;
285 	}
286 
287 	return -1;
288 }
289 
/* Resolve a `<type>::<field>` reference located at the start of
 * `pattern` to a byte offset using BTF.
 *
 * `matches` holds the result of applying field_regex to `pattern`:
 * group 1 spans the type name and group 2 the field name.
 *
 * Returns the offset of the field in bytes or -1 on failure, in which
 * case the reason is reported via PRINT_FAIL.
 */
static int find_field_offset(struct btf *btf, char *pattern, regmatch_t *matches)
{
	char type_name[128] = {0};
	char member_name[128] = {0};
	int type_len = matches[1].rm_eo - matches[1].rm_so;
	int member_len = matches[2].rm_eo - matches[2].rm_so;
	int type_id, byte_off;

	if (type_len >= sizeof(type_name)) {
		PRINT_FAIL("Malformed pattern: type ident is too long: %d\n", type_len);
		return -1;
	}

	if (member_len >= sizeof(member_name)) {
		PRINT_FAIL("Malformed pattern: field ident is too long: %d\n", member_len);
		return -1;
	}

	/* Both buffers are zero-filled and strictly longer than the copied
	 * spans, hence the copies are NUL-terminated.
	 */
	memcpy(type_name, pattern + matches[1].rm_so, type_len);
	memcpy(member_name, pattern + matches[2].rm_so, member_len);

	type_id = btf__find_by_name(btf, type_name);
	if (type_id < 0) {
		PRINT_FAIL("No BTF info for type %s\n", type_name);
		return -1;
	}

	byte_off = find_field_offset_aux(btf, type_id, member_name, 0);
	if (byte_off < 0) {
		PRINT_FAIL("No BTF info for field %s::%s\n", type_name, member_name);
		return -1;
	}

	return byte_off;
}
328 
/* Allocate a regex_t and compile the extended regular expression `pat`
 * into it.
 *
 * Returns the compiled expression, to be released with free_regex(),
 * or NULL if the allocation or the compilation fails; failures are
 * reported via PRINT_FAIL.
 */
static regex_t *compile_regex(char *pat)
{
	regex_t *compiled = malloc(sizeof(regex_t));
	char msg[512];
	int rc;

	if (!compiled) {
		PRINT_FAIL("Can't alloc regex\n");
		return NULL;
	}

	rc = regcomp(compiled, pat, REG_EXTENDED);
	if (rc == 0)
		return compiled;

	regerror(rc, compiled, msg, sizeof(msg));
	PRINT_FAIL("Can't compile regex: %s\n", msg);
	free(compiled);
	return NULL;
}
352 
/* Release a regular expression obtained from compile_regex().
 * Passing NULL is allowed and does nothing.
 */
static void free_regex(regex_t *re)
{
	if (re) {
		regfree(re);
		free(re);
	}
}
361 
362 static u32 max_line_len(char *str)
363 {
364 	u32 max_line = 0;
365 	char *next = str;
366 
367 	while (next) {
368 		next = strchr(str, '\n');
369 		if (next) {
370 			max_line = max_t(u32, max_line, (next - str));
371 			str = next + 1;
372 		} else {
373 			max_line = max_t(u32, max_line, strlen(str));
374 		}
375 	}
376 
377 	return min(max_line, 60u);
378 }
379 
/* Report a mismatch by printing the disassembly `text_origin` and the
 * pattern `pattern_origin` side by side to `out`. `text_pos` and
 * `pattern_pos` point into the respective strings at the place where
 * matching stopped. Text lines end at '\n', pattern lines end at ';'.
 *
 * The output is expected to look like:
 *
 *   Can't match disassembly(left) with pattern(right):
 *   r2 = *(u64 *)(r1 +0)  ;  $dst = *(u64 *)($ctx + bpf_sockopt_kern::sk1)
 *                     ^                             ^
 *   r0 = 0                ;
 *   exit                  ;
 *
 * The line with '^' marks is printed only below rows that have a mark
 * at a column greater than zero.
 */
static void print_match_error(FILE *out,
			      char *pattern_origin, char *text_origin,
			      char *pattern_pos, char *text_pos)
{
	/* Column at which the separator is printed */
	int pad_to = max_line_len(text_origin) + 2;
	char *pat = pattern_origin;
	char *txt = text_origin;

	fprintf(out, "Can't match disassembly(left) with pattern(right):\n");
	while (*pat || *txt) {
		int text_mark = -1, pattern_mark = -1;
		int col, last;

		/* Left part: one line of the disassembly */
		for (col = 0; *txt && *txt != '\n'; ++txt, ++col) {
			if (txt == text_pos)
				text_mark = col;
			fputc(*txt, out);
		}
		/* The divergence point may be the end of the line */
		if (txt == text_pos)
			text_mark = col;

		/* Align the separator */
		for (; col < pad_to; ++col)
			fputc(' ', out);
		fputs(";  ", out);
		col += 3;

		/* Right part: one ';'-terminated chunk of the pattern */
		for (; *pat && *pat != ';'; ++pat, ++col) {
			if (pat == pattern_pos)
				pattern_mark = col;
			fputc(*pat, out);
		}
		if (pat == pattern_pos)
			pattern_mark = col;

		fputc('\n', out);

		/* Step over the line terminators, if any */
		if (*pat)
			++pat;
		if (*txt)
			++txt;

		/* Nothing to highlight on this row */
		if (text_mark <= 0 && pattern_mark <= 0)
			continue;

		/* Extra row with '^' under the positions where match fails */
		last = max(text_mark, pattern_mark);
		for (col = 0; col <= last; ++col)
			fputc(col == text_mark || col == pattern_mark ? '^' : ' ', out);
		fputc('\n', out);
	}
}
456 
457 /* Test if `text` matches `pattern`. Pattern consists of the following elements:
458  *
459  * - Field offset references:
460  *
461  *     <type>::<field>
462  *
463  *   When such reference is encountered BTF is used to compute numerical
464  *   value for the offset of <field> in <type>. The `text` is expected to
465  *   contain matching numerical value.
466  *
467  * - Field groups:
468  *
469  *     $(<type>::<field> [+ <type>::<field>]*)
470  *
471  *   Allows to specify an offset that is a sum of multiple field offsets.
472  *   The `text` is expected to contain matching numerical value.
473  *
474  * - Variable references, e.g. `$src`, `$dst`, `$ctx`.
475  *   These are substitutions specified in `reg_map` array.
476  *   If a substring of pattern is equal to `reg_map[i][0]` the `text` is
477  *   expected to contain `reg_map[i][1]` in the matching position.
478  *
479  * - Whitespace is ignored, ';' counts as whitespace for `pattern`.
480  *
481  * - Any other characters, `pattern` and `text` should match one-to-one.
482  *
483  * Example of a pattern:
484  *
485  *                    __________ fields group ________________
486  *                   '                                        '
487  *   *(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid)) = $src;
488  *            ^^^^                   '______________________'
489  *     variable reference             field offset reference
490  */
491 static bool match_pattern(struct btf *btf, char *pattern, char *text, char *reg_map[][2])
492 {
493 	char *pattern_origin = pattern;
494 	char *text_origin = text;
495 	regmatch_t matches[3];
496 
497 _continue:
498 	while (*pattern) {
499 		if (!*text)
500 			goto err;
501 
502 		/* Skip whitespace */
503 		if (isspace(*pattern) || *pattern == ';') {
504 			if (!isspace(*text) && text != text_origin && isalnum(text[-1]))
505 				goto err;
506 			pattern = skip_space_and_semi(pattern);
507 			text = skip_space(text);
508 			continue;
509 		}
510 
511 		/* Check for variable references */
512 		for (int i = 0; reg_map[i][0]; ++i) {
513 			char *pattern_next, *text_next;
514 
515 			pattern_next = match_str(pattern, reg_map[i][0]);
516 			if (!pattern_next)
517 				continue;
518 
519 			text_next = match_str(text, reg_map[i][1]);
520 			if (!text_next)
521 				goto err;
522 
523 			pattern = pattern_next;
524 			text = text_next;
525 			goto _continue;
526 		}
527 
528 		/* Match field group:
529 		 *   $(sk_buff::cb + qdisc_skb_cb::tc_classid)
530 		 */
531 		if (strncmp(pattern, "$(", 2) == 0) {
532 			char *group_start = pattern, *text_next;
533 			int acc_offset = 0;
534 
535 			pattern += 2;
536 
537 			for (;;) {
538 				int field_offset;
539 
540 				pattern = skip_space(pattern);
541 				if (!*pattern) {
542 					PRINT_FAIL("Unexpected end of pattern\n");
543 					goto err;
544 				}
545 
546 				if (*pattern == ')') {
547 					++pattern;
548 					break;
549 				}
550 
551 				if (*pattern == '+') {
552 					++pattern;
553 					continue;
554 				}
555 
556 				printf("pattern: %s\n", pattern);
557 				if (regexec(field_regex, pattern, 3, matches, 0) != 0) {
558 					PRINT_FAIL("Field reference expected\n");
559 					goto err;
560 				}
561 
562 				field_offset = find_field_offset(btf, pattern, matches);
563 				if (field_offset < 0)
564 					goto err;
565 
566 				pattern += matches[0].rm_eo;
567 				acc_offset += field_offset;
568 			}
569 
570 			text_next = match_number(text, acc_offset);
571 			if (!text_next) {
572 				PRINT_FAIL("No match for group offset %.*s (%d)\n",
573 					   (int)(pattern - group_start),
574 					   group_start,
575 					   acc_offset);
576 				goto err;
577 			}
578 			text = text_next;
579 		}
580 
581 		/* Match field reference:
582 		 *   sk_buff::cb
583 		 */
584 		if (regexec(field_regex, pattern, 3, matches, 0) == 0) {
585 			int field_offset;
586 			char *text_next;
587 
588 			field_offset = find_field_offset(btf, pattern, matches);
589 			if (field_offset < 0)
590 				goto err;
591 
592 			text_next = match_number(text, field_offset);
593 			if (!text_next) {
594 				PRINT_FAIL("No match for field offset %.*s (%d)\n",
595 					   (int)matches[0].rm_eo, pattern, field_offset);
596 				goto err;
597 			}
598 
599 			pattern += matches[0].rm_eo;
600 			text = text_next;
601 			continue;
602 		}
603 
604 		/* If pattern points to identifier not followed by '::'
605 		 * skip the identifier to avoid n^2 application of the
606 		 * field reference rule.
607 		 */
608 		if (regexec(ident_regex, pattern, 1, matches, 0) == 0) {
609 			if (strncmp(pattern, text, matches[0].rm_eo) != 0)
610 				goto err;
611 
612 			pattern += matches[0].rm_eo;
613 			text += matches[0].rm_eo;
614 			continue;
615 		}
616 
617 		/* Match literally */
618 		if (*pattern != *text)
619 			goto err;
620 
621 		++pattern;
622 		++text;
623 	}
624 
625 	return true;
626 
627 err:
628 	test__fail();
629 	print_match_error(stdout, pattern_origin, text_origin, pattern, text);
630 	return false;
631 }
632 
/* Parameters of a single program handed to match_program(). */
struct prog_info {
	/* Label of the program flavour ("LDX", "STX" or "ST"), used as
	 * the name of the match assertion.
	 */
	char *prog_kind;
	/* Program type to load the instructions as */
	enum bpf_prog_type prog_type;
	/* Expected attach type supplied in the load options */
	enum bpf_attach_type expected_attach_type;
	/* Instructions of the program */
	struct bpf_insn *prog;
	/* Number of instructions in `prog` */
	u32 prog_len;
};
640 
641 static void match_program(struct btf *btf,
642 			  struct prog_info *pinfo,
643 			  char *pattern,
644 			  char *reg_map[][2],
645 			  bool skip_first_insn)
646 {
647 	struct bpf_insn *buf = NULL, *insn, *insn_end;
648 	int err = 0, prog_fd = 0;
649 	FILE *prog_out = NULL;
650 	char insn_buf[64];
651 	char *text = NULL;
652 	__u32 cnt = 0;
653 
654 	text = calloc(MAX_PROG_TEXT_SZ, 1);
655 	if (!text) {
656 		PRINT_FAIL("Can't allocate %d bytes\n", MAX_PROG_TEXT_SZ);
657 		goto out;
658 	}
659 
660 	// TODO: log level
661 	LIBBPF_OPTS(bpf_prog_load_opts, opts);
662 	opts.log_buf = text;
663 	opts.log_size = MAX_PROG_TEXT_SZ;
664 	opts.log_level = 1 | 2 | 4;
665 	opts.expected_attach_type = pinfo->expected_attach_type;
666 
667 	prog_fd = bpf_prog_load(pinfo->prog_type, NULL, "GPL",
668 				pinfo->prog, pinfo->prog_len, &opts);
669 	if (prog_fd < 0) {
670 		PRINT_FAIL("Can't load program, errno %d (%s), verifier log:\n%s\n",
671 			   errno, strerror(errno), text);
672 		goto out;
673 	}
674 
675 	memset(text, 0, MAX_PROG_TEXT_SZ);
676 
677 	err = get_xlated_program(prog_fd, &buf, &cnt);
678 	if (err) {
679 		PRINT_FAIL("Can't load back BPF program\n");
680 		goto out;
681 	}
682 
683 	prog_out = fmemopen(text, MAX_PROG_TEXT_SZ - 1, "w");
684 	if (!prog_out) {
685 		PRINT_FAIL("Can't open memory stream\n");
686 		goto out;
687 	}
688 	insn_end = buf + cnt;
689 	insn = buf + (skip_first_insn ? 1 : 0);
690 	while (insn < insn_end) {
691 		insn = disasm_insn(insn, insn_buf, sizeof(insn_buf));
692 		fprintf(prog_out, "%s\n", insn_buf);
693 	}
694 	fclose(prog_out);
695 
696 	ASSERT_TRUE(match_pattern(btf, pattern, text, reg_map),
697 		    pinfo->prog_kind);
698 
699 out:
700 	if (prog_fd)
701 		close(prog_fd);
702 	free(buf);
703 	free(text);
704 }
705 
706 static void run_one_testcase(struct btf *btf, struct test_case *test)
707 {
708 	struct prog_info pinfo = {};
709 	int bpf_sz;
710 
711 	if (!test__start_subtest(test->name))
712 		return;
713 
714 	switch (test->field_sz) {
715 	case 8:
716 		bpf_sz = BPF_DW;
717 		break;
718 	case 4:
719 		bpf_sz = BPF_W;
720 		break;
721 	case 2:
722 		bpf_sz = BPF_H;
723 		break;
724 	case 1:
725 		bpf_sz = BPF_B;
726 		break;
727 	default:
728 		PRINT_FAIL("Unexpected field size: %d, want 8,4,2 or 1\n", test->field_sz);
729 		return;
730 	}
731 
732 	pinfo.prog_type = test->prog_type;
733 	pinfo.expected_attach_type = test->expected_attach_type;
734 
735 	if (test->read) {
736 		struct bpf_insn ldx_prog[] = {
737 			BPF_LDX_MEM(bpf_sz, BPF_REG_2, BPF_REG_1, test->field_offset),
738 			BPF_MOV64_IMM(BPF_REG_0, 0),
739 			BPF_EXIT_INSN(),
740 		};
741 		char *reg_map[][2] = {
742 			{ "$ctx", "r1" },
743 			{ "$dst", "r2" },
744 			{}
745 		};
746 
747 		pinfo.prog_kind = "LDX";
748 		pinfo.prog = ldx_prog;
749 		pinfo.prog_len = ARRAY_SIZE(ldx_prog);
750 		match_program(btf, &pinfo, test->read, reg_map, false);
751 	}
752 
753 	if (test->write || test->write_st || test->write_stx) {
754 		struct bpf_insn stx_prog[] = {
755 			BPF_MOV64_IMM(BPF_REG_2, 0),
756 			BPF_STX_MEM(bpf_sz, BPF_REG_1, BPF_REG_2, test->field_offset),
757 			BPF_MOV64_IMM(BPF_REG_0, 0),
758 			BPF_EXIT_INSN(),
759 		};
760 		char *stx_reg_map[][2] = {
761 			{ "$ctx", "r1" },
762 			{ "$src", "r2" },
763 			{}
764 		};
765 		struct bpf_insn st_prog[] = {
766 			BPF_ST_MEM(bpf_sz, BPF_REG_1, test->field_offset,
767 				   test->st_value.use ? test->st_value.value : 42),
768 			BPF_MOV64_IMM(BPF_REG_0, 0),
769 			BPF_EXIT_INSN(),
770 		};
771 		char *st_reg_map[][2] = {
772 			{ "$ctx", "r1" },
773 			{ "$src", "42" },
774 			{}
775 		};
776 
777 		if (test->write || test->write_stx) {
778 			char *pattern = test->write_stx ? test->write_stx : test->write;
779 
780 			pinfo.prog_kind = "STX";
781 			pinfo.prog = stx_prog;
782 			pinfo.prog_len = ARRAY_SIZE(stx_prog);
783 			match_program(btf, &pinfo, pattern, stx_reg_map, true);
784 		}
785 
786 		if (test->write || test->write_st) {
787 			char *pattern = test->write_st ? test->write_st : test->write;
788 
789 			pinfo.prog_kind = "ST";
790 			pinfo.prog = st_prog;
791 			pinfo.prog_len = ARRAY_SIZE(st_prog);
792 			match_program(btf, &pinfo, pattern, st_reg_map, false);
793 		}
794 	}
795 
796 	test__end_subtest();
797 }
798 
799 void test_ctx_rewrite(void)
800 {
801 	struct btf *btf;
802 	int i;
803 
804 	field_regex = compile_regex("^([[:alpha:]_][[:alnum:]_]+)::([[:alpha:]_][[:alnum:]_]+)");
805 	ident_regex = compile_regex("^[[:alpha:]_][[:alnum:]_]+");
806 	if (!field_regex || !ident_regex)
807 		return;
808 
809 	btf = btf__load_vmlinux_btf();
810 	if (!btf) {
811 		PRINT_FAIL("Can't load vmlinux BTF, errno %d (%s)\n", errno, strerror(errno));
812 		goto out;
813 	}
814 
815 	for (i = 0; i < ARRAY_SIZE(test_cases); ++i)
816 		run_one_testcase(btf, &test_cases[i]);
817 
818 out:
819 	btf__free(btf);
820 	free_regex(field_regex);
821 	free_regex(ident_regex);
822 }
823