xref: /linux/arch/arm64/net/bpf_jit_comp.c (revision 566ab427f827b0256d3e8ce0235d088e6a9c28bd)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler for ARM64
4  *
5  * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
6  */
7 
8 #define pr_fmt(fmt) "bpf_jit: " fmt
9 
10 #include <linux/bitfield.h>
11 #include <linux/bpf.h>
12 #include <linux/filter.h>
13 #include <linux/memory.h>
14 #include <linux/printk.h>
15 #include <linux/slab.h>
16 
17 #include <asm/asm-extable.h>
18 #include <asm/byteorder.h>
19 #include <asm/cacheflush.h>
20 #include <asm/debug-monitors.h>
21 #include <asm/insn.h>
22 #include <asm/patching.h>
23 #include <asm/set_memory.h>
24 
25 #include "bpf_jit.h"
26 
27 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
28 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
29 #define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
30 #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
31 #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)
32 
/*
 * Range-check a signed immediate before it is encoded into an instruction.
 *
 * A positive @imm is rejected when it has any bit set at position @bits or
 * above; a negative @imm is rejected when its complement does.
 *
 * NOTE: this macro is not hygienic. On failure it prints a variable named
 * `i` taken from the caller's scope and then executes `return -EINVAL` in
 * the *calling* function, so it can only be used inside functions that
 * return int and have an `i` in scope.
 */
#define check_imm(bits, imm) do {				\
	if ((((imm) > 0) && ((imm) >> (bits))) ||		\
	    (((imm) < 0) && (~(imm) >> (bits)))) {		\
		pr_info("[%2d] imm=%d(0x%x) out of range\n",	\
			i, imm, imm);				\
		return -EINVAL;					\
	}							\
} while (0)
/* Convenience wrappers for the two widths used by this file. */
#define check_imm19(imm) check_imm(19, imm)
#define check_imm26(imm) check_imm(26, imm)
43 
/*
 * Map BPF registers to A64 registers.
 *
 * Indexed by BPF register number; the entries past the architectural BPF
 * registers (TMP_REG_*, TCCNT_PTR, ARENA_VM_START) are JIT-internal
 * registers that never appear in a BPF instruction.
 */
static const int bpf2a64[] = {
	/* return value from in-kernel function, and exit value from eBPF */
	[BPF_REG_0] = A64_R(7),
	/* arguments from eBPF program to in-kernel function */
	[BPF_REG_1] = A64_R(0),
	[BPF_REG_2] = A64_R(1),
	[BPF_REG_3] = A64_R(2),
	[BPF_REG_4] = A64_R(3),
	[BPF_REG_5] = A64_R(4),
	/* callee saved registers that in-kernel function will preserve */
	[BPF_REG_6] = A64_R(19),
	[BPF_REG_7] = A64_R(20),
	[BPF_REG_8] = A64_R(21),
	[BPF_REG_9] = A64_R(22),
	/* read-only frame pointer to access stack */
	[BPF_REG_FP] = A64_R(25),
	/* temporary registers for BPF JIT */
	[TMP_REG_1] = A64_R(10),
	[TMP_REG_2] = A64_R(11),
	[TMP_REG_3] = A64_R(12),
	/* tail_call_cnt_ptr */
	[TCCNT_PTR] = A64_R(26),
	/* temporary register for blinding constants */
	[BPF_REG_AX] = A64_R(9),
	/* callee saved register for kern_vm_start address */
	[ARENA_VM_START] = A64_R(28),
};
72 
/* Per-compilation state shared by all the emit helpers in this file. */
struct jit_ctx {
	const struct bpf_prog *prog;	/* BPF program being compiled */
	int idx;			/* index of the next A64 instruction slot */
	int epilogue_offset;		/* instruction index compared against idx by epilogue_offset() */
	int *offset;			/* table used by bpf2a64_offset() to convert BPF jump offsets */
	int exentry_idx;		/* next exception table entry to fill in */
	int nr_used_callee_reg;		/* number of valid entries in used_callee_reg[] */
	u8 used_callee_reg[8]; /* r6~r9, fp, arena_vm_start */
	__le32 *image;			/* buffer emit() stores into; NULL means "count only" */
	__le32 *ro_image;		/* buffer used for address/offset computations (branch PCs, extable) */
	u32 stack_size;			/* prog stack depth rounded up to 16 bytes, set by build_prologue() */
	u64 user_vm_start;		/* NOTE(review): presumably the arena's user-space base — confirm */
	u64 arena_vm_start;		/* when non-zero, loaded into the ARENA_VM_START register by the prologue */
	bool fp_used;			/* set when an instruction references BPF_REG_FP */
	bool write;			/* emit() only stores instructions while this is true */
};
89 
/*
 * In-image layout of the PLT emitted by build_plt(): two instructions
 * followed by the 64-bit branch target they load.
 */
struct bpf_plt {
	u32 insn_ldr; /* load target */
	u32 insn_br;  /* branch to target */
	u64 target;   /* target value */
};

/* Size and byte offset of the patchable target word inside the PLT. */
#define PLT_TARGET_SIZE   sizeof_field(struct bpf_plt, target)
#define PLT_TARGET_OFFSET offsetof(struct bpf_plt, target)
98 
99 static inline void emit(const u32 insn, struct jit_ctx *ctx)
100 {
101 	if (ctx->image != NULL && ctx->write)
102 		ctx->image[ctx->idx] = cpu_to_le32(insn);
103 
104 	ctx->idx++;
105 }
106 
107 static inline void emit_a64_mov_i(const int is64, const int reg,
108 				  const s32 val, struct jit_ctx *ctx)
109 {
110 	u16 hi = val >> 16;
111 	u16 lo = val & 0xffff;
112 
113 	if (hi & 0x8000) {
114 		if (hi == 0xffff) {
115 			emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
116 		} else {
117 			emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
118 			if (lo != 0xffff)
119 				emit(A64_MOVK(is64, reg, lo, 0), ctx);
120 		}
121 	} else {
122 		emit(A64_MOVZ(is64, reg, lo, 0), ctx);
123 		if (hi)
124 			emit(A64_MOVK(is64, reg, hi, 16), ctx);
125 	}
126 }
127 
128 static int i64_i16_blocks(const u64 val, bool inverse)
129 {
130 	return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
131 	       (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
132 	       (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
133 	       (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
134 }
135 
136 static inline void emit_a64_mov_i64(const int reg, const u64 val,
137 				    struct jit_ctx *ctx)
138 {
139 	u64 nrm_tmp = val, rev_tmp = ~val;
140 	bool inverse;
141 	int shift;
142 
143 	if (!(nrm_tmp >> 32))
144 		return emit_a64_mov_i(0, reg, (u32)val, ctx);
145 
146 	inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
147 	shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
148 					  (fls64(nrm_tmp) - 1)), 16), 0);
149 	if (inverse)
150 		emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
151 	else
152 		emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
153 	shift -= 16;
154 	while (shift >= 0) {
155 		if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
156 			emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
157 		shift -= 16;
158 	}
159 }
160 
161 static inline void emit_bti(u32 insn, struct jit_ctx *ctx)
162 {
163 	if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
164 		emit(insn, ctx);
165 }
166 
167 /*
168  * Kernel addresses in the vmalloc space use at most 48 bits, and the
169  * remaining bits are guaranteed to be 0x1. So we can compose the address
170  * with a fixed length movn/movk/movk sequence.
171  */
172 static inline void emit_addr_mov_i64(const int reg, const u64 val,
173 				     struct jit_ctx *ctx)
174 {
175 	u64 tmp = val;
176 	int shift = 0;
177 
178 	emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
179 	while (shift < 32) {
180 		tmp >>= 16;
181 		shift += 16;
182 		emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
183 	}
184 }
185 
186 static bool should_emit_indirect_call(long target, const struct jit_ctx *ctx)
187 {
188 	long offset;
189 
190 	/* when ctx->ro_image is not allocated or the target is unknown,
191 	 * emit indirect call
192 	 */
193 	if (!ctx->ro_image || !target)
194 		return true;
195 
196 	offset = target - (long)&ctx->ro_image[ctx->idx];
197 	return offset < -SZ_128M || offset >= SZ_128M;
198 }
199 
200 static void emit_direct_call(u64 target, struct jit_ctx *ctx)
201 {
202 	u32 insn;
203 	unsigned long pc;
204 
205 	pc = (unsigned long)&ctx->ro_image[ctx->idx];
206 	insn = aarch64_insn_gen_branch_imm(pc, target, AARCH64_INSN_BRANCH_LINK);
207 	emit(insn, ctx);
208 }
209 
210 static void emit_indirect_call(u64 target, struct jit_ctx *ctx)
211 {
212 	u8 tmp;
213 
214 	tmp = bpf2a64[TMP_REG_1];
215 	emit_addr_mov_i64(tmp, target, ctx);
216 	emit(A64_BLR(tmp), ctx);
217 }
218 
219 static void emit_call(u64 target, struct jit_ctx *ctx)
220 {
221 	if (should_emit_indirect_call((long)target, ctx))
222 		emit_indirect_call(target, ctx);
223 	else
224 		emit_direct_call(target, ctx);
225 }
226 
227 static inline int bpf2a64_offset(int bpf_insn, int off,
228 				 const struct jit_ctx *ctx)
229 {
230 	/* BPF JMP offset is relative to the next instruction */
231 	bpf_insn++;
232 	/*
233 	 * Whereas arm64 branch instructions encode the offset
234 	 * from the branch itself, so we must subtract 1 from the
235 	 * instruction offset.
236 	 */
237 	return ctx->offset[bpf_insn + off] - (ctx->offset[bpf_insn] - 1);
238 }
239 
240 static void jit_fill_hole(void *area, unsigned int size)
241 {
242 	__le32 *ptr;
243 	/* We are guaranteed to have aligned memory. */
244 	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
245 		*ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
246 }
247 
248 int bpf_arch_text_invalidate(void *dst, size_t len)
249 {
250 	if (!aarch64_insn_set(dst, AARCH64_BREAK_FAULT, len))
251 		return -EINVAL;
252 
253 	return 0;
254 }
255 
256 static inline int epilogue_offset(const struct jit_ctx *ctx)
257 {
258 	int to = ctx->epilogue_offset;
259 	int from = ctx->idx;
260 
261 	return to - from;
262 }
263 
264 static bool is_addsub_imm(u32 imm)
265 {
266 	/* Either imm12 or shifted imm12. */
267 	return !(imm & ~0xfff) || !(imm & ~0xfff000);
268 }
269 
270 /*
271  * There are 3 types of AArch64 LDR/STR (immediate) instruction:
272  * Post-index, Pre-index, Unsigned offset.
273  *
274  * For BPF ldr/str, the "unsigned offset" type is sufficient.
275  *
276  * "Unsigned offset" type LDR(immediate) format:
277  *
278  *    3                   2                   1                   0
279  *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
280  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
281  * |x x|1 1 1 0 0 1 0 1|         imm12         |    Rn   |    Rt   |
282  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
283  * scale
284  *
285  * "Unsigned offset" type STR(immediate) format:
286  *    3                   2                   1                   0
287  *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
288  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
289  * |x x|1 1 1 0 0 1 0 0|         imm12         |    Rn   |    Rt   |
290  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
291  * scale
292  *
293  * The offset is calculated from imm12 and scale in the following way:
294  *
295  * offset = (u64)imm12 << scale
296  */
static bool is_lsi_offset(int offset, int scale)
{
	/*
	 * Encodable as "imm12 << scale" only if the offset is non-negative,
	 * no larger than the biggest scaled imm12, and a multiple of the
	 * access size.
	 */
	return offset >= 0 &&
	       offset <= (0xFFF << scale) &&
	       !(offset & ((1 << scale) - 1));
}
310 
311 /* generated main prog prologue:
312  *      bti c // if CONFIG_ARM64_BTI_KERNEL
313  *      mov x9, lr
314  *      nop  // POKE_OFFSET
315  *      paciasp // if CONFIG_ARM64_PTR_AUTH_KERNEL
316  *      stp x29, lr, [sp, #-16]!
317  *      mov x29, sp
318  *      stp xzr, x26, [sp, #-16]!
319  *      mov x26, sp
320  *      // PROLOGUE_OFFSET
321  *	// save callee-saved registers
322  */
323 static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
324 {
325 	const bool is_main_prog = !bpf_is_subprog(ctx->prog);
326 	const u8 ptr = bpf2a64[TCCNT_PTR];
327 
328 	if (is_main_prog) {
329 		/* Initialize tail_call_cnt. */
330 		emit(A64_PUSH(A64_ZR, ptr, A64_SP), ctx);
331 		emit(A64_MOV(1, ptr, A64_SP), ctx);
332 	} else
333 		emit(A64_PUSH(ptr, ptr, A64_SP), ctx);
334 }
335 
336 static void find_used_callee_regs(struct jit_ctx *ctx)
337 {
338 	int i;
339 	const struct bpf_prog *prog = ctx->prog;
340 	const struct bpf_insn *insn = &prog->insnsi[0];
341 	int reg_used = 0;
342 
343 	for (i = 0; i < prog->len; i++, insn++) {
344 		if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
345 			reg_used |= 1;
346 
347 		if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
348 			reg_used |= 2;
349 
350 		if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
351 			reg_used |= 4;
352 
353 		if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
354 			reg_used |= 8;
355 
356 		if (insn->dst_reg == BPF_REG_FP || insn->src_reg == BPF_REG_FP) {
357 			ctx->fp_used = true;
358 			reg_used |= 16;
359 		}
360 	}
361 
362 	i = 0;
363 	if (reg_used & 1)
364 		ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_6];
365 
366 	if (reg_used & 2)
367 		ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_7];
368 
369 	if (reg_used & 4)
370 		ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_8];
371 
372 	if (reg_used & 8)
373 		ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_9];
374 
375 	if (reg_used & 16)
376 		ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_FP];
377 
378 	if (ctx->arena_vm_start)
379 		ctx->used_callee_reg[i++] = bpf2a64[ARENA_VM_START];
380 
381 	ctx->nr_used_callee_reg = i;
382 }
383 
384 /* Save callee-saved registers */
385 static void push_callee_regs(struct jit_ctx *ctx)
386 {
387 	int reg1, reg2, i;
388 
389 	/*
390 	 * Program acting as exception boundary should save all ARM64
391 	 * Callee-saved registers as the exception callback needs to recover
392 	 * all ARM64 Callee-saved registers in its epilogue.
393 	 */
394 	if (ctx->prog->aux->exception_boundary) {
395 		emit(A64_PUSH(A64_R(19), A64_R(20), A64_SP), ctx);
396 		emit(A64_PUSH(A64_R(21), A64_R(22), A64_SP), ctx);
397 		emit(A64_PUSH(A64_R(23), A64_R(24), A64_SP), ctx);
398 		emit(A64_PUSH(A64_R(25), A64_R(26), A64_SP), ctx);
399 		emit(A64_PUSH(A64_R(27), A64_R(28), A64_SP), ctx);
400 	} else {
401 		find_used_callee_regs(ctx);
402 		for (i = 0; i + 1 < ctx->nr_used_callee_reg; i += 2) {
403 			reg1 = ctx->used_callee_reg[i];
404 			reg2 = ctx->used_callee_reg[i + 1];
405 			emit(A64_PUSH(reg1, reg2, A64_SP), ctx);
406 		}
407 		if (i < ctx->nr_used_callee_reg) {
408 			reg1 = ctx->used_callee_reg[i];
409 			/* keep SP 16-byte aligned */
410 			emit(A64_PUSH(reg1, A64_ZR, A64_SP), ctx);
411 		}
412 	}
413 }
414 
415 /* Restore callee-saved registers */
416 static void pop_callee_regs(struct jit_ctx *ctx)
417 {
418 	struct bpf_prog_aux *aux = ctx->prog->aux;
419 	int reg1, reg2, i;
420 
421 	/*
422 	 * Program acting as exception boundary pushes R23 and R24 in addition
423 	 * to BPF callee-saved registers. Exception callback uses the boundary
424 	 * program's stack frame, so recover these extra registers in the above
425 	 * two cases.
426 	 */
427 	if (aux->exception_boundary || aux->exception_cb) {
428 		emit(A64_POP(A64_R(27), A64_R(28), A64_SP), ctx);
429 		emit(A64_POP(A64_R(25), A64_R(26), A64_SP), ctx);
430 		emit(A64_POP(A64_R(23), A64_R(24), A64_SP), ctx);
431 		emit(A64_POP(A64_R(21), A64_R(22), A64_SP), ctx);
432 		emit(A64_POP(A64_R(19), A64_R(20), A64_SP), ctx);
433 	} else {
434 		i = ctx->nr_used_callee_reg - 1;
435 		if (ctx->nr_used_callee_reg % 2 != 0) {
436 			reg1 = ctx->used_callee_reg[i];
437 			emit(A64_POP(reg1, A64_ZR, A64_SP), ctx);
438 			i--;
439 		}
440 		while (i > 0) {
441 			reg1 = ctx->used_callee_reg[i - 1];
442 			reg2 = ctx->used_callee_reg[i];
443 			emit(A64_POP(reg1, reg2, A64_SP), ctx);
444 			i -= 2;
445 		}
446 	}
447 }
448 
/* Number of optional instructions at program entry: 1 if configured, else 0 */
#define BTI_INSNS (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL) ? 1 : 0)
#define PAC_INSNS (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL) ? 1 : 0)

