xref: /linux/arch/arm64/net/bpf_jit_comp.c (revision 7a9b709e7cc5ce1ffb84ce07bf6d157e1de758df)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler for ARM64
4  *
5  * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
6  */
7 
8 #define pr_fmt(fmt) "bpf_jit: " fmt
9 
10 #include <linux/bitfield.h>
11 #include <linux/bpf.h>
12 #include <linux/filter.h>
13 #include <linux/memory.h>
14 #include <linux/printk.h>
15 #include <linux/slab.h>
16 
17 #include <asm/asm-extable.h>
18 #include <asm/byteorder.h>
19 #include <asm/cacheflush.h>
20 #include <asm/debug-monitors.h>
21 #include <asm/insn.h>
22 #include <asm/text-patching.h>
23 #include <asm/set_memory.h>
24 
25 #include "bpf_jit.h"
26 
27 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
28 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
29 #define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
30 #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
31 #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)
32 
33 #define check_imm(bits, imm) do {				\
34 	if ((((imm) > 0) && ((imm) >> (bits))) ||		\
35 	    (((imm) < 0) && (~(imm) >> (bits)))) {		\
36 		pr_info("[%2d] imm=%d(0x%x) out of range\n",	\
37 			i, imm, imm);				\
38 		return -EINVAL;					\
39 	}							\
40 } while (0)
41 #define check_imm19(imm) check_imm(19, imm)
42 #define check_imm26(imm) check_imm(26, imm)
43 
44 /* Map BPF registers to A64 registers */
45 static const int bpf2a64[] = {
46 	/* return value from in-kernel function, and exit value from eBPF */
47 	[BPF_REG_0] = A64_R(7),
48 	/* arguments from eBPF program to in-kernel function */
49 	[BPF_REG_1] = A64_R(0),
50 	[BPF_REG_2] = A64_R(1),
51 	[BPF_REG_3] = A64_R(2),
52 	[BPF_REG_4] = A64_R(3),
53 	[BPF_REG_5] = A64_R(4),
54 	/* callee saved registers that in-kernel function will preserve */
55 	[BPF_REG_6] = A64_R(19),
56 	[BPF_REG_7] = A64_R(20),
57 	[BPF_REG_8] = A64_R(21),
58 	[BPF_REG_9] = A64_R(22),
59 	/* read-only frame pointer to access stack */
60 	[BPF_REG_FP] = A64_R(25),
61 	/* temporary registers for BPF JIT */
62 	[TMP_REG_1] = A64_R(10),
63 	[TMP_REG_2] = A64_R(11),
64 	[TMP_REG_3] = A64_R(12),
65 	/* tail_call_cnt_ptr */
66 	[TCCNT_PTR] = A64_R(26),
67 	/* temporary register for blinding constants */
68 	[BPF_REG_AX] = A64_R(9),
69 	/* callee saved register for kern_vm_start address */
70 	[ARENA_VM_START] = A64_R(28),
71 };
72 
73 struct jit_ctx {
74 	const struct bpf_prog *prog;
75 	int idx;
76 	int epilogue_offset;
77 	int *offset;
78 	int exentry_idx;
79 	int nr_used_callee_reg;
80 	u8 used_callee_reg[8]; /* r6~r9, fp, arena_vm_start */
81 	__le32 *image;
82 	__le32 *ro_image;
83 	u32 stack_size;
84 	u64 user_vm_start;
85 	u64 arena_vm_start;
86 	bool fp_used;
87 	bool write;
88 };
89 
90 struct bpf_plt {
91 	u32 insn_ldr; /* load target */
92 	u32 insn_br;  /* branch to target */
93 	u64 target;   /* target value */
94 };
95 
96 #define PLT_TARGET_SIZE   sizeof_field(struct bpf_plt, target)
97 #define PLT_TARGET_OFFSET offsetof(struct bpf_plt, target)
98 
99 static inline void emit(const u32 insn, struct jit_ctx *ctx)
100 {
101 	if (ctx->image != NULL && ctx->write)
102 		ctx->image[ctx->idx] = cpu_to_le32(insn);
103 
104 	ctx->idx++;
105 }
106 
107 static inline void emit_a64_mov_i(const int is64, const int reg,
108 				  const s32 val, struct jit_ctx *ctx)
109 {
110 	u16 hi = val >> 16;
111 	u16 lo = val & 0xffff;
112 
113 	if (hi & 0x8000) {
114 		if (hi == 0xffff) {
115 			emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
116 		} else {
117 			emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
118 			if (lo != 0xffff)
119 				emit(A64_MOVK(is64, reg, lo, 0), ctx);
120 		}
121 	} else {
122 		emit(A64_MOVZ(is64, reg, lo, 0), ctx);
123 		if (hi)
124 			emit(A64_MOVK(is64, reg, hi, 16), ctx);
125 	}
126 }
127 
128 static int i64_i16_blocks(const u64 val, bool inverse)
129 {
130 	return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
131 	       (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
132 	       (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
133 	       (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
134 }
135 
136 static inline void emit_a64_mov_i64(const int reg, const u64 val,
137 				    struct jit_ctx *ctx)
138 {
139 	u64 nrm_tmp = val, rev_tmp = ~val;
140 	bool inverse;
141 	int shift;
142 
143 	if (!(nrm_tmp >> 32))
144 		return emit_a64_mov_i(0, reg, (u32)val, ctx);
145 
146 	inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
147 	shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
148 					  (fls64(nrm_tmp) - 1)), 16), 0);
149 	if (inverse)
150 		emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
151 	else
152 		emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
153 	shift -= 16;
154 	while (shift >= 0) {
155 		if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
156 			emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
157 		shift -= 16;
158 	}
159 }
160 
161 static inline void emit_bti(u32 insn, struct jit_ctx *ctx)
162 {
163 	if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
164 		emit(insn, ctx);
165 }
166 
167 /*
168  * Kernel addresses in the vmalloc space use at most 48 bits, and the
169  * remaining bits are guaranteed to be 0x1. So we can compose the address
170  * with a fixed length movn/movk/movk sequence.
171  */
172 static inline void emit_addr_mov_i64(const int reg, const u64 val,
173 				     struct jit_ctx *ctx)
174 {
175 	u64 tmp = val;
176 	int shift = 0;
177 
178 	emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
179 	while (shift < 32) {
180 		tmp >>= 16;
181 		shift += 16;
182 		emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
183 	}
184 }
185 
186 static bool should_emit_indirect_call(long target, const struct jit_ctx *ctx)
187 {
188 	long offset;
189 
190 	/* when ctx->ro_image is not allocated or the target is unknown,
191 	 * emit indirect call
192 	 */
193 	if (!ctx->ro_image || !target)
194 		return true;
195 
196 	offset = target - (long)&ctx->ro_image[ctx->idx];
197 	return offset < -SZ_128M || offset >= SZ_128M;
198 }
199 
200 static void emit_direct_call(u64 target, struct jit_ctx *ctx)
201 {
202 	u32 insn;
203 	unsigned long pc;
204 
205 	pc = (unsigned long)&ctx->ro_image[ctx->idx];
206 	insn = aarch64_insn_gen_branch_imm(pc, target, AARCH64_INSN_BRANCH_LINK);
207 	emit(insn, ctx);
208 }
209 
210 static void emit_indirect_call(u64 target, struct jit_ctx *ctx)
211 {
212 	u8 tmp;
213 
214 	tmp = bpf2a64[TMP_REG_1];
215 	emit_addr_mov_i64(tmp, target, ctx);
216 	emit(A64_BLR(tmp), ctx);
217 }
218 
219 static void emit_call(u64 target, struct jit_ctx *ctx)
220 {
221 	if (should_emit_indirect_call((long)target, ctx))
222 		emit_indirect_call(target, ctx);
223 	else
224 		emit_direct_call(target, ctx);
225 }
226 
227 static inline int bpf2a64_offset(int bpf_insn, int off,
228 				 const struct jit_ctx *ctx)
229 {
230 	/* BPF JMP offset is relative to the next instruction */
231 	bpf_insn++;
232 	/*
233 	 * Whereas arm64 branch instructions encode the offset
234 	 * from the branch itself, so we must subtract 1 from the
235 	 * instruction offset.
236 	 */
237 	return ctx->offset[bpf_insn + off] - (ctx->offset[bpf_insn] - 1);
238 }
239 
240 static void jit_fill_hole(void *area, unsigned int size)
241 {
242 	__le32 *ptr;
243 	/* We are guaranteed to have aligned memory. */
244 	for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
245 		*ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
246 }
247 
248 int bpf_arch_text_invalidate(void *dst, size_t len)
249 {
250 	if (!aarch64_insn_set(dst, AARCH64_BREAK_FAULT, len))
251 		return -EINVAL;
252 
253 	return 0;
254 }
255 
256 static inline int epilogue_offset(const struct jit_ctx *ctx)
257 {
258 	int to = ctx->epilogue_offset;
259 	int from = ctx->idx;
260 
261 	return to - from;
262 }
263 
264 static bool is_addsub_imm(u32 imm)
265 {
266 	/* Either imm12 or shifted imm12. */
267 	return !(imm & ~0xfff) || !(imm & ~0xfff000);
268 }
269 
270 static inline void emit_a64_add_i(const bool is64, const int dst, const int src,
271 				  const int tmp, const s32 imm, struct jit_ctx *ctx)
272 {
273 	if (is_addsub_imm(imm)) {
274 		emit(A64_ADD_I(is64, dst, src, imm), ctx);
275 	} else if (is_addsub_imm(-(u32)imm)) {
276 		emit(A64_SUB_I(is64, dst, src, -imm), ctx);
277 	} else {
278 		emit_a64_mov_i(is64, tmp, imm, ctx);
279 		emit(A64_ADD(is64, dst, src, tmp), ctx);
280 	}
281 }
282 
283 /*
284  * There are 3 types of AArch64 LDR/STR (immediate) instruction:
285  * Post-index, Pre-index, Unsigned offset.
286  *
287  * For BPF ldr/str, the "unsigned offset" type is sufficient.
288  *
289  * "Unsigned offset" type LDR(immediate) format:
290  *
291  *    3                   2                   1                   0
292  *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
293  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
294  * |x x|1 1 1 0 0 1 0 1|         imm12         |    Rn   |    Rt   |
295  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
296  * scale
297  *
298  * "Unsigned offset" type STR(immediate) format:
299  *    3                   2                   1                   0
300  *  1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
301  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
302  * |x x|1 1 1 0 0 1 0 0|         imm12         |    Rn   |    Rt   |
303  * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
304  * scale
305  *
306  * The offset is calculated from imm12 and scale in the following way:
307  *
308  * offset = (u64)imm12 << scale
309  */
310 static bool is_lsi_offset(int offset, int scale)
311 {
312 	if (offset < 0)
313 		return false;
314 
315 	if (offset > (0xFFF << scale))
316 		return false;
317 
318 	if (offset & ((1 << scale) - 1))
319 		return false;
320 
321 	return true;
322 }
323 
324 /* generated main prog prologue:
325  *      bti c // if CONFIG_ARM64_BTI_KERNEL
326  *      mov x9, lr
327  *      nop  // POKE_OFFSET
328  *      paciasp // if CONFIG_ARM64_PTR_AUTH_KERNEL
329  *      stp x29, lr, [sp, #-16]!
330  *      mov x29, sp
331  *      stp xzr, x26, [sp, #-16]!
332  *      mov x26, sp
333  *      // PROLOGUE_OFFSET
334  *	// save callee-saved registers
335  */
336 static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
337 {
338 	const bool is_main_prog = !bpf_is_subprog(ctx->prog);
339 	const u8 ptr = bpf2a64[TCCNT_PTR];
340 
341 	if (is_main_prog) {
342 		/* Initialize tail_call_cnt. */
343 		emit(A64_PUSH(A64_ZR, ptr, A64_SP), ctx);
344 		emit(A64_MOV(1, ptr, A64_SP), ctx);
345 	} else
346 		emit(A64_PUSH(ptr, ptr, A64_SP), ctx);
347 }
348 
349 static void find_used_callee_regs(struct jit_ctx *ctx)
350 {
351 	int i;
352 	const struct bpf_prog *prog = ctx->prog;
353 	const struct bpf_insn *insn = &prog->insnsi[0];
354 	int reg_used = 0;
355 
356 	for (i = 0; i < prog->len; i++, insn++) {
357 		if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
358 			reg_used |= 1;
359 
360 		if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
361 			reg_used |= 2;
362 
363 		if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
364 			reg_used |= 4;
365 
366 		if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
367 			reg_used |= 8;
368 
369 		if (insn->dst_reg == BPF_REG_FP || insn->src_reg == BPF_REG_FP) {
370 			ctx->fp_used = true;
371 			reg_used |= 16;
372 		}
373 	}
374 
375 	i = 0;
376 	if (reg_used & 1)
377 		ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_6];
378 
379 	if (reg_used & 2)
380 		ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_7];
381 
382 	if (reg_used & 4)
383 		ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_8];
384 
385 	if (reg_used & 8)
386 		ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_9];
387 
388 	if (reg_used & 16)
389 		ctx->used_callee_reg[i++] = bpf2a64[BPF_REG_FP];
390 
391 	if (ctx->arena_vm_start)
392 		ctx->used_callee_reg[i++] = bpf2a64[ARENA_VM_START];
393 
394 	ctx->nr_used_callee_reg = i;
395 }
396 
397 /* Save callee-saved registers */
398 static void push_callee_regs(struct jit_ctx *ctx)
399 {
400 	int reg1, reg2, i;
401 
402 	/*
403 	 * Program acting as exception boundary should save all ARM64
404 	 * Callee-saved registers as the exception callback needs to recover
405 	 * all ARM64 Callee-saved registers in its epilogue.
406 	 */
407 	if (ctx->prog->aux->exception_boundary) {
408 		emit(A64_PUSH(A64_R(19), A64_R(20), A64_SP), ctx);
409 		emit(A64_PUSH(A64_R(21), A64_R(22), A64_SP), ctx);
410 		emit(A64_PUSH(A64_R(23), A64_R(24), A64_SP), ctx);
411 		emit(A64_PUSH(A64_R(25), A64_R(26), A64_SP), ctx);
412 		emit(A64_PUSH(A64_R(27), A64_R(28), A64_SP), ctx);
413 	} else {
414 		find_used_callee_regs(ctx);
415 		for (i = 0; i + 1 < ctx->nr_used_callee_reg; i += 2) {
416 			reg1 = ctx->used_callee_reg[i];
417 			reg2 = ctx->used_callee_reg[i + 1];
418 			emit(A64_PUSH(reg1, reg2, A64_SP), ctx);
419 		}
420 		if (i < ctx->nr_used_callee_reg) {
421 			reg1 = ctx->used_callee_reg[i];
422 			/* keep SP 16-byte aligned */
423 			emit(A64_PUSH(reg1, A64_ZR, A64_SP), ctx);
424 		}
425 	}
426 }
427 
428 /* Restore callee-saved registers */
429 static void pop_callee_regs(struct jit_ctx *ctx)
430 {
431 	struct bpf_prog_aux *aux = ctx->prog->aux;
432 	int reg1, reg2, i;
433 
434 	/*
435 	 * Program acting as exception boundary pushes R23 and R24 in addition
436 	 * to BPF callee-saved registers. Exception callback uses the boundary
437 	 * program's stack frame, so recover these extra registers in the above
438 	 * two cases.
439 	 */
440 	if (aux->exception_boundary || aux->exception_cb) {
441 		emit(A64_POP(A64_R(27), A64_R(28), A64_SP), ctx);
442 		emit(A64_POP(A64_R(25), A64_R(26), A64_SP), ctx);
443 		emit(A64_POP(A64_R(23), A64_R(24), A64_SP), ctx);
444 		emit(A64_POP(A64_R(21), A64_R(22), A64_SP), ctx);
445 		emit(A64_POP(A64_R(19), A64_R(20), A64_SP), ctx);
446 	} else {
447 		i = ctx->nr_used_callee_reg - 1;
448 		if (ctx->nr_used_callee_reg % 2 != 0) {
449 			reg1 = ctx->used_callee_reg[i];
450 			emit(A64_POP(reg1, A64_ZR, A64_SP), ctx);
451 			i--;
452 		}
453 		while (i > 0) {
454 			reg1 = ctx->used_callee_reg[i - 1];
455 			reg2 = ctx->used_callee_reg[i];
456 			emit(A64_POP(reg1, reg2, A64_SP), ctx);
457 			i -= 2;
458 		}
459 	}
460 }
461 
462 #define BTI_INSNS (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL) ? 1 : 0)
463 #define PAC_INSNS (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL) ? 1 : 0)
464 
465 /* Offset of nop instruction in bpf prog entry to be poked */
466 #define POKE_OFFSET (BTI_INSNS + 1)
467 
468 /* Tail call offset to jump into */
469 #define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 4)
470 
471 static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
472 {
473 	const struct bpf_prog *prog = ctx->prog;
474 	const bool is_main_prog = !bpf_is_subprog(prog);
475 	const u8 fp = bpf2a64[BPF_REG_FP];
476 	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
477 	const int idx0 = ctx->idx;
478 	int cur_offset;
479 
480 	/*
481 	 * BPF prog stack layout
482 	 *
483 	 *                         high
484 	 * original A64_SP =>   0:+-----+ BPF prologue
485 	 *                        |FP/LR|
486 	 * current A64_FP =>  -16:+-----+
487 	 *                        | ... | callee saved registers
488 	 * BPF fp register => -64:+-----+ <= (BPF_FP)
489 	 *                        |     |
490 	 *                        | ... | BPF prog stack
491 	 *                        |     |
492 	 *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
493 	 *                        |RSVD | padding
494 	 * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
495 	 *                        |     |
496 	 *                        | ... | Function call stack
497 	 *                        |     |
498 	 *                        +-----+
499 	 *                          low
500 	 *
501 	 */
502 
503 	/* bpf function may be invoked by 3 instruction types:
504 	 * 1. bl, attached via freplace to bpf prog via short jump
505 	 * 2. br, attached via freplace to bpf prog via long jump
506 	 * 3. blr, working as a function pointer, used by emit_call.
507 	 * So BTI_JC should used here to support both br and blr.
508 	 */
509 	emit_bti(A64_BTI_JC, ctx);
510 
511 	emit(A64_MOV(1, A64_R(9), A64_LR), ctx);
512 	emit(A64_NOP, ctx);
513 
514 	if (!prog->aux->exception_cb) {
515 		/* Sign lr */
516 		if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
517 			emit(A64_PACIASP, ctx);
518 
519 		/* Save FP and LR registers to stay align with ARM64 AAPCS */
520 		emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
521 		emit(A64_MOV(1, A64_FP, A64_SP), ctx);
522 
523 		prepare_bpf_tail_call_cnt(ctx);
524 
525 		if (!ebpf_from_cbpf && is_main_prog) {
526 			cur_offset = ctx->idx - idx0;
527 			if (cur_offset != PROLOGUE_OFFSET) {
528 				pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
529 						cur_offset, PROLOGUE_OFFSET);
530 				return -1;
531 			}
532 			/* BTI landing pad for the tail call, done with a BR */
533 			emit_bti(A64_BTI_J, ctx);
534 		}
535 		push_callee_regs(ctx);
536 	} else {
537 		/*
538 		 * Exception callback receives FP of Main Program as third
539 		 * parameter
540 		 */
541 		emit(A64_MOV(1, A64_FP, A64_R(2)), ctx);
542 		/*
543 		 * Main Program already pushed the frame record and the
544 		 * callee-saved registers. The exception callback will not push
545 		 * anything and re-use the main program's stack.
546 		 *
547 		 * 12 registers are on the stack
548 		 */
549 		emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
550 	}
551 
552 	if (ctx->fp_used)
553 		/* Set up BPF prog stack base register */
554 		emit(A64_MOV(1, fp, A64_SP), ctx);
555 
556 	/* Stack must be multiples of 16B */
557 	ctx->stack_size = round_up(prog->aux->stack_depth, 16);
558 
559 	/* Set up function call stack */
560 	if (ctx->stack_size)
561 		emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
562 
563 	if (ctx->arena_vm_start)
564 		emit_a64_mov_i64(arena_vm_base, ctx->arena_vm_start, ctx);
565 
566 	return 0;
567 }
568 
569 static int emit_bpf_tail_call(struct jit_ctx *ctx)
570 {
571 	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
572 	const u8 r2 = bpf2a64[BPF_REG_2];
573 	const u8 r3 = bpf2a64[BPF_REG_3];
574 
575 	const u8 tmp = bpf2a64[TMP_REG_1];
576 	const u8 prg = bpf2a64[TMP_REG_2];
577 	const u8 tcc = bpf2a64[TMP_REG_3];
578 	const u8 ptr = bpf2a64[TCCNT_PTR];
579 	size_t off;
580 	__le32 *branch1 = NULL;
581 	__le32 *branch2 = NULL;
582 	__le32 *branch3 = NULL;
583 
584 	/* if (index >= array->map.max_entries)
585 	 *     goto out;
586 	 */
587 	off = offsetof(struct bpf_array, map.max_entries);
588 	emit_a64_mov_i64(tmp, off, ctx);
589 	emit(A64_LDR32(tmp, r2, tmp), ctx);
590 	emit(A64_MOV(0, r3, r3), ctx);
591 	emit(A64_CMP(0, r3, tmp), ctx);
592 	branch1 = ctx->image + ctx->idx;
593 	emit(A64_NOP, ctx);
594 
595 	/*
596 	 * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
597 	 *     goto out;
598 	 */
599 	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
600 	emit(A64_LDR64I(tcc, ptr, 0), ctx);
601 	emit(A64_CMP(1, tcc, tmp), ctx);
602 	branch2 = ctx->image + ctx->idx;
603 	emit(A64_NOP, ctx);
604 
605 	/* (*tail_call_cnt_ptr)++; */
606 	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
607 
608 	/* prog = array->ptrs[index];
609 	 * if (prog == NULL)
610 	 *     goto out;
611 	 */
612 	off = offsetof(struct bpf_array, ptrs);
613 	emit_a64_mov_i64(tmp, off, ctx);
614 	emit(A64_ADD(1, tmp, r2, tmp), ctx);
615 	emit(A64_LSL(1, prg, r3, 3), ctx);
616 	emit(A64_LDR64(prg, tmp, prg), ctx);
617 	branch3 = ctx->image + ctx->idx;
618 	emit(A64_NOP, ctx);
619 
620 	/* Update tail_call_cnt if the slot is populated. */
621 	emit(A64_STR64I(tcc, ptr, 0), ctx);
622 
623 	/* restore SP */
624 	if (ctx->stack_size)
625 		emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
626 
627 	pop_callee_regs(ctx);
628 
629 	/* goto *(prog->bpf_func + prologue_offset); */
630 	off = offsetof(struct bpf_prog, bpf_func);
631 	emit_a64_mov_i64(tmp, off, ctx);
632 	emit(A64_LDR64(tmp, prg, tmp), ctx);
633 	emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
634 	emit(A64_BR(tmp), ctx);
635 
636 	if (ctx->image) {
637 		off = &ctx->image[ctx->idx] - branch1;
638 		*branch1 = cpu_to_le32(A64_B_(A64_COND_CS, off));
639 
640 		off = &ctx->image[ctx->idx] - branch2;
641 		*branch2 = cpu_to_le32(A64_B_(A64_COND_CS, off));
642 
643 		off = &ctx->image[ctx->idx] - branch3;
644 		*branch3 = cpu_to_le32(A64_CBZ(1, prg, off));
645 	}
646 
647 	return 0;
648 }
649 
650 static int emit_atomic_ld_st(const struct bpf_insn *insn, struct jit_ctx *ctx)
651 {
652 	const s32 imm = insn->imm;
653 	const s16 off = insn->off;
654 	const u8 code = insn->code;
655 	const bool arena = BPF_MODE(code) == BPF_PROBE_ATOMIC;
656 	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
657 	const u8 dst = bpf2a64[insn->dst_reg];
658 	const u8 src = bpf2a64[insn->src_reg];
659 	const u8 tmp = bpf2a64[TMP_REG_1];
660 	u8 reg;
661 
662 	switch (imm) {
663 	case BPF_LOAD_ACQ:
664 		reg = src;
665 		break;
666 	case BPF_STORE_REL:
667 		reg = dst;
668 		break;
669 	default:
670 		pr_err_once("unknown atomic load/store op code %02x\n", imm);
671 		return -EINVAL;
672 	}
673 
674 	if (off) {
675 		emit_a64_add_i(1, tmp, reg, tmp, off, ctx);
676 		reg = tmp;
677 	}
678 	if (arena) {
679 		emit(A64_ADD(1, tmp, reg, arena_vm_base), ctx);
680 		reg = tmp;
681 	}
682 
683 	switch (imm) {
684 	case BPF_LOAD_ACQ:
685 		switch (BPF_SIZE(code)) {
686 		case BPF_B:
687 			emit(A64_LDARB(dst, reg), ctx);
688 			break;
689 		case BPF_H:
690 			emit(A64_LDARH(dst, reg), ctx);
691 			break;
692 		case BPF_W:
693 			emit(A64_LDAR32(dst, reg), ctx);
694 			break;
695 		case BPF_DW:
696 			emit(A64_LDAR64(dst, reg), ctx);
697 			break;
698 		}
699 		break;
700 	case BPF_STORE_REL:
701 		switch (BPF_SIZE(code)) {
702 		case BPF_B:
703 			emit(A64_STLRB(src, reg), ctx);
704 			break;
705 		case BPF_H:
706 			emit(A64_STLRH(src, reg), ctx);
707 			break;
708 		case BPF_W:
709 			emit(A64_STLR32(src, reg), ctx);
710 			break;
711 		case BPF_DW:
712 			emit(A64_STLR64(src, reg), ctx);
713 			break;
714 		}
715 		break;
716 	default:
717 		pr_err_once("unexpected atomic load/store op code %02x\n",
718 			    imm);
719 		return -EINVAL;
720 	}
721 
722 	return 0;
723 }
724 
725 #ifdef CONFIG_ARM64_LSE_ATOMICS
726 static int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
727 {
728 	const u8 code = insn->code;
729 	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
730 	const u8 dst = bpf2a64[insn->dst_reg];
731 	const u8 src = bpf2a64[insn->src_reg];
732 	const u8 tmp = bpf2a64[TMP_REG_1];
733 	const u8 tmp2 = bpf2a64[TMP_REG_2];
734 	const bool isdw = BPF_SIZE(code) == BPF_DW;
735 	const bool arena = BPF_MODE(code) == BPF_PROBE_ATOMIC;
736 	const s16 off = insn->off;
737 	u8 reg = dst;
738 
739 	if (off) {
740 		emit_a64_add_i(1, tmp, reg, tmp, off, ctx);
741 		reg = tmp;
742 	}
743 	if (arena) {
744 		emit(A64_ADD(1, tmp, reg, arena_vm_base), ctx);
745 		reg = tmp;
746 	}
747 
748 	switch (insn->imm) {
749 	/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
750 	case BPF_ADD:
751 		emit(A64_STADD(isdw, reg, src), ctx);
752 		break;
753 	case BPF_AND:
754 		emit(A64_MVN(isdw, tmp2, src), ctx);
755 		emit(A64_STCLR(isdw, reg, tmp2), ctx);
756 		break;
757 	case BPF_OR:
758 		emit(A64_STSET(isdw, reg, src), ctx);
759 		break;
760 	case BPF_XOR:
761 		emit(A64_STEOR(isdw, reg, src), ctx);
762 		break;
763 	/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
764 	case BPF_ADD | BPF_FETCH:
765 		emit(A64_LDADDAL(isdw, src, reg, src), ctx);
766 		break;
767 	case BPF_AND | BPF_FETCH:
768 		emit(A64_MVN(isdw, tmp2, src), ctx);
769 		emit(A64_LDCLRAL(isdw, src, reg, tmp2), ctx);
770 		break;
771 	case BPF_OR | BPF_FETCH:
772 		emit(A64_LDSETAL(isdw, src, reg, src), ctx);
773 		break;
774 	case BPF_XOR | BPF_FETCH:
775 		emit(A64_LDEORAL(isdw, src, reg, src), ctx);
776 		break;
777 	/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
778 	case BPF_XCHG:
779 		emit(A64_SWPAL(isdw, src, reg, src), ctx);
780 		break;
781 	/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
782 	case BPF_CMPXCHG:
783 		emit(A64_CASAL(isdw, src, reg, bpf2a64[BPF_REG_0]), ctx);
784 		break;
785 	default:
786 		pr_err_once("unknown atomic op code %02x\n", insn->imm);
787 		return -EINVAL;
788 	}
789 
790 	return 0;
791 }
792 #else
793 static inline int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
794 {
795 	return -EINVAL;
796 }
797 #endif
798 
799 static int emit_ll_sc_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
800 {
801 	const u8 code = insn->code;
802 	const u8 dst = bpf2a64[insn->dst_reg];
803 	const u8 src = bpf2a64[insn->src_reg];
804 	const u8 tmp = bpf2a64[TMP_REG_1];
805 	const u8 tmp2 = bpf2a64[TMP_REG_2];
806 	const u8 tmp3 = bpf2a64[TMP_REG_3];
807 	const int i = insn - ctx->prog->insnsi;
808 	const s32 imm = insn->imm;
809 	const s16 off = insn->off;
810 	const bool isdw = BPF_SIZE(code) == BPF_DW;
811 	u8 reg = dst;
812 	s32 jmp_offset;
813 
814 	if (BPF_MODE(code) == BPF_PROBE_ATOMIC) {
815 		/* ll_sc based atomics don't support unsafe pointers yet. */
816 		pr_err_once("unknown atomic opcode %02x\n", code);
817 		return -EINVAL;
818 	}
819 
820 	if (off) {
821 		emit_a64_add_i(1, tmp, reg, tmp, off, ctx);
822 		reg = tmp;
823 	}
824 
825 	if (imm == BPF_ADD || imm == BPF_AND ||
826 	    imm == BPF_OR || imm == BPF_XOR) {
827 		/* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
828 		emit(A64_LDXR(isdw, tmp2, reg), ctx);
829 		if (imm == BPF_ADD)
830 			emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
831 		else if (imm == BPF_AND)
832 			emit(A64_AND(isdw, tmp2, tmp2, src), ctx);
833 		else if (imm == BPF_OR)
834 			emit(A64_ORR(isdw, tmp2, tmp2, src), ctx);
835 		else
836 			emit(A64_EOR(isdw, tmp2, tmp2, src), ctx);
837 		emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
838 		jmp_offset = -3;
839 		check_imm19(jmp_offset);
840 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
841 	} else if (imm == (BPF_ADD | BPF_FETCH) ||
842 		   imm == (BPF_AND | BPF_FETCH) ||
843 		   imm == (BPF_OR | BPF_FETCH) ||
844 		   imm == (BPF_XOR | BPF_FETCH)) {
845 		/* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
846 		const u8 ax = bpf2a64[BPF_REG_AX];
847 
848 		emit(A64_MOV(isdw, ax, src), ctx);
849 		emit(A64_LDXR(isdw, src, reg), ctx);
850 		if (imm == (BPF_ADD | BPF_FETCH))
851 			emit(A64_ADD(isdw, tmp2, src, ax), ctx);
852 		else if (imm == (BPF_AND | BPF_FETCH))
853 			emit(A64_AND(isdw, tmp2, src, ax), ctx);
854 		else if (imm == (BPF_OR | BPF_FETCH))
855 			emit(A64_ORR(isdw, tmp2, src, ax), ctx);
856 		else
857 			emit(A64_EOR(isdw, tmp2, src, ax), ctx);
858 		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
859 		jmp_offset = -3;
860 		check_imm19(jmp_offset);
861 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
862 		emit(A64_DMB_ISH, ctx);
863 	} else if (imm == BPF_XCHG) {
864 		/* src_reg = atomic_xchg(dst_reg + off, src_reg); */
865 		emit(A64_MOV(isdw, tmp2, src), ctx);
866 		emit(A64_LDXR(isdw, src, reg), ctx);
867 		emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
868 		jmp_offset = -2;
869 		check_imm19(jmp_offset);
870 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
871 		emit(A64_DMB_ISH, ctx);
872 	} else if (imm == BPF_CMPXCHG) {
873 		/* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
874 		const u8 r0 = bpf2a64[BPF_REG_0];
875 
876 		emit(A64_MOV(isdw, tmp2, r0), ctx);
877 		emit(A64_LDXR(isdw, r0, reg), ctx);
878 		emit(A64_EOR(isdw, tmp3, r0, tmp2), ctx);
879 		jmp_offset = 4;
880 		check_imm19(jmp_offset);
881 		emit(A64_CBNZ(isdw, tmp3, jmp_offset), ctx);
882 		emit(A64_STLXR(isdw, src, reg, tmp3), ctx);
883 		jmp_offset = -4;
884 		check_imm19(jmp_offset);
885 		emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
886 		emit(A64_DMB_ISH, ctx);
887 	} else {
888 		pr_err_once("unknown atomic op code %02x\n", imm);
889 		return -EINVAL;
890 	}
891 
892 	return 0;
893 }
894 
895 void dummy_tramp(void);
896 
897 asm (
898 "	.pushsection .text, \"ax\", @progbits\n"
899 "	.global dummy_tramp\n"
900 "	.type dummy_tramp, %function\n"
901 "dummy_tramp:"
902 #if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
903 "	bti j\n" /* dummy_tramp is called via "br x10" */
904 #endif
905 "	mov x10, x30\n"
906 "	mov x30, x9\n"
907 "	ret x10\n"
908 "	.size dummy_tramp, .-dummy_tramp\n"
909 "	.popsection\n"
910 );
911 
912 /* build a plt initialized like this:
913  *
914  * plt:
915  *      ldr tmp, target
916  *      br tmp
917  * target:
918  *      .quad dummy_tramp
919  *
920  * when a long jump trampoline is attached, target is filled with the
921  * trampoline address, and when the trampoline is removed, target is
922  * restored to dummy_tramp address.
923  */
924 static void build_plt(struct jit_ctx *ctx)
925 {
926 	const u8 tmp = bpf2a64[TMP_REG_1];
927 	struct bpf_plt *plt = NULL;
928 
929 	/* make sure target is 64-bit aligned */
930 	if ((ctx->idx + PLT_TARGET_OFFSET / AARCH64_INSN_SIZE) % 2)
931 		emit(A64_NOP, ctx);
932 
933 	plt = (struct bpf_plt *)(ctx->image + ctx->idx);
934 	/* plt is called via bl, no BTI needed here */
935 	emit(A64_LDR64LIT(tmp, 2 * AARCH64_INSN_SIZE), ctx);
936 	emit(A64_BR(tmp), ctx);
937 
938 	if (ctx->image)
939 		plt->target = (u64)&dummy_tramp;
940 }
941 
942 static void build_epilogue(struct jit_ctx *ctx)
943 {
944 	const u8 r0 = bpf2a64[BPF_REG_0];
945 	const u8 ptr = bpf2a64[TCCNT_PTR];
946 
947 	/* We're done with BPF stack */
948 	if (ctx->stack_size)
949 		emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
950 
951 	pop_callee_regs(ctx);
952 
953 	emit(A64_POP(A64_ZR, ptr, A64_SP), ctx);
954 
955 	/* Restore FP/LR registers */
956 	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
957 
958 	/* Set return value */
959 	emit(A64_MOV(1, A64_R(0), r0), ctx);
960 
961 	/* Authenticate lr */
962 	if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
963 		emit(A64_AUTIASP, ctx);
964 
965 	emit(A64_RET(A64_LR), ctx);
966 }
967 
968 #define BPF_FIXUP_OFFSET_MASK	GENMASK(26, 0)
969 #define BPF_FIXUP_REG_MASK	GENMASK(31, 27)
970 #define DONT_CLEAR 5 /* Unused ARM64 register from BPF's POV */
971 
972 bool ex_handler_bpf(const struct exception_table_entry *ex,
973 		    struct pt_regs *regs)
974 {
975 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
976 	int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
977 
978 	if (dst_reg != DONT_CLEAR)
979 		regs->regs[dst_reg] = 0;
980 	regs->pc = (unsigned long)&ex->fixup - offset;
981 	return true;
982 }
983 
984 /* For accesses to BTF pointers, add an entry to the exception table */
985 static int add_exception_handler(const struct bpf_insn *insn,
986 				 struct jit_ctx *ctx,
987 				 int dst_reg)
988 {
989 	off_t ins_offset;
990 	off_t fixup_offset;
991 	unsigned long pc;
992 	struct exception_table_entry *ex;
993 
994 	if (!ctx->image)
995 		/* First pass */
996 		return 0;
997 
998 	if (BPF_MODE(insn->code) != BPF_PROBE_MEM &&
999 		BPF_MODE(insn->code) != BPF_PROBE_MEMSX &&
1000 			BPF_MODE(insn->code) != BPF_PROBE_MEM32 &&
1001 				BPF_MODE(insn->code) != BPF_PROBE_ATOMIC)
1002 		return 0;
1003 
1004 	if (!ctx->prog->aux->extable ||
1005 	    WARN_ON_ONCE(ctx->exentry_idx >= ctx->prog->aux->num_exentries))
1006 		return -EINVAL;
1007 
1008 	ex = &ctx->prog->aux->extable[ctx->exentry_idx];
1009 	pc = (unsigned long)&ctx->ro_image[ctx->idx - 1];
1010 
1011 	/*
1012 	 * This is the relative offset of the instruction that may fault from
1013 	 * the exception table itself. This will be written to the exception
1014 	 * table and if this instruction faults, the destination register will
1015 	 * be set to '0' and the execution will jump to the next instruction.
1016 	 */
1017 	ins_offset = pc - (long)&ex->insn;
1018 	if (WARN_ON_ONCE(ins_offset >= 0 || ins_offset < INT_MIN))
1019 		return -ERANGE;
1020 
1021 	/*
1022 	 * Since the extable follows the program, the fixup offset is always
1023 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
1024 	 * to keep things simple, and put the destination register in the upper
1025 	 * bits. We don't need to worry about buildtime or runtime sort
1026 	 * modifying the upper bits because the table is already sorted, and
1027 	 * isn't part of the main exception table.
1028 	 *
1029 	 * The fixup_offset is set to the next instruction from the instruction
1030 	 * that may fault. The execution will jump to this after handling the
1031 	 * fault.
1032 	 */
1033 	fixup_offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
1034 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset))
1035 		return -ERANGE;
1036 
1037 	/*
1038 	 * The offsets above have been calculated using the RO buffer but we
1039 	 * need to use the R/W buffer for writes.
1040 	 * switch ex to rw buffer for writing.
1041 	 */
1042 	ex = (void *)ctx->image + ((void *)ex - (void *)ctx->ro_image);
1043 
1044 	ex->insn = ins_offset;
1045 
1046 	if (BPF_CLASS(insn->code) != BPF_LDX)
1047 		dst_reg = DONT_CLEAR;
1048 
1049 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, fixup_offset) |
1050 		    FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
1051 
1052 	ex->type = EX_TYPE_BPF;
1053 
1054 	ctx->exentry_idx++;
1055 	return 0;
1056 }
1057 
1058 /* JITs an eBPF instruction.
1059  * Returns:
1060  * 0  - successfully JITed an 8-byte eBPF instruction.
1061  * >0 - successfully JITed a 16-byte eBPF instruction.
1062  * <0 - failed to JIT.
1063  */
1064 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
1065 		      bool extra_pass)
1066 {
1067 	const u8 code = insn->code;
1068 	u8 dst = bpf2a64[insn->dst_reg];
1069 	u8 src = bpf2a64[insn->src_reg];
1070 	const u8 tmp = bpf2a64[TMP_REG_1];
1071 	const u8 tmp2 = bpf2a64[TMP_REG_2];
1072 	const u8 fp = bpf2a64[BPF_REG_FP];
1073 	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
1074 	const s16 off = insn->off;
1075 	const s32 imm = insn->imm;
1076 	const int i = insn - ctx->prog->insnsi;
1077 	const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
1078 			  BPF_CLASS(code) == BPF_JMP;
1079 	u8 jmp_cond;
1080 	s32 jmp_offset;
1081 	u32 a64_insn;
1082 	u8 src_adj;
1083 	u8 dst_adj;
1084 	int off_adj;
1085 	int ret;
1086 	bool sign_extend;
1087 
1088 	switch (code) {
1089 	/* dst = src */
1090 	case BPF_ALU | BPF_MOV | BPF_X:
1091 	case BPF_ALU64 | BPF_MOV | BPF_X:
1092 		if (insn_is_cast_user(insn)) {
1093 			emit(A64_MOV(0, tmp, src), ctx); // 32-bit mov clears the upper 32 bits
1094 			emit_a64_mov_i(0, dst, ctx->user_vm_start >> 32, ctx);
1095 			emit(A64_LSL(1, dst, dst, 32), ctx);
1096 			emit(A64_CBZ(1, tmp, 2), ctx);
1097 			emit(A64_ORR(1, tmp, dst, tmp), ctx);
1098 			emit(A64_MOV(1, dst, tmp), ctx);
1099 			break;
1100 		} else if (insn_is_mov_percpu_addr(insn)) {
1101 			if (dst != src)
1102 				emit(A64_MOV(1, dst, src), ctx);
1103 			if (cpus_have_cap(ARM64_HAS_VIRT_HOST_EXTN))
1104 				emit(A64_MRS_TPIDR_EL2(tmp), ctx);
1105 			else
1106 				emit(A64_MRS_TPIDR_EL1(tmp), ctx);
1107 			emit(A64_ADD(1, dst, dst, tmp), ctx);
1108 			break;
1109 		}
1110 		switch (insn->off) {
1111 		case 0:
1112 			emit(A64_MOV(is64, dst, src), ctx);
1113 			break;
1114 		case 8:
1115 			emit(A64_SXTB(is64, dst, src), ctx);
1116 			break;
1117 		case 16:
1118 			emit(A64_SXTH(is64, dst, src), ctx);
1119 			break;
1120 		case 32:
1121 			emit(A64_SXTW(is64, dst, src), ctx);
1122 			break;
1123 		}
1124 		break;
1125 	/* dst = dst OP src */
1126 	case BPF_ALU | BPF_ADD | BPF_X:
1127 	case BPF_ALU64 | BPF_ADD | BPF_X:
1128 		emit(A64_ADD(is64, dst, dst, src), ctx);
1129 		break;
1130 	case BPF_ALU | BPF_SUB | BPF_X:
1131 	case BPF_ALU64 | BPF_SUB | BPF_X:
1132 		emit(A64_SUB(is64, dst, dst, src), ctx);
1133 		break;
1134 	case BPF_ALU | BPF_AND | BPF_X:
1135 	case BPF_ALU64 | BPF_AND | BPF_X:
1136 		emit(A64_AND(is64, dst, dst, src), ctx);
1137 		break;
1138 	case BPF_ALU | BPF_OR | BPF_X:
1139 	case BPF_ALU64 | BPF_OR | BPF_X:
1140 		emit(A64_ORR(is64, dst, dst, src), ctx);
1141 		break;
1142 	case BPF_ALU | BPF_XOR | BPF_X:
1143 	case BPF_ALU64 | BPF_XOR | BPF_X:
1144 		emit(A64_EOR(is64, dst, dst, src), ctx);
1145 		break;
1146 	case BPF_ALU | BPF_MUL | BPF_X:
1147 	case BPF_ALU64 | BPF_MUL | BPF_X:
1148 		emit(A64_MUL(is64, dst, dst, src), ctx);
1149 		break;
1150 	case BPF_ALU | BPF_DIV | BPF_X:
1151 	case BPF_ALU64 | BPF_DIV | BPF_X:
1152 		if (!off)
1153 			emit(A64_UDIV(is64, dst, dst, src), ctx);
1154 		else
1155 			emit(A64_SDIV(is64, dst, dst, src), ctx);
1156 		break;
1157 	case BPF_ALU | BPF_MOD | BPF_X:
1158 	case BPF_ALU64 | BPF_MOD | BPF_X:
1159 		if (!off)
1160 			emit(A64_UDIV(is64, tmp, dst, src), ctx);
1161 		else
1162 			emit(A64_SDIV(is64, tmp, dst, src), ctx);
1163 		emit(A64_MSUB(is64, dst, dst, tmp, src), ctx);
1164 		break;
1165 	case BPF_ALU | BPF_LSH | BPF_X:
1166 	case BPF_ALU64 | BPF_LSH | BPF_X:
1167 		emit(A64_LSLV(is64, dst, dst, src), ctx);
1168 		break;
1169 	case BPF_ALU | BPF_RSH | BPF_X:
1170 	case BPF_ALU64 | BPF_RSH | BPF_X:
1171 		emit(A64_LSRV(is64, dst, dst, src), ctx);
1172 		break;
1173 	case BPF_ALU | BPF_ARSH | BPF_X:
1174 	case BPF_ALU64 | BPF_ARSH | BPF_X:
1175 		emit(A64_ASRV(is64, dst, dst, src), ctx);
1176 		break;
1177 	/* dst = -dst */
1178 	case BPF_ALU | BPF_NEG:
1179 	case BPF_ALU64 | BPF_NEG:
1180 		emit(A64_NEG(is64, dst, dst), ctx);
1181 		break;
1182 	/* dst = BSWAP##imm(dst) */
1183 	case BPF_ALU | BPF_END | BPF_FROM_LE:
1184 	case BPF_ALU | BPF_END | BPF_FROM_BE:
1185 	case BPF_ALU64 | BPF_END | BPF_FROM_LE:
1186 #ifdef CONFIG_CPU_BIG_ENDIAN
1187 		if (BPF_CLASS(code) == BPF_ALU && BPF_SRC(code) == BPF_FROM_BE)
1188 			goto emit_bswap_uxt;
1189 #else /* !CONFIG_CPU_BIG_ENDIAN */
1190 		if (BPF_CLASS(code) == BPF_ALU && BPF_SRC(code) == BPF_FROM_LE)
1191 			goto emit_bswap_uxt;
1192 #endif
1193 		switch (imm) {
1194 		case 16:
1195 			emit(A64_REV16(is64, dst, dst), ctx);
1196 			/* zero-extend 16 bits into 64 bits */
1197 			emit(A64_UXTH(is64, dst, dst), ctx);
1198 			break;
1199 		case 32:
1200 			emit(A64_REV32(0, dst, dst), ctx);
1201 			/* upper 32 bits already cleared */
1202 			break;
1203 		case 64:
1204 			emit(A64_REV64(dst, dst), ctx);
1205 			break;
1206 		}
1207 		break;
1208 emit_bswap_uxt:
1209 		switch (imm) {
1210 		case 16:
1211 			/* zero-extend 16 bits into 64 bits */
1212 			emit(A64_UXTH(is64, dst, dst), ctx);
1213 			break;
1214 		case 32:
1215 			/* zero-extend 32 bits into 64 bits */
1216 			emit(A64_UXTW(is64, dst, dst), ctx);
1217 			break;
1218 		case 64:
1219 			/* nop */
1220 			break;
1221 		}
1222 		break;
1223 	/* dst = imm */
1224 	case BPF_ALU | BPF_MOV | BPF_K:
1225 	case BPF_ALU64 | BPF_MOV | BPF_K:
1226 		emit_a64_mov_i(is64, dst, imm, ctx);
1227 		break;
1228 	/* dst = dst OP imm */
1229 	case BPF_ALU | BPF_ADD | BPF_K:
1230 	case BPF_ALU64 | BPF_ADD | BPF_K:
1231 		emit_a64_add_i(is64, dst, dst, tmp, imm, ctx);
1232 		break;
1233 	case BPF_ALU | BPF_SUB | BPF_K:
1234 	case BPF_ALU64 | BPF_SUB | BPF_K:
1235 		if (is_addsub_imm(imm)) {
1236 			emit(A64_SUB_I(is64, dst, dst, imm), ctx);
1237 		} else if (is_addsub_imm(-(u32)imm)) {
1238 			emit(A64_ADD_I(is64, dst, dst, -imm), ctx);
1239 		} else {
1240 			emit_a64_mov_i(is64, tmp, imm, ctx);
1241 			emit(A64_SUB(is64, dst, dst, tmp), ctx);
1242 		}
1243 		break;
1244 	case BPF_ALU | BPF_AND | BPF_K:
1245 	case BPF_ALU64 | BPF_AND | BPF_K:
1246 		a64_insn = A64_AND_I(is64, dst, dst, imm);
1247 		if (a64_insn != AARCH64_BREAK_FAULT) {
1248 			emit(a64_insn, ctx);
1249 		} else {
1250 			emit_a64_mov_i(is64, tmp, imm, ctx);
1251 			emit(A64_AND(is64, dst, dst, tmp), ctx);
1252 		}
1253 		break;
1254 	case BPF_ALU | BPF_OR | BPF_K:
1255 	case BPF_ALU64 | BPF_OR | BPF_K:
1256 		a64_insn = A64_ORR_I(is64, dst, dst, imm);
1257 		if (a64_insn != AARCH64_BREAK_FAULT) {
1258 			emit(a64_insn, ctx);
1259 		} else {
1260 			emit_a64_mov_i(is64, tmp, imm, ctx);
1261 			emit(A64_ORR(is64, dst, dst, tmp), ctx);
1262 		}
1263 		break;
1264 	case BPF_ALU | BPF_XOR | BPF_K:
1265 	case BPF_ALU64 | BPF_XOR | BPF_K:
1266 		a64_insn = A64_EOR_I(is64, dst, dst, imm);
1267 		if (a64_insn != AARCH64_BREAK_FAULT) {
1268 			emit(a64_insn, ctx);
1269 		} else {
1270 			emit_a64_mov_i(is64, tmp, imm, ctx);
1271 			emit(A64_EOR(is64, dst, dst, tmp), ctx);
1272 		}
1273 		break;
1274 	case BPF_ALU | BPF_MUL | BPF_K:
1275 	case BPF_ALU64 | BPF_MUL | BPF_K:
1276 		emit_a64_mov_i(is64, tmp, imm, ctx);
1277 		emit(A64_MUL(is64, dst, dst, tmp), ctx);
1278 		break;
1279 	case BPF_ALU | BPF_DIV | BPF_K:
1280 	case BPF_ALU64 | BPF_DIV | BPF_K:
1281 		emit_a64_mov_i(is64, tmp, imm, ctx);
1282 		if (!off)
1283 			emit(A64_UDIV(is64, dst, dst, tmp), ctx);
1284 		else
1285 			emit(A64_SDIV(is64, dst, dst, tmp), ctx);
1286 		break;
1287 	case BPF_ALU | BPF_MOD | BPF_K:
1288 	case BPF_ALU64 | BPF_MOD | BPF_K:
1289 		emit_a64_mov_i(is64, tmp2, imm, ctx);
1290 		if (!off)
1291 			emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
1292 		else
1293 			emit(A64_SDIV(is64, tmp, dst, tmp2), ctx);
1294 		emit(A64_MSUB(is64, dst, dst, tmp, tmp2), ctx);
1295 		break;
1296 	case BPF_ALU | BPF_LSH | BPF_K:
1297 	case BPF_ALU64 | BPF_LSH | BPF_K:
1298 		emit(A64_LSL(is64, dst, dst, imm), ctx);
1299 		break;
1300 	case BPF_ALU | BPF_RSH | BPF_K:
1301 	case BPF_ALU64 | BPF_RSH | BPF_K:
1302 		emit(A64_LSR(is64, dst, dst, imm), ctx);
1303 		break;
1304 	case BPF_ALU | BPF_ARSH | BPF_K:
1305 	case BPF_ALU64 | BPF_ARSH | BPF_K:
1306 		emit(A64_ASR(is64, dst, dst, imm), ctx);
1307 		break;
1308 
1309 	/* JUMP off */
1310 	case BPF_JMP | BPF_JA:
1311 	case BPF_JMP32 | BPF_JA:
1312 		if (BPF_CLASS(code) == BPF_JMP)
1313 			jmp_offset = bpf2a64_offset(i, off, ctx);
1314 		else
1315 			jmp_offset = bpf2a64_offset(i, imm, ctx);
1316 		check_imm26(jmp_offset);
1317 		emit(A64_B(jmp_offset), ctx);
1318 		break;
1319 	/* IF (dst COND src) JUMP off */
1320 	case BPF_JMP | BPF_JEQ | BPF_X:
1321 	case BPF_JMP | BPF_JGT | BPF_X:
1322 	case BPF_JMP | BPF_JLT | BPF_X:
1323 	case BPF_JMP | BPF_JGE | BPF_X:
1324 	case BPF_JMP | BPF_JLE | BPF_X:
1325 	case BPF_JMP | BPF_JNE | BPF_X:
1326 	case BPF_JMP | BPF_JSGT | BPF_X:
1327 	case BPF_JMP | BPF_JSLT | BPF_X:
1328 	case BPF_JMP | BPF_JSGE | BPF_X:
1329 	case BPF_JMP | BPF_JSLE | BPF_X:
1330 	case BPF_JMP32 | BPF_JEQ | BPF_X:
1331 	case BPF_JMP32 | BPF_JGT | BPF_X:
1332 	case BPF_JMP32 | BPF_JLT | BPF_X:
1333 	case BPF_JMP32 | BPF_JGE | BPF_X:
1334 	case BPF_JMP32 | BPF_JLE | BPF_X:
1335 	case BPF_JMP32 | BPF_JNE | BPF_X:
1336 	case BPF_JMP32 | BPF_JSGT | BPF_X:
1337 	case BPF_JMP32 | BPF_JSLT | BPF_X:
1338 	case BPF_JMP32 | BPF_JSGE | BPF_X:
1339 	case BPF_JMP32 | BPF_JSLE | BPF_X:
1340 		emit(A64_CMP(is64, dst, src), ctx);
1341 emit_cond_jmp:
1342 		jmp_offset = bpf2a64_offset(i, off, ctx);
1343 		check_imm19(jmp_offset);
1344 		switch (BPF_OP(code)) {
1345 		case BPF_JEQ:
1346 			jmp_cond = A64_COND_EQ;
1347 			break;
1348 		case BPF_JGT:
1349 			jmp_cond = A64_COND_HI;
1350 			break;
1351 		case BPF_JLT:
1352 			jmp_cond = A64_COND_CC;
1353 			break;
1354 		case BPF_JGE:
1355 			jmp_cond = A64_COND_CS;
1356 			break;
1357 		case BPF_JLE:
1358 			jmp_cond = A64_COND_LS;
1359 			break;
1360 		case BPF_JSET:
1361 		case BPF_JNE:
1362 			jmp_cond = A64_COND_NE;
1363 			break;
1364 		case BPF_JSGT:
1365 			jmp_cond = A64_COND_GT;
1366 			break;
1367 		case BPF_JSLT:
1368 			jmp_cond = A64_COND_LT;
1369 			break;
1370 		case BPF_JSGE:
1371 			jmp_cond = A64_COND_GE;
1372 			break;
1373 		case BPF_JSLE:
1374 			jmp_cond = A64_COND_LE;
1375 			break;
1376 		default:
1377 			return -EFAULT;
1378 		}
1379 		emit(A64_B_(jmp_cond, jmp_offset), ctx);
1380 		break;
1381 	case BPF_JMP | BPF_JSET | BPF_X:
1382 	case BPF_JMP32 | BPF_JSET | BPF_X:
1383 		emit(A64_TST(is64, dst, src), ctx);
1384 		goto emit_cond_jmp;
1385 	/* IF (dst COND imm) JUMP off */
1386 	case BPF_JMP | BPF_JEQ | BPF_K:
1387 	case BPF_JMP | BPF_JGT | BPF_K:
1388 	case BPF_JMP | BPF_JLT | BPF_K:
1389 	case BPF_JMP | BPF_JGE | BPF_K:
1390 	case BPF_JMP | BPF_JLE | BPF_K:
1391 	case BPF_JMP | BPF_JNE | BPF_K:
1392 	case BPF_JMP | BPF_JSGT | BPF_K:
1393 	case BPF_JMP | BPF_JSLT | BPF_K:
1394 	case BPF_JMP | BPF_JSGE | BPF_K:
1395 	case BPF_JMP | BPF_JSLE | BPF_K:
1396 	case BPF_JMP32 | BPF_JEQ | BPF_K:
1397 	case BPF_JMP32 | BPF_JGT | BPF_K:
1398 	case BPF_JMP32 | BPF_JLT | BPF_K:
1399 	case BPF_JMP32 | BPF_JGE | BPF_K:
1400 	case BPF_JMP32 | BPF_JLE | BPF_K:
1401 	case BPF_JMP32 | BPF_JNE | BPF_K:
1402 	case BPF_JMP32 | BPF_JSGT | BPF_K:
1403 	case BPF_JMP32 | BPF_JSLT | BPF_K:
1404 	case BPF_JMP32 | BPF_JSGE | BPF_K:
1405 	case BPF_JMP32 | BPF_JSLE | BPF_K:
1406 		if (is_addsub_imm(imm)) {
1407 			emit(A64_CMP_I(is64, dst, imm), ctx);
1408 		} else if (is_addsub_imm(-(u32)imm)) {
1409 			emit(A64_CMN_I(is64, dst, -imm), ctx);
1410 		} else {
1411 			emit_a64_mov_i(is64, tmp, imm, ctx);
1412 			emit(A64_CMP(is64, dst, tmp), ctx);
1413 		}
1414 		goto emit_cond_jmp;
1415 	case BPF_JMP | BPF_JSET | BPF_K:
1416 	case BPF_JMP32 | BPF_JSET | BPF_K:
1417 		a64_insn = A64_TST_I(is64, dst, imm);
1418 		if (a64_insn != AARCH64_BREAK_FAULT) {
1419 			emit(a64_insn, ctx);
1420 		} else {
1421 			emit_a64_mov_i(is64, tmp, imm, ctx);
1422 			emit(A64_TST(is64, dst, tmp), ctx);
1423 		}
1424 		goto emit_cond_jmp;
1425 	/* function call */
1426 	case BPF_JMP | BPF_CALL:
1427 	{
1428 		const u8 r0 = bpf2a64[BPF_REG_0];
1429 		bool func_addr_fixed;
1430 		u64 func_addr;
1431 		u32 cpu_offset;
1432 
1433 		/* Implement helper call to bpf_get_smp_processor_id() inline */
1434 		if (insn->src_reg == 0 && insn->imm == BPF_FUNC_get_smp_processor_id) {
1435 			cpu_offset = offsetof(struct thread_info, cpu);
1436 
1437 			emit(A64_MRS_SP_EL0(tmp), ctx);
1438 			if (is_lsi_offset(cpu_offset, 2)) {
1439 				emit(A64_LDR32I(r0, tmp, cpu_offset), ctx);
1440 			} else {
1441 				emit_a64_mov_i(1, tmp2, cpu_offset, ctx);
1442 				emit(A64_LDR32(r0, tmp, tmp2), ctx);
1443 			}
1444 			break;
1445 		}
1446 
1447 		/* Implement helper call to bpf_get_current_task/_btf() inline */
1448 		if (insn->src_reg == 0 && (insn->imm == BPF_FUNC_get_current_task ||
1449 					   insn->imm == BPF_FUNC_get_current_task_btf)) {
1450 			emit(A64_MRS_SP_EL0(r0), ctx);
1451 			break;
1452 		}
1453 
1454 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1455 					    &func_addr, &func_addr_fixed);
1456 		if (ret < 0)
1457 			return ret;
1458 		emit_call(func_addr, ctx);
1459 		emit(A64_MOV(1, r0, A64_R(0)), ctx);
1460 		break;
1461 	}
1462 	/* tail call */
1463 	case BPF_JMP | BPF_TAIL_CALL:
1464 		if (emit_bpf_tail_call(ctx))
1465 			return -EFAULT;
1466 		break;
1467 	/* function return */
1468 	case BPF_JMP | BPF_EXIT:
1469 		/* Optimization: when last instruction is EXIT,
1470 		   simply fallthrough to epilogue. */
1471 		if (i == ctx->prog->len - 1)
1472 			break;
1473 		jmp_offset = epilogue_offset(ctx);
1474 		check_imm26(jmp_offset);
1475 		emit(A64_B(jmp_offset), ctx);
1476 		break;
1477 
1478 	/* dst = imm64 */
1479 	case BPF_LD | BPF_IMM | BPF_DW:
1480 	{
1481 		const struct bpf_insn insn1 = insn[1];
1482 		u64 imm64;
1483 
1484 		imm64 = (u64)insn1.imm << 32 | (u32)imm;
1485 		if (bpf_pseudo_func(insn))
1486 			emit_addr_mov_i64(dst, imm64, ctx);
1487 		else
1488 			emit_a64_mov_i64(dst, imm64, ctx);
1489 
1490 		return 1;
1491 	}
1492 
1493 	/* LDX: dst = (u64)*(unsigned size *)(src + off) */
1494 	case BPF_LDX | BPF_MEM | BPF_W:
1495 	case BPF_LDX | BPF_MEM | BPF_H:
1496 	case BPF_LDX | BPF_MEM | BPF_B:
1497 	case BPF_LDX | BPF_MEM | BPF_DW:
1498 	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1499 	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1500 	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1501 	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1502 	/* LDXS: dst_reg = (s64)*(signed size *)(src_reg + off) */
1503 	case BPF_LDX | BPF_MEMSX | BPF_B:
1504 	case BPF_LDX | BPF_MEMSX | BPF_H:
1505 	case BPF_LDX | BPF_MEMSX | BPF_W:
1506 	case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1507 	case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1508 	case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1509 	case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
1510 	case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
1511 	case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
1512 	case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
1513 		if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
1514 			emit(A64_ADD(1, tmp2, src, arena_vm_base), ctx);
1515 			src = tmp2;
1516 		}
1517 		if (src == fp) {
1518 			src_adj = A64_SP;
1519 			off_adj = off + ctx->stack_size;
1520 		} else {
1521 			src_adj = src;
1522 			off_adj = off;
1523 		}
1524 		sign_extend = (BPF_MODE(insn->code) == BPF_MEMSX ||
1525 				BPF_MODE(insn->code) == BPF_PROBE_MEMSX);
1526 		switch (BPF_SIZE(code)) {
1527 		case BPF_W:
1528 			if (is_lsi_offset(off_adj, 2)) {
1529 				if (sign_extend)
1530 					emit(A64_LDRSWI(dst, src_adj, off_adj), ctx);
1531 				else
1532 					emit(A64_LDR32I(dst, src_adj, off_adj), ctx);
1533 			} else {
1534 				emit_a64_mov_i(1, tmp, off, ctx);
1535 				if (sign_extend)
1536 					emit(A64_LDRSW(dst, src, tmp), ctx);
1537 				else
1538 					emit(A64_LDR32(dst, src, tmp), ctx);
1539 			}
1540 			break;
1541 		case BPF_H:
1542 			if (is_lsi_offset(off_adj, 1)) {
1543 				if (sign_extend)
1544 					emit(A64_LDRSHI(dst, src_adj, off_adj), ctx);
1545 				else
1546 					emit(A64_LDRHI(dst, src_adj, off_adj), ctx);
1547 			} else {
1548 				emit_a64_mov_i(1, tmp, off, ctx);
1549 				if (sign_extend)
1550 					emit(A64_LDRSH(dst, src, tmp), ctx);
1551 				else
1552 					emit(A64_LDRH(dst, src, tmp), ctx);
1553 			}
1554 			break;
1555 		case BPF_B:
1556 			if (is_lsi_offset(off_adj, 0)) {
1557 				if (sign_extend)
1558 					emit(A64_LDRSBI(dst, src_adj, off_adj), ctx);
1559 				else
1560 					emit(A64_LDRBI(dst, src_adj, off_adj), ctx);
1561 			} else {
1562 				emit_a64_mov_i(1, tmp, off, ctx);
1563 				if (sign_extend)
1564 					emit(A64_LDRSB(dst, src, tmp), ctx);
1565 				else
1566 					emit(A64_LDRB(dst, src, tmp), ctx);
1567 			}
1568 			break;
1569 		case BPF_DW:
1570 			if (is_lsi_offset(off_adj, 3)) {
1571 				emit(A64_LDR64I(dst, src_adj, off_adj), ctx);
1572 			} else {
1573 				emit_a64_mov_i(1, tmp, off, ctx);
1574 				emit(A64_LDR64(dst, src, tmp), ctx);
1575 			}
1576 			break;
1577 		}
1578 
1579 		ret = add_exception_handler(insn, ctx, dst);
1580 		if (ret)
1581 			return ret;
1582 		break;
1583 
1584 	/* speculation barrier */
1585 	case BPF_ST | BPF_NOSPEC:
1586 		/*
1587 		 * Nothing required here.
1588 		 *
1589 		 * In case of arm64, we rely on the firmware mitigation of
1590 		 * Speculative Store Bypass as controlled via the ssbd kernel
1591 		 * parameter. Whenever the mitigation is enabled, it works
1592 		 * for all of the kernel code with no need to provide any
1593 		 * additional instructions.
1594 		 */
1595 		break;
1596 
1597 	/* ST: *(size *)(dst + off) = imm */
1598 	case BPF_ST | BPF_MEM | BPF_W:
1599 	case BPF_ST | BPF_MEM | BPF_H:
1600 	case BPF_ST | BPF_MEM | BPF_B:
1601 	case BPF_ST | BPF_MEM | BPF_DW:
1602 	case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
1603 	case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
1604 	case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
1605 	case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
1606 		if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
1607 			emit(A64_ADD(1, tmp2, dst, arena_vm_base), ctx);
1608 			dst = tmp2;
1609 		}
1610 		if (dst == fp) {
1611 			dst_adj = A64_SP;
1612 			off_adj = off + ctx->stack_size;
1613 		} else {
1614 			dst_adj = dst;
1615 			off_adj = off;
1616 		}
1617 		/* Load imm to a register then store it */
1618 		emit_a64_mov_i(1, tmp, imm, ctx);
1619 		switch (BPF_SIZE(code)) {
1620 		case BPF_W:
1621 			if (is_lsi_offset(off_adj, 2)) {
1622 				emit(A64_STR32I(tmp, dst_adj, off_adj), ctx);
1623 			} else {
1624 				emit_a64_mov_i(1, tmp2, off, ctx);
1625 				emit(A64_STR32(tmp, dst, tmp2), ctx);
1626 			}
1627 			break;
1628 		case BPF_H:
1629 			if (is_lsi_offset(off_adj, 1)) {
1630 				emit(A64_STRHI(tmp, dst_adj, off_adj), ctx);
1631 			} else {
1632 				emit_a64_mov_i(1, tmp2, off, ctx);
1633 				emit(A64_STRH(tmp, dst, tmp2), ctx);
1634 			}
1635 			break;
1636 		case BPF_B:
1637 			if (is_lsi_offset(off_adj, 0)) {
1638 				emit(A64_STRBI(tmp, dst_adj, off_adj), ctx);
1639 			} else {
1640 				emit_a64_mov_i(1, tmp2, off, ctx);
1641 				emit(A64_STRB(tmp, dst, tmp2), ctx);
1642 			}
1643 			break;
1644 		case BPF_DW:
1645 			if (is_lsi_offset(off_adj, 3)) {
1646 				emit(A64_STR64I(tmp, dst_adj, off_adj), ctx);
1647 			} else {
1648 				emit_a64_mov_i(1, tmp2, off, ctx);
1649 				emit(A64_STR64(tmp, dst, tmp2), ctx);
1650 			}
1651 			break;
1652 		}
1653 
1654 		ret = add_exception_handler(insn, ctx, dst);
1655 		if (ret)
1656 			return ret;
1657 		break;
1658 
1659 	/* STX: *(size *)(dst + off) = src */
1660 	case BPF_STX | BPF_MEM | BPF_W:
1661 	case BPF_STX | BPF_MEM | BPF_H:
1662 	case BPF_STX | BPF_MEM | BPF_B:
1663 	case BPF_STX | BPF_MEM | BPF_DW:
1664 	case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
1665 	case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
1666 	case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
1667 	case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
1668 		if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
1669 			emit(A64_ADD(1, tmp2, dst, arena_vm_base), ctx);
1670 			dst = tmp2;
1671 		}
1672 		if (dst == fp) {
1673 			dst_adj = A64_SP;
1674 			off_adj = off + ctx->stack_size;
1675 		} else {
1676 			dst_adj = dst;
1677 			off_adj = off;
1678 		}
1679 		switch (BPF_SIZE(code)) {
1680 		case BPF_W:
1681 			if (is_lsi_offset(off_adj, 2)) {
1682 				emit(A64_STR32I(src, dst_adj, off_adj), ctx);
1683 			} else {
1684 				emit_a64_mov_i(1, tmp, off, ctx);
1685 				emit(A64_STR32(src, dst, tmp), ctx);
1686 			}
1687 			break;
1688 		case BPF_H:
1689 			if (is_lsi_offset(off_adj, 1)) {
1690 				emit(A64_STRHI(src, dst_adj, off_adj), ctx);
1691 			} else {
1692 				emit_a64_mov_i(1, tmp, off, ctx);
1693 				emit(A64_STRH(src, dst, tmp), ctx);
1694 			}
1695 			break;
1696 		case BPF_B:
1697 			if (is_lsi_offset(off_adj, 0)) {
1698 				emit(A64_STRBI(src, dst_adj, off_adj), ctx);
1699 			} else {
1700 				emit_a64_mov_i(1, tmp, off, ctx);
1701 				emit(A64_STRB(src, dst, tmp), ctx);
1702 			}
1703 			break;
1704 		case BPF_DW:
1705 			if (is_lsi_offset(off_adj, 3)) {
1706 				emit(A64_STR64I(src, dst_adj, off_adj), ctx);
1707 			} else {
1708 				emit_a64_mov_i(1, tmp, off, ctx);
1709 				emit(A64_STR64(src, dst, tmp), ctx);
1710 			}
1711 			break;
1712 		}
1713 
1714 		ret = add_exception_handler(insn, ctx, dst);
1715 		if (ret)
1716 			return ret;
1717 		break;
1718 
1719 	case BPF_STX | BPF_ATOMIC | BPF_B:
1720 	case BPF_STX | BPF_ATOMIC | BPF_H:
1721 	case BPF_STX | BPF_ATOMIC | BPF_W:
1722 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1723 	case BPF_STX | BPF_PROBE_ATOMIC | BPF_B:
1724 	case BPF_STX | BPF_PROBE_ATOMIC | BPF_H:
1725 	case BPF_STX | BPF_PROBE_ATOMIC | BPF_W:
1726 	case BPF_STX | BPF_PROBE_ATOMIC | BPF_DW:
1727 		if (bpf_atomic_is_load_store(insn))
1728 			ret = emit_atomic_ld_st(insn, ctx);
1729 		else if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS))
1730 			ret = emit_lse_atomic(insn, ctx);
1731 		else
1732 			ret = emit_ll_sc_atomic(insn, ctx);
1733 		if (ret)
1734 			return ret;
1735 
1736 		ret = add_exception_handler(insn, ctx, dst);
1737 		if (ret)
1738 			return ret;
1739 		break;
1740 
1741 	default:
1742 		pr_err_once("unknown opcode %02x\n", code);
1743 		return -EINVAL;
1744 	}
1745 
1746 	return 0;
1747 }
1748 
1749 static int build_body(struct jit_ctx *ctx, bool extra_pass)
1750 {
1751 	const struct bpf_prog *prog = ctx->prog;
1752 	int i;
1753 
1754 	/*
1755 	 * - offset[0] offset of the end of prologue,
1756 	 *   start of the 1st instruction.
1757 	 * - offset[1] - offset of the end of 1st instruction,
1758 	 *   start of the 2nd instruction
1759 	 * [....]
1760 	 * - offset[3] - offset of the end of 3rd instruction,
1761 	 *   start of 4th instruction
1762 	 */
1763 	for (i = 0; i < prog->len; i++) {
1764 		const struct bpf_insn *insn = &prog->insnsi[i];
1765 		int ret;
1766 
1767 		ctx->offset[i] = ctx->idx;
1768 		ret = build_insn(insn, ctx, extra_pass);
1769 		if (ret > 0) {
1770 			i++;
1771 			ctx->offset[i] = ctx->idx;
1772 			continue;
1773 		}
1774 		if (ret)
1775 			return ret;
1776 	}
1777 	/*
1778 	 * offset is allocated with prog->len + 1 so fill in
1779 	 * the last element with the offset after the last
1780 	 * instruction (end of program)
1781 	 */
1782 	ctx->offset[i] = ctx->idx;
1783 
1784 	return 0;
1785 }
1786 
1787 static int validate_code(struct jit_ctx *ctx)
1788 {
1789 	int i;
1790 
1791 	for (i = 0; i < ctx->idx; i++) {
1792 		u32 a64_insn = le32_to_cpu(ctx->image[i]);
1793 
1794 		if (a64_insn == AARCH64_BREAK_FAULT)
1795 			return -1;
1796 	}
1797 	return 0;
1798 }
1799 
1800 static int validate_ctx(struct jit_ctx *ctx)
1801 {
1802 	if (validate_code(ctx))
1803 		return -1;
1804 
1805 	if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries))
1806 		return -1;
1807 
1808 	return 0;
1809 }
1810 
1811 static inline void bpf_flush_icache(void *start, void *end)
1812 {
1813 	flush_icache_range((unsigned long)start, (unsigned long)end);
1814 }
1815 
1816 struct arm64_jit_data {
1817 	struct bpf_binary_header *header;
1818 	u8 *ro_image;
1819 	struct bpf_binary_header *ro_header;
1820 	struct jit_ctx ctx;
1821 };
1822 
1823 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1824 {
1825 	int image_size, prog_size, extable_size, extable_align, extable_offset;
1826 	struct bpf_prog *tmp, *orig_prog = prog;
1827 	struct bpf_binary_header *header;
1828 	struct bpf_binary_header *ro_header;
1829 	struct arm64_jit_data *jit_data;
1830 	bool was_classic = bpf_prog_was_classic(prog);
1831 	bool tmp_blinded = false;
1832 	bool extra_pass = false;
1833 	struct jit_ctx ctx;
1834 	u8 *image_ptr;
1835 	u8 *ro_image_ptr;
1836 	int body_idx;
1837 	int exentry_idx;
1838 
1839 	if (!prog->jit_requested)
1840 		return orig_prog;
1841 
1842 	tmp = bpf_jit_blind_constants(prog);
1843 	/* If blinding was requested and we failed during blinding,
1844 	 * we must fall back to the interpreter.
1845 	 */
1846 	if (IS_ERR(tmp))
1847 		return orig_prog;
1848 	if (tmp != prog) {
1849 		tmp_blinded = true;
1850 		prog = tmp;
1851 	}
1852 
1853 	jit_data = prog->aux->jit_data;
1854 	if (!jit_data) {
1855 		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1856 		if (!jit_data) {
1857 			prog = orig_prog;
1858 			goto out;
1859 		}
1860 		prog->aux->jit_data = jit_data;
1861 	}
1862 	if (jit_data->ctx.offset) {
1863 		ctx = jit_data->ctx;
1864 		ro_image_ptr = jit_data->ro_image;
1865 		ro_header = jit_data->ro_header;
1866 		header = jit_data->header;
1867 		image_ptr = (void *)header + ((void *)ro_image_ptr
1868 						 - (void *)ro_header);
1869 		extra_pass = true;
1870 		prog_size = sizeof(u32) * ctx.idx;
1871 		goto skip_init_ctx;
1872 	}
1873 	memset(&ctx, 0, sizeof(ctx));
1874 	ctx.prog = prog;
1875 
1876 	ctx.offset = kvcalloc(prog->len + 1, sizeof(int), GFP_KERNEL);
1877 	if (ctx.offset == NULL) {
1878 		prog = orig_prog;
1879 		goto out_off;
1880 	}
1881 
1882 	ctx.user_vm_start = bpf_arena_get_user_vm_start(prog->aux->arena);
1883 	ctx.arena_vm_start = bpf_arena_get_kern_vm_start(prog->aux->arena);
1884 
1885 	/* Pass 1: Estimate the maximum image size.
1886 	 *
1887 	 * BPF line info needs ctx->offset[i] to be the offset of
1888 	 * instruction[i] in jited image, so build prologue first.
1889 	 */
1890 	if (build_prologue(&ctx, was_classic)) {
1891 		prog = orig_prog;
1892 		goto out_off;
1893 	}
1894 
1895 	if (build_body(&ctx, extra_pass)) {
1896 		prog = orig_prog;
1897 		goto out_off;
1898 	}
1899 
1900 	ctx.epilogue_offset = ctx.idx;
1901 	build_epilogue(&ctx);
1902 	build_plt(&ctx);
1903 
1904 	extable_align = __alignof__(struct exception_table_entry);
1905 	extable_size = prog->aux->num_exentries *
1906 		sizeof(struct exception_table_entry);
1907 
1908 	/* Now we know the maximum image size. */
1909 	prog_size = sizeof(u32) * ctx.idx;
1910 	/* also allocate space for plt target */
1911 	extable_offset = round_up(prog_size + PLT_TARGET_SIZE, extable_align);
1912 	image_size = extable_offset + extable_size;
1913 	ro_header = bpf_jit_binary_pack_alloc(image_size, &ro_image_ptr,
1914 					      sizeof(u32), &header, &image_ptr,
1915 					      jit_fill_hole);
1916 	if (!ro_header) {
1917 		prog = orig_prog;
1918 		goto out_off;
1919 	}
1920 
1921 	/* Pass 2: Determine jited position and result for each instruction */
1922 
1923 	/*
1924 	 * Use the image(RW) for writing the JITed instructions. But also save
1925 	 * the ro_image(RX) for calculating the offsets in the image. The RW
1926 	 * image will be later copied to the RX image from where the program
1927 	 * will run. The bpf_jit_binary_pack_finalize() will do this copy in the
1928 	 * final step.
1929 	 */
1930 	ctx.image = (__le32 *)image_ptr;
1931 	ctx.ro_image = (__le32 *)ro_image_ptr;
1932 	if (extable_size)
1933 		prog->aux->extable = (void *)ro_image_ptr + extable_offset;
1934 skip_init_ctx:
1935 	ctx.idx = 0;
1936 	ctx.exentry_idx = 0;
1937 	ctx.write = true;
1938 
1939 	build_prologue(&ctx, was_classic);
1940 
1941 	/* Record exentry_idx and body_idx before first build_body */
1942 	exentry_idx = ctx.exentry_idx;
1943 	body_idx = ctx.idx;
1944 	/* Dont write body instructions to memory for now */
1945 	ctx.write = false;
1946 
1947 	if (build_body(&ctx, extra_pass)) {
1948 		prog = orig_prog;
1949 		goto out_free_hdr;
1950 	}
1951 
1952 	ctx.epilogue_offset = ctx.idx;
1953 	ctx.exentry_idx = exentry_idx;
1954 	ctx.idx = body_idx;
1955 	ctx.write = true;
1956 
1957 	/* Pass 3: Adjust jump offset and write final image */
1958 	if (build_body(&ctx, extra_pass) ||
1959 		WARN_ON_ONCE(ctx.idx != ctx.epilogue_offset)) {
1960 		prog = orig_prog;
1961 		goto out_free_hdr;
1962 	}
1963 
1964 	build_epilogue(&ctx);
1965 	build_plt(&ctx);
1966 
1967 	/* Extra pass to validate JITed code. */
1968 	if (validate_ctx(&ctx)) {
1969 		prog = orig_prog;
1970 		goto out_free_hdr;
1971 	}
1972 
1973 	/* update the real prog size */
1974 	prog_size = sizeof(u32) * ctx.idx;
1975 
1976 	/* And we're done. */
1977 	if (bpf_jit_enable > 1)
1978 		bpf_jit_dump(prog->len, prog_size, 2, ctx.image);
1979 
1980 	if (!prog->is_func || extra_pass) {
1981 		/* The jited image may shrink since the jited result for
1982 		 * BPF_CALL to subprog may be changed from indirect call
1983 		 * to direct call.
1984 		 */
1985 		if (extra_pass && ctx.idx > jit_data->ctx.idx) {
1986 			pr_err_once("multi-func JIT bug %d > %d\n",
1987 				    ctx.idx, jit_data->ctx.idx);
1988 			prog->bpf_func = NULL;
1989 			prog->jited = 0;
1990 			prog->jited_len = 0;
1991 			goto out_free_hdr;
1992 		}
1993 		if (WARN_ON(bpf_jit_binary_pack_finalize(ro_header, header))) {
1994 			/* ro_header has been freed */
1995 			ro_header = NULL;
1996 			prog = orig_prog;
1997 			goto out_off;
1998 		}
1999 		/*
2000 		 * The instructions have now been copied to the ROX region from
2001 		 * where they will execute. Now the data cache has to be cleaned to
2002 		 * the PoU and the I-cache has to be invalidated for the VAs.
2003 		 */
2004 		bpf_flush_icache(ro_header, ctx.ro_image + ctx.idx);
2005 	} else {
2006 		jit_data->ctx = ctx;
2007 		jit_data->ro_image = ro_image_ptr;
2008 		jit_data->header = header;
2009 		jit_data->ro_header = ro_header;
2010 	}
2011 
2012 	prog->bpf_func = (void *)ctx.ro_image;
2013 	prog->jited = 1;
2014 	prog->jited_len = prog_size;
2015 
2016 	if (!prog->is_func || extra_pass) {
2017 		int i;
2018 
2019 		/* offset[prog->len] is the size of program */
2020 		for (i = 0; i <= prog->len; i++)
2021 			ctx.offset[i] *= AARCH64_INSN_SIZE;
2022 		bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
2023 out_off:
2024 		kvfree(ctx.offset);
2025 		kfree(jit_data);
2026 		prog->aux->jit_data = NULL;
2027 	}
2028 out:
2029 	if (tmp_blinded)
2030 		bpf_jit_prog_release_other(prog, prog == orig_prog ?
2031 					   tmp : orig_prog);
2032 	return prog;
2033 
2034 out_free_hdr:
2035 	if (header) {
2036 		bpf_arch_text_copy(&ro_header->size, &header->size,
2037 				   sizeof(header->size));
2038 		bpf_jit_binary_pack_free(ro_header, header);
2039 	}
2040 	goto out_off;
2041 }
2042 
2043 bool bpf_jit_supports_kfunc_call(void)
2044 {
2045 	return true;
2046 }
2047 
2048 void *bpf_arch_text_copy(void *dst, void *src, size_t len)
2049 {
2050 	if (!aarch64_insn_copy(dst, src, len))
2051 		return ERR_PTR(-EINVAL);
2052 	return dst;
2053 }
2054 
2055 u64 bpf_jit_alloc_exec_limit(void)
2056 {
2057 	return VMALLOC_END - VMALLOC_START;
2058 }
2059 
2060 /* Indicate the JIT backend supports mixing bpf2bpf and tailcalls. */
2061 bool bpf_jit_supports_subprog_tailcalls(void)
2062 {
2063 	return true;
2064 }
2065 
2066 static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
2067 			    int args_off, int retval_off, int run_ctx_off,
2068 			    bool save_ret)
2069 {
2070 	__le32 *branch;
2071 	u64 enter_prog;
2072 	u64 exit_prog;
2073 	struct bpf_prog *p = l->link.prog;
2074 	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
2075 
2076 	enter_prog = (u64)bpf_trampoline_enter(p);
2077 	exit_prog = (u64)bpf_trampoline_exit(p);
2078 
2079 	if (l->cookie == 0) {
2080 		/* if cookie is zero, one instruction is enough to store it */
2081 		emit(A64_STR64I(A64_ZR, A64_SP, run_ctx_off + cookie_off), ctx);
2082 	} else {
2083 		emit_a64_mov_i64(A64_R(10), l->cookie, ctx);
2084 		emit(A64_STR64I(A64_R(10), A64_SP, run_ctx_off + cookie_off),
2085 		     ctx);
2086 	}
2087 
2088 	/* save p to callee saved register x19 to avoid loading p with mov_i64
2089 	 * each time.
2090 	 */
2091 	emit_addr_mov_i64(A64_R(19), (const u64)p, ctx);
2092 
2093 	/* arg1: prog */
2094 	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
2095 	/* arg2: &run_ctx */
2096 	emit(A64_ADD_I(1, A64_R(1), A64_SP, run_ctx_off), ctx);
2097 
2098 	emit_call(enter_prog, ctx);
2099 
2100 	/* save return value to callee saved register x20 */
2101 	emit(A64_MOV(1, A64_R(20), A64_R(0)), ctx);
2102 
2103 	/* if (__bpf_prog_enter(prog) == 0)
2104 	 *         goto skip_exec_of_prog;
2105 	 */
2106 	branch = ctx->image + ctx->idx;
2107 	emit(A64_NOP, ctx);
2108 
2109 	emit(A64_ADD_I(1, A64_R(0), A64_SP, args_off), ctx);
2110 	if (!p->jited)
2111 		emit_addr_mov_i64(A64_R(1), (const u64)p->insnsi, ctx);
2112 
2113 	emit_call((const u64)p->bpf_func, ctx);
2114 
2115 	if (save_ret)
2116 		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
2117 
2118 	if (ctx->image) {
2119 		int offset = &ctx->image[ctx->idx] - branch;
2120 		*branch = cpu_to_le32(A64_CBZ(1, A64_R(0), offset));
2121 	}
2122 
2123 	/* arg1: prog */
2124 	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
2125 	/* arg2: start time */
2126 	emit(A64_MOV(1, A64_R(1), A64_R(20)), ctx);
2127 	/* arg3: &run_ctx */
2128 	emit(A64_ADD_I(1, A64_R(2), A64_SP, run_ctx_off), ctx);
2129 
2130 	emit_call(exit_prog, ctx);
2131 }
2132 
2133 static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
2134 			       int args_off, int retval_off, int run_ctx_off,
2135 			       __le32 **branches)
2136 {
2137 	int i;
2138 
2139 	/* The first fmod_ret program will receive a garbage return value.
2140 	 * Set this to 0 to avoid confusing the program.
2141 	 */
2142 	emit(A64_STR64I(A64_ZR, A64_SP, retval_off), ctx);
2143 	for (i = 0; i < tl->nr_links; i++) {
2144 		invoke_bpf_prog(ctx, tl->links[i], args_off, retval_off,
2145 				run_ctx_off, true);
2146 		/* if (*(u64 *)(sp + retval_off) !=  0)
2147 		 *	goto do_fexit;
2148 		 */
2149 		emit(A64_LDR64I(A64_R(10), A64_SP, retval_off), ctx);
2150 		/* Save the location of branch, and generate a nop.
2151 		 * This nop will be replaced with a cbnz later.
2152 		 */
2153 		branches[i] = ctx->image + ctx->idx;
2154 		emit(A64_NOP, ctx);
2155 	}
2156 }
2157 
2158 static void save_args(struct jit_ctx *ctx, int args_off, int nregs)
2159 {
2160 	int i;
2161 
2162 	for (i = 0; i < nregs; i++) {
2163 		emit(A64_STR64I(i, A64_SP, args_off), ctx);
2164 		args_off += 8;
2165 	}
2166 }
2167 
2168 static void restore_args(struct jit_ctx *ctx, int args_off, int nregs)
2169 {
2170 	int i;
2171 
2172 	for (i = 0; i < nregs; i++) {
2173 		emit(A64_LDR64I(i, A64_SP, args_off), ctx);
2174 		args_off += 8;
2175 	}
2176 }
2177 
2178 static bool is_struct_ops_tramp(const struct bpf_tramp_links *fentry_links)
2179 {
2180 	return fentry_links->nr_links == 1 &&
2181 		fentry_links->links[0]->link.type == BPF_LINK_TYPE_STRUCT_OPS;
2182 }
2183 
2184 /* Based on the x86's implementation of arch_prepare_bpf_trampoline().
2185  *
2186  * bpf prog and function entry before bpf trampoline hooked:
2187  *   mov x9, lr
2188  *   nop
2189  *
2190  * bpf prog and function entry after bpf trampoline hooked:
2191  *   mov x9, lr
2192  *   bl  <bpf_trampoline or plt>
2193  *
2194  */
2195 static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
2196 			      struct bpf_tramp_links *tlinks, void *func_addr,
2197 			      int nregs, u32 flags)
2198 {
2199 	int i;
2200 	int stack_size;
2201 	int retaddr_off;
2202 	int regs_off;
2203 	int retval_off;
2204 	int args_off;
2205 	int nregs_off;
2206 	int ip_off;
2207 	int run_ctx_off;
2208 	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
2209 	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
2210 	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
2211 	bool save_ret;
2212 	__le32 **branches = NULL;
2213 	bool is_struct_ops = is_struct_ops_tramp(fentry);
2214 
2215 	/* trampoline stack layout:
2216 	 *                  [ parent ip         ]
2217 	 *                  [ FP                ]
2218 	 * SP + retaddr_off [ self ip           ]
2219 	 *                  [ FP                ]
2220 	 *
2221 	 *                  [ padding           ] align SP to multiples of 16
2222 	 *
2223 	 *                  [ x20               ] callee saved reg x20
2224 	 * SP + regs_off    [ x19               ] callee saved reg x19
2225 	 *
2226 	 * SP + retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
2227 	 *                                        BPF_TRAMP_F_RET_FENTRY_RET
2228 	 *
2229 	 *                  [ arg reg N         ]
2230 	 *                  [ ...               ]
2231 	 * SP + args_off    [ arg reg 1         ]
2232 	 *
2233 	 * SP + nregs_off   [ arg regs count    ]
2234 	 *
2235 	 * SP + ip_off      [ traced function   ] BPF_TRAMP_F_IP_ARG flag
2236 	 *
2237 	 * SP + run_ctx_off [ bpf_tramp_run_ctx ]
2238 	 */
2239 
2240 	stack_size = 0;
2241 	run_ctx_off = stack_size;
2242 	/* room for bpf_tramp_run_ctx */
2243 	stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
2244 
2245 	ip_off = stack_size;
2246 	/* room for IP address argument */
2247 	if (flags & BPF_TRAMP_F_IP_ARG)
2248 		stack_size += 8;
2249 
2250 	nregs_off = stack_size;
2251 	/* room for args count */
2252 	stack_size += 8;
2253 
2254 	args_off = stack_size;
2255 	/* room for args */
2256 	stack_size += nregs * 8;
2257 
2258 	/* room for return value */
2259 	retval_off = stack_size;
2260 	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
2261 	if (save_ret)
2262 		stack_size += 8;
2263 
2264 	/* room for callee saved registers, currently x19 and x20 are used */
2265 	regs_off = stack_size;
2266 	stack_size += 16;
2267 
2268 	/* round up to multiples of 16 to avoid SPAlignmentFault */
2269 	stack_size = round_up(stack_size, 16);
2270 
2271 	/* return address locates above FP */
2272 	retaddr_off = stack_size + 8;
2273 
2274 	/* bpf trampoline may be invoked by 3 instruction types:
2275 	 * 1. bl, attached to bpf prog or kernel function via short jump
2276 	 * 2. br, attached to bpf prog or kernel function via long jump
2277 	 * 3. blr, working as a function pointer, used by struct_ops.
2278 	 * So BTI_JC should used here to support both br and blr.
2279 	 */
2280 	emit_bti(A64_BTI_JC, ctx);
2281 
2282 	/* x9 is not set for struct_ops */
2283 	if (!is_struct_ops) {
2284 		/* frame for parent function */
2285 		emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
2286 		emit(A64_MOV(1, A64_FP, A64_SP), ctx);
2287 	}
2288 
2289 	/* frame for patched function for tracing, or caller for struct_ops */
2290 	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
2291 	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
2292 
2293 	/* allocate stack space */
2294 	emit(A64_SUB_I(1, A64_SP, A64_SP, stack_size), ctx);
2295 
2296 	if (flags & BPF_TRAMP_F_IP_ARG) {
2297 		/* save ip address of the traced function */
2298 		emit_addr_mov_i64(A64_R(10), (const u64)func_addr, ctx);
2299 		emit(A64_STR64I(A64_R(10), A64_SP, ip_off), ctx);
2300 	}
2301 
2302 	/* save arg regs count*/
2303 	emit(A64_MOVZ(1, A64_R(10), nregs, 0), ctx);
2304 	emit(A64_STR64I(A64_R(10), A64_SP, nregs_off), ctx);
2305 
2306 	/* save arg regs */
2307 	save_args(ctx, args_off, nregs);
2308 
2309 	/* save callee saved registers */
2310 	emit(A64_STR64I(A64_R(19), A64_SP, regs_off), ctx);
2311 	emit(A64_STR64I(A64_R(20), A64_SP, regs_off + 8), ctx);
2312 
2313 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
2314 		/* for the first pass, assume the worst case */
2315 		if (!ctx->image)
2316 			ctx->idx += 4;
2317 		else
2318 			emit_a64_mov_i64(A64_R(0), (const u64)im, ctx);
2319 		emit_call((const u64)__bpf_tramp_enter, ctx);
2320 	}
2321 
2322 	for (i = 0; i < fentry->nr_links; i++)
2323 		invoke_bpf_prog(ctx, fentry->links[i], args_off,
2324 				retval_off, run_ctx_off,
2325 				flags & BPF_TRAMP_F_RET_FENTRY_RET);
2326 
2327 	if (fmod_ret->nr_links) {
2328 		branches = kcalloc(fmod_ret->nr_links, sizeof(__le32 *),
2329 				   GFP_KERNEL);
2330 		if (!branches)
2331 			return -ENOMEM;
2332 
2333 		invoke_bpf_mod_ret(ctx, fmod_ret, args_off, retval_off,
2334 				   run_ctx_off, branches);
2335 	}
2336 
2337 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
2338 		restore_args(ctx, args_off, nregs);
2339 		/* call original func */
2340 		emit(A64_LDR64I(A64_R(10), A64_SP, retaddr_off), ctx);
2341 		emit(A64_ADR(A64_LR, AARCH64_INSN_SIZE * 2), ctx);
2342 		emit(A64_RET(A64_R(10)), ctx);
2343 		/* store return value */
2344 		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
2345 		/* reserve a nop for bpf_tramp_image_put */
2346 		im->ip_after_call = ctx->ro_image + ctx->idx;
2347 		emit(A64_NOP, ctx);
2348 	}
2349 
2350 	/* update the branches saved in invoke_bpf_mod_ret with cbnz */
2351 	for (i = 0; i < fmod_ret->nr_links && ctx->image != NULL; i++) {
2352 		int offset = &ctx->image[ctx->idx] - branches[i];
2353 		*branches[i] = cpu_to_le32(A64_CBNZ(1, A64_R(10), offset));
2354 	}
2355 
2356 	for (i = 0; i < fexit->nr_links; i++)
2357 		invoke_bpf_prog(ctx, fexit->links[i], args_off, retval_off,
2358 				run_ctx_off, false);
2359 
2360 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
2361 		im->ip_epilogue = ctx->ro_image + ctx->idx;
2362 		/* for the first pass, assume the worst case */
2363 		if (!ctx->image)
2364 			ctx->idx += 4;
2365 		else
2366 			emit_a64_mov_i64(A64_R(0), (const u64)im, ctx);
2367 		emit_call((const u64)__bpf_tramp_exit, ctx);
2368 	}
2369 
2370 	if (flags & BPF_TRAMP_F_RESTORE_REGS)
2371 		restore_args(ctx, args_off, nregs);
2372 
2373 	/* restore callee saved register x19 and x20 */
2374 	emit(A64_LDR64I(A64_R(19), A64_SP, regs_off), ctx);
2375 	emit(A64_LDR64I(A64_R(20), A64_SP, regs_off + 8), ctx);
2376 
2377 	if (save_ret)
2378 		emit(A64_LDR64I(A64_R(0), A64_SP, retval_off), ctx);
2379 
2380 	/* reset SP  */
2381 	emit(A64_MOV(1, A64_SP, A64_FP), ctx);
2382 
2383 	if (is_struct_ops) {
2384 		emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
2385 		emit(A64_RET(A64_LR), ctx);
2386 	} else {
2387 		/* pop frames */
2388 		emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
2389 		emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);
2390 
2391 		if (flags & BPF_TRAMP_F_SKIP_FRAME) {
2392 			/* skip patched function, return to parent */
2393 			emit(A64_MOV(1, A64_LR, A64_R(9)), ctx);
2394 			emit(A64_RET(A64_R(9)), ctx);
2395 		} else {
2396 			/* return to patched function */
2397 			emit(A64_MOV(1, A64_R(10), A64_LR), ctx);
2398 			emit(A64_MOV(1, A64_LR, A64_R(9)), ctx);
2399 			emit(A64_RET(A64_R(10)), ctx);
2400 		}
2401 	}
2402 
2403 	kfree(branches);
2404 
2405 	return ctx->idx;
2406 }
2407 
2408 static int btf_func_model_nregs(const struct btf_func_model *m)
2409 {
2410 	int nregs = m->nr_args;
2411 	int i;
2412 
2413 	/* extra registers needed for struct argument */
2414 	for (i = 0; i < MAX_BPF_FUNC_ARGS; i++) {
2415 		/* The arg_size is at most 16 bytes, enforced by the verifier. */
2416 		if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
2417 			nregs += (m->arg_size[i] + 7) / 8 - 1;
2418 	}
2419 
2420 	return nregs;
2421 }
2422 
2423 int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
2424 			     struct bpf_tramp_links *tlinks, void *func_addr)
2425 {
2426 	struct jit_ctx ctx = {
2427 		.image = NULL,
2428 		.idx = 0,
2429 	};
2430 	struct bpf_tramp_image im;
2431 	int nregs, ret;
2432 
2433 	nregs = btf_func_model_nregs(m);
2434 	/* the first 8 registers are used for arguments */
2435 	if (nregs > 8)
2436 		return -ENOTSUPP;
2437 
2438 	ret = prepare_trampoline(&ctx, &im, tlinks, func_addr, nregs, flags);
2439 	if (ret < 0)
2440 		return ret;
2441 
2442 	return ret < 0 ? ret : ret * AARCH64_INSN_SIZE;
2443 }
2444 
2445 void *arch_alloc_bpf_trampoline(unsigned int size)
2446 {
2447 	return bpf_prog_pack_alloc(size, jit_fill_hole);
2448 }
2449 
2450 void arch_free_bpf_trampoline(void *image, unsigned int size)
2451 {
2452 	bpf_prog_pack_free(image, size);
2453 }
2454 
2455 int arch_protect_bpf_trampoline(void *image, unsigned int size)
2456 {
2457 	return 0;
2458 }
2459 
2460 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *ro_image,
2461 				void *ro_image_end, const struct btf_func_model *m,
2462 				u32 flags, struct bpf_tramp_links *tlinks,
2463 				void *func_addr)
2464 {
2465 	int ret, nregs;
2466 	void *image, *tmp;
2467 	u32 size = ro_image_end - ro_image;
2468 
2469 	/* image doesn't need to be in module memory range, so we can
2470 	 * use kvmalloc.
2471 	 */
2472 	image = kvmalloc(size, GFP_KERNEL);
2473 	if (!image)
2474 		return -ENOMEM;
2475 
2476 	struct jit_ctx ctx = {
2477 		.image = image,
2478 		.ro_image = ro_image,
2479 		.idx = 0,
2480 		.write = true,
2481 	};
2482 
2483 	nregs = btf_func_model_nregs(m);
2484 	/* the first 8 registers are used for arguments */
2485 	if (nregs > 8)
2486 		return -ENOTSUPP;
2487 
2488 	jit_fill_hole(image, (unsigned int)(ro_image_end - ro_image));
2489 	ret = prepare_trampoline(&ctx, im, tlinks, func_addr, nregs, flags);
2490 
2491 	if (ret > 0 && validate_code(&ctx) < 0) {
2492 		ret = -EINVAL;
2493 		goto out;
2494 	}
2495 
2496 	if (ret > 0)
2497 		ret *= AARCH64_INSN_SIZE;
2498 
2499 	tmp = bpf_arch_text_copy(ro_image, image, size);
2500 	if (IS_ERR(tmp)) {
2501 		ret = PTR_ERR(tmp);
2502 		goto out;
2503 	}
2504 
2505 	bpf_flush_icache(ro_image, ro_image + size);
2506 out:
2507 	kvfree(image);
2508 	return ret;
2509 }
2510 
2511 static bool is_long_jump(void *ip, void *target)
2512 {
2513 	long offset;
2514 
2515 	/* NULL target means this is a NOP */
2516 	if (!target)
2517 		return false;
2518 
2519 	offset = (long)target - (long)ip;
2520 	return offset < -SZ_128M || offset >= SZ_128M;
2521 }
2522 
2523 static int gen_branch_or_nop(enum aarch64_insn_branch_type type, void *ip,
2524 			     void *addr, void *plt, u32 *insn)
2525 {
2526 	void *target;
2527 
2528 	if (!addr) {
2529 		*insn = aarch64_insn_gen_nop();
2530 		return 0;
2531 	}
2532 
2533 	if (is_long_jump(ip, addr))
2534 		target = plt;
2535 	else
2536 		target = addr;
2537 
2538 	*insn = aarch64_insn_gen_branch_imm((unsigned long)ip,
2539 					    (unsigned long)target,
2540 					    type);
2541 
2542 	return *insn != AARCH64_BREAK_FAULT ? 0 : -EFAULT;
2543 }
2544 
2545 /* Replace the branch instruction from @ip to @old_addr in a bpf prog or a bpf
2546  * trampoline with the branch instruction from @ip to @new_addr. If @old_addr
2547  * or @new_addr is NULL, the old or new instruction is NOP.
2548  *
2549  * When @ip is the bpf prog entry, a bpf trampoline is being attached or
2550  * detached. Since bpf trampoline and bpf prog are allocated separately with
2551  * vmalloc, the address distance may exceed 128MB, the maximum branch range.
2552  * So long jump should be handled.
2553  *
2554  * When a bpf prog is constructed, a plt pointing to empty trampoline
2555  * dummy_tramp is placed at the end:
2556  *
2557  *      bpf_prog:
2558  *              mov x9, lr
2559  *              nop // patchsite
2560  *              ...
2561  *              ret
2562  *
2563  *      plt:
2564  *              ldr x10, target
2565  *              br x10
2566  *      target:
2567  *              .quad dummy_tramp // plt target
2568  *
2569  * This is also the state when no trampoline is attached.
2570  *
2571  * When a short-jump bpf trampoline is attached, the patchsite is patched
2572  * to a bl instruction to the trampoline directly:
2573  *
2574  *      bpf_prog:
2575  *              mov x9, lr
2576  *              bl <short-jump bpf trampoline address> // patchsite
2577  *              ...
2578  *              ret
2579  *
2580  *      plt:
2581  *              ldr x10, target
2582  *              br x10
2583  *      target:
2584  *              .quad dummy_tramp // plt target
2585  *
2586  * When a long-jump bpf trampoline is attached, the plt target is filled with
2587  * the trampoline address and the patchsite is patched to a bl instruction to
2588  * the plt:
2589  *
2590  *      bpf_prog:
2591  *              mov x9, lr
2592  *              bl plt // patchsite
2593  *              ...
2594  *              ret
2595  *
2596  *      plt:
2597  *              ldr x10, target
2598  *              br x10
2599  *      target:
2600  *              .quad <long-jump bpf trampoline address> // plt target
2601  *
2602  * The dummy_tramp is used to prevent another CPU from jumping to unknown
2603  * locations during the patching process, making the patching process easier.
2604  */
2605 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
2606 		       void *old_addr, void *new_addr)
2607 {
2608 	int ret;
2609 	u32 old_insn;
2610 	u32 new_insn;
2611 	u32 replaced;
2612 	struct bpf_plt *plt = NULL;
2613 	unsigned long size = 0UL;
2614 	unsigned long offset = ~0UL;
2615 	enum aarch64_insn_branch_type branch_type;
2616 	char namebuf[KSYM_NAME_LEN];
2617 	void *image = NULL;
2618 	u64 plt_target = 0ULL;
2619 	bool poking_bpf_entry;
2620 
2621 	if (!__bpf_address_lookup((unsigned long)ip, &size, &offset, namebuf))
2622 		/* Only poking bpf text is supported. Since kernel function
2623 		 * entry is set up by ftrace, we reply on ftrace to poke kernel
2624 		 * functions.
2625 		 */
2626 		return -ENOTSUPP;
2627 
2628 	image = ip - offset;
2629 	/* zero offset means we're poking bpf prog entry */
2630 	poking_bpf_entry = (offset == 0UL);
2631 
2632 	/* bpf prog entry, find plt and the real patchsite */
2633 	if (poking_bpf_entry) {
2634 		/* plt locates at the end of bpf prog */
2635 		plt = image + size - PLT_TARGET_OFFSET;
2636 
2637 		/* skip to the nop instruction in bpf prog entry:
2638 		 * bti c // if BTI enabled
2639 		 * mov x9, x30
2640 		 * nop
2641 		 */
2642 		ip = image + POKE_OFFSET * AARCH64_INSN_SIZE;
2643 	}
2644 
2645 	/* long jump is only possible at bpf prog entry */
2646 	if (WARN_ON((is_long_jump(ip, new_addr) || is_long_jump(ip, old_addr)) &&
2647 		    !poking_bpf_entry))
2648 		return -EINVAL;
2649 
2650 	if (poke_type == BPF_MOD_CALL)
2651 		branch_type = AARCH64_INSN_BRANCH_LINK;
2652 	else
2653 		branch_type = AARCH64_INSN_BRANCH_NOLINK;
2654 
2655 	if (gen_branch_or_nop(branch_type, ip, old_addr, plt, &old_insn) < 0)
2656 		return -EFAULT;
2657 
2658 	if (gen_branch_or_nop(branch_type, ip, new_addr, plt, &new_insn) < 0)
2659 		return -EFAULT;
2660 
2661 	if (is_long_jump(ip, new_addr))
2662 		plt_target = (u64)new_addr;
2663 	else if (is_long_jump(ip, old_addr))
2664 		/* if the old target is a long jump and the new target is not,
2665 		 * restore the plt target to dummy_tramp, so there is always a
2666 		 * legal and harmless address stored in plt target, and we'll
2667 		 * never jump from plt to an unknown place.
2668 		 */
2669 		plt_target = (u64)&dummy_tramp;
2670 
2671 	if (plt_target) {
2672 		/* non-zero plt_target indicates we're patching a bpf prog,
2673 		 * which is read only.
2674 		 */
2675 		if (set_memory_rw(PAGE_MASK & ((uintptr_t)&plt->target), 1))
2676 			return -EFAULT;
2677 		WRITE_ONCE(plt->target, plt_target);
2678 		set_memory_ro(PAGE_MASK & ((uintptr_t)&plt->target), 1);
2679 		/* since plt target points to either the new trampoline
2680 		 * or dummy_tramp, even if another CPU reads the old plt
2681 		 * target value before fetching the bl instruction to plt,
2682 		 * it will be brought back by dummy_tramp, so no barrier is
2683 		 * required here.
2684 		 */
2685 	}
2686 
2687 	/* if the old target and the new target are both long jumps, no
2688 	 * patching is required
2689 	 */
2690 	if (old_insn == new_insn)
2691 		return 0;
2692 
2693 	mutex_lock(&text_mutex);
2694 	if (aarch64_insn_read(ip, &replaced)) {
2695 		ret = -EFAULT;
2696 		goto out;
2697 	}
2698 
2699 	if (replaced != old_insn) {
2700 		ret = -EFAULT;
2701 		goto out;
2702 	}
2703 
2704 	/* We call aarch64_insn_patch_text_nosync() to replace instruction
2705 	 * atomically, so no other CPUs will fetch a half-new and half-old
2706 	 * instruction. But there is chance that another CPU executes the
2707 	 * old instruction after the patching operation finishes (e.g.,
2708 	 * pipeline not flushed, or icache not synchronized yet).
2709 	 *
2710 	 * 1. when a new trampoline is attached, it is not a problem for
2711 	 *    different CPUs to jump to different trampolines temporarily.
2712 	 *
2713 	 * 2. when an old trampoline is freed, we should wait for all other
2714 	 *    CPUs to exit the trampoline and make sure the trampoline is no
2715 	 *    longer reachable, since bpf_tramp_image_put() function already
2716 	 *    uses percpu_ref and task-based rcu to do the sync, no need to call
2717 	 *    the sync version here, see bpf_tramp_image_put() for details.
2718 	 */
2719 	ret = aarch64_insn_patch_text_nosync(ip, new_insn);
2720 out:
2721 	mutex_unlock(&text_mutex);
2722 
2723 	return ret;
2724 }
2725 
2726 bool bpf_jit_supports_ptr_xchg(void)
2727 {
2728 	return true;
2729 }
2730 
2731 bool bpf_jit_supports_exceptions(void)
2732 {
2733 	/* We unwind through both kernel frames starting from within bpf_throw
2734 	 * call and BPF frames. Therefore we require FP unwinder to be enabled
2735 	 * to walk kernel frames and reach BPF frames in the stack trace.
2736 	 * ARM64 kernel is aways compiled with CONFIG_FRAME_POINTER=y
2737 	 */
2738 	return true;
2739 }
2740 
2741 bool bpf_jit_supports_arena(void)
2742 {
2743 	return true;
2744 }
2745 
2746 bool bpf_jit_supports_insn(struct bpf_insn *insn, bool in_arena)
2747 {
2748 	if (!in_arena)
2749 		return true;
2750 	switch (insn->code) {
2751 	case BPF_STX | BPF_ATOMIC | BPF_W:
2752 	case BPF_STX | BPF_ATOMIC | BPF_DW:
2753 		if (!bpf_atomic_is_load_store(insn) &&
2754 		    !cpus_have_cap(ARM64_HAS_LSE_ATOMICS))
2755 			return false;
2756 	}
2757 	return true;
2758 }
2759 
2760 bool bpf_jit_supports_percpu_insn(void)
2761 {
2762 	return true;
2763 }
2764 
2765 bool bpf_jit_inlines_helper_call(s32 imm)
2766 {
2767 	switch (imm) {
2768 	case BPF_FUNC_get_smp_processor_id:
2769 	case BPF_FUNC_get_current_task:
2770 	case BPF_FUNC_get_current_task_btf:
2771 		return true;
2772 	default:
2773 		return false;
2774 	}
2775 }
2776 
2777 void bpf_jit_free(struct bpf_prog *prog)
2778 {
2779 	if (prog->jited) {
2780 		struct arm64_jit_data *jit_data = prog->aux->jit_data;
2781 		struct bpf_binary_header *hdr;
2782 
2783 		/*
2784 		 * If we fail the final pass of JIT (from jit_subprogs),
2785 		 * the program may not be finalized yet. Call finalize here
2786 		 * before freeing it.
2787 		 */
2788 		if (jit_data) {
2789 			bpf_arch_text_copy(&jit_data->ro_header->size, &jit_data->header->size,
2790 					   sizeof(jit_data->header->size));
2791 			kfree(jit_data);
2792 		}
2793 		hdr = bpf_jit_binary_pack_hdr(prog);
2794 		bpf_jit_binary_pack_free(hdr, NULL);
2795 		WARN_ON_ONCE(!bpf_prog_kallsyms_verify_off(prog));
2796 	}
2797 
2798 	bpf_prog_unlock_free(prog);
2799 }
2800