xref: /linux/arch/riscv/net/bpf_jit_comp64.c (revision 1f77ed9422cbc41e1a5d17654b7e527a4a23b665)
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7 
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include "bpf_jit.h"
12 
13 #define RV_REG_TCC RV_REG_A6
14 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
15 
16 static const int regmap[] = {
17 	[BPF_REG_0] =	RV_REG_A5,
18 	[BPF_REG_1] =	RV_REG_A0,
19 	[BPF_REG_2] =	RV_REG_A1,
20 	[BPF_REG_3] =	RV_REG_A2,
21 	[BPF_REG_4] =	RV_REG_A3,
22 	[BPF_REG_5] =	RV_REG_A4,
23 	[BPF_REG_6] =	RV_REG_S1,
24 	[BPF_REG_7] =	RV_REG_S2,
25 	[BPF_REG_8] =	RV_REG_S3,
26 	[BPF_REG_9] =	RV_REG_S4,
27 	[BPF_REG_FP] =	RV_REG_S5,
28 	[BPF_REG_AX] =	RV_REG_T0,
29 };
30 
31 static const int pt_regmap[] = {
32 	[RV_REG_A0] = offsetof(struct pt_regs, a0),
33 	[RV_REG_A1] = offsetof(struct pt_regs, a1),
34 	[RV_REG_A2] = offsetof(struct pt_regs, a2),
35 	[RV_REG_A3] = offsetof(struct pt_regs, a3),
36 	[RV_REG_A4] = offsetof(struct pt_regs, a4),
37 	[RV_REG_A5] = offsetof(struct pt_regs, a5),
38 	[RV_REG_S1] = offsetof(struct pt_regs, s1),
39 	[RV_REG_S2] = offsetof(struct pt_regs, s2),
40 	[RV_REG_S3] = offsetof(struct pt_regs, s3),
41 	[RV_REG_S4] = offsetof(struct pt_regs, s4),
42 	[RV_REG_S5] = offsetof(struct pt_regs, s5),
43 	[RV_REG_T0] = offsetof(struct pt_regs, t0),
44 };
45 
46 enum {
47 	RV_CTX_F_SEEN_TAIL_CALL =	0,
48 	RV_CTX_F_SEEN_CALL =		RV_REG_RA,
49 	RV_CTX_F_SEEN_S1 =		RV_REG_S1,
50 	RV_CTX_F_SEEN_S2 =		RV_REG_S2,
51 	RV_CTX_F_SEEN_S3 =		RV_REG_S3,
52 	RV_CTX_F_SEEN_S4 =		RV_REG_S4,
53 	RV_CTX_F_SEEN_S5 =		RV_REG_S5,
54 	RV_CTX_F_SEEN_S6 =		RV_REG_S6,
55 };
56 
57 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
58 {
59 	u8 reg = regmap[bpf_reg];
60 
61 	switch (reg) {
62 	case RV_CTX_F_SEEN_S1:
63 	case RV_CTX_F_SEEN_S2:
64 	case RV_CTX_F_SEEN_S3:
65 	case RV_CTX_F_SEEN_S4:
66 	case RV_CTX_F_SEEN_S5:
67 	case RV_CTX_F_SEEN_S6:
68 		__set_bit(reg, &ctx->flags);
69 	}
70 	return reg;
71 };
72 
73 static bool seen_reg(int reg, struct rv_jit_context *ctx)
74 {
75 	switch (reg) {
76 	case RV_CTX_F_SEEN_CALL:
77 	case RV_CTX_F_SEEN_S1:
78 	case RV_CTX_F_SEEN_S2:
79 	case RV_CTX_F_SEEN_S3:
80 	case RV_CTX_F_SEEN_S4:
81 	case RV_CTX_F_SEEN_S5:
82 	case RV_CTX_F_SEEN_S6:
83 		return test_bit(reg, &ctx->flags);
84 	}
85 	return false;
86 }
87 
88 static void mark_fp(struct rv_jit_context *ctx)
89 {
90 	__set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
91 }
92 
93 static void mark_call(struct rv_jit_context *ctx)
94 {
95 	__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
96 }
97 
98 static bool seen_call(struct rv_jit_context *ctx)
99 {
100 	return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
101 }
102 
103 static void mark_tail_call(struct rv_jit_context *ctx)
104 {
105 	__set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
106 }
107 
108 static bool seen_tail_call(struct rv_jit_context *ctx)
109 {
110 	return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
111 }
112 
113 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
114 {
115 	mark_tail_call(ctx);
116 
117 	if (seen_call(ctx)) {
118 		__set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
119 		return RV_REG_S6;
120 	}
121 	return RV_REG_A6;
122 }
123 
124 static bool is_32b_int(s64 val)
125 {
126 	return -(1L << 31) <= val && val < (1L << 31);
127 }
128 
129 static bool in_auipc_jalr_range(s64 val)
130 {
131 	/*
132 	 * auipc+jalr can reach any signed PC-relative offset in the range
133 	 * [-2^31 - 2^11, 2^31 - 2^11).
134 	 */
135 	return (-(1L << 31) - (1L << 11)) <= val &&
136 		val < ((1L << 31) - (1L << 11));
137 }
138 
139 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
140 {
141 	/* Note that the immediate from the add is sign-extended,
142 	 * which means that we need to compensate this by adding 2^12,
143 	 * when the 12th bit is set. A simpler way of doing this, and
144 	 * getting rid of the check, is to just add 2**11 before the
145 	 * shift. The "Loading a 32-Bit constant" example from the
146 	 * "Computer Organization and Design, RISC-V edition" book by
147 	 * Patterson/Hennessy highlights this fact.
148 	 *
149 	 * This also means that we need to process LSB to MSB.
150 	 */
151 	s64 upper = (val + (1 << 11)) >> 12;
152 	/* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
153 	 * and addi are signed and RVC checks will perform signed comparisons.
154 	 */
155 	s64 lower = ((val & 0xfff) << 52) >> 52;
156 	int shift;
157 
158 	if (is_32b_int(val)) {
159 		if (upper)
160 			emit_lui(rd, upper, ctx);
161 
162 		if (!upper) {
163 			emit_li(rd, lower, ctx);
164 			return;
165 		}
166 
167 		emit_addiw(rd, rd, lower, ctx);
168 		return;
169 	}
170 
171 	shift = __ffs(upper);
172 	upper >>= shift;
173 	shift += 12;
174 
175 	emit_imm(rd, upper, ctx);
176 
177 	emit_slli(rd, rd, shift, ctx);
178 	if (lower)
179 		emit_addi(rd, rd, lower, ctx);
180 }
181 
182 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
183 {
184 	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
185 
186 	if (seen_reg(RV_REG_RA, ctx)) {
187 		emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
188 		store_offset -= 8;
189 	}
190 	emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
191 	store_offset -= 8;
192 	if (seen_reg(RV_REG_S1, ctx)) {
193 		emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
194 		store_offset -= 8;
195 	}
196 	if (seen_reg(RV_REG_S2, ctx)) {
197 		emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
198 		store_offset -= 8;
199 	}
200 	if (seen_reg(RV_REG_S3, ctx)) {
201 		emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
202 		store_offset -= 8;
203 	}
204 	if (seen_reg(RV_REG_S4, ctx)) {
205 		emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
206 		store_offset -= 8;
207 	}
208 	if (seen_reg(RV_REG_S5, ctx)) {
209 		emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
210 		store_offset -= 8;
211 	}
212 	if (seen_reg(RV_REG_S6, ctx)) {
213 		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
214 		store_offset -= 8;
215 	}
216 
217 	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
218 	/* Set return value. */
219 	if (!is_tail_call)
220 		emit_mv(RV_REG_A0, RV_REG_A5, ctx);
221 	emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
222 		  is_tail_call ? 4 : 0, /* skip TCC init */
223 		  ctx);
224 }
225 
226 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
227 		     struct rv_jit_context *ctx)
228 {
229 	switch (cond) {
230 	case BPF_JEQ:
231 		emit(rv_beq(rd, rs, rvoff >> 1), ctx);
232 		return;
233 	case BPF_JGT:
234 		emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
235 		return;
236 	case BPF_JLT:
237 		emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
238 		return;
239 	case BPF_JGE:
240 		emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
241 		return;
242 	case BPF_JLE:
243 		emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
244 		return;
245 	case BPF_JNE:
246 		emit(rv_bne(rd, rs, rvoff >> 1), ctx);
247 		return;
248 	case BPF_JSGT:
249 		emit(rv_blt(rs, rd, rvoff >> 1), ctx);
250 		return;
251 	case BPF_JSLT:
252 		emit(rv_blt(rd, rs, rvoff >> 1), ctx);
253 		return;
254 	case BPF_JSGE:
255 		emit(rv_bge(rd, rs, rvoff >> 1), ctx);
256 		return;
257 	case BPF_JSLE:
258 		emit(rv_bge(rs, rd, rvoff >> 1), ctx);
259 	}
260 }
261 
262 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
263 			struct rv_jit_context *ctx)
264 {
265 	s64 upper, lower;
266 
267 	if (is_13b_int(rvoff)) {
268 		emit_bcc(cond, rd, rs, rvoff, ctx);
269 		return;
270 	}
271 
272 	/* Adjust for jal */
273 	rvoff -= 4;
274 
275 	/* Transform, e.g.:
276 	 *   bne rd,rs,foo
277 	 * to
278 	 *   beq rd,rs,<.L1>
279 	 *   (auipc foo)
280 	 *   jal(r) foo
281 	 * .L1
282 	 */
283 	cond = invert_bpf_cond(cond);
284 	if (is_21b_int(rvoff)) {
285 		emit_bcc(cond, rd, rs, 8, ctx);
286 		emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
287 		return;
288 	}
289 
290 	/* 32b No need for an additional rvoff adjustment, since we
291 	 * get that from the auipc at PC', where PC = PC' + 4.
292 	 */
293 	upper = (rvoff + (1 << 11)) >> 12;
294 	lower = rvoff & 0xfff;
295 
296 	emit_bcc(cond, rd, rs, 12, ctx);
297 	emit(rv_auipc(RV_REG_T1, upper), ctx);
298 	emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
299 }
300 
301 static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
302 {
303 	emit_slli(reg, reg, 32, ctx);
304 	emit_srli(reg, reg, 32, ctx);
305 }
306 
307 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
308 {
309 	int tc_ninsn, off, start_insn = ctx->ninsns;
310 	u8 tcc = rv_tail_call_reg(ctx);
311 
312 	/* a0: &ctx
313 	 * a1: &array
314 	 * a2: index
315 	 *
316 	 * if (index >= array->map.max_entries)
317 	 *	goto out;
318 	 */
319 	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
320 		   ctx->offset[0];
321 	emit_zext_32(RV_REG_A2, ctx);
322 
323 	off = offsetof(struct bpf_array, map.max_entries);
324 	if (is_12b_check(off, insn))
325 		return -1;
326 	emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
327 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
328 	emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
329 
330 	/* if (TCC-- < 0)
331 	 *     goto out;
332 	 */
333 	emit_addi(RV_REG_T1, tcc, -1, ctx);
334 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
335 	emit_branch(BPF_JSLT, tcc, RV_REG_ZERO, off, ctx);
336 
337 	/* prog = array->ptrs[index];
338 	 * if (!prog)
339 	 *     goto out;
340 	 */
341 	emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
342 	emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
343 	off = offsetof(struct bpf_array, ptrs);
344 	if (is_12b_check(off, insn))
345 		return -1;
346 	emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
347 	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
348 	emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
349 
350 	/* goto *(prog->bpf_func + 4); */
351 	off = offsetof(struct bpf_prog, bpf_func);
352 	if (is_12b_check(off, insn))
353 		return -1;
354 	emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
355 	emit_mv(RV_REG_TCC, RV_REG_T1, ctx);
356 	__build_epilogue(true, ctx);
357 	return 0;
358 }
359 
360 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
361 		      struct rv_jit_context *ctx)
362 {
363 	u8 code = insn->code;
364 
365 	switch (code) {
366 	case BPF_JMP | BPF_JA:
367 	case BPF_JMP | BPF_CALL:
368 	case BPF_JMP | BPF_EXIT:
369 	case BPF_JMP | BPF_TAIL_CALL:
370 		break;
371 	default:
372 		*rd = bpf_to_rv_reg(insn->dst_reg, ctx);
373 	}
374 
375 	if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
376 	    code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
377 	    code & BPF_LDX || code & BPF_STX)
378 		*rs = bpf_to_rv_reg(insn->src_reg, ctx);
379 }
380 
381 static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
382 {
383 	emit_mv(RV_REG_T2, *rd, ctx);
384 	emit_zext_32(RV_REG_T2, ctx);
385 	emit_mv(RV_REG_T1, *rs, ctx);
386 	emit_zext_32(RV_REG_T1, ctx);
387 	*rd = RV_REG_T2;
388 	*rs = RV_REG_T1;
389 }
390 
391 static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
392 {
393 	emit_addiw(RV_REG_T2, *rd, 0, ctx);
394 	emit_addiw(RV_REG_T1, *rs, 0, ctx);
395 	*rd = RV_REG_T2;
396 	*rs = RV_REG_T1;
397 }
398 
399 static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
400 {
401 	emit_mv(RV_REG_T2, *rd, ctx);
402 	emit_zext_32(RV_REG_T2, ctx);
403 	emit_zext_32(RV_REG_T1, ctx);
404 	*rd = RV_REG_T2;
405 }
406 
407 static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
408 {
409 	emit_addiw(RV_REG_T2, *rd, 0, ctx);
410 	*rd = RV_REG_T2;
411 }
412 
413 static int emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
414 			      struct rv_jit_context *ctx)
415 {
416 	s64 upper, lower;
417 
418 	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
419 		emit(rv_jal(rd, rvoff >> 1), ctx);
420 		return 0;
421 	} else if (in_auipc_jalr_range(rvoff)) {
422 		upper = (rvoff + (1 << 11)) >> 12;
423 		lower = rvoff & 0xfff;
424 		emit(rv_auipc(RV_REG_T1, upper), ctx);
425 		emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
426 		return 0;
427 	}
428 
429 	pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
430 	return -ERANGE;
431 }
432 
433 static bool is_signed_bpf_cond(u8 cond)
434 {
435 	return cond == BPF_JSGT || cond == BPF_JSLT ||
436 		cond == BPF_JSGE || cond == BPF_JSLE;
437 }
438 
439 static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
440 {
441 	s64 off = 0;
442 	u64 ip;
443 	u8 rd;
444 	int ret;
445 
446 	if (addr && ctx->insns) {
447 		ip = (u64)(long)(ctx->insns + ctx->ninsns);
448 		off = addr - ip;
449 	}
450 
451 	ret = emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
452 	if (ret)
453 		return ret;
454 	rd = bpf_to_rv_reg(BPF_REG_0, ctx);
455 	emit_mv(rd, RV_REG_A0, ctx);
456 	return 0;
457 }
458 
459 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
460 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
461 
462 bool ex_handler_bpf(const struct exception_table_entry *ex,
463 		    struct pt_regs *regs)
464 {
465 	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
466 	int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
467 
468 	*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
469 	regs->epc = (unsigned long)&ex->fixup - offset;
470 
471 	return true;
472 }
473 
474 /* For accesses to BTF pointers, add an entry to the exception table */
475 static int add_exception_handler(const struct bpf_insn *insn,
476 				 struct rv_jit_context *ctx,
477 				 int dst_reg, int insn_len)
478 {
479 	struct exception_table_entry *ex;
480 	unsigned long pc;
481 	off_t offset;
482 
483 	if (!ctx->insns || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
484 		return 0;
485 
486 	if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
487 		return -EINVAL;
488 
489 	if (WARN_ON_ONCE(insn_len > ctx->ninsns))
490 		return -EINVAL;
491 
492 	if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
493 		return -EINVAL;
494 
495 	ex = &ctx->prog->aux->extable[ctx->nexentries];
496 	pc = (unsigned long)&ctx->insns[ctx->ninsns - insn_len];
497 
498 	offset = pc - (long)&ex->insn;
499 	if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
500 		return -ERANGE;
501 	ex->insn = pc;
502 
503 	/*
504 	 * Since the extable follows the program, the fixup offset is always
505 	 * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
506 	 * to keep things simple, and put the destination register in the upper
507 	 * bits. We don't need to worry about buildtime or runtime sort
508 	 * modifying the upper bits because the table is already sorted, and
509 	 * isn't part of the main exception table.
510 	 */
511 	offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
512 	if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
513 		return -ERANGE;
514 
515 	ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
516 		FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
517 	ex->type = EX_TYPE_BPF;
518 
519 	ctx->nexentries++;
520 	return 0;
521 }
522 
523 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
524 		      bool extra_pass)
525 {
526 	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
527 		    BPF_CLASS(insn->code) == BPF_JMP;
528 	int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
529 	struct bpf_prog_aux *aux = ctx->prog->aux;
530 	u8 rd = -1, rs = -1, code = insn->code;
531 	s16 off = insn->off;
532 	s32 imm = insn->imm;
533 
534 	init_regs(&rd, &rs, insn, ctx);
535 
536 	switch (code) {
537 	/* dst = src */
538 	case BPF_ALU | BPF_MOV | BPF_X:
539 	case BPF_ALU64 | BPF_MOV | BPF_X:
540 		if (imm == 1) {
541 			/* Special mov32 for zext */
542 			emit_zext_32(rd, ctx);
543 			break;
544 		}
545 		emit_mv(rd, rs, ctx);
546 		if (!is64 && !aux->verifier_zext)
547 			emit_zext_32(rd, ctx);
548 		break;
549 
550 	/* dst = dst OP src */
551 	case BPF_ALU | BPF_ADD | BPF_X:
552 	case BPF_ALU64 | BPF_ADD | BPF_X:
553 		emit_add(rd, rd, rs, ctx);
554 		if (!is64 && !aux->verifier_zext)
555 			emit_zext_32(rd, ctx);
556 		break;
557 	case BPF_ALU | BPF_SUB | BPF_X:
558 	case BPF_ALU64 | BPF_SUB | BPF_X:
559 		if (is64)
560 			emit_sub(rd, rd, rs, ctx);
561 		else
562 			emit_subw(rd, rd, rs, ctx);
563 
564 		if (!is64 && !aux->verifier_zext)
565 			emit_zext_32(rd, ctx);
566 		break;
567 	case BPF_ALU | BPF_AND | BPF_X:
568 	case BPF_ALU64 | BPF_AND | BPF_X:
569 		emit_and(rd, rd, rs, ctx);
570 		if (!is64 && !aux->verifier_zext)
571 			emit_zext_32(rd, ctx);
572 		break;
573 	case BPF_ALU | BPF_OR | BPF_X:
574 	case BPF_ALU64 | BPF_OR | BPF_X:
575 		emit_or(rd, rd, rs, ctx);
576 		if (!is64 && !aux->verifier_zext)
577 			emit_zext_32(rd, ctx);
578 		break;
579 	case BPF_ALU | BPF_XOR | BPF_X:
580 	case BPF_ALU64 | BPF_XOR | BPF_X:
581 		emit_xor(rd, rd, rs, ctx);
582 		if (!is64 && !aux->verifier_zext)
583 			emit_zext_32(rd, ctx);
584 		break;
585 	case BPF_ALU | BPF_MUL | BPF_X:
586 	case BPF_ALU64 | BPF_MUL | BPF_X:
587 		emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
588 		if (!is64 && !aux->verifier_zext)
589 			emit_zext_32(rd, ctx);
590 		break;
591 	case BPF_ALU | BPF_DIV | BPF_X:
592 	case BPF_ALU64 | BPF_DIV | BPF_X:
593 		emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
594 		if (!is64 && !aux->verifier_zext)
595 			emit_zext_32(rd, ctx);
596 		break;
597 	case BPF_ALU | BPF_MOD | BPF_X:
598 	case BPF_ALU64 | BPF_MOD | BPF_X:
599 		emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
600 		if (!is64 && !aux->verifier_zext)
601 			emit_zext_32(rd, ctx);
602 		break;
603 	case BPF_ALU | BPF_LSH | BPF_X:
604 	case BPF_ALU64 | BPF_LSH | BPF_X:
605 		emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
606 		if (!is64 && !aux->verifier_zext)
607 			emit_zext_32(rd, ctx);
608 		break;
609 	case BPF_ALU | BPF_RSH | BPF_X:
610 	case BPF_ALU64 | BPF_RSH | BPF_X:
611 		emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
612 		if (!is64 && !aux->verifier_zext)
613 			emit_zext_32(rd, ctx);
614 		break;
615 	case BPF_ALU | BPF_ARSH | BPF_X:
616 	case BPF_ALU64 | BPF_ARSH | BPF_X:
617 		emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
618 		if (!is64 && !aux->verifier_zext)
619 			emit_zext_32(rd, ctx);
620 		break;
621 
622 	/* dst = -dst */
623 	case BPF_ALU | BPF_NEG:
624 	case BPF_ALU64 | BPF_NEG:
625 		emit_sub(rd, RV_REG_ZERO, rd, ctx);
626 		if (!is64 && !aux->verifier_zext)
627 			emit_zext_32(rd, ctx);
628 		break;
629 
630 	/* dst = BSWAP##imm(dst) */
631 	case BPF_ALU | BPF_END | BPF_FROM_LE:
632 		switch (imm) {
633 		case 16:
634 			emit_slli(rd, rd, 48, ctx);
635 			emit_srli(rd, rd, 48, ctx);
636 			break;
637 		case 32:
638 			if (!aux->verifier_zext)
639 				emit_zext_32(rd, ctx);
640 			break;
641 		case 64:
642 			/* Do nothing */
643 			break;
644 		}
645 		break;
646 
647 	case BPF_ALU | BPF_END | BPF_FROM_BE:
648 		emit_li(RV_REG_T2, 0, ctx);
649 
650 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
651 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
652 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
653 		emit_srli(rd, rd, 8, ctx);
654 		if (imm == 16)
655 			goto out_be;
656 
657 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
658 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
659 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
660 		emit_srli(rd, rd, 8, ctx);
661 
662 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
663 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
664 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
665 		emit_srli(rd, rd, 8, ctx);
666 		if (imm == 32)
667 			goto out_be;
668 
669 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
670 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
671 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
672 		emit_srli(rd, rd, 8, ctx);
673 
674 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
675 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
676 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
677 		emit_srli(rd, rd, 8, ctx);
678 
679 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
680 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
681 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
682 		emit_srli(rd, rd, 8, ctx);
683 
684 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
685 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
686 		emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
687 		emit_srli(rd, rd, 8, ctx);
688 out_be:
689 		emit_andi(RV_REG_T1, rd, 0xff, ctx);
690 		emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
691 
692 		emit_mv(rd, RV_REG_T2, ctx);
693 		break;
694 
695 	/* dst = imm */
696 	case BPF_ALU | BPF_MOV | BPF_K:
697 	case BPF_ALU64 | BPF_MOV | BPF_K:
698 		emit_imm(rd, imm, ctx);
699 		if (!is64 && !aux->verifier_zext)
700 			emit_zext_32(rd, ctx);
701 		break;
702 
703 	/* dst = dst OP imm */
704 	case BPF_ALU | BPF_ADD | BPF_K:
705 	case BPF_ALU64 | BPF_ADD | BPF_K:
706 		if (is_12b_int(imm)) {
707 			emit_addi(rd, rd, imm, ctx);
708 		} else {
709 			emit_imm(RV_REG_T1, imm, ctx);
710 			emit_add(rd, rd, RV_REG_T1, ctx);
711 		}
712 		if (!is64 && !aux->verifier_zext)
713 			emit_zext_32(rd, ctx);
714 		break;
715 	case BPF_ALU | BPF_SUB | BPF_K:
716 	case BPF_ALU64 | BPF_SUB | BPF_K:
717 		if (is_12b_int(-imm)) {
718 			emit_addi(rd, rd, -imm, ctx);
719 		} else {
720 			emit_imm(RV_REG_T1, imm, ctx);
721 			emit_sub(rd, rd, RV_REG_T1, ctx);
722 		}
723 		if (!is64 && !aux->verifier_zext)
724 			emit_zext_32(rd, ctx);
725 		break;
726 	case BPF_ALU | BPF_AND | BPF_K:
727 	case BPF_ALU64 | BPF_AND | BPF_K:
728 		if (is_12b_int(imm)) {
729 			emit_andi(rd, rd, imm, ctx);
730 		} else {
731 			emit_imm(RV_REG_T1, imm, ctx);
732 			emit_and(rd, rd, RV_REG_T1, ctx);
733 		}
734 		if (!is64 && !aux->verifier_zext)
735 			emit_zext_32(rd, ctx);
736 		break;
737 	case BPF_ALU | BPF_OR | BPF_K:
738 	case BPF_ALU64 | BPF_OR | BPF_K:
739 		if (is_12b_int(imm)) {
740 			emit(rv_ori(rd, rd, imm), ctx);
741 		} else {
742 			emit_imm(RV_REG_T1, imm, ctx);
743 			emit_or(rd, rd, RV_REG_T1, ctx);
744 		}
745 		if (!is64 && !aux->verifier_zext)
746 			emit_zext_32(rd, ctx);
747 		break;
748 	case BPF_ALU | BPF_XOR | BPF_K:
749 	case BPF_ALU64 | BPF_XOR | BPF_K:
750 		if (is_12b_int(imm)) {
751 			emit(rv_xori(rd, rd, imm), ctx);
752 		} else {
753 			emit_imm(RV_REG_T1, imm, ctx);
754 			emit_xor(rd, rd, RV_REG_T1, ctx);
755 		}
756 		if (!is64 && !aux->verifier_zext)
757 			emit_zext_32(rd, ctx);
758 		break;
759 	case BPF_ALU | BPF_MUL | BPF_K:
760 	case BPF_ALU64 | BPF_MUL | BPF_K:
761 		emit_imm(RV_REG_T1, imm, ctx);
762 		emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
763 		     rv_mulw(rd, rd, RV_REG_T1), ctx);
764 		if (!is64 && !aux->verifier_zext)
765 			emit_zext_32(rd, ctx);
766 		break;
767 	case BPF_ALU | BPF_DIV | BPF_K:
768 	case BPF_ALU64 | BPF_DIV | BPF_K:
769 		emit_imm(RV_REG_T1, imm, ctx);
770 		emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
771 		     rv_divuw(rd, rd, RV_REG_T1), ctx);
772 		if (!is64 && !aux->verifier_zext)
773 			emit_zext_32(rd, ctx);
774 		break;
775 	case BPF_ALU | BPF_MOD | BPF_K:
776 	case BPF_ALU64 | BPF_MOD | BPF_K:
777 		emit_imm(RV_REG_T1, imm, ctx);
778 		emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
779 		     rv_remuw(rd, rd, RV_REG_T1), ctx);
780 		if (!is64 && !aux->verifier_zext)
781 			emit_zext_32(rd, ctx);
782 		break;
783 	case BPF_ALU | BPF_LSH | BPF_K:
784 	case BPF_ALU64 | BPF_LSH | BPF_K:
785 		emit_slli(rd, rd, imm, ctx);
786 
787 		if (!is64 && !aux->verifier_zext)
788 			emit_zext_32(rd, ctx);
789 		break;
790 	case BPF_ALU | BPF_RSH | BPF_K:
791 	case BPF_ALU64 | BPF_RSH | BPF_K:
792 		if (is64)
793 			emit_srli(rd, rd, imm, ctx);
794 		else
795 			emit(rv_srliw(rd, rd, imm), ctx);
796 
797 		if (!is64 && !aux->verifier_zext)
798 			emit_zext_32(rd, ctx);
799 		break;
800 	case BPF_ALU | BPF_ARSH | BPF_K:
801 	case BPF_ALU64 | BPF_ARSH | BPF_K:
802 		if (is64)
803 			emit_srai(rd, rd, imm, ctx);
804 		else
805 			emit(rv_sraiw(rd, rd, imm), ctx);
806 
807 		if (!is64 && !aux->verifier_zext)
808 			emit_zext_32(rd, ctx);
809 		break;
810 
811 	/* JUMP off */
812 	case BPF_JMP | BPF_JA:
813 		rvoff = rv_offset(i, off, ctx);
814 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
815 		if (ret)
816 			return ret;
817 		break;
818 
819 	/* IF (dst COND src) JUMP off */
820 	case BPF_JMP | BPF_JEQ | BPF_X:
821 	case BPF_JMP32 | BPF_JEQ | BPF_X:
822 	case BPF_JMP | BPF_JGT | BPF_X:
823 	case BPF_JMP32 | BPF_JGT | BPF_X:
824 	case BPF_JMP | BPF_JLT | BPF_X:
825 	case BPF_JMP32 | BPF_JLT | BPF_X:
826 	case BPF_JMP | BPF_JGE | BPF_X:
827 	case BPF_JMP32 | BPF_JGE | BPF_X:
828 	case BPF_JMP | BPF_JLE | BPF_X:
829 	case BPF_JMP32 | BPF_JLE | BPF_X:
830 	case BPF_JMP | BPF_JNE | BPF_X:
831 	case BPF_JMP32 | BPF_JNE | BPF_X:
832 	case BPF_JMP | BPF_JSGT | BPF_X:
833 	case BPF_JMP32 | BPF_JSGT | BPF_X:
834 	case BPF_JMP | BPF_JSLT | BPF_X:
835 	case BPF_JMP32 | BPF_JSLT | BPF_X:
836 	case BPF_JMP | BPF_JSGE | BPF_X:
837 	case BPF_JMP32 | BPF_JSGE | BPF_X:
838 	case BPF_JMP | BPF_JSLE | BPF_X:
839 	case BPF_JMP32 | BPF_JSLE | BPF_X:
840 	case BPF_JMP | BPF_JSET | BPF_X:
841 	case BPF_JMP32 | BPF_JSET | BPF_X:
842 		rvoff = rv_offset(i, off, ctx);
843 		if (!is64) {
844 			s = ctx->ninsns;
845 			if (is_signed_bpf_cond(BPF_OP(code)))
846 				emit_sext_32_rd_rs(&rd, &rs, ctx);
847 			else
848 				emit_zext_32_rd_rs(&rd, &rs, ctx);
849 			e = ctx->ninsns;
850 
851 			/* Adjust for extra insns */
852 			rvoff -= ninsns_rvoff(e - s);
853 		}
854 
855 		if (BPF_OP(code) == BPF_JSET) {
856 			/* Adjust for and */
857 			rvoff -= 4;
858 			emit_and(RV_REG_T1, rd, rs, ctx);
859 			emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
860 				    ctx);
861 		} else {
862 			emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
863 		}
864 		break;
865 
866 	/* IF (dst COND imm) JUMP off */
867 	case BPF_JMP | BPF_JEQ | BPF_K:
868 	case BPF_JMP32 | BPF_JEQ | BPF_K:
869 	case BPF_JMP | BPF_JGT | BPF_K:
870 	case BPF_JMP32 | BPF_JGT | BPF_K:
871 	case BPF_JMP | BPF_JLT | BPF_K:
872 	case BPF_JMP32 | BPF_JLT | BPF_K:
873 	case BPF_JMP | BPF_JGE | BPF_K:
874 	case BPF_JMP32 | BPF_JGE | BPF_K:
875 	case BPF_JMP | BPF_JLE | BPF_K:
876 	case BPF_JMP32 | BPF_JLE | BPF_K:
877 	case BPF_JMP | BPF_JNE | BPF_K:
878 	case BPF_JMP32 | BPF_JNE | BPF_K:
879 	case BPF_JMP | BPF_JSGT | BPF_K:
880 	case BPF_JMP32 | BPF_JSGT | BPF_K:
881 	case BPF_JMP | BPF_JSLT | BPF_K:
882 	case BPF_JMP32 | BPF_JSLT | BPF_K:
883 	case BPF_JMP | BPF_JSGE | BPF_K:
884 	case BPF_JMP32 | BPF_JSGE | BPF_K:
885 	case BPF_JMP | BPF_JSLE | BPF_K:
886 	case BPF_JMP32 | BPF_JSLE | BPF_K:
887 		rvoff = rv_offset(i, off, ctx);
888 		s = ctx->ninsns;
889 		if (imm) {
890 			emit_imm(RV_REG_T1, imm, ctx);
891 			rs = RV_REG_T1;
892 		} else {
893 			/* If imm is 0, simply use zero register. */
894 			rs = RV_REG_ZERO;
895 		}
896 		if (!is64) {
897 			if (is_signed_bpf_cond(BPF_OP(code)))
898 				emit_sext_32_rd(&rd, ctx);
899 			else
900 				emit_zext_32_rd_t1(&rd, ctx);
901 		}
902 		e = ctx->ninsns;
903 
904 		/* Adjust for extra insns */
905 		rvoff -= ninsns_rvoff(e - s);
906 		emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
907 		break;
908 
909 	case BPF_JMP | BPF_JSET | BPF_K:
910 	case BPF_JMP32 | BPF_JSET | BPF_K:
911 		rvoff = rv_offset(i, off, ctx);
912 		s = ctx->ninsns;
913 		if (is_12b_int(imm)) {
914 			emit_andi(RV_REG_T1, rd, imm, ctx);
915 		} else {
916 			emit_imm(RV_REG_T1, imm, ctx);
917 			emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
918 		}
919 		/* For jset32, we should clear the upper 32 bits of t1, but
920 		 * sign-extension is sufficient here and saves one instruction,
921 		 * as t1 is used only in comparison against zero.
922 		 */
923 		if (!is64 && imm < 0)
924 			emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
925 		e = ctx->ninsns;
926 		rvoff -= ninsns_rvoff(e - s);
927 		emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
928 		break;
929 
930 	/* function call */
931 	case BPF_JMP | BPF_CALL:
932 	{
933 		bool fixed;
934 		u64 addr;
935 
936 		mark_call(ctx);
937 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &addr,
938 					    &fixed);
939 		if (ret < 0)
940 			return ret;
941 		ret = emit_call(fixed, addr, ctx);
942 		if (ret)
943 			return ret;
944 		break;
945 	}
946 	/* tail call */
947 	case BPF_JMP | BPF_TAIL_CALL:
948 		if (emit_bpf_tail_call(i, ctx))
949 			return -1;
950 		break;
951 
952 	/* function return */
953 	case BPF_JMP | BPF_EXIT:
954 		if (i == ctx->prog->len - 1)
955 			break;
956 
957 		rvoff = epilogue_offset(ctx);
958 		ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
959 		if (ret)
960 			return ret;
961 		break;
962 
963 	/* dst = imm64 */
964 	case BPF_LD | BPF_IMM | BPF_DW:
965 	{
966 		struct bpf_insn insn1 = insn[1];
967 		u64 imm64;
968 
969 		imm64 = (u64)insn1.imm << 32 | (u32)imm;
970 		emit_imm(rd, imm64, ctx);
971 		return 1;
972 	}
973 
974 	/* LDX: dst = *(size *)(src + off) */
975 	case BPF_LDX | BPF_MEM | BPF_B:
976 	case BPF_LDX | BPF_MEM | BPF_H:
977 	case BPF_LDX | BPF_MEM | BPF_W:
978 	case BPF_LDX | BPF_MEM | BPF_DW:
979 	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
980 	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
981 	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
982 	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
983 	{
984 		int insn_len, insns_start;
985 
986 		switch (BPF_SIZE(code)) {
987 		case BPF_B:
988 			if (is_12b_int(off)) {
989 				insns_start = ctx->ninsns;
990 				emit(rv_lbu(rd, off, rs), ctx);
991 				insn_len = ctx->ninsns - insns_start;
992 				break;
993 			}
994 
995 			emit_imm(RV_REG_T1, off, ctx);
996 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
997 			insns_start = ctx->ninsns;
998 			emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
999 			insn_len = ctx->ninsns - insns_start;
1000 			if (insn_is_zext(&insn[1]))
1001 				return 1;
1002 			break;
1003 		case BPF_H:
1004 			if (is_12b_int(off)) {
1005 				insns_start = ctx->ninsns;
1006 				emit(rv_lhu(rd, off, rs), ctx);
1007 				insn_len = ctx->ninsns - insns_start;
1008 				break;
1009 			}
1010 
1011 			emit_imm(RV_REG_T1, off, ctx);
1012 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1013 			insns_start = ctx->ninsns;
1014 			emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1015 			insn_len = ctx->ninsns - insns_start;
1016 			if (insn_is_zext(&insn[1]))
1017 				return 1;
1018 			break;
1019 		case BPF_W:
1020 			if (is_12b_int(off)) {
1021 				insns_start = ctx->ninsns;
1022 				emit(rv_lwu(rd, off, rs), ctx);
1023 				insn_len = ctx->ninsns - insns_start;
1024 				break;
1025 			}
1026 
1027 			emit_imm(RV_REG_T1, off, ctx);
1028 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1029 			insns_start = ctx->ninsns;
1030 			emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1031 			insn_len = ctx->ninsns - insns_start;
1032 			if (insn_is_zext(&insn[1]))
1033 				return 1;
1034 			break;
1035 		case BPF_DW:
1036 			if (is_12b_int(off)) {
1037 				insns_start = ctx->ninsns;
1038 				emit_ld(rd, off, rs, ctx);
1039 				insn_len = ctx->ninsns - insns_start;
1040 				break;
1041 			}
1042 
1043 			emit_imm(RV_REG_T1, off, ctx);
1044 			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1045 			insns_start = ctx->ninsns;
1046 			emit_ld(rd, 0, RV_REG_T1, ctx);
1047 			insn_len = ctx->ninsns - insns_start;
1048 			break;
1049 		}
1050 
1051 		ret = add_exception_handler(insn, ctx, rd, insn_len);
1052 		if (ret)
1053 			return ret;
1054 		break;
1055 	}
1056 	/* speculation barrier */
1057 	case BPF_ST | BPF_NOSPEC:
1058 		break;
1059 
1060 	/* ST: *(size *)(dst + off) = imm */
1061 	case BPF_ST | BPF_MEM | BPF_B:
1062 		emit_imm(RV_REG_T1, imm, ctx);
1063 		if (is_12b_int(off)) {
1064 			emit(rv_sb(rd, off, RV_REG_T1), ctx);
1065 			break;
1066 		}
1067 
1068 		emit_imm(RV_REG_T2, off, ctx);
1069 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1070 		emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1071 		break;
1072 
1073 	case BPF_ST | BPF_MEM | BPF_H:
1074 		emit_imm(RV_REG_T1, imm, ctx);
1075 		if (is_12b_int(off)) {
1076 			emit(rv_sh(rd, off, RV_REG_T1), ctx);
1077 			break;
1078 		}
1079 
1080 		emit_imm(RV_REG_T2, off, ctx);
1081 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1082 		emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1083 		break;
1084 	case BPF_ST | BPF_MEM | BPF_W:
1085 		emit_imm(RV_REG_T1, imm, ctx);
1086 		if (is_12b_int(off)) {
1087 			emit_sw(rd, off, RV_REG_T1, ctx);
1088 			break;
1089 		}
1090 
1091 		emit_imm(RV_REG_T2, off, ctx);
1092 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1093 		emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1094 		break;
1095 	case BPF_ST | BPF_MEM | BPF_DW:
1096 		emit_imm(RV_REG_T1, imm, ctx);
1097 		if (is_12b_int(off)) {
1098 			emit_sd(rd, off, RV_REG_T1, ctx);
1099 			break;
1100 		}
1101 
1102 		emit_imm(RV_REG_T2, off, ctx);
1103 		emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1104 		emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1105 		break;
1106 
1107 	/* STX: *(size *)(dst + off) = src */
1108 	case BPF_STX | BPF_MEM | BPF_B:
1109 		if (is_12b_int(off)) {
1110 			emit(rv_sb(rd, off, rs), ctx);
1111 			break;
1112 		}
1113 
1114 		emit_imm(RV_REG_T1, off, ctx);
1115 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1116 		emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1117 		break;
1118 	case BPF_STX | BPF_MEM | BPF_H:
1119 		if (is_12b_int(off)) {
1120 			emit(rv_sh(rd, off, rs), ctx);
1121 			break;
1122 		}
1123 
1124 		emit_imm(RV_REG_T1, off, ctx);
1125 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1126 		emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1127 		break;
1128 	case BPF_STX | BPF_MEM | BPF_W:
1129 		if (is_12b_int(off)) {
1130 			emit_sw(rd, off, rs, ctx);
1131 			break;
1132 		}
1133 
1134 		emit_imm(RV_REG_T1, off, ctx);
1135 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1136 		emit_sw(RV_REG_T1, 0, rs, ctx);
1137 		break;
1138 	case BPF_STX | BPF_MEM | BPF_DW:
1139 		if (is_12b_int(off)) {
1140 			emit_sd(rd, off, rs, ctx);
1141 			break;
1142 		}
1143 
1144 		emit_imm(RV_REG_T1, off, ctx);
1145 		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1146 		emit_sd(RV_REG_T1, 0, rs, ctx);
1147 		break;
1148 	case BPF_STX | BPF_ATOMIC | BPF_W:
1149 	case BPF_STX | BPF_ATOMIC | BPF_DW:
1150 		if (insn->imm != BPF_ADD) {
1151 			pr_err("bpf-jit: not supported: atomic operation %02x ***\n",
1152 			       insn->imm);
1153 			return -EINVAL;
1154 		}
1155 
1156 		/* atomic_add: lock *(u32 *)(dst + off) += src
1157 		 * atomic_add: lock *(u64 *)(dst + off) += src
1158 		 */
1159 
1160 		if (off) {
1161 			if (is_12b_int(off)) {
1162 				emit_addi(RV_REG_T1, rd, off, ctx);
1163 			} else {
1164 				emit_imm(RV_REG_T1, off, ctx);
1165 				emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1166 			}
1167 
1168 			rd = RV_REG_T1;
1169 		}
1170 
1171 		emit(BPF_SIZE(code) == BPF_W ?
1172 		     rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0) :
1173 		     rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0), ctx);
1174 		break;
1175 	default:
1176 		pr_err("bpf-jit: unknown opcode %02x\n", code);
1177 		return -EINVAL;
1178 	}
1179 
1180 	return 0;
1181 }
1182 
1183 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1184 {
1185 	int stack_adjust = 0, store_offset, bpf_stack_adjust;
1186 
1187 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1188 	if (bpf_stack_adjust)
1189 		mark_fp(ctx);
1190 
1191 	if (seen_reg(RV_REG_RA, ctx))
1192 		stack_adjust += 8;
1193 	stack_adjust += 8; /* RV_REG_FP */
1194 	if (seen_reg(RV_REG_S1, ctx))
1195 		stack_adjust += 8;
1196 	if (seen_reg(RV_REG_S2, ctx))
1197 		stack_adjust += 8;
1198 	if (seen_reg(RV_REG_S3, ctx))
1199 		stack_adjust += 8;
1200 	if (seen_reg(RV_REG_S4, ctx))
1201 		stack_adjust += 8;
1202 	if (seen_reg(RV_REG_S5, ctx))
1203 		stack_adjust += 8;
1204 	if (seen_reg(RV_REG_S6, ctx))
1205 		stack_adjust += 8;
1206 
1207 	stack_adjust = round_up(stack_adjust, 16);
1208 	stack_adjust += bpf_stack_adjust;
1209 
1210 	store_offset = stack_adjust - 8;
1211 
1212 	/* First instruction is always setting the tail-call-counter
1213 	 * (TCC) register. This instruction is skipped for tail calls.
1214 	 * Force using a 4-byte (non-compressed) instruction.
1215 	 */
1216 	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1217 
1218 	emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1219 
1220 	if (seen_reg(RV_REG_RA, ctx)) {
1221 		emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1222 		store_offset -= 8;
1223 	}
1224 	emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1225 	store_offset -= 8;
1226 	if (seen_reg(RV_REG_S1, ctx)) {
1227 		emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1228 		store_offset -= 8;
1229 	}
1230 	if (seen_reg(RV_REG_S2, ctx)) {
1231 		emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1232 		store_offset -= 8;
1233 	}
1234 	if (seen_reg(RV_REG_S3, ctx)) {
1235 		emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1236 		store_offset -= 8;
1237 	}
1238 	if (seen_reg(RV_REG_S4, ctx)) {
1239 		emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1240 		store_offset -= 8;
1241 	}
1242 	if (seen_reg(RV_REG_S5, ctx)) {
1243 		emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1244 		store_offset -= 8;
1245 	}
1246 	if (seen_reg(RV_REG_S6, ctx)) {
1247 		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1248 		store_offset -= 8;
1249 	}
1250 
1251 	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1252 
1253 	if (bpf_stack_adjust)
1254 		emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1255 
1256 	/* Program contains calls and tail calls, so RV_REG_TCC need
1257 	 * to be saved across calls.
1258 	 */
1259 	if (seen_tail_call(ctx) && seen_call(ctx))
1260 		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1261 
1262 	ctx->stack_size = stack_adjust;
1263 }
1264 
1265 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1266 {
1267 	__build_epilogue(false, ctx);
1268 }
1269