/*
 * Offset of nop instruction in bpf prog entry to be poked
 * (the optional BTI plus "mov x9, lr" come before it).
 */
#define POKE_OFFSET (BTI_INSNS + 1)

/*
 * Tail call offset to jump into, counted in instructions from the program
 * entry: optional BTI, "mov x9, lr" + nop (2), optional PACIASP, then the
 * frame record setup and tail call counter setup of the main prog (4).
 */
#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 4)
457 
/*
 * Emit the program prologue and compute ctx->stack_size.
 *
 * @ctx:            JIT context.
 * @ebpf_from_cbpf: when false and the program is a main program, the
 *                  prologue length is validated against PROLOGUE_OFFSET and
 *                  a BTI landing pad for tail calls is emitted.
 *
 * Returns 0 on success, -1 when the number of instructions emitted ahead of
 * the tail call entry point does not match PROLOGUE_OFFSET.
 */
static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
{
	const struct bpf_prog *prog = ctx->prog;
	const bool is_main_prog = !bpf_is_subprog(prog);
	const u8 fp = bpf2a64[BPF_REG_FP];
	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
	const int idx0 = ctx->idx;
	int cur_offset;

	/*
	 * BPF prog stack layout
	 *
	 *                         high
	 * original A64_SP =>   0:+-----+ BPF prologue
	 *                        |FP/LR|
	 * current A64_FP =>  -16:+-----+
	 *                        | ... | callee saved registers
	 * BPF fp register => -64:+-----+ <= (BPF_FP)
	 *                        |     |
	 *                        | ... | BPF prog stack
	 *                        |     |
	 *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
	 *                        |RSVD | padding
	 * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
	 *                        |     |
	 *                        | ... | Function call stack
	 *                        |     |
	 *                        +-----+
	 *                          low
	 *
	 */

	/* bpf function may be invoked by 3 instruction types:
	 * 1. bl, attached via freplace to bpf prog via short jump
	 * 2. br, attached via freplace to bpf prog via long jump
	 * 3. blr, working as a function pointer, used by emit_call.
	 * So BTI_JC should be used here to support both br and blr.
	 */
	emit_bti(A64_BTI_JC, ctx);

	/* Keep a copy of LR in x9, then leave a nop at POKE_OFFSET */
	emit(A64_MOV(1, A64_R(9), A64_LR), ctx);
	emit(A64_NOP, ctx);

	if (!prog->aux->exception_cb) {
		/* Sign lr */
		if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
			emit(A64_PACIASP, ctx);

		/* Save FP and LR registers to stay aligned with ARM64 AAPCS */
		emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
		emit(A64_MOV(1, A64_FP, A64_SP), ctx);

		prepare_bpf_tail_call_cnt(ctx);

		if (!ebpf_from_cbpf && is_main_prog) {
			/*
			 * Tail calls jump to a fixed instruction offset, so
			 * everything emitted so far must add up to exactly
			 * PROLOGUE_OFFSET instructions.
			 */
			cur_offset = ctx->idx - idx0;
			if (cur_offset != PROLOGUE_OFFSET) {
				pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
						cur_offset, PROLOGUE_OFFSET);
				return -1;
			}
			/* BTI landing pad for the tail call, done with a BR */
			emit_bti(A64_BTI_J, ctx);
		}
		push_callee_regs(ctx);
	} else {
		/*
		 * Exception callback receives FP of Main Program as third
		 * parameter
		 */
		emit(A64_MOV(1, A64_FP, A64_R(2)), ctx);
		/*
		 * Main Program already pushed the frame record and the
		 * callee-saved registers. The exception callback will not push
		 * anything and re-use the main program's stack.
		 *
		 * 12 registers are on the stack
		 */
		emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
	}

	/*
	 * ctx->fp_used is set by find_used_callee_regs(), which
	 * push_callee_regs() runs on its non-exception-boundary path.
	 */
	if (ctx->fp_used)
		/* Set up BPF prog stack base register */
		emit(A64_MOV(1, fp, A64_SP), ctx);

	/* Stack must be multiples of 16B */
	ctx->stack_size = round_up(prog->aux->stack_depth, 16);

	/* Set up function call stack */
	if (ctx->stack_size)
		emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);

	/* Load the arena base into its dedicated register when one is set */
	if (ctx->arena_vm_start)
		emit_a64_mov_i64(arena_vm_base, ctx->arena_vm_start, ctx);

	return 0;
}
555 
556 static int emit_bpf_tail_call(struct jit_ctx *ctx)
557 {
558 	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
559 	const u8 r2 = bpf2a64[BPF_REG_2];
560 	const u8 r3 = bpf2a64[BPF_REG_3];
561 
562 	const u8 tmp = bpf2a64[TMP_REG_1];
563 	const u8 prg = bpf2a64[TMP_REG_2];
564 	const u8 tcc = bpf2a64[TMP_REG_3];
565 	const u8 ptr = bpf2a64[TCCNT_PTR];
566 	size_t off;
567 	__le32 *branch1 = NULL;
568 	__le32 *branch2 = NULL;
569 	__le32 *branch3 = NULL;
570 
571 	/* if (index >= array->map.max_entries)
572 	 *     goto out;
573 	 */
574 	off = offsetof(struct bpf_array, map.max_entries);
575 	emit_a64_mov_i64(tmp, off, ctx);
576 	emit(A64_LDR32(tmp, r2, tmp), ctx);
577 	emit(A64_MOV(0, r3, r3), ctx);
578 	emit(A64_CMP(0, r3, tmp), ctx);
579 	branch1 = ctx->image + ctx->idx;
580 	emit(A64_NOP, ctx);
581 
582 	/*
583 	 * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
584 	 *     goto out;
585 	 */
586 	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
587 	emit(A64_LDR64I(tcc, ptr, 0), ctx);
588 	emit(A64_CMP(1, tcc, tmp), ctx);
589 	branch2 = ctx->image + ctx->idx;
590 	emit(A64_NOP, ctx);
591 
592 	/* (*tail_call_cnt_ptr)++; */
593 	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
594 
595 	/* prog = array->ptrs[index];
596 	 * if (prog == NULL)
597 	 *     goto out;
598 	 */
599 	off = offsetof(struct bpf_array, ptrs);
600 	emit_a64_mov_i64(tmp, off, ctx);
601 	emit(A64_ADD(1, tmp, r2, tmp), ctx);
602 	emit(A64_LSL(1, prg, r3, 3), ctx);
603 	emit(A64_LDR64(prg, tmp, prg), ctx);
604 	branch3 = ctx->image + ctx->idx;
605 	emit(A64_NOP, ctx);
606 
607 	/* Update tail_call_cnt if the slot is populated. */
608 	emit(A64_STR64I(tcc, ptr, 0), ctx);
609 
610 	/* restore SP */
611 	if (ctx->stack_size)
612 		emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
613 
614 	pop_callee_regs(ctx);
615 
616 	/* goto *(prog->bpf_func + prologue_offset); */
617 	off = offsetof(struct bpf_prog, bpf_func);
618 	emit_a64_mov_i64(tmp, off, ctx);
619 	emit(A64_LDR64(tmp, prg, tmp), ctx);
620 	emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
621 	emit(A64_BR(tmp), ctx);
622 
623 	if (ctx->image) {
624 		off = &ctx->image[ctx->idx] - branch1;
625 		*branch1 = cpu_to_le32(A64_B_(A64_COND_CS, off));
626 
627 		off = &ctx->image[ctx->idx] - branch2;
628 		*branch2 = cpu_to_le32(A64_B_(A64_COND_CS, off));
629 
630 		off = &ctx->image[ctx->idx] - branch3;
631 		*branch3 = cpu_to_le32(A64_CBZ(1, prg, off));
632 	}
633 
634 	return 0;
635 }
636 
637 #ifdef CONFIG_ARM64_LSE_ATOMICS
638 static int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
639 {
640 	const u8 code = insn->code;
641 	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
642 	const u8 dst = bpf2a64[insn->dst_reg];
643 	const u8 src = bpf2a64[insn->src_reg];
644 	const u8 tmp = bpf2a64[TMP_REG_1];
645 	const u8 tmp2 = bpf2a64[TMP_REG_2];
646 	const bool isdw = BPF_SIZE(code) == BPF_DW;
647 	const bool arena = BPF_MODE(code) == BPF_PROBE_ATOMIC;
648 	const s16 off = insn->off;
649 	u8 reg = dst;
650 
651 	if (off || arena) {
652 		if (off) {
653 			emit_a64_mov_i(1, tmp, off, ctx);
654 			emit(A64_ADD(1, tmp, tmp, dst), ctx);
655 			reg = tmp;
656 		}
657 		if (arena) {
658 			emit(A64_ADD(1, tmp, reg, arena_vm_base), ctx);
659 			reg = tmp;
660 		}
661 	}
662 
663 	switch (insn->imm) {
664 	/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
665 	case BPF_ADD:
666 		emit(A64_STADD(isdw, reg, src), ctx);
667 		break;
668 	case BPF_AND:
669 		emit(A64_MVN(isdw, tmp2, src), ctx);
670 		emit(A64_STCLR(isdw, reg, tmp2), ctx);
671 		break;
672 	case BPF_OR:
673 		emit(A64_STSET(isdw, reg, src), ctx);
674 		break;
675 	case BPF_XOR:
676 		emit(A64_STEOR(isdw, reg, src), ctx);
677 		break;
678 	/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
679 	case BPF_ADD | BPF_FETCH:
680 		emit(A64_LDADDAL(isdw, src, reg, src), ctx);
681 		break;
682 	case BPF_AND | BPF_FETCH:
683 		emit(A64_MVN(isdw, tmp2, src), ctx);
684 		emit(A64_LDCLRAL(isdw, src, reg, tmp2), ctx);
685 		break;
686 	case BPF_OR | BPF_FETCH:
687 		emit(A64_LDSETAL(isdw, src, reg, src), ctx);
688 		break;
689 	case BPF_XOR | BPF_FETCH:
690 		emit(A64_LDEORAL(isdw, src, reg, src), ctx);
691 		break;
692 	/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
693 	case BPF_XCHG:
694 		emit(A64_SWPAL(isdw, src, reg, src), ctx);
695 		break;
696 	/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
697 	case BPF_CMPXCHG:
698 		emit(A64_CASAL(isdw, src, reg, bpf2a64[BPF_REG_0]), ctx);
699 		break;
700 	default:
701 		pr_err_once("unknown atomic op code %02x\n", insn->imm);
702 		return -EINVAL;
703 	}
704 
705 	return 0;
706 }
707 #else
708 static inline int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
709 {
710 	return -EINVAL;
711 }
712 #endif
713 
/*
 * Emit a BPF atomic operation as a load-exclusive/store-exclusive retry
 * loop.
 *
 * The memory operand is dst_reg + off (built in the first JIT temporary
 * when off is non-zero). Each loop re-runs from its LDXR whenever the
 * store-exclusive reports failure in tmp3.
 *
 * Returns 0 on success, -EINVAL for BPF_PROBE_ATOMIC instructions or an
 * unknown atomic operation.
 */
