// SPDX-License-Identifier: GPL-2.0
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "bpf_misc.h"

/*
 * x86-64 JIT test for a tail call issued from a BPF-to-BPF subprogram.
 *
 * main() calls sub(). sub() tail calls index 0 of jmp_table, and that slot
 * holds main() itself, so the program tail calls back into its own entry.
 * The __jited() annotations on main() list the instructions the JIT is
 * expected to emit for both functions. They focus on how the tail call
 * counter, and a pointer to it (tail_call_cnt_ptr), are carried in rax and
 * in the stack slots rbp[-8] / rbp[-16].
 */

/* Forward declaration: jmp_table below takes &main before main() is defined. */
int main(void);

/* Single-slot program array, statically initialised so that slot 0 is main(). */
struct {
	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
	__uint(max_entries, 1);
	__uint(key_size, sizeof(__u32));
	__array(values, void (void));
} jmp_table SEC(".maps") = {
	.values = {
		[0] = (void *) &main,
	},
};

/*
 * Subprogram that performs the tail call.
 *
 * It loads r2 = &jmp_table and r3 = 0 (the slot index), then executes helper
 * call 12. The expected JIT output for sub() below shows this call being
 * compiled into the bpf_tail_call sequence. If the tail call is not taken,
 * execution falls through to "exit".
 *
 * __noinline keeps sub() a separate subprogram; the expectations below rely on
 * a real "callq" from main() and a separate prologue for sub().
 * __auxiliary presumably tells the test loader that sub() is not a test
 * program of its own (defined in bpf_misc.h) - confirm there.
 */
__noinline __auxiliary
static __naked int sub(void)
{
	asm volatile (
	"r2 = %[jmp_table] ll;"
	"r3 = 0;"
	"call 12;"
	"exit;"
	:
	: __imm_addr(jmp_table)
	: __clobber_all);
}

/*
 * NOTE(review): these strings are matched against JIT disassembly text, so
 * whitespace inside them may be significant. This copy shows single spaces
 * between fields. Confirm they match the separators the disassembler emits.
 */
__success
__arch_x86_64
/* Program entry for main(): regular function prologue. */
__jited(" endbr64")
__jited(" nopl (%rax,%rax)")
__jited(" xorq %rax, %rax")
__jited(" pushq %rbp")
__jited(" movq %rsp, %rbp")
/*
 * Tail call prologue for the program:
 * - establish a memory location for the tail call counter at &rbp[-8];
 * - spill tail_call_cnt_ptr at &rbp[-16];
 * - expect the tail call counter to be passed in rax;
 * - for the entry program, rax is a raw counter (value <= 33), so "ja L0" is
 *   not taken. The counter is pushed to rbp[-8] and rax becomes &rbp[-8]
 *   (movq %rsp, %rax), i.e. tail_call_cnt_ptr;
 * - for a tail-called program, rax already is tail_call_cnt_ptr (value > 33).
 *   "ja L0" is taken and the pointer is stored in both slots.
 */
__jited(" endbr64")
__jited(" cmpq $0x21, %rax")
__jited(" ja L0")
__jited(" pushq %rax")
__jited(" movq %rsp, %rax")
__jited(" jmp L1")
__jited("L0: pushq %rax")		/* rbp[-8]  = rax */
__jited("L1: pushq %rax")		/* rbp[-16] = rax */
/*
 * On the subprogram call, reload rax with tail_call_cnt_ptr from rbp[-16],
 * because the original rax may have been clobbered by this point.
 */
__jited(" movq -0x10(%rbp), %rax")
__jited(" callq 0x{{.*}}")		/* call to sub() */
__jited(" xorl %eax, %eax")		/* r0 = 0 */
__jited(" leave")
__jited(" {{(retq|jmp 0x)}}")		/* return or jump to rethunk */
/* Presumably skips whatever is emitted between main() and sub() - confirm. */
__jited("...")
/* Subprogram entry for sub(): regular function prologue (rax is not zeroed). */
__jited(" endbr64")
__jited(" nopl (%rax,%rax)")
__jited(" nopl (%rax)")
__jited(" pushq %rbp")
__jited(" movq %rsp, %rbp")
/*
 * Tail call prologue for the subprogram. rax arrives holding the address of
 * the tail call counter (loaded by main() just before the call), and it is
 * stored at rbp[-16]. It is also stored at rbp[-8].
 */
__jited(" endbr64")
__jited(" pushq %rax")			/* rbp[-8]  = rax */
__jited(" pushq %rax")			/* rbp[-16] = rax */
__jited(" movabsq ${{.*}}, %rsi")	/* r2 = &jmp_table */
__jited(" xorl %edx, %edx")		/* r3 = 0 */
/*
 * bpf_tail_call implementation:
 * - load tail_call_cnt_ptr from rbp[-16];
 * - if *tail_call_cnt_ptr < 33, increment it and jump to the target;
 * - otherwise skip to L0 and return from sub().
 * The two pops drop the spill slots before the jump. The last one leaves
 * rbp[-8] (tail_call_cnt_ptr) in rax for the target.
 */
__jited(" movq -0x10(%rbp), %rax")
__jited(" cmpq $0x21, (%rax)")
__jited(" jae L0")
__jited(" nopl (%rax,%rax)")
__jited(" addq $0x1, (%rax)")		/* *tail_call_cnt_ptr += 1 */
__jited(" popq %rax")
__jited(" popq %rax")
__jited(" jmp {{.*}}")			/* jump to tail call target */
__jited("L0: leave")
__jited(" {{(retq|jmp 0x)}}")		/* return or jump to rethunk */
/*
 * Entry program (tc section). It calls sub() via a BPF-to-BPF call, then
 * returns 0. It is also the tail call target stored in jmp_table[0].
 */
SEC("tc")
__naked int main(void)
{
	asm volatile (
	"call %[sub];"
	"r0 = 0;"
	"exit;"
	:
	: __imm(sub)
	: __clobber_all);
}

char __license[] SEC("license") = "GPL";