xref: /linux/tools/testing/selftests/bpf/progs/verifier_tailcall_jit.c (revision 566ab427f827b0256d3e8ce0235d088e6a9c28bd)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/bpf.h>
3 #include <bpf/bpf_helpers.h>
4 #include "bpf_misc.h"
5 
6 int main(void);
7 
8 struct {
9 	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
10 	__uint(max_entries, 1);
11 	__uint(key_size, sizeof(__u32));
12 	__array(values, void (void));
13 } jmp_table SEC(".maps") = {
14 	.values = {
15 		[0] = (void *) &main,
16 	},
17 };
18 
19 __noinline __auxiliary
20 static __naked int sub(void)
21 {
22 	asm volatile (
23 	"r2 = %[jmp_table] ll;"
24 	"r3 = 0;"
25 	"call 12;"
26 	"exit;"
27 	:
28 	: __imm_addr(jmp_table)
29 	: __clobber_all);
30 }
31 
32 __success
33 __arch_x86_64
34 /* program entry for main(), regular function prologue */
35 __jited("	endbr64")
36 __jited("	nopl	(%rax,%rax)")
37 __jited("	xorq	%rax, %rax")
38 __jited("	pushq	%rbp")
39 __jited("	movq	%rsp, %rbp")
40 /* tail call prologue for program:
41  * - establish memory location for tail call counter at &rbp[-8];
42  * - spill tail_call_cnt_ptr at &rbp[-16];
43  * - expect tail call counter to be passed in rax;
44  * - for entry program rax is a raw counter, value < 33;
45  * - for tail called program rax is tail_call_cnt_ptr (value > 33).
46  */
47 __jited("	endbr64")
48 __jited("	cmpq	$0x21, %rax")
49 __jited("	ja	L0")
50 __jited("	pushq	%rax")
51 __jited("	movq	%rsp, %rax")
52 __jited("	jmp	L1")
53 __jited("L0:	pushq	%rax")			/* rbp[-8]  = rax         */
54 __jited("L1:	pushq	%rax")			/* rbp[-16] = rax         */
55 /* on subprogram call restore rax to be tail_call_cnt_ptr from rbp[-16]
56  * (cause original rax might be clobbered by this point)
57  */
58 __jited("	movq	-0x10(%rbp), %rax")
59 __jited("	callq	0x{{.*}}")		/* call to sub()          */
60 __jited("	xorl	%eax, %eax")
61 __jited("	leave")
62 __jited("	{{(retq|jmp	0x)}}")		/* return or jump to rethunk */
63 __jited("...")
64 /* subprogram entry for sub(), regular function prologue */
65 __jited("	endbr64")
66 __jited("	nopl	(%rax,%rax)")
67 __jited("	nopl	(%rax)")
68 __jited("	pushq	%rbp")
69 __jited("	movq	%rsp, %rbp")
70 /* tail call prologue for subprogram address of tail call counter
71  * stored at rbp[-16].
72  */
73 __jited("	endbr64")
74 __jited("	pushq	%rax")			/* rbp[-8]  = rax          */
75 __jited("	pushq	%rax")			/* rbp[-16] = rax          */
76 __jited("	movabsq	${{.*}}, %rsi")		/* r2 = &jmp_table         */
77 __jited("	xorl	%edx, %edx")		/* r3 = 0                  */
78 /* bpf_tail_call implementation:
79  * - load tail_call_cnt_ptr from rbp[-16];
80  * - if *tail_call_cnt_ptr < 33, increment it and jump to target;
81  * - otherwise do nothing.
82  */
83 __jited("	movq	-0x10(%rbp), %rax")
84 __jited("	cmpq	$0x21, (%rax)")
85 __jited("	jae	L0")
86 __jited("	nopl	(%rax,%rax)")
87 __jited("	addq	$0x1, (%rax)")		/* *tail_call_cnt_ptr += 1 */
88 __jited("	popq	%rax")
89 __jited("	popq	%rax")
90 __jited("	jmp	{{.*}}")		/* jump to tail call tgt   */
91 __jited("L0:	leave")
92 __jited("	{{(retq|jmp	0x)}}")		/* return or jump to rethunk */
93 SEC("tc")
94 __naked int main(void)
95 {
96 	asm volatile (
97 	"call %[sub];"
98 	"r0 = 0;"
99 	"exit;"
100 	:
101 	: __imm(sub)
102 	: __clobber_all);
103 }
104 
105 char __license[] SEC("license") = "GPL";
106