static int emit_ll_sc_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
{
	const u8 code = insn->code;
	const u8 dst = bpf2a64[insn->dst_reg];
	const u8 src = bpf2a64[insn->src_reg];
	const u8 tmp = bpf2a64[TMP_REG_1];
	const u8 tmp2 = bpf2a64[TMP_REG_2];
	const u8 tmp3 = bpf2a64[TMP_REG_3];
	/* Index of the BPF instruction; referenced by name from check_imm19() */
	const int i = insn - ctx->prog->insnsi;
	const s32 imm = insn->imm;
	const s16 off = insn->off;
	const bool isdw = BPF_SIZE(code) == BPF_DW;
	u8 reg;
	s32 jmp_offset;

	if (BPF_MODE(code) == BPF_PROBE_ATOMIC) {
		/* ll_sc based atomics don't support unsafe pointers yet. */
		pr_err_once("unknown atomic opcode %02x\n", code);
		return -EINVAL;
	}

	/* reg = address of the memory operand */
	if (!off) {
		reg = dst;
	} else {
		emit_a64_mov_i(1, tmp, off, ctx);
		emit(A64_ADD(1, tmp, tmp, dst), ctx);
		reg = tmp;
	}

	if (imm == BPF_ADD || imm == BPF_AND ||
	    imm == BPF_OR || imm == BPF_XOR) {
		/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
		emit(A64_LDXR(isdw, tmp2, reg), ctx);
		if (imm == BPF_ADD)
			emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
		else if (imm == BPF_AND)
			emit(A64_AND(isdw, tmp2, tmp2, src), ctx);
		else if (imm == BPF_OR)
			emit(A64_ORR(isdw, tmp2, tmp2, src), ctx);
		else
			emit(A64_EOR(isdw, tmp2, tmp2, src), ctx);
		emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
		/* Store failed: go back 3 instructions, to the LDXR */
		jmp_offset = -3;
		check_imm19(jmp_offset);
		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
	} else if (imm == (BPF_ADD | BPF_FETCH) ||
		   imm == (BPF_AND | BPF_FETCH) ||
		   imm == (BPF_OR | BPF_FETCH) ||
		   imm == (BPF_XOR | BPF_FETCH)) {
		/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
		const u8 ax = bpf2a64[BPF_REG_AX];

		/* src receives the old value, so keep the operand in ax */
		emit(A64_MOV(isdw, ax, src), ctx);
		emit(A64_LDXR(isdw, src, reg), ctx);
		if (imm == (BPF_ADD | BPF_FETCH))
			emit(A64_ADD(isdw, tmp2, src, ax), ctx);
		else if (imm == (BPF_AND | BPF_FETCH))
			emit(A64_AND(isdw, tmp2, src, ax), ctx);
		else if (imm == (BPF_OR | BPF_FETCH))
			emit(A64_ORR(isdw, tmp2, src, ax), ctx);
		else
			emit(A64_EOR(isdw, tmp2, src, ax), ctx);
		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
		/* Store failed: go back 3 instructions, to the LDXR */
		jmp_offset = -3;
		check_imm19(jmp_offset);
		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
		emit(A64_DMB_ISH, ctx);
	} else if (imm == BPF_XCHG) {
		/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
		emit(A64_MOV(isdw, tmp2, src), ctx);
		emit(A64_LDXR(isdw, src, reg), ctx);
		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
		/* Store failed: go back 2 instructions, to the LDXR */
		jmp_offset = -2;
		check_imm19(jmp_offset);
		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
		emit(A64_DMB_ISH, ctx);
	} else if (imm == BPF_CMPXCHG) {
		/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
		const u8 r0 = bpf2a64[BPF_REG_0];

		/* r0 receives the old value, so keep the expected value in tmp2 */
		emit(A64_MOV(isdw, tmp2, r0), ctx);
		emit(A64_LDXR(isdw, r0, reg), ctx);
		emit(A64_EOR(isdw, tmp3, r0, tmp2), ctx);
		/*
		 * Old value differs from the expected one: skip the store,
		 * the retry branch and the barrier (4 instructions ahead).
		 */
		jmp_offset = 4;
		check_imm19(jmp_offset);
		emit(A64_CBNZ(isdw, tmp3, jmp_offset), ctx);
		emit(A64_STLXR(isdw, src, reg, tmp3), ctx);
		/* Store failed: go back 4 instructions, to the LDXR */
		jmp_offset = -4;
		check_imm19(jmp_offset);
		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
		emit(A64_DMB_ISH, ctx);
	} else {
		pr_err_once("unknown atomic op code %02x\n", imm);
		return -EINVAL;
	}

	return 0;
}
812 
/*
 * Default PLT target, installed by build_plt().
 *
 * It returns to the address held in x30 on entry while reloading x30 from
 * x9. NOTE(review): x9 presumably holds the LR copy made by the
 * "mov x9, lr" at the start of the program prologue — confirm against the
 * code that pokes the prologue nop.
 */
void dummy_tramp(void);

asm (
"	.pushsection .text, \"ax\", @progbits\n"
"	.global dummy_tramp\n"
"	.type dummy_tramp, %function\n"
"dummy_tramp:"
#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
"	bti j\n" /* dummy_tramp is called via "br x10" */
#endif
"	mov x10, x30\n"
"	mov x30, x9\n"
"	ret x10\n"
"	.size dummy_tramp, .-dummy_tramp\n"
"	.popsection\n"
);
829 
830 /* build a plt initialized like this:
831  *
832  * plt:
833  *      ldr tmp, target
834  *      br tmp
835  * target:
836  *      .quad dummy_tramp
837  *
838  * when a long jump trampoline is attached, target is filled with the
839  * trampoline address, and when the trampoline is removed, target is
840  * restored to dummy_tramp address.
841  */
842 static void build_plt(struct jit_ctx *ctx)
843 {
844 	const u8 tmp = bpf2a64[TMP_REG_1];
845 	struct bpf_plt *plt = NULL;
846 
847 	/* make sure target is 64-bit aligned */
848 	if ((ctx->idx + PLT_TARGET_OFFSET / AARCH64_INSN_SIZE) % 2)
849 		emit(A64_NOP, ctx);
850 
851 	plt = (struct bpf_plt *)(ctx->image + ctx->idx);
852 	/* plt is called via bl, no BTI needed here */
853 	emit(A64_LDR64LIT(tmp, 2 * AARCH64_INSN_SIZE), ctx);
854 	emit(A64_BR(tmp), ctx);
855 
856 	if (ctx->image)
857 		plt->target = (u64)&dummy_tramp;
858 }
859 
860 static void build_epilogue(struct jit_ctx *ctx)
861 {
862 	const u8 r0 = bpf2a64[BPF_REG_0];
863 	const u8 ptr = bpf2a64[TCCNT_PTR];
864 
865 	/* We're done with BPF stack */
866 	if (ctx->stack_size)
867 		emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
868 
869 	pop_callee_regs(ctx);
870 
871 	emit(A64_POP(A64_ZR, ptr, A64_SP), ctx);
872 
873 	/* Restore FP/LR registers */
874 	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
875 
876 	/* Set return value */
877 	emit(A64_MOV(1, A64_R(0), r0), ctx);
878 
879 	/* Authenticate lr */
880 	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
881 		emit(A64_AUTIASP, ctx);
882 
883 	emit(A64_RET(A64_LR), ctx);
884 }
885 
/*
 * Layout of exception_table_entry::fixup for BPF entries: bits 0-26 hold
 * the (positive) fixup offset, bits 27-31 the A64 register to zero.
 */
#define BPF_FIXUP_OFFSET_MASK	GENMASK(26, 0)
#define BPF_FIXUP_REG_MASK	GENMASK(31, 27)
/* Register number meaning "do not zero any register" in the fixup handler */
#define DONT_CLEAR 5 /* Unused ARM64 register from BPF's POV */
889 
890 bool ex_handler_bpf(const struct exception_table_entry *ex,
891 		    struct pt_regs *regs)
892 {
893 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
894 	int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
895 
896 	if (dst_reg != DONT_CLEAR)
897 		regs->regs[dst_reg] = 0;
898 	regs->pc = (unsigned long)&ex->fixup - offset;
899 	return true;
900 }
901 
902 /* For accesses to BTF pointers, add an entry to the exception table */
903 static int add_exception_handler(const struct bpf_insn *insn,
904 				 struct jit_ctx *ctx,
905 				 int dst_reg)
906 {
907 	off_t ins_offset;
908 	off_t fixup_offset;
909 	unsigned long pc;
910 	struct exception_table_entry *ex;
911 
912 	if (!ctx->image)
913 		/* First pass */
914 		return 0;
915 
916 	if (BPF_MODE(insn->code) != BPF_PROBE_MEM &&
917 		BPF_MODE(insn->code) != BPF_PROBE_MEMSX &&
918 			BPF_MODE(insn->code) != BPF_PROBE_MEM32 &&
919 				BPF_MODE(insn->code) != BPF_PROBE_ATOMIC)
920 		return 0;
921 
922 	if (!ctx->prog->aux->extable ||
923 	    WARN_ON_ONCE(ctx->exentry_idx >= ctx->prog->aux->num_exentries))
924 		return -EINVAL;
925 
926 	ex = &ctx->prog->aux->extable[ctx->exentry_idx];
927 	pc = (unsigned long)&ctx->ro_image[ctx->idx - 1];
928 
929 	/*
930 	 * This is the relative offset of the instruction that may fault from
931 	 * the exception table itself. This will be written to the exception
932 	 * table and if this instruction faults, the destination register will
933 	 * be set to '0' and the execution will jump to the next instruction.
934 	 */
935 	ins_offset = pc - (long)&ex->insn;
936 	if (WARN_ON_ONCE(ins_offset >= 0 || ins_offset < INT_MIN))
937 		return -ERANGE;
938 
939 	/*
940 	 * Since the extable follows the program, the fixup offset is always
941 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
942 	 * to keep things simple, and put the destination register in the upper
943 	 * bits. We don't need to worry about buildtime or runtime sort
944 	 * modifying the upper bits because the table is already sorted, and
945 	 * isn't part of the main exception table.
946 	 *
947 	 * The fixup_offset is set to the next instruction from the instruction
948 	 * that may fault. The execution will jump to this after handling the
949 	 * fault.
950 	 */
951 	fixup_offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
952 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset))
953 		return -ERANGE;
954 
955 	/*
956 	 * The offsets above have been calculated using the RO buffer but we
957 	 * need to use the R/W buffer for writes.
958 	 * switch ex to rw buffer for writing.
959 	 */
960 	ex = (void *)ctx->image + ((void *)ex - (void *)ctx->ro_image);
961 
962 	ex->insn = ins_offset;
963 
964 	if (BPF_CLASS(insn->code) != BPF_LDX)
965 		dst_reg = DONT_CLEAR;
966 
967 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, fixup_offset) |
968 		    FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
969 
970 	ex->type = EX_TYPE_BPF;
971 
972 	ctx->exentry_idx++;
973 	return 0;
974 }
975 
/* JITs an eBPF instruction.
 *
 * Translates the single eBPF instruction @insn into A64 instructions that
 * are handed to emit() via @ctx. @extra_pass is only forwarded to
 * bpf_jit_get_func_addr() when a call target has to be resolved.
 *
 * Returns:
 * 0  - successfully JITed an 8-byte eBPF instruction.
 * >0 - successfully JITed a 16-byte eBPF instruction.
 * <0 - failed to JIT.
 */
