xref: /linux/arch/riscv/net/bpf_jit_comp32.c (revision 5d085ad2e68cceec8332b23ea8f630a28b506366)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * BPF JIT compiler for RV32G
4  *
5  * Copyright (c) 2020 Luke Nelson <luke.r.nels@gmail.com>
6  * Copyright (c) 2020 Xi Wang <xi.wang@gmail.com>
7  *
8  * The code is based on the BPF JIT compiler for RV64G by Björn Töpel and
9  * the BPF JIT compiler for 32-bit ARM by Shubham Bansal and Mircea Gherzan.
10  */
11 
12 #include <linux/bpf.h>
13 #include <linux/filter.h>
14 #include "bpf_jit.h"
15 
16 enum {
17 	/* Stack layout - these are offsets from (top of stack - 4). */
18 	BPF_R6_HI,
19 	BPF_R6_LO,
20 	BPF_R7_HI,
21 	BPF_R7_LO,
22 	BPF_R8_HI,
23 	BPF_R8_LO,
24 	BPF_R9_HI,
25 	BPF_R9_LO,
26 	BPF_AX_HI,
27 	BPF_AX_LO,
28 	/* Stack space for BPF_REG_6 through BPF_REG_9 and BPF_REG_AX. */
29 	BPF_JIT_SCRATCH_REGS,
30 };
31 
32 #define STACK_OFFSET(k) (-4 - ((k) * 4))
33 
34 #define TMP_REG_1	(MAX_BPF_JIT_REG + 0)
35 #define TMP_REG_2	(MAX_BPF_JIT_REG + 1)
36 
37 #define RV_REG_TCC		RV_REG_T6
38 #define RV_REG_TCC_SAVED	RV_REG_S7
39 
40 static const s8 bpf2rv32[][2] = {
41 	/* Return value from in-kernel function, and exit value from eBPF. */
42 	[BPF_REG_0] = {RV_REG_S2, RV_REG_S1},
43 	/* Arguments from eBPF program to in-kernel function. */
44 	[BPF_REG_1] = {RV_REG_A1, RV_REG_A0},
45 	[BPF_REG_2] = {RV_REG_A3, RV_REG_A2},
46 	[BPF_REG_3] = {RV_REG_A5, RV_REG_A4},
47 	[BPF_REG_4] = {RV_REG_A7, RV_REG_A6},
48 	[BPF_REG_5] = {RV_REG_S4, RV_REG_S3},
49 	/*
50 	 * Callee-saved registers that in-kernel function will preserve.
51 	 * Stored on the stack.
52 	 */
53 	[BPF_REG_6] = {STACK_OFFSET(BPF_R6_HI), STACK_OFFSET(BPF_R6_LO)},
54 	[BPF_REG_7] = {STACK_OFFSET(BPF_R7_HI), STACK_OFFSET(BPF_R7_LO)},
55 	[BPF_REG_8] = {STACK_OFFSET(BPF_R8_HI), STACK_OFFSET(BPF_R8_LO)},
56 	[BPF_REG_9] = {STACK_OFFSET(BPF_R9_HI), STACK_OFFSET(BPF_R9_LO)},
57 	/* Read-only frame pointer to access BPF stack. */
58 	[BPF_REG_FP] = {RV_REG_S6, RV_REG_S5},
59 	/* Temporary register for blinding constants. Stored on the stack. */
60 	[BPF_REG_AX] = {STACK_OFFSET(BPF_AX_HI), STACK_OFFSET(BPF_AX_LO)},
61 	/*
62 	 * Temporary registers used by the JIT to operate on registers stored
63 	 * on the stack. Save t0 and t1 to be used as temporaries in generated
64 	 * code.
65 	 */
66 	[TMP_REG_1] = {RV_REG_T3, RV_REG_T2},
67 	[TMP_REG_2] = {RV_REG_T5, RV_REG_T4},
68 };
69 
70 static s8 hi(const s8 *r)
71 {
72 	return r[0];
73 }
74 
75 static s8 lo(const s8 *r)
76 {
77 	return r[1];
78 }
79 
80 static void emit_imm(const s8 rd, s32 imm, struct rv_jit_context *ctx)
81 {
82 	u32 upper = (imm + (1 << 11)) >> 12;
83 	u32 lower = imm & 0xfff;
84 
85 	if (upper) {
86 		emit(rv_lui(rd, upper), ctx);
87 		emit(rv_addi(rd, rd, lower), ctx);
88 	} else {
89 		emit(rv_addi(rd, RV_REG_ZERO, lower), ctx);
90 	}
91 }
92 
93 static void emit_imm32(const s8 *rd, s32 imm, struct rv_jit_context *ctx)
94 {
95 	/* Emit immediate into lower bits. */
96 	emit_imm(lo(rd), imm, ctx);
97 
98 	/* Sign-extend into upper bits. */
99 	if (imm >= 0)
100 		emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
101 	else
102 		emit(rv_addi(hi(rd), RV_REG_ZERO, -1), ctx);
103 }
104 
105 static void emit_imm64(const s8 *rd, s32 imm_hi, s32 imm_lo,
106 		       struct rv_jit_context *ctx)
107 {
108 	emit_imm(lo(rd), imm_lo, ctx);
109 	emit_imm(hi(rd), imm_hi, ctx);
110 }
111 
112 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
113 {
114 	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 4;
115 	const s8 *r0 = bpf2rv32[BPF_REG_0];
116 
117 	store_offset -= 4 * BPF_JIT_SCRATCH_REGS;
118 
119 	/* Set return value if not tail call. */
120 	if (!is_tail_call) {
121 		emit(rv_addi(RV_REG_A0, lo(r0), 0), ctx);
122 		emit(rv_addi(RV_REG_A1, hi(r0), 0), ctx);
123 	}
124 
125 	/* Restore callee-saved registers. */
126 	emit(rv_lw(RV_REG_RA, store_offset - 0, RV_REG_SP), ctx);
127 	emit(rv_lw(RV_REG_FP, store_offset - 4, RV_REG_SP), ctx);
128 	emit(rv_lw(RV_REG_S1, store_offset - 8, RV_REG_SP), ctx);
129 	emit(rv_lw(RV_REG_S2, store_offset - 12, RV_REG_SP), ctx);
130 	emit(rv_lw(RV_REG_S3, store_offset - 16, RV_REG_SP), ctx);
131 	emit(rv_lw(RV_REG_S4, store_offset - 20, RV_REG_SP), ctx);
132 	emit(rv_lw(RV_REG_S5, store_offset - 24, RV_REG_SP), ctx);
133 	emit(rv_lw(RV_REG_S6, store_offset - 28, RV_REG_SP), ctx);
134 	emit(rv_lw(RV_REG_S7, store_offset - 32, RV_REG_SP), ctx);
135 
136 	emit(rv_addi(RV_REG_SP, RV_REG_SP, stack_adjust), ctx);
137 
138 	if (is_tail_call) {
139 		/*
140 		 * goto *(t0 + 4);
141 		 * Skips first instruction of prologue which initializes tail
142 		 * call counter. Assumes t0 contains address of target program,
143 		 * see emit_bpf_tail_call.
144 		 */
145 		emit(rv_jalr(RV_REG_ZERO, RV_REG_T0, 4), ctx);
146 	} else {
147 		emit(rv_jalr(RV_REG_ZERO, RV_REG_RA, 0), ctx);
148 	}
149 }
150 
151 static bool is_stacked(s8 reg)
152 {
153 	return reg < 0;
154 }
155 
156 static const s8 *bpf_get_reg64(const s8 *reg, const s8 *tmp,
157 			       struct rv_jit_context *ctx)
158 {
159 	if (is_stacked(hi(reg))) {
160 		emit(rv_lw(hi(tmp), hi(reg), RV_REG_FP), ctx);
161 		emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
162 		reg = tmp;
163 	}
164 	return reg;
165 }
166 
167 static void bpf_put_reg64(const s8 *reg, const s8 *src,
168 			  struct rv_jit_context *ctx)
169 {
170 	if (is_stacked(hi(reg))) {
171 		emit(rv_sw(RV_REG_FP, hi(reg), hi(src)), ctx);
172 		emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
173 	}
174 }
175 
176 static const s8 *bpf_get_reg32(const s8 *reg, const s8 *tmp,
177 			       struct rv_jit_context *ctx)
178 {
179 	if (is_stacked(lo(reg))) {
180 		emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
181 		reg = tmp;
182 	}
183 	return reg;
184 }
185 
186 static void bpf_put_reg32(const s8 *reg, const s8 *src,
187 			  struct rv_jit_context *ctx)
188 {
189 	if (is_stacked(lo(reg))) {
190 		emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
191 		if (!ctx->prog->aux->verifier_zext)
192 			emit(rv_sw(RV_REG_FP, hi(reg), RV_REG_ZERO), ctx);
193 	} else if (!ctx->prog->aux->verifier_zext) {
194 		emit(rv_addi(hi(reg), RV_REG_ZERO, 0), ctx);
195 	}
196 }
197 
198 static void emit_jump_and_link(u8 rd, s32 rvoff, bool force_jalr,
199 			       struct rv_jit_context *ctx)
200 {
201 	s32 upper, lower;
202 
203 	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
204 		emit(rv_jal(rd, rvoff >> 1), ctx);
205 		return;
206 	}
207 
208 	upper = (rvoff + (1 << 11)) >> 12;
209 	lower = rvoff & 0xfff;
210 	emit(rv_auipc(RV_REG_T1, upper), ctx);
211 	emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
212 }
213 
214 static void emit_alu_i64(const s8 *dst, s32 imm,
215 			 struct rv_jit_context *ctx, const u8 op)
216 {
217 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
218 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
219 
220 	switch (op) {
221 	case BPF_MOV:
222 		emit_imm32(rd, imm, ctx);
223 		break;
224 	case BPF_AND:
225 		if (is_12b_int(imm)) {
226 			emit(rv_andi(lo(rd), lo(rd), imm), ctx);
227 		} else {
228 			emit_imm(RV_REG_T0, imm, ctx);
229 			emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
230 		}
231 		if (imm >= 0)
232 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
233 		break;
234 	case BPF_OR:
235 		if (is_12b_int(imm)) {
236 			emit(rv_ori(lo(rd), lo(rd), imm), ctx);
237 		} else {
238 			emit_imm(RV_REG_T0, imm, ctx);
239 			emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
240 		}
241 		if (imm < 0)
242 			emit(rv_ori(hi(rd), RV_REG_ZERO, -1), ctx);
243 		break;
244 	case BPF_XOR:
245 		if (is_12b_int(imm)) {
246 			emit(rv_xori(lo(rd), lo(rd), imm), ctx);
247 		} else {
248 			emit_imm(RV_REG_T0, imm, ctx);
249 			emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
250 		}
251 		if (imm < 0)
252 			emit(rv_xori(hi(rd), hi(rd), -1), ctx);
253 		break;
254 	case BPF_LSH:
255 		if (imm >= 32) {
256 			emit(rv_slli(hi(rd), lo(rd), imm - 32), ctx);
257 			emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
258 		} else if (imm == 0) {
259 			/* Do nothing. */
260 		} else {
261 			emit(rv_srli(RV_REG_T0, lo(rd), 32 - imm), ctx);
262 			emit(rv_slli(hi(rd), hi(rd), imm), ctx);
263 			emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
264 			emit(rv_slli(lo(rd), lo(rd), imm), ctx);
265 		}
266 		break;
267 	case BPF_RSH:
268 		if (imm >= 32) {
269 			emit(rv_srli(lo(rd), hi(rd), imm - 32), ctx);
270 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
271 		} else if (imm == 0) {
272 			/* Do nothing. */
273 		} else {
274 			emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
275 			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
276 			emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
277 			emit(rv_srli(hi(rd), hi(rd), imm), ctx);
278 		}
279 		break;
280 	case BPF_ARSH:
281 		if (imm >= 32) {
282 			emit(rv_srai(lo(rd), hi(rd), imm - 32), ctx);
283 			emit(rv_srai(hi(rd), hi(rd), 31), ctx);
284 		} else if (imm == 0) {
285 			/* Do nothing. */
286 		} else {
287 			emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
288 			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
289 			emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
290 			emit(rv_srai(hi(rd), hi(rd), imm), ctx);
291 		}
292 		break;
293 	}
294 
295 	bpf_put_reg64(dst, rd, ctx);
296 }
297 
298 static void emit_alu_i32(const s8 *dst, s32 imm,
299 			 struct rv_jit_context *ctx, const u8 op)
300 {
301 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
302 	const s8 *rd = bpf_get_reg32(dst, tmp1, ctx);
303 
304 	switch (op) {
305 	case BPF_MOV:
306 		emit_imm(lo(rd), imm, ctx);
307 		break;
308 	case BPF_ADD:
309 		if (is_12b_int(imm)) {
310 			emit(rv_addi(lo(rd), lo(rd), imm), ctx);
311 		} else {
312 			emit_imm(RV_REG_T0, imm, ctx);
313 			emit(rv_add(lo(rd), lo(rd), RV_REG_T0), ctx);
314 		}
315 		break;
316 	case BPF_SUB:
317 		if (is_12b_int(-imm)) {
318 			emit(rv_addi(lo(rd), lo(rd), -imm), ctx);
319 		} else {
320 			emit_imm(RV_REG_T0, imm, ctx);
321 			emit(rv_sub(lo(rd), lo(rd), RV_REG_T0), ctx);
322 		}
323 		break;
324 	case BPF_AND:
325 		if (is_12b_int(imm)) {
326 			emit(rv_andi(lo(rd), lo(rd), imm), ctx);
327 		} else {
328 			emit_imm(RV_REG_T0, imm, ctx);
329 			emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
330 		}
331 		break;
332 	case BPF_OR:
333 		if (is_12b_int(imm)) {
334 			emit(rv_ori(lo(rd), lo(rd), imm), ctx);
335 		} else {
336 			emit_imm(RV_REG_T0, imm, ctx);
337 			emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
338 		}
339 		break;
340 	case BPF_XOR:
341 		if (is_12b_int(imm)) {
342 			emit(rv_xori(lo(rd), lo(rd), imm), ctx);
343 		} else {
344 			emit_imm(RV_REG_T0, imm, ctx);
345 			emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
346 		}
347 		break;
348 	case BPF_LSH:
349 		if (is_12b_int(imm)) {
350 			emit(rv_slli(lo(rd), lo(rd), imm), ctx);
351 		} else {
352 			emit_imm(RV_REG_T0, imm, ctx);
353 			emit(rv_sll(lo(rd), lo(rd), RV_REG_T0), ctx);
354 		}
355 		break;
356 	case BPF_RSH:
357 		if (is_12b_int(imm)) {
358 			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
359 		} else {
360 			emit_imm(RV_REG_T0, imm, ctx);
361 			emit(rv_srl(lo(rd), lo(rd), RV_REG_T0), ctx);
362 		}
363 		break;
364 	case BPF_ARSH:
365 		if (is_12b_int(imm)) {
366 			emit(rv_srai(lo(rd), lo(rd), imm), ctx);
367 		} else {
368 			emit_imm(RV_REG_T0, imm, ctx);
369 			emit(rv_sra(lo(rd), lo(rd), RV_REG_T0), ctx);
370 		}
371 		break;
372 	}
373 
374 	bpf_put_reg32(dst, rd, ctx);
375 }
376 
377 static void emit_alu_r64(const s8 *dst, const s8 *src,
378 			 struct rv_jit_context *ctx, const u8 op)
379 {
380 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
381 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
382 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
383 	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
384 
385 	switch (op) {
386 	case BPF_MOV:
387 		emit(rv_addi(lo(rd), lo(rs), 0), ctx);
388 		emit(rv_addi(hi(rd), hi(rs), 0), ctx);
389 		break;
390 	case BPF_ADD:
391 		if (rd == rs) {
392 			emit(rv_srli(RV_REG_T0, lo(rd), 31), ctx);
393 			emit(rv_slli(hi(rd), hi(rd), 1), ctx);
394 			emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
395 			emit(rv_slli(lo(rd), lo(rd), 1), ctx);
396 		} else {
397 			emit(rv_add(lo(rd), lo(rd), lo(rs)), ctx);
398 			emit(rv_sltu(RV_REG_T0, lo(rd), lo(rs)), ctx);
399 			emit(rv_add(hi(rd), hi(rd), hi(rs)), ctx);
400 			emit(rv_add(hi(rd), hi(rd), RV_REG_T0), ctx);
401 		}
402 		break;
403 	case BPF_SUB:
404 		emit(rv_sub(RV_REG_T1, hi(rd), hi(rs)), ctx);
405 		emit(rv_sltu(RV_REG_T0, lo(rd), lo(rs)), ctx);
406 		emit(rv_sub(hi(rd), RV_REG_T1, RV_REG_T0), ctx);
407 		emit(rv_sub(lo(rd), lo(rd), lo(rs)), ctx);
408 		break;
409 	case BPF_AND:
410 		emit(rv_and(lo(rd), lo(rd), lo(rs)), ctx);
411 		emit(rv_and(hi(rd), hi(rd), hi(rs)), ctx);
412 		break;
413 	case BPF_OR:
414 		emit(rv_or(lo(rd), lo(rd), lo(rs)), ctx);
415 		emit(rv_or(hi(rd), hi(rd), hi(rs)), ctx);
416 		break;
417 	case BPF_XOR:
418 		emit(rv_xor(lo(rd), lo(rd), lo(rs)), ctx);
419 		emit(rv_xor(hi(rd), hi(rd), hi(rs)), ctx);
420 		break;
421 	case BPF_MUL:
422 		emit(rv_mul(RV_REG_T0, hi(rs), lo(rd)), ctx);
423 		emit(rv_mul(hi(rd), hi(rd), lo(rs)), ctx);
424 		emit(rv_mulhu(RV_REG_T1, lo(rd), lo(rs)), ctx);
425 		emit(rv_add(hi(rd), hi(rd), RV_REG_T0), ctx);
426 		emit(rv_mul(lo(rd), lo(rd), lo(rs)), ctx);
427 		emit(rv_add(hi(rd), hi(rd), RV_REG_T1), ctx);
428 		break;
429 	case BPF_LSH:
430 		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
431 		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
432 		emit(rv_sll(hi(rd), lo(rd), RV_REG_T0), ctx);
433 		emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
434 		emit(rv_jal(RV_REG_ZERO, 16), ctx);
435 		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
436 		emit(rv_srli(RV_REG_T0, lo(rd), 1), ctx);
437 		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
438 		emit(rv_srl(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
439 		emit(rv_sll(hi(rd), hi(rd), lo(rs)), ctx);
440 		emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
441 		emit(rv_sll(lo(rd), lo(rd), lo(rs)), ctx);
442 		break;
443 	case BPF_RSH:
444 		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
445 		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
446 		emit(rv_srl(lo(rd), hi(rd), RV_REG_T0), ctx);
447 		emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
448 		emit(rv_jal(RV_REG_ZERO, 16), ctx);
449 		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
450 		emit(rv_slli(RV_REG_T0, hi(rd), 1), ctx);
451 		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
452 		emit(rv_sll(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
453 		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
454 		emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
455 		emit(rv_srl(hi(rd), hi(rd), lo(rs)), ctx);
456 		break;
457 	case BPF_ARSH:
458 		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
459 		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
460 		emit(rv_sra(lo(rd), hi(rd), RV_REG_T0), ctx);
461 		emit(rv_srai(hi(rd), hi(rd), 31), ctx);
462 		emit(rv_jal(RV_REG_ZERO, 16), ctx);
463 		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
464 		emit(rv_slli(RV_REG_T0, hi(rd), 1), ctx);
465 		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
466 		emit(rv_sll(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
467 		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
468 		emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
469 		emit(rv_sra(hi(rd), hi(rd), lo(rs)), ctx);
470 		break;
471 	case BPF_NEG:
472 		emit(rv_sub(lo(rd), RV_REG_ZERO, lo(rd)), ctx);
473 		emit(rv_sltu(RV_REG_T0, RV_REG_ZERO, lo(rd)), ctx);
474 		emit(rv_sub(hi(rd), RV_REG_ZERO, hi(rd)), ctx);
475 		emit(rv_sub(hi(rd), hi(rd), RV_REG_T0), ctx);
476 		break;
477 	}
478 
479 	bpf_put_reg64(dst, rd, ctx);
480 }
481 
482 static void emit_alu_r32(const s8 *dst, const s8 *src,
483 			 struct rv_jit_context *ctx, const u8 op)
484 {
485 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
486 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
487 	const s8 *rd = bpf_get_reg32(dst, tmp1, ctx);
488 	const s8 *rs = bpf_get_reg32(src, tmp2, ctx);
489 
490 	switch (op) {
491 	case BPF_MOV:
492 		emit(rv_addi(lo(rd), lo(rs), 0), ctx);
493 		break;
494 	case BPF_ADD:
495 		emit(rv_add(lo(rd), lo(rd), lo(rs)), ctx);
496 		break;
497 	case BPF_SUB:
498 		emit(rv_sub(lo(rd), lo(rd), lo(rs)), ctx);
499 		break;
500 	case BPF_AND:
501 		emit(rv_and(lo(rd), lo(rd), lo(rs)), ctx);
502 		break;
503 	case BPF_OR:
504 		emit(rv_or(lo(rd), lo(rd), lo(rs)), ctx);
505 		break;
506 	case BPF_XOR:
507 		emit(rv_xor(lo(rd), lo(rd), lo(rs)), ctx);
508 		break;
509 	case BPF_MUL:
510 		emit(rv_mul(lo(rd), lo(rd), lo(rs)), ctx);
511 		break;
512 	case BPF_DIV:
513 		emit(rv_divu(lo(rd), lo(rd), lo(rs)), ctx);
514 		break;
515 	case BPF_MOD:
516 		emit(rv_remu(lo(rd), lo(rd), lo(rs)), ctx);
517 		break;
518 	case BPF_LSH:
519 		emit(rv_sll(lo(rd), lo(rd), lo(rs)), ctx);
520 		break;
521 	case BPF_RSH:
522 		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
523 		break;
524 	case BPF_ARSH:
525 		emit(rv_sra(lo(rd), lo(rd), lo(rs)), ctx);
526 		break;
527 	case BPF_NEG:
528 		emit(rv_sub(lo(rd), RV_REG_ZERO, lo(rd)), ctx);
529 		break;
530 	}
531 
532 	bpf_put_reg32(dst, rd, ctx);
533 }
534 
535 static int emit_branch_r64(const s8 *src1, const s8 *src2, s32 rvoff,
536 			   struct rv_jit_context *ctx, const u8 op)
537 {
538 	int e, s = ctx->ninsns;
539 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
540 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
541 
542 	const s8 *rs1 = bpf_get_reg64(src1, tmp1, ctx);
543 	const s8 *rs2 = bpf_get_reg64(src2, tmp2, ctx);
544 
545 	/*
546 	 * NO_JUMP skips over the rest of the instructions and the
547 	 * emit_jump_and_link, meaning the BPF branch is not taken.
548 	 * JUMP skips directly to the emit_jump_and_link, meaning
549 	 * the BPF branch is taken.
550 	 *
551 	 * The fallthrough case results in the BPF branch being taken.
552 	 */
553 #define NO_JUMP(idx) (6 + (2 * (idx)))
554 #define JUMP(idx) (2 + (2 * (idx)))
555 
556 	switch (op) {
557 	case BPF_JEQ:
558 		emit(rv_bne(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
559 		emit(rv_bne(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
560 		break;
561 	case BPF_JGT:
562 		emit(rv_bgtu(hi(rs1), hi(rs2), JUMP(2)), ctx);
563 		emit(rv_bltu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
564 		emit(rv_bleu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
565 		break;
566 	case BPF_JLT:
567 		emit(rv_bltu(hi(rs1), hi(rs2), JUMP(2)), ctx);
568 		emit(rv_bgtu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
569 		emit(rv_bgeu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
570 		break;
571 	case BPF_JGE:
572 		emit(rv_bgtu(hi(rs1), hi(rs2), JUMP(2)), ctx);
573 		emit(rv_bltu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
574 		emit(rv_bltu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
575 		break;
576 	case BPF_JLE:
577 		emit(rv_bltu(hi(rs1), hi(rs2), JUMP(2)), ctx);
578 		emit(rv_bgtu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
579 		emit(rv_bgtu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
580 		break;
581 	case BPF_JNE:
582 		emit(rv_bne(hi(rs1), hi(rs2), JUMP(1)), ctx);
583 		emit(rv_beq(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
584 		break;
585 	case BPF_JSGT:
586 		emit(rv_bgt(hi(rs1), hi(rs2), JUMP(2)), ctx);
587 		emit(rv_blt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
588 		emit(rv_bleu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
589 		break;
590 	case BPF_JSLT:
591 		emit(rv_blt(hi(rs1), hi(rs2), JUMP(2)), ctx);
592 		emit(rv_bgt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
593 		emit(rv_bgeu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
594 		break;
595 	case BPF_JSGE:
596 		emit(rv_bgt(hi(rs1), hi(rs2), JUMP(2)), ctx);
597 		emit(rv_blt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
598 		emit(rv_bltu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
599 		break;
600 	case BPF_JSLE:
601 		emit(rv_blt(hi(rs1), hi(rs2), JUMP(2)), ctx);
602 		emit(rv_bgt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
603 		emit(rv_bgtu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
604 		break;
605 	case BPF_JSET:
606 		emit(rv_and(RV_REG_T0, hi(rs1), hi(rs2)), ctx);
607 		emit(rv_bne(RV_REG_T0, RV_REG_ZERO, JUMP(2)), ctx);
608 		emit(rv_and(RV_REG_T0, lo(rs1), lo(rs2)), ctx);
609 		emit(rv_beq(RV_REG_T0, RV_REG_ZERO, NO_JUMP(0)), ctx);
610 		break;
611 	}
612 
613 #undef NO_JUMP
614 #undef JUMP
615 
616 	e = ctx->ninsns;
617 	/* Adjust for extra insns. */
618 	rvoff -= (e - s) << 2;
619 	emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
620 	return 0;
621 }
622 
623 static int emit_bcc(u8 op, u8 rd, u8 rs, int rvoff, struct rv_jit_context *ctx)
624 {
625 	int e, s = ctx->ninsns;
626 	bool far = false;
627 	int off;
628 
629 	if (op == BPF_JSET) {
630 		/*
631 		 * BPF_JSET is a special case: it has no inverse so we always
632 		 * treat it as a far branch.
633 		 */
634 		far = true;
635 	} else if (!is_13b_int(rvoff)) {
636 		op = invert_bpf_cond(op);
637 		far = true;
638 	}
639 
640 	/*
641 	 * For a far branch, the condition is negated and we jump over the
642 	 * branch itself, and the two instructions from emit_jump_and_link.
643 	 * For a near branch, just use rvoff.
644 	 */
645 	off = far ? 6 : (rvoff >> 1);
646 
647 	switch (op) {
648 	case BPF_JEQ:
649 		emit(rv_beq(rd, rs, off), ctx);
650 		break;
651 	case BPF_JGT:
652 		emit(rv_bgtu(rd, rs, off), ctx);
653 		break;
654 	case BPF_JLT:
655 		emit(rv_bltu(rd, rs, off), ctx);
656 		break;
657 	case BPF_JGE:
658 		emit(rv_bgeu(rd, rs, off), ctx);
659 		break;
660 	case BPF_JLE:
661 		emit(rv_bleu(rd, rs, off), ctx);
662 		break;
663 	case BPF_JNE:
664 		emit(rv_bne(rd, rs, off), ctx);
665 		break;
666 	case BPF_JSGT:
667 		emit(rv_bgt(rd, rs, off), ctx);
668 		break;
669 	case BPF_JSLT:
670 		emit(rv_blt(rd, rs, off), ctx);
671 		break;
672 	case BPF_JSGE:
673 		emit(rv_bge(rd, rs, off), ctx);
674 		break;
675 	case BPF_JSLE:
676 		emit(rv_ble(rd, rs, off), ctx);
677 		break;
678 	case BPF_JSET:
679 		emit(rv_and(RV_REG_T0, rd, rs), ctx);
680 		emit(rv_beq(RV_REG_T0, RV_REG_ZERO, off), ctx);
681 		break;
682 	}
683 
684 	if (far) {
685 		e = ctx->ninsns;
686 		/* Adjust for extra insns. */
687 		rvoff -= (e - s) << 2;
688 		emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
689 	}
690 	return 0;
691 }
692 
693 static int emit_branch_r32(const s8 *src1, const s8 *src2, s32 rvoff,
694 			   struct rv_jit_context *ctx, const u8 op)
695 {
696 	int e, s = ctx->ninsns;
697 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
698 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
699 
700 	const s8 *rs1 = bpf_get_reg32(src1, tmp1, ctx);
701 	const s8 *rs2 = bpf_get_reg32(src2, tmp2, ctx);
702 
703 	e = ctx->ninsns;
704 	/* Adjust for extra insns. */
705 	rvoff -= (e - s) << 2;
706 
707 	if (emit_bcc(op, lo(rs1), lo(rs2), rvoff, ctx))
708 		return -1;
709 
710 	return 0;
711 }
712 
713 static void emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
714 {
715 	const s8 *r0 = bpf2rv32[BPF_REG_0];
716 	const s8 *r5 = bpf2rv32[BPF_REG_5];
717 	u32 upper = ((u32)addr + (1 << 11)) >> 12;
718 	u32 lower = addr & 0xfff;
719 
720 	/* R1-R4 already in correct registers---need to push R5 to stack. */
721 	emit(rv_addi(RV_REG_SP, RV_REG_SP, -16), ctx);
722 	emit(rv_sw(RV_REG_SP, 0, lo(r5)), ctx);
723 	emit(rv_sw(RV_REG_SP, 4, hi(r5)), ctx);
724 
725 	/* Backup TCC. */
726 	emit(rv_addi(RV_REG_TCC_SAVED, RV_REG_TCC, 0), ctx);
727 
728 	/*
729 	 * Use lui/jalr pair to jump to absolute address. Don't use emit_imm as
730 	 * the number of emitted instructions should not depend on the value of
731 	 * addr.
732 	 */
733 	emit(rv_lui(RV_REG_T1, upper), ctx);
734 	emit(rv_jalr(RV_REG_RA, RV_REG_T1, lower), ctx);
735 
736 	/* Restore TCC. */
737 	emit(rv_addi(RV_REG_TCC, RV_REG_TCC_SAVED, 0), ctx);
738 
739 	/* Set return value and restore stack. */
740 	emit(rv_addi(lo(r0), RV_REG_A0, 0), ctx);
741 	emit(rv_addi(hi(r0), RV_REG_A1, 0), ctx);
742 	emit(rv_addi(RV_REG_SP, RV_REG_SP, 16), ctx);
743 }
744 
745 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
746 {
747 	/*
748 	 * R1 -> &ctx
749 	 * R2 -> &array
750 	 * R3 -> index
751 	 */
752 	int tc_ninsn, off, start_insn = ctx->ninsns;
753 	const s8 *arr_reg = bpf2rv32[BPF_REG_2];
754 	const s8 *idx_reg = bpf2rv32[BPF_REG_3];
755 
756 	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
757 		ctx->offset[0];
758 
759 	/* max_entries = array->map.max_entries; */
760 	off = offsetof(struct bpf_array, map.max_entries);
761 	if (is_12b_check(off, insn))
762 		return -1;
763 	emit(rv_lw(RV_REG_T1, off, lo(arr_reg)), ctx);
764 
765 	/*
766 	 * if (index >= max_entries)
767 	 *   goto out;
768 	 */
769 	off = (tc_ninsn - (ctx->ninsns - start_insn)) << 2;
770 	emit_bcc(BPF_JGE, lo(idx_reg), RV_REG_T1, off, ctx);
771 
772 	/*
773 	 * temp_tcc = tcc - 1;
774 	 * if (tcc < 0)
775 	 *   goto out;
776 	 */
777 	emit(rv_addi(RV_REG_T1, RV_REG_TCC, -1), ctx);
778 	off = (tc_ninsn - (ctx->ninsns - start_insn)) << 2;
779 	emit_bcc(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
780 
781 	/*
782 	 * prog = array->ptrs[index];
783 	 * if (!prog)
784 	 *   goto out;
785 	 */
786 	emit(rv_slli(RV_REG_T0, lo(idx_reg), 2), ctx);
787 	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(arr_reg)), ctx);
788 	off = offsetof(struct bpf_array, ptrs);
789 	if (is_12b_check(off, insn))
790 		return -1;
791 	emit(rv_lw(RV_REG_T0, off, RV_REG_T0), ctx);
792 	off = (tc_ninsn - (ctx->ninsns - start_insn)) << 2;
793 	emit_bcc(BPF_JEQ, RV_REG_T0, RV_REG_ZERO, off, ctx);
794 
795 	/*
796 	 * tcc = temp_tcc;
797 	 * goto *(prog->bpf_func + 4);
798 	 */
799 	off = offsetof(struct bpf_prog, bpf_func);
800 	if (is_12b_check(off, insn))
801 		return -1;
802 	emit(rv_lw(RV_REG_T0, off, RV_REG_T0), ctx);
803 	emit(rv_addi(RV_REG_TCC, RV_REG_T1, 0), ctx);
804 	/* Epilogue jumps to *(t0 + 4). */
805 	__build_epilogue(true, ctx);
806 	return 0;
807 }
808 
809 static int emit_load_r64(const s8 *dst, const s8 *src, s16 off,
810 			 struct rv_jit_context *ctx, const u8 size)
811 {
812 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
813 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
814 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
815 	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
816 
817 	emit_imm(RV_REG_T0, off, ctx);
818 	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(rs)), ctx);
819 
820 	switch (size) {
821 	case BPF_B:
822 		emit(rv_lbu(lo(rd), 0, RV_REG_T0), ctx);
823 		if (!ctx->prog->aux->verifier_zext)
824 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
825 		break;
826 	case BPF_H:
827 		emit(rv_lhu(lo(rd), 0, RV_REG_T0), ctx);
828 		if (!ctx->prog->aux->verifier_zext)
829 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
830 		break;
831 	case BPF_W:
832 		emit(rv_lw(lo(rd), 0, RV_REG_T0), ctx);
833 		if (!ctx->prog->aux->verifier_zext)
834 			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
835 		break;
836 	case BPF_DW:
837 		emit(rv_lw(lo(rd), 0, RV_REG_T0), ctx);
838 		emit(rv_lw(hi(rd), 4, RV_REG_T0), ctx);
839 		break;
840 	}
841 
842 	bpf_put_reg64(dst, rd, ctx);
843 	return 0;
844 }
845 
846 static int emit_store_r64(const s8 *dst, const s8 *src, s16 off,
847 			  struct rv_jit_context *ctx, const u8 size,
848 			  const u8 mode)
849 {
850 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
851 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
852 	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
853 	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
854 
855 	if (mode == BPF_XADD && size != BPF_W)
856 		return -1;
857 
858 	emit_imm(RV_REG_T0, off, ctx);
859 	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(rd)), ctx);
860 
861 	switch (size) {
862 	case BPF_B:
863 		emit(rv_sb(RV_REG_T0, 0, lo(rs)), ctx);
864 		break;
865 	case BPF_H:
866 		emit(rv_sh(RV_REG_T0, 0, lo(rs)), ctx);
867 		break;
868 	case BPF_W:
869 		switch (mode) {
870 		case BPF_MEM:
871 			emit(rv_sw(RV_REG_T0, 0, lo(rs)), ctx);
872 			break;
873 		case BPF_XADD:
874 			emit(rv_amoadd_w(RV_REG_ZERO, lo(rs), RV_REG_T0, 0, 0),
875 			     ctx);
876 			break;
877 		}
878 		break;
879 	case BPF_DW:
880 		emit(rv_sw(RV_REG_T0, 0, lo(rs)), ctx);
881 		emit(rv_sw(RV_REG_T0, 4, hi(rs)), ctx);
882 		break;
883 	}
884 
885 	return 0;
886 }
887 
888 static void emit_rev16(const s8 rd, struct rv_jit_context *ctx)
889 {
890 	emit(rv_slli(rd, rd, 16), ctx);
891 	emit(rv_slli(RV_REG_T1, rd, 8), ctx);
892 	emit(rv_srli(rd, rd, 8), ctx);
893 	emit(rv_add(RV_REG_T1, rd, RV_REG_T1), ctx);
894 	emit(rv_srli(rd, RV_REG_T1, 16), ctx);
895 }
896 
897 static void emit_rev32(const s8 rd, struct rv_jit_context *ctx)
898 {
899 	emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 0), ctx);
900 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
901 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
902 	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
903 	emit(rv_srli(rd, rd, 8), ctx);
904 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
905 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
906 	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
907 	emit(rv_srli(rd, rd, 8), ctx);
908 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
909 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
910 	emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
911 	emit(rv_srli(rd, rd, 8), ctx);
912 	emit(rv_andi(RV_REG_T0, rd, 255), ctx);
913 	emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
914 	emit(rv_addi(rd, RV_REG_T1, 0), ctx);
915 }
916 
917 static void emit_zext64(const s8 *dst, struct rv_jit_context *ctx)
918 {
919 	const s8 *rd;
920 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
921 
922 	rd = bpf_get_reg64(dst, tmp1, ctx);
923 	emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
924 	bpf_put_reg64(dst, rd, ctx);
925 }
926 
927 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
928 		      bool extra_pass)
929 {
930 	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
931 		BPF_CLASS(insn->code) == BPF_JMP;
932 	int s, e, rvoff, i = insn - ctx->prog->insnsi;
933 	u8 code = insn->code;
934 	s16 off = insn->off;
935 	s32 imm = insn->imm;
936 
937 	const s8 *dst = bpf2rv32[insn->dst_reg];
938 	const s8 *src = bpf2rv32[insn->src_reg];
939 	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
940 	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
941 
942 	switch (code) {
943 	case BPF_ALU64 | BPF_MOV | BPF_X:
944 
945 	case BPF_ALU64 | BPF_ADD | BPF_X:
946 	case BPF_ALU64 | BPF_ADD | BPF_K:
947 
948 	case BPF_ALU64 | BPF_SUB | BPF_X:
949 	case BPF_ALU64 | BPF_SUB | BPF_K:
950 
951 	case BPF_ALU64 | BPF_AND | BPF_X:
952 	case BPF_ALU64 | BPF_OR | BPF_X:
953 	case BPF_ALU64 | BPF_XOR | BPF_X:
954 
955 	case BPF_ALU64 | BPF_MUL | BPF_X:
956 	case BPF_ALU64 | BPF_MUL | BPF_K:
957 
958 	case BPF_ALU64 | BPF_LSH | BPF_X:
959 	case BPF_ALU64 | BPF_RSH | BPF_X:
960 	case BPF_ALU64 | BPF_ARSH | BPF_X:
961 		if (BPF_SRC(code) == BPF_K) {
962 			emit_imm32(tmp2, imm, ctx);
963 			src = tmp2;
964 		}
965 		emit_alu_r64(dst, src, ctx, BPF_OP(code));
966 		break;
967 
968 	case BPF_ALU64 | BPF_NEG:
969 		emit_alu_r64(dst, tmp2, ctx, BPF_OP(code));
970 		break;
971 
972 	case BPF_ALU64 | BPF_DIV | BPF_X:
973 	case BPF_ALU64 | BPF_DIV | BPF_K:
974 	case BPF_ALU64 | BPF_MOD | BPF_X:
975 	case BPF_ALU64 | BPF_MOD | BPF_K:
976 		goto notsupported;
977 
978 	case BPF_ALU64 | BPF_MOV | BPF_K:
979 	case BPF_ALU64 | BPF_AND | BPF_K:
980 	case BPF_ALU64 | BPF_OR | BPF_K:
981 	case BPF_ALU64 | BPF_XOR | BPF_K:
982 	case BPF_ALU64 | BPF_LSH | BPF_K:
983 	case BPF_ALU64 | BPF_RSH | BPF_K:
984 	case BPF_ALU64 | BPF_ARSH | BPF_K:
985 		emit_alu_i64(dst, imm, ctx, BPF_OP(code));
986 		break;
987 
988 	case BPF_ALU | BPF_MOV | BPF_X:
989 		if (imm == 1) {
990 			/* Special mov32 for zext. */
991 			emit_zext64(dst, ctx);
992 			break;
993 		}
994 		/* Fallthrough. */
995 
996 	case BPF_ALU | BPF_ADD | BPF_X:
997 	case BPF_ALU | BPF_SUB | BPF_X:
998 	case BPF_ALU | BPF_AND | BPF_X:
999 	case BPF_ALU | BPF_OR | BPF_X:
1000 	case BPF_ALU | BPF_XOR | BPF_X:
1001 
1002 	case BPF_ALU | BPF_MUL | BPF_X:
1003 	case BPF_ALU | BPF_MUL | BPF_K:
1004 
1005 	case BPF_ALU | BPF_DIV | BPF_X:
1006 	case BPF_ALU | BPF_DIV | BPF_K:
1007 
1008 	case BPF_ALU | BPF_MOD | BPF_X:
1009 	case BPF_ALU | BPF_MOD | BPF_K:
1010 
1011 	case BPF_ALU | BPF_LSH | BPF_X:
1012 	case BPF_ALU | BPF_RSH | BPF_X:
1013 	case BPF_ALU | BPF_ARSH | BPF_X:
1014 		if (BPF_SRC(code) == BPF_K) {
1015 			emit_imm32(tmp2, imm, ctx);
1016 			src = tmp2;
1017 		}
1018 		emit_alu_r32(dst, src, ctx, BPF_OP(code));
1019 		break;
1020 
1021 	case BPF_ALU | BPF_MOV | BPF_K:
1022 	case BPF_ALU | BPF_ADD | BPF_K:
1023 	case BPF_ALU | BPF_SUB | BPF_K:
1024 	case BPF_ALU | BPF_AND | BPF_K:
1025 	case BPF_ALU | BPF_OR | BPF_K:
1026 	case BPF_ALU | BPF_XOR | BPF_K:
1027 	case BPF_ALU | BPF_LSH | BPF_K:
1028 	case BPF_ALU | BPF_RSH | BPF_K:
1029 	case BPF_ALU | BPF_ARSH | BPF_K:
1030 		/*
1031 		 * mul,div,mod are handled in the BPF_X case since there are
1032 		 * no RISC-V I-type equivalents.
1033 		 */
1034 		emit_alu_i32(dst, imm, ctx, BPF_OP(code));
1035 		break;
1036 
1037 	case BPF_ALU | BPF_NEG:
1038 		/*
1039 		 * src is ignored---choose tmp2 as a dummy register since it
1040 		 * is not on the stack.
1041 		 */
1042 		emit_alu_r32(dst, tmp2, ctx, BPF_OP(code));
1043 		break;
1044 
1045 	case BPF_ALU | BPF_END | BPF_FROM_LE:
1046 	{
1047 		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
1048 
1049 		switch (imm) {
1050 		case 16:
1051 			emit(rv_slli(lo(rd), lo(rd), 16), ctx);
1052 			emit(rv_srli(lo(rd), lo(rd), 16), ctx);
1053 			/* Fallthrough. */
1054 		case 32:
1055 			if (!ctx->prog->aux->verifier_zext)
1056 				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
1057 			break;
1058 		case 64:
1059 			/* Do nothing. */
1060 			break;
1061 		default:
1062 			pr_err("bpf-jit: BPF_END imm %d invalid\n", imm);
1063 			return -1;
1064 		}
1065 
1066 		bpf_put_reg64(dst, rd, ctx);
1067 		break;
1068 	}
1069 
1070 	case BPF_ALU | BPF_END | BPF_FROM_BE:
1071 	{
1072 		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
1073 
1074 		switch (imm) {
1075 		case 16:
1076 			emit_rev16(lo(rd), ctx);
1077 			if (!ctx->prog->aux->verifier_zext)
1078 				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
1079 			break;
1080 		case 32:
1081 			emit_rev32(lo(rd), ctx);
1082 			if (!ctx->prog->aux->verifier_zext)
1083 				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
1084 			break;
1085 		case 64:
1086 			/* Swap upper and lower halves. */
1087 			emit(rv_addi(RV_REG_T0, lo(rd), 0), ctx);
1088 			emit(rv_addi(lo(rd), hi(rd), 0), ctx);
1089 			emit(rv_addi(hi(rd), RV_REG_T0, 0), ctx);
1090 
1091 			/* Swap each half. */
1092 			emit_rev32(lo(rd), ctx);
1093 			emit_rev32(hi(rd), ctx);
1094 			break;
1095 		default:
1096 			pr_err("bpf-jit: BPF_END imm %d invalid\n", imm);
1097 			return -1;
1098 		}
1099 
1100 		bpf_put_reg64(dst, rd, ctx);
1101 		break;
1102 	}
1103 
1104 	case BPF_JMP | BPF_JA:
1105 		rvoff = rv_offset(i, off, ctx);
1106 		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
1107 		break;
1108 
1109 	case BPF_JMP | BPF_CALL:
1110 	{
1111 		bool fixed;
1112 		int ret;
1113 		u64 addr;
1114 
1115 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &addr,
1116 					    &fixed);
1117 		if (ret < 0)
1118 			return ret;
1119 		emit_call(fixed, addr, ctx);
1120 		break;
1121 	}
1122 
1123 	case BPF_JMP | BPF_TAIL_CALL:
1124 		if (emit_bpf_tail_call(i, ctx))
1125 			return -1;
1126 		break;
1127 
1128 	case BPF_JMP | BPF_JEQ | BPF_X:
1129 	case BPF_JMP | BPF_JEQ | BPF_K:
1130 	case BPF_JMP32 | BPF_JEQ | BPF_X:
1131 	case BPF_JMP32 | BPF_JEQ | BPF_K:
1132 
1133 	case BPF_JMP | BPF_JNE | BPF_X:
1134 	case BPF_JMP | BPF_JNE | BPF_K:
1135 	case BPF_JMP32 | BPF_JNE | BPF_X:
1136 	case BPF_JMP32 | BPF_JNE | BPF_K:
1137 
1138 	case BPF_JMP | BPF_JLE | BPF_X:
1139 	case BPF_JMP | BPF_JLE | BPF_K:
1140 	case BPF_JMP32 | BPF_JLE | BPF_X:
1141 	case BPF_JMP32 | BPF_JLE | BPF_K:
1142 
1143 	case BPF_JMP | BPF_JLT | BPF_X:
1144 	case BPF_JMP | BPF_JLT | BPF_K:
1145 	case BPF_JMP32 | BPF_JLT | BPF_X:
1146 	case BPF_JMP32 | BPF_JLT | BPF_K:
1147 
1148 	case BPF_JMP | BPF_JGE | BPF_X:
1149 	case BPF_JMP | BPF_JGE | BPF_K:
1150 	case BPF_JMP32 | BPF_JGE | BPF_X:
1151 	case BPF_JMP32 | BPF_JGE | BPF_K:
1152 
1153 	case BPF_JMP | BPF_JGT | BPF_X:
1154 	case BPF_JMP | BPF_JGT | BPF_K:
1155 	case BPF_JMP32 | BPF_JGT | BPF_X:
1156 	case BPF_JMP32 | BPF_JGT | BPF_K:
1157 
1158 	case BPF_JMP | BPF_JSLE | BPF_X:
1159 	case BPF_JMP | BPF_JSLE | BPF_K:
1160 	case BPF_JMP32 | BPF_JSLE | BPF_X:
1161 	case BPF_JMP32 | BPF_JSLE | BPF_K:
1162 
1163 	case BPF_JMP | BPF_JSLT | BPF_X:
1164 	case BPF_JMP | BPF_JSLT | BPF_K:
1165 	case BPF_JMP32 | BPF_JSLT | BPF_X:
1166 	case BPF_JMP32 | BPF_JSLT | BPF_K:
1167 
1168 	case BPF_JMP | BPF_JSGE | BPF_X:
1169 	case BPF_JMP | BPF_JSGE | BPF_K:
1170 	case BPF_JMP32 | BPF_JSGE | BPF_X:
1171 	case BPF_JMP32 | BPF_JSGE | BPF_K:
1172 
1173 	case BPF_JMP | BPF_JSGT | BPF_X:
1174 	case BPF_JMP | BPF_JSGT | BPF_K:
1175 	case BPF_JMP32 | BPF_JSGT | BPF_X:
1176 	case BPF_JMP32 | BPF_JSGT | BPF_K:
1177 
1178 	case BPF_JMP | BPF_JSET | BPF_X:
1179 	case BPF_JMP | BPF_JSET | BPF_K:
1180 	case BPF_JMP32 | BPF_JSET | BPF_X:
1181 	case BPF_JMP32 | BPF_JSET | BPF_K:
1182 		rvoff = rv_offset(i, off, ctx);
1183 		if (BPF_SRC(code) == BPF_K) {
1184 			s = ctx->ninsns;
1185 			emit_imm32(tmp2, imm, ctx);
1186 			src = tmp2;
1187 			e = ctx->ninsns;
1188 			rvoff -= (e - s) << 2;
1189 		}
1190 
1191 		if (is64)
1192 			emit_branch_r64(dst, src, rvoff, ctx, BPF_OP(code));
1193 		else
1194 			emit_branch_r32(dst, src, rvoff, ctx, BPF_OP(code));
1195 		break;
1196 
1197 	case BPF_JMP | BPF_EXIT:
1198 		if (i == ctx->prog->len - 1)
1199 			break;
1200 
1201 		rvoff = epilogue_offset(ctx);
1202 		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
1203 		break;
1204 
1205 	case BPF_LD | BPF_IMM | BPF_DW:
1206 	{
1207 		struct bpf_insn insn1 = insn[1];
1208 		s32 imm_lo = imm;
1209 		s32 imm_hi = insn1.imm;
1210 		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
1211 
1212 		emit_imm64(rd, imm_hi, imm_lo, ctx);
1213 		bpf_put_reg64(dst, rd, ctx);
1214 		return 1;
1215 	}
1216 
1217 	case BPF_LDX | BPF_MEM | BPF_B:
1218 	case BPF_LDX | BPF_MEM | BPF_H:
1219 	case BPF_LDX | BPF_MEM | BPF_W:
1220 	case BPF_LDX | BPF_MEM | BPF_DW:
1221 		if (emit_load_r64(dst, src, off, ctx, BPF_SIZE(code)))
1222 			return -1;
1223 		break;
1224 
1225 	case BPF_ST | BPF_MEM | BPF_B:
1226 	case BPF_ST | BPF_MEM | BPF_H:
1227 	case BPF_ST | BPF_MEM | BPF_W:
1228 	case BPF_ST | BPF_MEM | BPF_DW:
1229 
1230 	case BPF_STX | BPF_MEM | BPF_B:
1231 	case BPF_STX | BPF_MEM | BPF_H:
1232 	case BPF_STX | BPF_MEM | BPF_W:
1233 	case BPF_STX | BPF_MEM | BPF_DW:
1234 	case BPF_STX | BPF_XADD | BPF_W:
1235 		if (BPF_CLASS(code) == BPF_ST) {
1236 			emit_imm32(tmp2, imm, ctx);
1237 			src = tmp2;
1238 		}
1239 
1240 		if (emit_store_r64(dst, src, off, ctx, BPF_SIZE(code),
1241 				   BPF_MODE(code)))
1242 			return -1;
1243 		break;
1244 
1245 	/* No hardware support for 8-byte atomics in RV32. */
1246 	case BPF_STX | BPF_XADD | BPF_DW:
1247 		/* Fallthrough. */
1248 
1249 notsupported:
1250 		pr_info_once("bpf-jit: not supported: opcode %02x ***\n", code);
1251 		return -EFAULT;
1252 
1253 	default:
1254 		pr_err("bpf-jit: unknown opcode %02x\n", code);
1255 		return -EINVAL;
1256 	}
1257 
1258 	return 0;
1259 }
1260 
1261 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1262 {
1263 	/* Make space to save 9 registers: ra, fp, s1--s7. */
1264 	int stack_adjust = 9 * sizeof(u32), store_offset, bpf_stack_adjust;
1265 	const s8 *fp = bpf2rv32[BPF_REG_FP];
1266 	const s8 *r1 = bpf2rv32[BPF_REG_1];
1267 
1268 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1269 	stack_adjust += bpf_stack_adjust;
1270 
1271 	store_offset = stack_adjust - 4;
1272 
1273 	stack_adjust += 4 * BPF_JIT_SCRATCH_REGS;
1274 
1275 	/*
1276 	 * The first instruction sets the tail-call-counter (TCC) register.
1277 	 * This instruction is skipped by tail calls.
1278 	 */
1279 	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1280 
1281 	emit(rv_addi(RV_REG_SP, RV_REG_SP, -stack_adjust), ctx);
1282 
1283 	/* Save callee-save registers. */
1284 	emit(rv_sw(RV_REG_SP, store_offset - 0, RV_REG_RA), ctx);
1285 	emit(rv_sw(RV_REG_SP, store_offset - 4, RV_REG_FP), ctx);
1286 	emit(rv_sw(RV_REG_SP, store_offset - 8, RV_REG_S1), ctx);
1287 	emit(rv_sw(RV_REG_SP, store_offset - 12, RV_REG_S2), ctx);
1288 	emit(rv_sw(RV_REG_SP, store_offset - 16, RV_REG_S3), ctx);
1289 	emit(rv_sw(RV_REG_SP, store_offset - 20, RV_REG_S4), ctx);
1290 	emit(rv_sw(RV_REG_SP, store_offset - 24, RV_REG_S5), ctx);
1291 	emit(rv_sw(RV_REG_SP, store_offset - 28, RV_REG_S6), ctx);
1292 	emit(rv_sw(RV_REG_SP, store_offset - 32, RV_REG_S7), ctx);
1293 
1294 	/* Set fp: used as the base address for stacked BPF registers. */
1295 	emit(rv_addi(RV_REG_FP, RV_REG_SP, stack_adjust), ctx);
1296 
1297 	/* Set up BPF stack pointer. */
1298 	emit(rv_addi(lo(fp), RV_REG_SP, bpf_stack_adjust), ctx);
1299 	emit(rv_addi(hi(fp), RV_REG_ZERO, 0), ctx);
1300 
1301 	/* Set up context pointer. */
1302 	emit(rv_addi(lo(r1), RV_REG_A0, 0), ctx);
1303 	emit(rv_addi(hi(r1), RV_REG_ZERO, 0), ctx);
1304 
1305 	ctx->stack_size = stack_adjust;
1306 }
1307 
1308 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1309 {
1310 	__build_epilogue(false, ctx);
1311 }
1312