xref: /linux/arch/x86/kernel/static_call.c (revision 0a94608f0f7de9b1135ffea3546afe68eafef57f)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/static_call.h>
3 #include <linux/memory.h>
4 #include <linux/bug.h>
5 #include <asm/text-patching.h>
6 
/*
 * Kinds of 5-byte instruction written at a static_call site or trampoline.
 *
 * The explicit values matter: __sc_insn() computes the kind arithmetically
 * as 2*tail + null, so these must stay in this order with these values.
 */
enum insn_type {
	CALL = 0,	/* site: call */
	NOP  = 1,	/* site: cond-call */
	JMP  = 2,	/* trampoline / site: tail-call */
	RET  = 3,	/* trampoline / site: cond-tail-call */
};
13 
14 /*
15  * cs cs cs xorl %eax, %eax - a single 5 byte instruction that clears %[er]ax
16  */
17 static const u8 xor5rax[] = { 0x2e, 0x2e, 0x2e, 0x31, 0xc0 };
18 
19 static const u8 retinsn[] = { RET_INSN_OPCODE, 0xcc, 0xcc, 0xcc, 0xcc };
20 
21 static void __ref __static_call_transform(void *insn, enum insn_type type, void *func)
22 {
23 	const void *emulate = NULL;
24 	int size = CALL_INSN_SIZE;
25 	const void *code;
26 
27 	switch (type) {
28 	case CALL:
29 		code = text_gen_insn(CALL_INSN_OPCODE, insn, func);
30 		if (func == &__static_call_return0) {
31 			emulate = code;
32 			code = &xor5rax;
33 		}
34 
35 		break;
36 
37 	case NOP:
38 		code = x86_nops[5];
39 		break;
40 
41 	case JMP:
42 		code = text_gen_insn(JMP32_INSN_OPCODE, insn, func);
43 		break;
44 
45 	case RET:
46 		code = &retinsn;
47 		break;
48 	}
49 
50 	if (memcmp(insn, code, size) == 0)
51 		return;
52 
53 	if (unlikely(system_state == SYSTEM_BOOTING))
54 		return text_poke_early(insn, code, size);
55 
56 	text_poke_bp(insn, code, size, emulate);
57 }
58 
59 static void __static_call_validate(void *insn, bool tail, bool tramp)
60 {
61 	u8 opcode = *(u8 *)insn;
62 
63 	if (tramp && memcmp(insn+5, "SCT", 3)) {
64 		pr_err("trampoline signature fail");
65 		BUG();
66 	}
67 
68 	if (tail) {
69 		if (opcode == JMP32_INSN_OPCODE ||
70 		    opcode == RET_INSN_OPCODE)
71 			return;
72 	} else {
73 		if (opcode == CALL_INSN_OPCODE ||
74 		    !memcmp(insn, x86_nops[5], 5) ||
75 		    !memcmp(insn, xor5rax, 5))
76 			return;
77 	}
78 
79 	/*
80 	 * If we ever trigger this, our text is corrupt, we'll probably not live long.
81 	 */
82 	pr_err("unexpected static_call insn opcode 0x%x at %pS\n", opcode, insn);
83 	BUG();
84 }
85 
86 static inline enum insn_type __sc_insn(bool null, bool tail)
87 {
88 	/*
89 	 * Encode the following table without branches:
90 	 *
91 	 *	tail	null	insn
92 	 *	-----+-------+------
93 	 *	  0  |   0   |  CALL
94 	 *	  0  |   1   |  NOP
95 	 *	  1  |   0   |  JMP
96 	 *	  1  |   1   |  RET
97 	 */
98 	return 2*tail + null;
99 }
100 
101 void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
102 {
103 	mutex_lock(&text_mutex);
104 
105 	if (tramp) {
106 		__static_call_validate(tramp, true, true);
107 		__static_call_transform(tramp, __sc_insn(!func, true), func);
108 	}
109 
110 	if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site) {
111 		__static_call_validate(site, tail, false);
112 		__static_call_transform(site, __sc_insn(!func, tail), func);
113 	}
114 
115 	mutex_unlock(&text_mutex);
116 }
117 EXPORT_SYMBOL_GPL(arch_static_call_transform);
118