static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
		      bool extra_pass)
{
	const u8 code = insn->code;
	u8 dst = bpf2a64[insn->dst_reg];
	u8 src = bpf2a64[insn->src_reg];
	const u8 tmp = bpf2a64[TMP_REG_1];
	const u8 tmp2 = bpf2a64[TMP_REG_2];
	const u8 fp = bpf2a64[BPF_REG_FP];
	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
	const s16 off = insn->off;
	const s32 imm = insn->imm;
	/* index of this instruction in the program; also used by check_imm() */
	const int i = insn - ctx->prog->insnsi;
	/* ALU64 and JMP operate on 64 bits, every other class on 32 bits */
	const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
			  BPF_CLASS(code) == BPF_JMP;
	u8 jmp_cond;
	s32 jmp_offset;
	u32 a64_insn;
	u8 src_adj;
	u8 dst_adj;
	int off_adj;
	int ret;
	bool sign_extend;

	switch (code) {
	/* dst = src */
	case BPF_ALU | BPF_MOV | BPF_X:
	case BPF_ALU64 | BPF_MOV | BPF_X:
		if (insn_is_cast_user(insn)) {
			/*
			 * Build (user_vm_start's upper 32 bits << 32) | low 32
			 * bits of src. When the low 32 bits are zero the CBZ
			 * branches two instructions ahead, past the ORR, so
			 * dst ends up as zero.
			 */
			emit(A64_MOV(0, tmp, src), ctx); // 32-bit mov clears the upper 32 bits
			emit_a64_mov_i(0, dst, ctx->user_vm_start >> 32, ctx);
			emit(A64_LSL(1, dst, dst, 32), ctx);
			emit(A64_CBZ(1, tmp, 2), ctx);
			emit(A64_ORR(1, tmp, dst, tmp), ctx);
			emit(A64_MOV(1, dst, tmp), ctx);
			break;
		} else if (insn_is_mov_percpu_addr(insn)) {
			/*
			 * dst = src + TPIDR_EL2 when the CPUs have the
			 * ARM64_HAS_VIRT_HOST_EXTN capability, otherwise
			 * dst = src + TPIDR_EL1.
			 */
			if (dst != src)
				emit(A64_MOV(1, dst, src), ctx);
			if (cpus_have_cap(ARM64_HAS_VIRT_HOST_EXTN))
				emit(A64_MRS_TPIDR_EL2(tmp), ctx);
			else
				emit(A64_MRS_TPIDR_EL1(tmp), ctx);
			emit(A64_ADD(1, dst, dst, tmp), ctx);
			break;
		}
		/* off == 0 is a plain move; 8/16/32 select a sign-extending move */
		switch (insn->off) {
		case 0:
			emit(A64_MOV(is64, dst, src), ctx);
			break;
		case 8:
			emit(A64_SXTB(is64, dst, src), ctx);
			break;
		case 16:
			emit(A64_SXTH(is64, dst, src), ctx);
			break;
		case 32:
			emit(A64_SXTW(is64, dst, src), ctx);
			break;
		}
		break;
	/* dst = dst OP src */
	case BPF_ALU | BPF_ADD | BPF_X:
	case BPF_ALU64 | BPF_ADD | BPF_X:
		emit(A64_ADD(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_SUB | BPF_X:
	case BPF_ALU64 | BPF_SUB | BPF_X:
		emit(A64_SUB(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_AND | BPF_X:
	case BPF_ALU64 | BPF_AND | BPF_X:
		emit(A64_AND(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_OR | BPF_X:
	case BPF_ALU64 | BPF_OR | BPF_X:
		emit(A64_ORR(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_XOR | BPF_X:
	case BPF_ALU64 | BPF_XOR | BPF_X:
		emit(A64_EOR(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_MUL | BPF_X:
	case BPF_ALU64 | BPF_MUL | BPF_X:
		emit(A64_MUL(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_DIV | BPF_X:
	case BPF_ALU64 | BPF_DIV | BPF_X:
		/* off == 0 selects the unsigned divide, non-zero the signed one */
		if (!off)
			emit(A64_UDIV(is64, dst, dst, src), ctx);
		else
			emit(A64_SDIV(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_MOD | BPF_X:
	case BPF_ALU64 | BPF_MOD | BPF_X:
		/* quotient goes to tmp, then a multiply-subtract yields the remainder */
		if (!off)
			emit(A64_UDIV(is64, tmp, dst, src), ctx);
		else
			emit(A64_SDIV(is64, tmp, dst, src), ctx);
		emit(A64_MSUB(is64, dst, dst, tmp, src), ctx);
		break;
	case BPF_ALU | BPF_LSH | BPF_X:
	case BPF_ALU64 | BPF_LSH | BPF_X:
		emit(A64_LSLV(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_RSH | BPF_X:
	case BPF_ALU64 | BPF_RSH | BPF_X:
		emit(A64_LSRV(is64, dst, dst, src), ctx);
		break;
	case BPF_ALU | BPF_ARSH | BPF_X:
	case BPF_ALU64 | BPF_ARSH | BPF_X:
		emit(A64_ASRV(is64, dst, dst, src), ctx);
		break;
	/* dst = -dst */
	case BPF_ALU | BPF_NEG:
	case BPF_ALU64 | BPF_NEG:
		emit(A64_NEG(is64, dst, dst), ctx);
		break;
	/* dst = BSWAP##imm(dst) */
	case BPF_ALU | BPF_END | BPF_FROM_LE:
	case BPF_ALU | BPF_END | BPF_FROM_BE:
	case BPF_ALU64 | BPF_END | BPF_FROM_LE:
		/*
		 * A BPF_ALU conversion from the byte order the kernel is built
		 * for needs no byte reversal, only the zero-extension done at
		 * emit_bswap_uxt. Everything else reverses bytes below.
		 */
#ifdef CONFIG_CPU_BIG_ENDIAN
		if (BPF_CLASS(code) == BPF_ALU && BPF_SRC(code) == BPF_FROM_BE)
			goto emit_bswap_uxt;
#else /* !CONFIG_CPU_BIG_ENDIAN */
		if (BPF_CLASS(code) == BPF_ALU && BPF_SRC(code) == BPF_FROM_LE)
			goto emit_bswap_uxt;
#endif
		switch (imm) {
		case 16:
			emit(A64_REV16(is64, dst, dst), ctx);
			/* zero-extend 16 bits into 64 bits */
			emit(A64_UXTH(is64, dst, dst), ctx);
			break;
		case 32:
			emit(A64_REV32(0, dst, dst), ctx);
			/* upper 32 bits already cleared */
			break;
		case 64:
			emit(A64_REV64(dst, dst), ctx);
			break;
		}
		break;
emit_bswap_uxt:
		switch (imm) {
		case 16:
			/* zero-extend 16 bits into 64 bits */
			emit(A64_UXTH(is64, dst, dst), ctx);
			break;
		case 32:
			/* zero-extend 32 bits into 64 bits */
			emit(A64_UXTW(is64, dst, dst), ctx);
			break;
		case 64:
			/* nop */
			break;
		}
		break;
	/* dst = imm */
	case BPF_ALU | BPF_MOV | BPF_K:
	case BPF_ALU64 | BPF_MOV | BPF_K:
		emit_a64_mov_i(is64, dst, imm, ctx);
		break;
	/* dst = dst OP imm */
	case BPF_ALU | BPF_ADD | BPF_K:
	case BPF_ALU64 | BPF_ADD | BPF_K:
		/*
		 * Use the immediate form when imm, or its negation with the
		 * opposite operation, passes is_addsub_imm(); otherwise
		 * materialise imm in tmp first.
		 */
		if (is_addsub_imm(imm)) {
			emit(A64_ADD_I(is64, dst, dst, imm), ctx);
		} else if (is_addsub_imm(-imm)) {
			emit(A64_SUB_I(is64, dst, dst, -imm), ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_ADD(is64, dst, dst, tmp), ctx);
		}
		break;
	case BPF_ALU | BPF_SUB | BPF_K:
	case BPF_ALU64 | BPF_SUB | BPF_K:
		if (is_addsub_imm(imm)) {
			emit(A64_SUB_I(is64, dst, dst, imm), ctx);
		} else if (is_addsub_imm(-imm)) {
			emit(A64_ADD_I(is64, dst, dst, -imm), ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_SUB(is64, dst, dst, tmp), ctx);
		}
		break;
	case BPF_ALU | BPF_AND | BPF_K:
	case BPF_ALU64 | BPF_AND | BPF_K:
		/*
		 * AARCH64_BREAK_FAULT from the immediate encoder means imm
		 * could not be encoded; fall back to the register form.
		 */
		a64_insn = A64_AND_I(is64, dst, dst, imm);
		if (a64_insn != AARCH64_BREAK_FAULT) {
			emit(a64_insn, ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_AND(is64, dst, dst, tmp), ctx);
		}
		break;
	case BPF_ALU | BPF_OR | BPF_K:
	case BPF_ALU64 | BPF_OR | BPF_K:
		a64_insn = A64_ORR_I(is64, dst, dst, imm);
		if (a64_insn != AARCH64_BREAK_FAULT) {
			emit(a64_insn, ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_ORR(is64, dst, dst, tmp), ctx);
		}
		break;
	case BPF_ALU | BPF_XOR | BPF_K:
	case BPF_ALU64 | BPF_XOR | BPF_K:
		a64_insn = A64_EOR_I(is64, dst, dst, imm);
		if (a64_insn != AARCH64_BREAK_FAULT) {
			emit(a64_insn, ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_EOR(is64, dst, dst, tmp), ctx);
		}
		break;
	case BPF_ALU | BPF_MUL | BPF_K:
	case BPF_ALU64 | BPF_MUL | BPF_K:
		emit_a64_mov_i(is64, tmp, imm, ctx);
		emit(A64_MUL(is64, dst, dst, tmp), ctx);
		break;
	case BPF_ALU | BPF_DIV | BPF_K:
	case BPF_ALU64 | BPF_DIV | BPF_K:
		emit_a64_mov_i(is64, tmp, imm, ctx);
		if (!off)
			emit(A64_UDIV(is64, dst, dst, tmp), ctx);
		else
			emit(A64_SDIV(is64, dst, dst, tmp), ctx);
		break;
	case BPF_ALU | BPF_MOD | BPF_K:
	case BPF_ALU64 | BPF_MOD | BPF_K:
		/* tmp2 holds the divisor, tmp the quotient */
		emit_a64_mov_i(is64, tmp2, imm, ctx);
		if (!off)
			emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
		else
			emit(A64_SDIV(is64, tmp, dst, tmp2), ctx);
		emit(A64_MSUB(is64, dst, dst, tmp, tmp2), ctx);
		break;
	case BPF_ALU | BPF_LSH | BPF_K:
	case BPF_ALU64 | BPF_LSH | BPF_K:
		emit(A64_LSL(is64, dst, dst, imm), ctx);
		break;
	case BPF_ALU | BPF_RSH | BPF_K:
	case BPF_ALU64 | BPF_RSH | BPF_K:
		emit(A64_LSR(is64, dst, dst, imm), ctx);
		break;
	case BPF_ALU | BPF_ARSH | BPF_K:
	case BPF_ALU64 | BPF_ARSH | BPF_K:
		emit(A64_ASR(is64, dst, dst, imm), ctx);
		break;

	/* JUMP off */
	case BPF_JMP | BPF_JA:
	case BPF_JMP32 | BPF_JA:
		/* BPF_JMP takes the distance from off, BPF_JMP32 from imm */
		if (BPF_CLASS(code) == BPF_JMP)
			jmp_offset = bpf2a64_offset(i, off, ctx);
		else
			jmp_offset = bpf2a64_offset(i, imm, ctx);
		check_imm26(jmp_offset);
		emit(A64_B(jmp_offset), ctx);
		break;
	/* IF (dst COND src) JUMP off */
	case BPF_JMP | BPF_JEQ | BPF_X:
	case BPF_JMP | BPF_JGT | BPF_X:
	case BPF_JMP | BPF_JLT | BPF_X:
	case BPF_JMP | BPF_JGE | BPF_X:
	case BPF_JMP | BPF_JLE | BPF_X:
	case BPF_JMP | BPF_JNE | BPF_X:
	case BPF_JMP | BPF_JSGT | BPF_X:
	case BPF_JMP | BPF_JSLT | BPF_X:
	case BPF_JMP | BPF_JSGE | BPF_X:
	case BPF_JMP | BPF_JSLE | BPF_X:
	case BPF_JMP32 | BPF_JEQ | BPF_X:
	case BPF_JMP32 | BPF_JGT | BPF_X:
	case BPF_JMP32 | BPF_JLT | BPF_X:
	case BPF_JMP32 | BPF_JGE | BPF_X:
	case BPF_JMP32 | BPF_JLE | BPF_X:
	case BPF_JMP32 | BPF_JNE | BPF_X:
	case BPF_JMP32 | BPF_JSGT | BPF_X:
	case BPF_JMP32 | BPF_JSLT | BPF_X:
	case BPF_JMP32 | BPF_JSGE | BPF_X:
	case BPF_JMP32 | BPF_JSLE | BPF_X:
		emit(A64_CMP(is64, dst, src), ctx);
		/*
		 * Shared tail for every conditional jump: the compare/test has
		 * already been emitted, only the conditional branch is left.
		 */
emit_cond_jmp:
		jmp_offset = bpf2a64_offset(i, off, ctx);
		check_imm19(jmp_offset);
		switch (BPF_OP(code)) {
		case BPF_JEQ:
			jmp_cond = A64_COND_EQ;
			break;
		case BPF_JGT:
			jmp_cond = A64_COND_HI;
			break;
		case BPF_JLT:
			jmp_cond = A64_COND_CC;
			break;
		case BPF_JGE:
			jmp_cond = A64_COND_CS;
			break;
		case BPF_JLE:
			jmp_cond = A64_COND_LS;
			break;
		/* JSET arrives here after a TST, so it shares the NE condition */
		case BPF_JSET:
		case BPF_JNE:
			jmp_cond = A64_COND_NE;
			break;
		case BPF_JSGT:
			jmp_cond = A64_COND_GT;
			break;
		case BPF_JSLT:
			jmp_cond = A64_COND_LT;
			break;
		case BPF_JSGE:
			jmp_cond = A64_COND_GE;
			break;
		case BPF_JSLE:
			jmp_cond = A64_COND_LE;
			break;
		default:
			return -EFAULT;
		}
		emit(A64_B_(jmp_cond, jmp_offset), ctx);
		break;
	case BPF_JMP | BPF_JSET | BPF_X:
	case BPF_JMP32 | BPF_JSET | BPF_X:
		emit(A64_TST(is64, dst, src), ctx);
		goto emit_cond_jmp;
	/* IF (dst COND imm) JUMP off */
	case BPF_JMP | BPF_JEQ | BPF_K:
	case BPF_JMP | BPF_JGT | BPF_K:
	case BPF_JMP | BPF_JLT | BPF_K:
	case BPF_JMP | BPF_JGE | BPF_K:
	case BPF_JMP | BPF_JLE | BPF_K:
	case BPF_JMP | BPF_JNE | BPF_K:
	case BPF_JMP | BPF_JSGT | BPF_K:
	case BPF_JMP | BPF_JSLT | BPF_K:
	case BPF_JMP | BPF_JSGE | BPF_K:
	case BPF_JMP | BPF_JSLE | BPF_K:
	case BPF_JMP32 | BPF_JEQ | BPF_K:
	case BPF_JMP32 | BPF_JGT | BPF_K:
	case BPF_JMP32 | BPF_JLT | BPF_K:
	case BPF_JMP32 | BPF_JGE | BPF_K:
	case BPF_JMP32 | BPF_JLE | BPF_K:
	case BPF_JMP32 | BPF_JNE | BPF_K:
	case BPF_JMP32 | BPF_JSGT | BPF_K:
	case BPF_JMP32 | BPF_JSLT | BPF_K:
	case BPF_JMP32 | BPF_JSGE | BPF_K:
	case BPF_JMP32 | BPF_JSLE | BPF_K:
		/* compare against imm directly when it (or -imm via CMN) fits */
		if (is_addsub_imm(imm)) {
			emit(A64_CMP_I(is64, dst, imm), ctx);
		} else if (is_addsub_imm(-imm)) {
			emit(A64_CMN_I(is64, dst, -imm), ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_CMP(is64, dst, tmp), ctx);
		}
		goto emit_cond_jmp;
	case BPF_JMP | BPF_JSET | BPF_K:
	case BPF_JMP32 | BPF_JSET | BPF_K:
		a64_insn = A64_TST_I(is64, dst, imm);
		if (a64_insn != AARCH64_BREAK_FAULT) {
			emit(a64_insn, ctx);
		} else {
			emit_a64_mov_i(is64, tmp, imm, ctx);
			emit(A64_TST(is64, dst, tmp), ctx);
		}
		goto emit_cond_jmp;
	/* function call */
	case BPF_JMP | BPF_CALL:
	{
		const u8 r0 = bpf2a64[BPF_REG_0];
		bool func_addr_fixed;
		u64 func_addr;
		u32 cpu_offset;

		/* Implement helper call to bpf_get_smp_processor_id() inline */
		if (insn->src_reg == 0 && insn->imm == BPF_FUNC_get_smp_processor_id) {
			cpu_offset = offsetof(struct thread_info, cpu);

			/* r0 = 32-bit load from (SP_EL0 + cpu_offset) */
			emit(A64_MRS_SP_EL0(tmp), ctx);
			if (is_lsi_offset(cpu_offset, 2)) {
				emit(A64_LDR32I(r0, tmp, cpu_offset), ctx);
			} else {
				emit_a64_mov_i(1, tmp2, cpu_offset, ctx);
				emit(A64_LDR32(r0, tmp, tmp2), ctx);
			}
			break;
		}

		/* Implement helper call to bpf_get_current_task/_btf() inline */
		if (insn->src_reg == 0 && (insn->imm == BPF_FUNC_get_current_task ||
					   insn->imm == BPF_FUNC_get_current_task_btf)) {
			emit(A64_MRS_SP_EL0(r0), ctx);
			break;
		}

		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
					    &func_addr, &func_addr_fixed);
		if (ret < 0)
			return ret;
		emit_call(func_addr, ctx);
		/* copy x0 into the A64 register that backs BPF_REG_0 */
		emit(A64_MOV(1, r0, A64_R(0)), ctx);
		break;
	}
	/* tail call */
	case BPF_JMP | BPF_TAIL_CALL:
		if (emit_bpf_tail_call(ctx))
			return -EFAULT;
		break;
	/* function return */
	case BPF_JMP | BPF_EXIT:
		/* Optimization: when last instruction is EXIT,
		   simply fallthrough to epilogue. */
		if (i == ctx->prog->len - 1)
			break;
		jmp_offset = epilogue_offset(ctx);
		check_imm26(jmp_offset);
		emit(A64_B(jmp_offset), ctx);
		break;

	/* dst = imm64 */
	case BPF_LD | BPF_IMM | BPF_DW:
	{
		/* the upper 32 bits of the immediate live in the next slot */
		const struct bpf_insn insn1 = insn[1];
		u64 imm64;

		imm64 = (u64)insn1.imm << 32 | (u32)imm;
		if (bpf_pseudo_func(insn))
			emit_addr_mov_i64(dst, imm64, ctx);
		else
			emit_a64_mov_i64(dst, imm64, ctx);

		/* positive return: this was a 16-byte instruction */
		return 1;
	}

	/* LDX: dst = (u64)*(unsigned size *)(src + off) */
	case BPF_LDX | BPF_MEM | BPF_W:
	case BPF_LDX | BPF_MEM | BPF_H:
	case BPF_LDX | BPF_MEM | BPF_B:
	case BPF_LDX | BPF_MEM | BPF_DW:
	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
	/* LDXS: dst_reg = (s64)*(signed size *)(src_reg + off) */
	case BPF_LDX | BPF_MEMSX | BPF_B:
	case BPF_LDX | BPF_MEMSX | BPF_H:
	case BPF_LDX | BPF_MEMSX | BPF_W:
	case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
	case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
	case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
		/* PROBE_MEM32: the base becomes src + arena_vm_base, held in tmp2 */
		if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
			emit(A64_ADD(1, tmp2, src, arena_vm_base), ctx);
			src = tmp2;
		}
		/*
		 * Accesses based on the BPF frame pointer are re-expressed as
		 * SP + (off + stack_size) for the immediate-offset form.
		 */
		if (src == fp) {
			src_adj = A64_SP;
			off_adj = off + ctx->stack_size;
		} else {
			src_adj = src;
			off_adj = off;
		}
		sign_extend = (BPF_MODE(insn->code) == BPF_MEMSX ||
				BPF_MODE(insn->code) == BPF_PROBE_MEMSX);
		/*
		 * For each size: use the immediate-offset load when off_adj
		 * passes is_lsi_offset(); otherwise load off into tmp and use
		 * the register-offset form, which works on the unadjusted
		 * src/off pair.
		 */
		switch (BPF_SIZE(code)) {
		case BPF_W:
			if (is_lsi_offset(off_adj, 2)) {
				if (sign_extend)
					emit(A64_LDRSWI(dst, src_adj, off_adj), ctx);
				else
					emit(A64_LDR32I(dst, src_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp, off, ctx);
				if (sign_extend)
					emit(A64_LDRSW(dst, src, tmp), ctx);
				else
					emit(A64_LDR32(dst, src, tmp), ctx);
			}
			break;
		case BPF_H:
			if (is_lsi_offset(off_adj, 1)) {
				if (sign_extend)
					emit(A64_LDRSHI(dst, src_adj, off_adj), ctx);
				else
					emit(A64_LDRHI(dst, src_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp, off, ctx);
				if (sign_extend)
					emit(A64_LDRSH(dst, src, tmp), ctx);
				else
					emit(A64_LDRH(dst, src, tmp), ctx);
			}
			break;
		case BPF_B:
			if (is_lsi_offset(off_adj, 0)) {
				if (sign_extend)
					emit(A64_LDRSBI(dst, src_adj, off_adj), ctx);
				else
					emit(A64_LDRBI(dst, src_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp, off, ctx);
				if (sign_extend)
					emit(A64_LDRSB(dst, src, tmp), ctx);
				else
					emit(A64_LDRB(dst, src, tmp), ctx);
			}
			break;
		case BPF_DW:
			if (is_lsi_offset(off_adj, 3)) {
				emit(A64_LDR64I(dst, src_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp, off, ctx);
				emit(A64_LDR64(dst, src, tmp), ctx);
			}
			break;
		}

		/* let add_exception_handler() set up any extable entry for this access */
		ret = add_exception_handler(insn, ctx, dst);
		if (ret)
			return ret;
		break;

	/* speculation barrier */
	case BPF_ST | BPF_NOSPEC:
		/*
		 * Nothing required here.
		 *
		 * In case of arm64, we rely on the firmware mitigation of
		 * Speculative Store Bypass as controlled via the ssbd kernel
		 * parameter. Whenever the mitigation is enabled, it works
		 * for all of the kernel code with no need to provide any
		 * additional instructions.
		 */
		break;

	/* ST: *(size *)(dst + off) = imm */
	case BPF_ST | BPF_MEM | BPF_W:
	case BPF_ST | BPF_MEM | BPF_H:
	case BPF_ST | BPF_MEM | BPF_B:
	case BPF_ST | BPF_MEM | BPF_DW:
	case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
	case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
	case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
	case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
		/* PROBE_MEM32: the base becomes dst + arena_vm_base, held in tmp2 */
		if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
			emit(A64_ADD(1, tmp2, dst, arena_vm_base), ctx);
			dst = tmp2;
		}
		/* frame-pointer based stores are re-expressed relative to SP */
		if (dst == fp) {
			dst_adj = A64_SP;
			off_adj = off + ctx->stack_size;
		} else {
			dst_adj = dst;
			off_adj = off;
		}
		/* Load imm to a register then store it */
		emit_a64_mov_i(1, tmp, imm, ctx);
		/*
		 * tmp holds the value to store, so the register-offset
		 * fallback loads off into tmp2 instead.
		 */
		switch (BPF_SIZE(code)) {
		case BPF_W:
			if (is_lsi_offset(off_adj, 2)) {
				emit(A64_STR32I(tmp, dst_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp2, off, ctx);
				emit(A64_STR32(tmp, dst, tmp2), ctx);
			}
			break;
		case BPF_H:
			if (is_lsi_offset(off_adj, 1)) {
				emit(A64_STRHI(tmp, dst_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp2, off, ctx);
				emit(A64_STRH(tmp, dst, tmp2), ctx);
			}
			break;
		case BPF_B:
			if (is_lsi_offset(off_adj, 0)) {
				emit(A64_STRBI(tmp, dst_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp2, off, ctx);
				emit(A64_STRB(tmp, dst, tmp2), ctx);
			}
			break;
		case BPF_DW:
			if (is_lsi_offset(off_adj, 3)) {
				emit(A64_STR64I(tmp, dst_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp2, off, ctx);
				emit(A64_STR64(tmp, dst, tmp2), ctx);
			}
			break;
		}

		ret = add_exception_handler(insn, ctx, dst);
		if (ret)
			return ret;
		break;

	/* STX: *(size *)(dst + off) = src */
	case BPF_STX | BPF_MEM | BPF_W:
	case BPF_STX | BPF_MEM | BPF_H:
	case BPF_STX | BPF_MEM | BPF_B:
	case BPF_STX | BPF_MEM | BPF_DW:
	case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
	case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
	case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
	case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
		/* PROBE_MEM32: the base becomes dst + arena_vm_base, held in tmp2 */
		if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
			emit(A64_ADD(1, tmp2, dst, arena_vm_base), ctx);
			dst = tmp2;
		}
		/* frame-pointer based stores are re-expressed relative to SP */
		if (dst == fp) {
			dst_adj = A64_SP;
			off_adj = off + ctx->stack_size;
		} else {
			dst_adj = dst;
			off_adj = off;
		}
		switch (BPF_SIZE(code)) {
		case BPF_W:
			if (is_lsi_offset(off_adj, 2)) {
				emit(A64_STR32I(src, dst_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp, off, ctx);
				emit(A64_STR32(src, dst, tmp), ctx);
			}
			break;
		case BPF_H:
			if (is_lsi_offset(off_adj, 1)) {
				emit(A64_STRHI(src, dst_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp, off, ctx);
				emit(A64_STRH(src, dst, tmp), ctx);
			}
			break;
		case BPF_B:
			if (is_lsi_offset(off_adj, 0)) {
				emit(A64_STRBI(src, dst_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp, off, ctx);
				emit(A64_STRB(src, dst, tmp), ctx);
			}
			break;
		case BPF_DW:
			if (is_lsi_offset(off_adj, 3)) {
				emit(A64_STR64I(src, dst_adj, off_adj), ctx);
			} else {
				emit_a64_mov_i(1, tmp, off, ctx);
				emit(A64_STR64(src, dst, tmp), ctx);
			}
			break;
		}

		ret = add_exception_handler(insn, ctx, dst);
		if (ret)
			return ret;
		break;

	/* atomic operations: pick the LSE or the LL/SC emitter by CPU capability */
	case BPF_STX | BPF_ATOMIC | BPF_W:
	case BPF_STX | BPF_ATOMIC | BPF_DW:
	case BPF_STX | BPF_PROBE_ATOMIC | BPF_W:
	case BPF_STX | BPF_PROBE_ATOMIC | BPF_DW:
		if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS))
			ret = emit_lse_atomic(insn, ctx);
		else
			ret = emit_ll_sc_atomic(insn, ctx);
		if (ret)
			return ret;

		ret = add_exception_handler(insn, ctx, dst);
		if (ret)
			return ret;
		break;

	default:
		pr_err_once("unknown opcode %02x\n", code);
		return -EINVAL;
	}

	return 0;
}
1667 
1668 static int build_body(struct jit_ctx *ctx, bool extra_pass)
1669 {
1670 	const struct bpf_prog *prog = ctx->prog;
1671 	int i;
1672 
1673 	/*
1674 	 * - offset[0] offset of the end of prologue,
1675 	 *   start of the 1st instruction.
1676 	 * - offset[1] - offset of the end of 1st instruction,
1677 	 *   start of the 2nd instruction
1678 	 * [....]
1679 	 * - offset[3] - offset of the end of 3rd instruction,
1680 	 *   start of 4th instruction
1681 	 */
1682 	for (i = 0; i < prog->len; i++) {
1683 		const struct bpf_insn *insn = &prog->insnsi[i];
1684 		int ret;
1685 
1686 		ctx->offset[i] = ctx->idx;
1687 		ret = build_insn(insn, ctx, extra_pass);
1688 		if (ret > 0) {
1689 			i++;
1690 			ctx->offset[i] = ctx->idx;
1691 			continue;
1692 		}
1693 		if (ret)
1694 			return ret;
1695 	}
1696 	/*
1697 	 * offset is allocated with prog->len + 1 so fill in
1698 	 * the last element with the offset after the last
1699 	 * instruction (end of program)
1700 	 */
1701 	ctx->offset[i] = ctx->idx;
1702 
1703 	return 0;
1704 }
1705 
1706 static int validate_code(struct jit_ctx *ctx)
1707 {
1708 	int i;
1709 
1710 	for (i = 0; i < ctx->idx; i++) {
1711 		u32 a64_insn = le32_to_cpu(ctx->image[i]);
1712 
1713 		if (a64_insn == AARCH64_BREAK_FAULT)
1714 			return -1;
1715 	}
1716 	return 0;
1717 }
1718 
1719 static int validate_ctx(struct jit_ctx *ctx)
1720 {
1721 	if (validate_code(ctx))
1722 		return -1;
1723 
1724 	if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries))
1725 		return -1;
1726 
1727 	return 0;
1728 }
1729 
/* Pointer-typed convenience wrapper around flush_icache_range(). */
static inline void bpf_flush_icache(void *start, void *end)
{
	const unsigned long first = (unsigned long)start;
	const unsigned long last = (unsigned long)end;

	flush_icache_range(first, last);
}
1734 
/*
 * JIT state stashed in prog->aux->jit_data by bpf_int_jit_compile() so that a
 * later call for the same program can pick up where the previous one stopped.
 */
struct arm64_jit_data {
	/* header of the writable image returned by bpf_jit_binary_pack_alloc() */
	struct bpf_binary_header *header;
	/* start of the read-only image; its offsets are the ones the code uses */
	u8 *ro_image;
	/* header of the read-only image */
	struct bpf_binary_header *ro_header;
	/* copy of the JIT context, including the per-instruction offset array */
	struct jit_ctx ctx;
};
1741 
/*
 * JIT-compile @prog into A64 code.
 *
 * The image is produced in three passes: pass 1 runs the emitters without an
 * image to learn the maximum size, pass 2 re-runs the body without writing to
 * settle the final instruction offsets, and pass 3 writes the final image.
 * When a saved context is found in prog->aux->jit_data the function resumes
 * from it (extra pass) and skips pass 1 and the allocation.
 *
 * Returns the program the caller should continue with. If JITing was not
 * requested, or a step fails, that is generally the program passed in; on
 * success it is the (possibly constant-blinded) program with bpf_func, jited
 * and jited_len filled in.
 */
struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	int image_size, prog_size, extable_size, extable_align, extable_offset;
	struct bpf_prog *tmp, *orig_prog = prog;
	struct bpf_binary_header *header;
	struct bpf_binary_header *ro_header;
	struct arm64_jit_data *jit_data;
	bool was_classic = bpf_prog_was_classic(prog);
	bool tmp_blinded = false;
	bool extra_pass = false;
	struct jit_ctx ctx;
	u8 *image_ptr;
	u8 *ro_image_ptr;
	int body_idx;
	int exentry_idx;

	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/* If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	/* a different program coming back means blinding produced a new one */
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
		if (!jit_data) {
			prog = orig_prog;
			goto out;
		}
		prog->aux->jit_data = jit_data;
	}
	/*
	 * A saved offset array means an earlier call stored its state: restore
	 * it and jump straight to the write passes.
	 */
	if (jit_data->ctx.offset) {
		ctx = jit_data->ctx;
		ro_image_ptr = jit_data->ro_image;
		ro_header = jit_data->ro_header;
		header = jit_data->header;
		/* the RW image sits at the same distance from its header as the RO one */
		image_ptr = (void *)header + ((void *)ro_image_ptr
						 - (void *)ro_header);
		extra_pass = true;
		prog_size = sizeof(u32) * ctx.idx;
		goto skip_init_ctx;
	}
	memset(&ctx, 0, sizeof(ctx));
	ctx.prog = prog;

	/* one entry per instruction plus one for the end of the program */
	ctx.offset = kvcalloc(prog->len + 1, sizeof(int), GFP_KERNEL);
	if (ctx.offset == NULL) {
		prog = orig_prog;
		goto out_off;
	}

	ctx.user_vm_start = bpf_arena_get_user_vm_start(prog->aux->arena);
	ctx.arena_vm_start = bpf_arena_get_kern_vm_start(prog->aux->arena);

	/* Pass 1: Estimate the maximum image size.
	 *
	 * BPF line info needs ctx->offset[i] to be the offset of
	 * instruction[i] in jited image, so build prologue first.
	 */
	if (build_prologue(&ctx, was_classic)) {
		prog = orig_prog;
		goto out_off;
	}

	if (build_body(&ctx, extra_pass)) {
		prog = orig_prog;
		goto out_off;
	}

	ctx.epilogue_offset = ctx.idx;
	build_epilogue(&ctx);
	build_plt(&ctx);

	extable_align = __alignof__(struct exception_table_entry);
	extable_size = prog->aux->num_exentries *
		sizeof(struct exception_table_entry);

	/* Now we know the maximum image size. */
	prog_size = sizeof(u32) * ctx.idx;
	/* also allocate space for plt target */
	extable_offset = round_up(prog_size + PLT_TARGET_SIZE, extable_align);
	image_size = extable_offset + extable_size;
	ro_header = bpf_jit_binary_pack_alloc(image_size, &ro_image_ptr,
					      sizeof(u32), &header, &image_ptr,
					      jit_fill_hole);
	if (!ro_header) {
		prog = orig_prog;
		goto out_off;
	}

	/* Pass 2: Determine jited position and result for each instruction */

	/*
	 * Use the image(RW) for writing the JITed instructions. But also save
	 * the ro_image(RX) for calculating the offsets in the image. The RW
	 * image will be later copied to the RX image from where the program
	 * will run. The bpf_jit_binary_pack_finalize() will do this copy in the
	 * final step.
	 */
	ctx.image = (__le32 *)image_ptr;
	ctx.ro_image = (__le32 *)ro_image_ptr;
	/* the exception table is placed after the code and PLT target */
	if (extable_size)
		prog->aux->extable = (void *)ro_image_ptr + extable_offset;
skip_init_ctx:
	ctx.idx = 0;
	ctx.exentry_idx = 0;
	ctx.write = true;

	build_prologue(&ctx, was_classic);

	/* Record exentry_idx and body_idx before first build_body */
	exentry_idx = ctx.exentry_idx;
	body_idx = ctx.idx;
	/* Dont write body instructions to memory for now */
	ctx.write = false;

	if (build_body(&ctx, extra_pass)) {
		prog = orig_prog;
		goto out_free_hdr;
	}

	/* rewind to the start of the body so pass 3 re-emits it for real */
	ctx.epilogue_offset = ctx.idx;
	ctx.exentry_idx = exentry_idx;
	ctx.idx = body_idx;
	ctx.write = true;

	/* Pass 3: Adjust jump offset and write final image */
	if (build_body(&ctx, extra_pass) ||
		WARN_ON_ONCE(ctx.idx != ctx.epilogue_offset)) {
		prog = orig_prog;
		goto out_free_hdr;
	}

	build_epilogue(&ctx);
	build_plt(&ctx);

	/* Extra pass to validate JITed code. */
	if (validate_ctx(&ctx)) {
		prog = orig_prog;
		goto out_free_hdr;
	}

	/* update the real prog size */
	prog_size = sizeof(u32) * ctx.idx;

	/* And we're done. */
	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, prog_size, 2, ctx.image);

	/*
	 * Finalize now unless this is the first call for a prog with is_func
	 * set; in that case the state is saved for a later extra pass.
	 */
	if (!prog->is_func || extra_pass) {
		/* The jited image may shrink since the jited result for
		 * BPF_CALL to subprog may be changed from indirect call
		 * to direct call.
		 */
		if (extra_pass && ctx.idx > jit_data->ctx.idx) {
			pr_err_once("multi-func JIT bug %d > %d\n",
				    ctx.idx, jit_data->ctx.idx);
			prog->bpf_func = NULL;
			prog->jited = 0;
			prog->jited_len = 0;
			goto out_free_hdr;
		}
		if (WARN_ON(bpf_jit_binary_pack_finalize(ro_header, header))) {
			/* ro_header has been freed */
			ro_header = NULL;
			prog = orig_prog;
			goto out_off;
		}
		/*
		 * The instructions have now been copied to the ROX region from
		 * where they will execute. Now the data cache has to be cleaned to
		 * the PoU and the I-cache has to be invalidated for the VAs.
		 */
		bpf_flush_icache(ro_header, ctx.ro_image + ctx.idx);
	} else {
		jit_data->ctx = ctx;
		jit_data->ro_image = ro_image_ptr;
		jit_data->header = header;
		jit_data->ro_header = ro_header;
	}

	prog->bpf_func = (void *)ctx.ro_image;
	prog->jited = 1;
	prog->jited_len = prog_size;

	if (!prog->is_func || extra_pass) {
		int i;

		/* offset[prog->len] is the size of program */
		for (i = 0; i <= prog->len; i++)
			ctx.offset[i] *= AARCH64_INSN_SIZE;
		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
		/*
		 * Failure paths jump here as well: drop the offset array and
		 * the saved JIT state.
		 */
out_off:
		kvfree(ctx.offset);
		kfree(jit_data);
		prog->aux->jit_data = NULL;
	}
out:
	/* release whichever of the blinded/original pair is not being returned */
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);
	return prog;

out_free_hdr:
	if (header) {
		/* copy the size field from the RW header into the RO header before freeing */
		bpf_arch_text_copy(&ro_header->size, &header->size,
				   sizeof(header->size));
		bpf_jit_binary_pack_free(ro_header, header);
	}
	goto out_off;
}
1961 
/* Report whether this JIT backend supports kfunc calls; it always does. */
bool bpf_jit_supports_kfunc_call(void)
{
	const bool kfunc_calls_supported = true;

	return kfunc_calls_supported;
}
1966 
1967 void *bpf_arch_text_copy(void *dst, void *src, size_t len)
1968 {
1969 	if (!aarch64_insn_copy(dst, src, len))
1970 		return ERR_PTR(-EINVAL);
1971 	return dst;
1972 }
1973 
1974 u64 bpf_jit_alloc_exec_limit(void)
1975 {
1976 	return VMALLOC_END - VMALLOC_START;
1977 }
1978 
/*
 * Tell the caller that this JIT backend can combine bpf2bpf calls with tail
 * calls; the answer is unconditionally yes.
 */
bool bpf_jit_supports_subprog_tailcalls(void)
{
	const bool subprog_tailcalls_supported = true;

	return subprog_tailcalls_supported;
}
1984 
1985 static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
1986 			    int args_off, int retval_off, int run_ctx_off,
1987 			    bool save_ret)
1988 {
1989 	__le32 *branch;
1990 	u64 enter_prog;
1991 	u64 exit_prog;
1992 	struct bpf_prog *p = l->link.prog;
1993 	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
1994 
1995 	enter_prog = (u64)bpf_trampoline_enter(p);
1996 	exit_prog = (u64)bpf_trampoline_exit(p);
1997 
1998 	if (l->cookie == 0) {
1999 		/* if cookie is zero, one instruction is enough to store it */
2000 		emit(A64_STR64I(A64_ZR, A64_SP, run_ctx_off + cookie_off), ctx);
2001 	} else {
2002 		emit_a64_mov_i64(A64_R(10), l->cookie, ctx);
2003 		emit(A64_STR64I(A64_R(10), A64_SP, run_ctx_off + cookie_off),
2004 		     ctx);
2005 	}
2006 
2007 	/* save p to callee saved register x19 to avoid loading p with mov_i64
2008 	 * each time.
2009 	 */
2010 	emit_addr_mov_i64(A64_R(19), (const u64)p, ctx);
2011 
2012 	/* arg1: prog */
2013 	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
2014 	/* arg2: &run_ctx */
2015 	emit(A64_ADD_I(1, A64_R(1), A64_SP, run_ctx_off), ctx);
2016 
2017 	emit_call(enter_prog, ctx);
2018 
2019 	/* save return value to callee saved register x20 */
2020 	emit(A64_MOV(1, A64_R(20), A64_R(0)), ctx);
2021 
2022 	/* if (__bpf_prog_enter(prog) == 0)
2023 	 *         goto skip_exec_of_prog;
2024 	 */
2025 	branch = ctx->image + ctx->idx;
2026 	emit(A64_NOP, ctx);
2027 
2028 	emit(A64_ADD_I(1, A64_R(0), A64_SP, args_off), ctx);
2029 	if (!p->jited)
2030 		emit_addr_mov_i64(A64_R(1), (const u64)p->insnsi, ctx);
2031 
2032 	emit_call((const u64)p->bpf_func, ctx);
2033 
2034 	if (save_ret)
2035 		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
2036 
2037 	if (ctx->image) {
2038 		int offset = &ctx->image[ctx->idx] - branch;
2039 		*branch = cpu_to_le32(A64_CBZ(1, A64_R(0), offset));
2040 	}
2041 
2042 	/* arg1: prog */
2043 	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
2044 	/* arg2: start time */
2045 	emit(A64_MOV(1, A64_R(1), A64_R(20)), ctx);
2046 	/* arg3: &run_ctx */
2047 	emit(A64_ADD_I(1, A64_R(2), A64_SP, run_ctx_off), ctx);
2048 
2049 	emit_call(exit_prog, ctx);
2050 }
2051 
2052 static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
2053 			       int args_off, int retval_off, int run_ctx_off,
2054 			       __le32 **branches)
2055 {
2056 	int i;
2057 
2058 	/* The first fmod_ret program will receive a garbage return value.
2059 	 * Set this to 0 to avoid confusing the program.
2060 	 */
2061 	emit(A64_STR64I(A64_ZR, A64_SP, retval_off), ctx);
2062 	for (i = 0; i < tl->nr_links; i++) {
2063 		invoke_bpf_prog(ctx, tl->links[i], args_off, retval_off,
2064 				run_ctx_off, true);
2065 		/* if (*(u64 *)(sp + retval_off) !=  0)
2066 		 *	goto do_fexit;
2067 		 */
2068 		emit(A64_LDR64I(A64_R(10), A64_SP, retval_off), ctx);
2069 		/* Save the location of branch, and generate a nop.
2070 		 * This nop will be replaced with a cbnz later.
2071 		 */
2072 		branches[i] = ctx->image + ctx->idx;
2073 		emit(A64_NOP, ctx);
2074 	}
2075 }
2076 
2077 static void save_args(struct jit_ctx *ctx, int args_off, int nregs)
2078 {
2079 	int i;
2080 
2081 	for (i = 0; i < nregs; i++) {
2082 		emit(A64_STR64I(i, A64_SP, args_off), ctx);
2083 		args_off += 8;
2084 	}
2085 }
2086 
2087 static void restore_args(struct jit_ctx *ctx, int args_off, int nregs)
2088 {
2089 	int i;
2090 
2091 	for (i = 0; i < nregs; i++) {
2092 		emit(A64_LDR64I(i, A64_SP, args_off), ctx);
2093 		args_off += 8;
2094 	}
2095 }
2096 
2097 /* Based on the x86's implementation of arch_prepare_bpf_trampoline().
2098  *
2099  * bpf prog and function entry before bpf trampoline hooked:
2100  *   mov x9, lr
2101  *   nop
2102  *
2103  * bpf prog and function entry after bpf trampoline hooked:
2104  *   mov x9, lr
2105  *   bl  <bpf_trampoline or plt>
2106  *
2107  */
2108 static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
2109 			      struct bpf_tramp_links *tlinks, void *func_addr,
2110 			      int nregs, u32 flags)
2111 {
2112 	int i;
2113 	int stack_size;
2114 	int retaddr_off;
2115 	int regs_off;
2116 	int retval_off;
2117 	int args_off;
2118 	int nregs_off;
2119 	int ip_off;
2120 	int run_ctx_off;
2121 	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
2122 	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
2123 	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
2124 	bool save_ret;
2125 	__le32 **branches = NULL;
2126 
2127 	/* trampoline stack layout:
2128 	 *                  [ parent ip         ]
2129 	 *                  [ FP                ]
2130 	 * SP + retaddr_off [ self ip           ]
2131 	 *                  [ FP                ]
2132 	 *
2133 	 *                  [ padding           ] align SP to multiples of 16
2134 	 *
2135 	 *                  [ x20               ] callee saved reg x20
2136 	 * SP + regs_off    [ x19               ] callee saved reg x19
2137 	 *
2138 	 * SP + retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
2139 	 *                                        BPF_TRAMP_F_RET_FENTRY_RET
2140 	 *
2141 	 *                  [ arg reg N         ]
2142 	 *                  [ ...               ]
2143 	 * SP + args_off    [ arg reg 1         ]
2144 	 *
2145 	 * SP + nregs_off   [ arg regs count    ]
2146 	 *
2147 	 * SP + ip_off      [ traced function   ] BPF_TRAMP_F_IP_ARG flag
2148 	 *
2149 	 * SP + run_ctx_off [ bpf_tramp_run_ctx ]
2150 	 */
2151 
2152 	stack_size = 0;
2153 	run_ctx_off = stack_size;
2154 	/* room for bpf_tramp_run_ctx */
2155 	stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
2156 
2157 	ip_off = stack_size;
2158 	/* room for IP address argument */
2159 	if (flags & BPF_TRAMP_F_IP_ARG)
2160 		stack_size += 8;
2161 
2162 	nregs_off = stack_size;
2163 	/* room for args count */
2164 	stack_size += 8;
2165 
2166 	args_off = stack_size;
2167 	/* room for args */
2168 	stack_size += nregs * 8;
2169 
2170 	/* room for return value */
2171 	retval_off = stack_size;
2172 	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
2173 	if (save_ret)
2174 		stack_size += 8;
2175 
2176 	/* room for callee saved registers, currently x19 and x20 are used */
2177 	regs_off = stack_size;
2178 	stack_size += 16;
2179 
2180 	/* round up to multiples of 16 to avoid SPAlignmentFault */
2181 	stack_size = round_up(stack_size, 16);
2182 
2183 	/* return address locates above FP */
2184 	retaddr_off = stack_size + 8;
2185 
2186 	/* bpf trampoline may be invoked by 3 instruction types:
2187 	 * 1. bl, attached to bpf prog or kernel function via short jump
2188 	 * 2. br, attached to bpf prog or kernel function via long jump
2189 	 * 3. blr, working as a function pointer, used by struct_ops.
2190 	 * So BTI_JC should used here to support both br and blr.
2191 	 */
2192 	emit_bti(A64_BTI_JC, ctx);
2193 
2194 	/* frame for parent function */
2195 	emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
2196 	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
2197 
2198 	/* frame for patched function */
2199 	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
2200 	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
2201 
2202 	/* allocate stack space */
2203 	emit(A64_SUB_I(1, A64_SP, A64_SP, stack_size), ctx);
2204 
2205 	if (flags & BPF_TRAMP_F_IP_ARG) {
2206 		/* save ip address of the traced function */
2207 		emit_addr_mov_i64(A64_R(10), (const u64)func_addr, ctx);
2208 		emit(A64_STR64I(A64_R(10), A64_SP, ip_off), ctx);
2209 	}
2210 
2211 	/* save arg regs count*/
2212 	emit(A64_MOVZ(1, A64_R(10), nregs, 0), ctx);
2213 	emit(A64_STR64I(A64_R(10), A64_SP, nregs_off), ctx);
2214 
2215 	/* save arg regs */
2216 	save_args(ctx, args_off, nregs);
2217 
2218 	/* save callee saved registers */
2219 	emit(A64_STR64I(A64_R(19), A64_SP, regs_off), ctx);
2220 	emit(A64_STR64I(A64_R(20), A64_SP, regs_off + 8), ctx);
2221 
2222 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
2223 		emit_a64_mov_i64(A64_R(0), (const u64)im, ctx);
2224 		emit_call((const u64)__bpf_tramp_enter, ctx);
2225 	}
2226 
2227 	for (i = 0; i < fentry->nr_links; i++)
2228 		invoke_bpf_prog(ctx, fentry->links[i], args_off,
2229 				retval_off, run_ctx_off,
2230 				flags & BPF_TRAMP_F_RET_FENTRY_RET);
2231 
2232 	if (fmod_ret->nr_links) {
2233 		branches = kcalloc(fmod_ret->nr_links, sizeof(__le32 *),
2234 				   GFP_KERNEL);
2235 		if (!branches)
2236 			return -ENOMEM;
2237 
2238 		invoke_bpf_mod_ret(ctx, fmod_ret, args_off, retval_off,
2239 				   run_ctx_off, branches);
2240 	}
2241 
2242 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
2243 		restore_args(ctx, args_off, nregs);
2244 		/* call original func */
2245 		emit(A64_LDR64I(A64_R(10), A64_SP, retaddr_off), ctx);
2246 		emit(A64_ADR(A64_LR, AARCH64_INSN_SIZE * 2), ctx);
2247 		emit(A64_RET(A64_R(10)), ctx);
2248 		/* store return value */
2249 		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
2250 		/* reserve a nop for bpf_tramp_image_put */
2251 		im->ip_after_call = ctx->ro_image + ctx->idx;
2252 		emit(A64_NOP, ctx);
2253 	}
2254 
2255 	/* update the branches saved in invoke_bpf_mod_ret with cbnz */
2256 	for (i = 0; i < fmod_ret->nr_links && ctx->image != NULL; i++) {
2257 		int offset = &ctx->image[ctx->idx] - branches[i];
2258 		*branches[i] = cpu_to_le32(A64_CBNZ(1, A64_R(10), offset));
2259 	}
2260 
2261 	for (i = 0; i < fexit->nr_links; i++)
2262 		invoke_bpf_prog(ctx, fexit->links[i], args_off, retval_off,
2263 				run_ctx_off, false);
2264 
2265 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
2266 		im->ip_epilogue = ctx->ro_image + ctx->idx;
2267 		emit_a64_mov_i64(A64_R(0), (const u64)im, ctx);
2268 		emit_call((const u64)__bpf_tramp_exit, ctx);
2269 	}
2270 
2271 	if (flags & BPF_TRAMP_F_RESTORE_REGS)
2272 		restore_args(ctx, args_off, nregs);
2273 
2274 	/* restore callee saved register x19 and x20 */
2275 	emit(A64_LDR64I(A64_R(19), A64_SP, regs_off), ctx);
2276 	emit(A64_LDR64I(A64_R(20), A64_SP, regs_off + 8), ctx);
2277 
2278 	if (save_ret)
2279 		emit(A64_LDR64I(A64_R(0), A64_SP, retval_off), ctx);
2280 
2281 	/* reset SP  */
2282 	emit(A64_MOV(1, A64_SP, A64_FP), ctx);
2283 
2284 	/* pop frames  */
2285 	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
2286 	emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);
2287 
2288 	if (flags & BPF_TRAMP_F_SKIP_FRAME) {
2289 		/* skip patched function, return to parent */
2290 		emit(A64_MOV(1, A64_LR, A64_R(9)), ctx);
2291 		emit(A64_RET(A64_R(9)), ctx);
2292 	} else {
2293 		/* return to patched function */
2294 		emit(A64_MOV(1, A64_R(10), A64_LR), ctx);
2295 		emit(A64_MOV(1, A64_LR, A64_R(9)), ctx);
2296 		emit(A64_RET(A64_R(10)), ctx);
2297 	}
2298 
2299 	kfree(branches);
2300 
2301 	return ctx->idx;
2302 }
2303 
2304 static int btf_func_model_nregs(const struct btf_func_model *m)
2305 {
2306 	int nregs = m->nr_args;
2307 	int i;
2308 
2309 	/* extra registers needed for struct argument */
2310 	for (i = 0; i < MAX_BPF_FUNC_ARGS; i++) {
2311 		/* The arg_size is at most 16 bytes, enforced by the verifier. */
2312 		if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
2313 			nregs += (m->arg_size[i] + 7) / 8 - 1;
2314 	}
2315 
2316 	return nregs;
2317 }
2318 
2319 int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
2320 			     struct bpf_tramp_links *tlinks, void *func_addr)
2321 {
2322 	struct jit_ctx ctx = {
2323 		.image = NULL,
2324 		.idx = 0,
2325 	};
2326 	struct bpf_tramp_image im;
2327 	int nregs, ret;
2328 
2329 	nregs = btf_func_model_nregs(m);
2330 	/* the first 8 registers are used for arguments */
2331 	if (nregs > 8)
2332 		return -ENOTSUPP;
2333 
2334 	ret = prepare_trampoline(&ctx, &im, tlinks, func_addr, nregs, flags);
2335 	if (ret < 0)
2336 		return ret;
2337 
2338 	return ret < 0 ? ret : ret * AARCH64_INSN_SIZE;
2339 }
2340 
2341 void *arch_alloc_bpf_trampoline(unsigned int size)
2342 {
2343 	return bpf_prog_pack_alloc(size, jit_fill_hole);
2344 }
2345 
/* Release a trampoline image of @size bytes through bpf_prog_pack_free(),
 * the counterpart of the allocator used by arch_alloc_bpf_trampoline().
 */
void arch_free_bpf_trampoline(void *image, unsigned int size)
{
	bpf_prog_pack_free(image, size);
}
2350 
/* No-op on arm64: neither @image nor @size is used and 0 is always
 * returned.
 * NOTE(review): presumably no protection change is needed because the
 * image is written through bpf_arch_text_copy() - confirm.
 */
int arch_protect_bpf_trampoline(void *image, unsigned int size)
{
	return 0;
}
2355 
2356 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *ro_image,
2357 				void *ro_image_end, const struct btf_func_model *m,
2358 				u32 flags, struct bpf_tramp_links *tlinks,
2359 				void *func_addr)
2360 {
2361 	int ret, nregs;
2362 	void *image, *tmp;
2363 	u32 size = ro_image_end - ro_image;
2364 
2365 	/* image doesn't need to be in module memory range, so we can
2366 	 * use kvmalloc.
2367 	 */
2368 	image = kvmalloc(size, GFP_KERNEL);
2369 	if (!image)
2370 		return -ENOMEM;
2371 
2372 	struct jit_ctx ctx = {
2373 		.image = image,
2374 		.ro_image = ro_image,
2375 		.idx = 0,
2376 		.write = true,
2377 	};
2378 
2379 	nregs = btf_func_model_nregs(m);
2380 	/* the first 8 registers are used for arguments */
2381 	if (nregs > 8)
2382 		return -ENOTSUPP;
2383 
2384 	jit_fill_hole(image, (unsigned int)(ro_image_end - ro_image));
2385 	ret = prepare_trampoline(&ctx, im, tlinks, func_addr, nregs, flags);
2386 
2387 	if (ret > 0 && validate_code(&ctx) < 0) {
2388 		ret = -EINVAL;
2389 		goto out;
2390 	}
2391 
2392 	if (ret > 0)
2393 		ret *= AARCH64_INSN_SIZE;
2394 
2395 	tmp = bpf_arch_text_copy(ro_image, image, size);
2396 	if (IS_ERR(tmp)) {
2397 		ret = PTR_ERR(tmp);
2398 		goto out;
2399 	}
2400 
2401 	bpf_flush_icache(ro_image, ro_image + size);
2402 out:
2403 	kvfree(image);
2404 	return ret;
2405 }
2406 
2407 static bool is_long_jump(void *ip, void *target)
2408 {
2409 	long offset;
2410 
2411 	/* NULL target means this is a NOP */
2412 	if (!target)
2413 		return false;
2414 
2415 	offset = (long)target - (long)ip;
2416 	return offset < -SZ_128M || offset >= SZ_128M;
2417 }
2418 
2419 static int gen_branch_or_nop(enum aarch64_insn_branch_type type, void *ip,
2420 			     void *addr, void *plt, u32 *insn)
2421 {
2422 	void *target;
2423 
2424 	if (!addr) {
2425 		*insn = aarch64_insn_gen_nop();
2426 		return 0;
2427 	}
2428 
2429 	if (is_long_jump(ip, addr))
2430 		target = plt;
2431 	else
2432 		target = addr;
2433 
2434 	*insn = aarch64_insn_gen_branch_imm((unsigned long)ip,
2435 					    (unsigned long)target,
2436 					    type);
2437 
2438 	return *insn != AARCH64_BREAK_FAULT ? 0 : -EFAULT;
2439 }
2440 
2441 /* Replace the branch instruction from @ip to @old_addr in a bpf prog or a bpf
2442  * trampoline with the branch instruction from @ip to @new_addr. If @old_addr
2443  * or @new_addr is NULL, the old or new instruction is NOP.
2444  *
2445  * When @ip is the bpf prog entry, a bpf trampoline is being attached or
2446  * detached. Since bpf trampoline and bpf prog are allocated separately with
2447  * vmalloc, the address distance may exceed 128MB, the maximum branch range.
2448  * So long jump should be handled.
2449  *
2450  * When a bpf prog is constructed, a plt pointing to empty trampoline
2451  * dummy_tramp is placed at the end:
2452  *
2453  *      bpf_prog:
2454  *              mov x9, lr
2455  *              nop // patchsite
2456  *              ...
2457  *              ret
2458  *
2459  *      plt:
2460  *              ldr x10, target
2461  *              br x10
2462  *      target:
2463  *              .quad dummy_tramp // plt target
2464  *
2465  * This is also the state when no trampoline is attached.
2466  *
2467  * When a short-jump bpf trampoline is attached, the patchsite is patched
2468  * to a bl instruction to the trampoline directly:
2469  *
2470  *      bpf_prog:
2471  *              mov x9, lr
2472  *              bl <short-jump bpf trampoline address> // patchsite
2473  *              ...
2474  *              ret
2475  *
2476  *      plt:
2477  *              ldr x10, target
2478  *              br x10
2479  *      target:
2480  *              .quad dummy_tramp // plt target
2481  *
2482  * When a long-jump bpf trampoline is attached, the plt target is filled with
2483  * the trampoline address and the patchsite is patched to a bl instruction to
2484  * the plt:
2485  *
2486  *      bpf_prog:
2487  *              mov x9, lr
2488  *              bl plt // patchsite
2489  *              ...
2490  *              ret
2491  *
2492  *      plt:
2493  *              ldr x10, target
2494  *              br x10
2495  *      target:
2496  *              .quad <long-jump bpf trampoline address> // plt target
2497  *
2498  * The dummy_tramp is used to prevent another CPU from jumping to unknown
2499  * locations during the patching process, making the patching process easier.
2500  */
2501 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
2502 		       void *old_addr, void *new_addr)
2503 {
2504 	int ret;
2505 	u32 old_insn;
2506 	u32 new_insn;
2507 	u32 replaced;
2508 	struct bpf_plt *plt = NULL;
2509 	unsigned long size = 0UL;
2510 	unsigned long offset = ~0UL;
2511 	enum aarch64_insn_branch_type branch_type;
2512 	char namebuf[KSYM_NAME_LEN];
2513 	void *image = NULL;
2514 	u64 plt_target = 0ULL;
2515 	bool poking_bpf_entry;
2516 
2517 	if (!__bpf_address_lookup((unsigned long)ip, &size, &offset, namebuf))
2518 		/* Only poking bpf text is supported. Since kernel function
2519 		 * entry is set up by ftrace, we reply on ftrace to poke kernel
2520 		 * functions.
2521 		 */
2522 		return -ENOTSUPP;
2523 
2524 	image = ip - offset;
2525 	/* zero offset means we're poking bpf prog entry */
2526 	poking_bpf_entry = (offset == 0UL);
2527 
2528 	/* bpf prog entry, find plt and the real patchsite */
2529 	if (poking_bpf_entry) {
2530 		/* plt locates at the end of bpf prog */
2531 		plt = image + size - PLT_TARGET_OFFSET;
2532 
2533 		/* skip to the nop instruction in bpf prog entry:
2534 		 * bti c // if BTI enabled
2535 		 * mov x9, x30
2536 		 * nop
2537 		 */
2538 		ip = image + POKE_OFFSET * AARCH64_INSN_SIZE;
2539 	}
2540 
2541 	/* long jump is only possible at bpf prog entry */
2542 	if (WARN_ON((is_long_jump(ip, new_addr) || is_long_jump(ip, old_addr)) &&
2543 		    !poking_bpf_entry))
2544 		return -EINVAL;
2545 
2546 	if (poke_type == BPF_MOD_CALL)
2547 		branch_type = AARCH64_INSN_BRANCH_LINK;
2548 	else
2549 		branch_type = AARCH64_INSN_BRANCH_NOLINK;
2550 
2551 	if (gen_branch_or_nop(branch_type, ip, old_addr, plt, &old_insn) < 0)
2552 		return -EFAULT;
2553 
2554 	if (gen_branch_or_nop(branch_type, ip, new_addr, plt, &new_insn) < 0)
2555 		return -EFAULT;
2556 
2557 	if (is_long_jump(ip, new_addr))
2558 		plt_target = (u64)new_addr;
2559 	else if (is_long_jump(ip, old_addr))
2560 		/* if the old target is a long jump and the new target is not,
2561 		 * restore the plt target to dummy_tramp, so there is always a
2562 		 * legal and harmless address stored in plt target, and we'll
2563 		 * never jump from plt to an unknown place.
2564 		 */
2565 		plt_target = (u64)&dummy_tramp;
2566 
2567 	if (plt_target) {
2568 		/* non-zero plt_target indicates we're patching a bpf prog,
2569 		 * which is read only.
2570 		 */
2571 		if (set_memory_rw(PAGE_MASK & ((uintptr_t)&plt->target), 1))
2572 			return -EFAULT;
2573 		WRITE_ONCE(plt->target, plt_target);
2574 		set_memory_ro(PAGE_MASK & ((uintptr_t)&plt->target), 1);
2575 		/* since plt target points to either the new trampoline
2576 		 * or dummy_tramp, even if another CPU reads the old plt
2577 		 * target value before fetching the bl instruction to plt,
2578 		 * it will be brought back by dummy_tramp, so no barrier is
2579 		 * required here.
2580 		 */
2581 	}
2582 
2583 	/* if the old target and the new target are both long jumps, no
2584 	 * patching is required
2585 	 */
2586 	if (old_insn == new_insn)
2587 		return 0;
2588 
2589 	mutex_lock(&text_mutex);
2590 	if (aarch64_insn_read(ip, &replaced)) {
2591 		ret = -EFAULT;
2592 		goto out;
2593 	}
2594 
2595 	if (replaced != old_insn) {
2596 		ret = -EFAULT;
2597 		goto out;
2598 	}
2599 
2600 	/* We call aarch64_insn_patch_text_nosync() to replace instruction
2601 	 * atomically, so no other CPUs will fetch a half-new and half-old
2602 	 * instruction. But there is chance that another CPU executes the
2603 	 * old instruction after the patching operation finishes (e.g.,
2604 	 * pipeline not flushed, or icache not synchronized yet).
2605 	 *
2606 	 * 1. when a new trampoline is attached, it is not a problem for
2607 	 *    different CPUs to jump to different trampolines temporarily.
2608 	 *
2609 	 * 2. when an old trampoline is freed, we should wait for all other
2610 	 *    CPUs to exit the trampoline and make sure the trampoline is no
2611 	 *    longer reachable, since bpf_tramp_image_put() function already
2612 	 *    uses percpu_ref and task-based rcu to do the sync, no need to call
2613 	 *    the sync version here, see bpf_tramp_image_put() for details.
2614 	 */
2615 	ret = aarch64_insn_patch_text_nosync(ip, new_insn);
2616 out:
2617 	mutex_unlock(&text_mutex);
2618 
2619 	return ret;
2620 }
2621 
/* Capability query: always reports pointer xchg support. */
bool bpf_jit_supports_ptr_xchg(void)
{
	return true;
}
2626 
/* Capability query: always reports support for BPF exceptions. */
bool bpf_jit_supports_exceptions(void)
{
	/* Unwinding starts inside the bpf_throw call and walks through both
	 * kernel frames and BPF frames, so the FP unwinder must be enabled
	 * to walk the kernel frames and reach the BPF frames in the stack
	 * trace. The ARM64 kernel is always compiled with
	 * CONFIG_FRAME_POINTER=y.
	 */
	return true;
}
2636 
/* Capability query: always reports BPF arena support. */
bool bpf_jit_supports_arena(void)
{
	return true;
}
2641 
2642 bool bpf_jit_supports_insn(struct bpf_insn *insn, bool in_arena)
2643 {
2644 	if (!in_arena)
2645 		return true;
2646 	switch (insn->code) {
2647 	case BPF_STX | BPF_ATOMIC | BPF_W:
2648 	case BPF_STX | BPF_ATOMIC | BPF_DW:
2649 		if (!cpus_have_cap(ARM64_HAS_LSE_ATOMICS))
2650 			return false;
2651 	}
2652 	return true;
2653 }
2654 
/* Capability query: always reports per-CPU instruction support. */
bool bpf_jit_supports_percpu_insn(void)
{
	return true;
}
2659 
2660 bool bpf_jit_inlines_helper_call(s32 imm)
2661 {
2662 	switch (imm) {
2663 	case BPF_FUNC_get_smp_processor_id:
2664 	case BPF_FUNC_get_current_task:
2665 	case BPF_FUNC_get_current_task_btf:
2666 		return true;
2667 	default:
2668 		return false;
2669 	}
2670 }
2671 
2672 void bpf_jit_free(struct bpf_prog *prog)
2673 {
2674 	if (prog->jited) {
2675 		struct arm64_jit_data *jit_data = prog->aux->jit_data;
2676 		struct bpf_binary_header *hdr;
2677 
2678 		/*
2679 		 * If we fail the final pass of JIT (from jit_subprogs),
2680 		 * the program may not be finalized yet. Call finalize here
2681 		 * before freeing it.
2682 		 */
2683 		if (jit_data) {
2684 			bpf_arch_text_copy(&jit_data->ro_header->size, &jit_data->header->size,
2685 					   sizeof(jit_data->header->size));
2686 			kfree(jit_data);
2687 		}
2688 		hdr = bpf_jit_binary_pack_hdr(prog);
2689 		bpf_jit_binary_pack_free(hdr, NULL);
2690 		WARN_ON_ONCE(!bpf_prog_kallsyms_verify_off(prog));
2691 	}
2692 
2693 	bpf_prog_unlock_free(prog);
2694 }
2695