xref: /linux/kernel/static_call.c (revision 8b83369ddcb3fb9cab5c1088987ce477565bb630)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/init.h>
3 #include <linux/static_call.h>
4 #include <linux/bug.h>
5 #include <linux/smp.h>
6 #include <linux/sort.h>
7 #include <linux/slab.h>
8 #include <linux/module.h>
9 #include <linux/cpu.h>
10 #include <linux/processor.h>
11 #include <asm/sections.h>
12 
13 extern struct static_call_site __start_static_call_sites[],
14 			       __stop_static_call_sites[];
15 extern struct static_call_tramp_key __start_static_call_tramp_key[],
16 				    __stop_static_call_tramp_key[];
17 
18 static bool static_call_initialized;
19 
20 /* mutex to protect key modules/sites */
21 static DEFINE_MUTEX(static_call_mutex);
22 
23 static void static_call_lock(void)
24 {
25 	mutex_lock(&static_call_mutex);
26 }
27 
28 static void static_call_unlock(void)
29 {
30 	mutex_unlock(&static_call_mutex);
31 }
32 
33 static inline void *static_call_addr(struct static_call_site *site)
34 {
35 	return (void *)((long)site->addr + (long)&site->addr);
36 }
37 
38 
39 static inline struct static_call_key *static_call_key(const struct static_call_site *site)
40 {
41 	return (struct static_call_key *)
42 		(((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
43 }
44 
45 /* These assume the key is word-aligned. */
46 static inline bool static_call_is_init(struct static_call_site *site)
47 {
48 	return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
49 }
50 
51 static inline bool static_call_is_tail(struct static_call_site *site)
52 {
53 	return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
54 }
55 
56 static inline void static_call_set_init(struct static_call_site *site)
57 {
58 	site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) -
59 		    (long)&site->key;
60 }
61 
/*
 * sort() comparator: order call sites by the address of the key they
 * reference, so that all sites of one key become contiguous.
 */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_key *ka = static_call_key(_a);
	const struct static_call_key *kb = static_call_key(_b);

	if (ka == kb)
		return 0;

	return ka < kb ? -1 : 1;
}
77 
78 static void static_call_site_swap(void *_a, void *_b, int size)
79 {
80 	long delta = (unsigned long)_a - (unsigned long)_b;
81 	struct static_call_site *a = _a;
82 	struct static_call_site *b = _b;
83 	struct static_call_site tmp = *a;
84 
85 	a->addr = b->addr  - delta;
86 	a->key  = b->key   - delta;
87 
88 	b->addr = tmp.addr + delta;
89 	b->key  = tmp.key  + delta;
90 }
91 
92 static inline void static_call_sort_entries(struct static_call_site *start,
93 					    struct static_call_site *stop)
94 {
95 	sort(start, stop - start, sizeof(struct static_call_site),
96 	     static_call_site_cmp, static_call_site_swap);
97 }
98 
99 static inline bool static_call_key_has_mods(struct static_call_key *key)
100 {
101 	return !(key->type & 1);
102 }
103 
104 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
105 {
106 	if (!static_call_key_has_mods(key))
107 		return NULL;
108 
109 	return key->mods;
110 }
111 
112 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
113 {
114 	if (static_call_key_has_mods(key))
115 		return NULL;
116 
117 	return (struct static_call_site *)(key->type & ~1);
118 }
119 
120 void __static_call_update(struct static_call_key *key, void *tramp, void *func)
121 {
122 	struct static_call_site *site, *stop;
123 	struct static_call_mod *site_mod, first;
124 
125 	cpus_read_lock();
126 	static_call_lock();
127 
128 	if (key->func == func)
129 		goto done;
130 
131 	key->func = func;
132 
133 	arch_static_call_transform(NULL, tramp, func, false);
134 
135 	/*
136 	 * If uninitialized, we'll not update the callsites, but they still
137 	 * point to the trampoline and we just patched that.
138 	 */
139 	if (WARN_ON_ONCE(!static_call_initialized))
140 		goto done;
141 
142 	first = (struct static_call_mod){
143 		.next = static_call_key_next(key),
144 		.mod = NULL,
145 		.sites = static_call_key_sites(key),
146 	};
147 
148 	for (site_mod = &first; site_mod; site_mod = site_mod->next) {
149 		struct module *mod = site_mod->mod;
150 
151 		if (!site_mod->sites) {
152 			/*
153 			 * This can happen if the static call key is defined in
154 			 * a module which doesn't use it.
155 			 *
156 			 * It also happens in the has_mods case, where the
157 			 * 'first' entry has no sites associated with it.
158 			 */
159 			continue;
160 		}
161 
162 		stop = __stop_static_call_sites;
163 
164 #ifdef CONFIG_MODULES
165 		if (mod) {
166 			stop = mod->static_call_sites +
167 			       mod->num_static_call_sites;
168 		}
169 #endif
170 
171 		for (site = site_mod->sites;
172 		     site < stop && static_call_key(site) == key; site++) {
173 			void *site_addr = static_call_addr(site);
174 
175 			if (static_call_is_init(site)) {
176 				/*
177 				 * Don't write to call sites which were in
178 				 * initmem and have since been freed.
179 				 */
180 				if (!mod && system_state >= SYSTEM_RUNNING)
181 					continue;
182 				if (mod && !within_module_init((unsigned long)site_addr, mod))
183 					continue;
184 			}
185 
186 			if (!kernel_text_address((unsigned long)site_addr)) {
187 				WARN_ONCE(1, "can't patch static call site at %pS",
188 					  site_addr);
189 				continue;
190 			}
191 
192 			arch_static_call_transform(site_addr, NULL, func,
193 				static_call_is_tail(site));
194 		}
195 	}
196 
197 done:
198 	static_call_unlock();
199 	cpus_read_unlock();
200 }
201 EXPORT_SYMBOL_GPL(__static_call_update);
202 
203 static int __static_call_init(struct module *mod,
204 			      struct static_call_site *start,
205 			      struct static_call_site *stop)
206 {
207 	struct static_call_site *site;
208 	struct static_call_key *key, *prev_key = NULL;
209 	struct static_call_mod *site_mod;
210 
211 	if (start == stop)
212 		return 0;
213 
214 	static_call_sort_entries(start, stop);
215 
216 	for (site = start; site < stop; site++) {
217 		void *site_addr = static_call_addr(site);
218 
219 		if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
220 		    (!mod && init_section_contains(site_addr, 1)))
221 			static_call_set_init(site);
222 
223 		key = static_call_key(site);
224 		if (key != prev_key) {
225 			prev_key = key;
226 
227 			/*
228 			 * For vmlinux (!mod) avoid the allocation by storing
229 			 * the sites pointer in the key itself. Also see
230 			 * __static_call_update()'s @first.
231 			 *
232 			 * This allows architectures (eg. x86) to call
233 			 * static_call_init() before memory allocation works.
234 			 */
235 			if (!mod) {
236 				key->sites = site;
237 				key->type |= 1;
238 				goto do_transform;
239 			}
240 
241 			site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
242 			if (!site_mod)
243 				return -ENOMEM;
244 
245 			/*
246 			 * When the key has a direct sites pointer, extract
247 			 * that into an explicit struct static_call_mod, so we
248 			 * can have a list of modules.
249 			 */
250 			if (static_call_key_sites(key)) {
251 				site_mod->mod = NULL;
252 				site_mod->next = NULL;
253 				site_mod->sites = static_call_key_sites(key);
254 
255 				key->mods = site_mod;
256 
257 				site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
258 				if (!site_mod)
259 					return -ENOMEM;
260 			}
261 
262 			site_mod->mod = mod;
263 			site_mod->sites = site;
264 			site_mod->next = static_call_key_next(key);
265 			key->mods = site_mod;
266 		}
267 
268 do_transform:
269 		arch_static_call_transform(site_addr, NULL, key->func,
270 				static_call_is_tail(site));
271 	}
272 
273 	return 0;
274 }
275 
276 static int addr_conflict(struct static_call_site *site, void *start, void *end)
277 {
278 	unsigned long addr = (unsigned long)static_call_addr(site);
279 
280 	if (addr <= (unsigned long)end &&
281 	    addr + CALL_INSN_SIZE > (unsigned long)start)
282 		return 1;
283 
284 	return 0;
285 }
286 
287 static int __static_call_text_reserved(struct static_call_site *iter_start,
288 				       struct static_call_site *iter_stop,
289 				       void *start, void *end)
290 {
291 	struct static_call_site *iter = iter_start;
292 
293 	while (iter < iter_stop) {
294 		if (addr_conflict(iter, start, end))
295 			return 1;
296 		iter++;
297 	}
298 
299 	return 0;
300 }
301 
302 #ifdef CONFIG_MODULES
303 
304 static int __static_call_mod_text_reserved(void *start, void *end)
305 {
306 	struct module *mod;
307 	int ret;
308 
309 	preempt_disable();
310 	mod = __module_text_address((unsigned long)start);
311 	WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
312 	if (!try_module_get(mod))
313 		mod = NULL;
314 	preempt_enable();
315 
316 	if (!mod)
317 		return 0;
318 
319 	ret = __static_call_text_reserved(mod->static_call_sites,
320 			mod->static_call_sites + mod->num_static_call_sites,
321 			start, end);
322 
323 	module_put(mod);
324 
325 	return ret;
326 }
327 
328 static unsigned long tramp_key_lookup(unsigned long addr)
329 {
330 	struct static_call_tramp_key *start = __start_static_call_tramp_key;
331 	struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
332 	struct static_call_tramp_key *tramp_key;
333 
334 	for (tramp_key = start; tramp_key != stop; tramp_key++) {
335 		unsigned long tramp;
336 
337 		tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
338 		if (tramp == addr)
339 			return (long)tramp_key->key + (long)&tramp_key->key;
340 	}
341 
342 	return 0;
343 }
344 
345 static int static_call_add_module(struct module *mod)
346 {
347 	struct static_call_site *start = mod->static_call_sites;
348 	struct static_call_site *stop = start + mod->num_static_call_sites;
349 	struct static_call_site *site;
350 
351 	for (site = start; site != stop; site++) {
352 		unsigned long addr = (unsigned long)static_call_key(site);
353 		unsigned long key;
354 
355 		/*
356 		 * Is the key is exported, 'addr' points to the key, which
357 		 * means modules are allowed to call static_call_update() on
358 		 * it.
359 		 *
360 		 * Otherwise, the key isn't exported, and 'addr' points to the
361 		 * trampoline so we need to lookup the key.
362 		 *
363 		 * We go through this dance to prevent crazy modules from
364 		 * abusing sensitive static calls.
365 		 */
366 		if (!kernel_text_address(addr))
367 			continue;
368 
369 		key = tramp_key_lookup(addr);
370 		if (!key) {
371 			pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
372 				static_call_addr(site));
373 			return -EINVAL;
374 		}
375 
376 		site->key = (key - (long)&site->key) |
377 			    (site->key & STATIC_CALL_SITE_FLAGS);
378 	}
379 
380 	return __static_call_init(mod, start, stop);
381 }
382 
383 static void static_call_del_module(struct module *mod)
384 {
385 	struct static_call_site *start = mod->static_call_sites;
386 	struct static_call_site *stop = mod->static_call_sites +
387 					mod->num_static_call_sites;
388 	struct static_call_key *key, *prev_key = NULL;
389 	struct static_call_mod *site_mod, **prev;
390 	struct static_call_site *site;
391 
392 	for (site = start; site < stop; site++) {
393 		key = static_call_key(site);
394 		if (key == prev_key)
395 			continue;
396 
397 		prev_key = key;
398 
399 		for (prev = &key->mods, site_mod = key->mods;
400 		     site_mod && site_mod->mod != mod;
401 		     prev = &site_mod->next, site_mod = site_mod->next)
402 			;
403 
404 		if (!site_mod)
405 			continue;
406 
407 		*prev = site_mod->next;
408 		kfree(site_mod);
409 	}
410 }
411 
412 static int static_call_module_notify(struct notifier_block *nb,
413 				     unsigned long val, void *data)
414 {
415 	struct module *mod = data;
416 	int ret = 0;
417 
418 	cpus_read_lock();
419 	static_call_lock();
420 
421 	switch (val) {
422 	case MODULE_STATE_COMING:
423 		ret = static_call_add_module(mod);
424 		if (ret) {
425 			WARN(1, "Failed to allocate memory for static calls");
426 			static_call_del_module(mod);
427 		}
428 		break;
429 	case MODULE_STATE_GOING:
430 		static_call_del_module(mod);
431 		break;
432 	}
433 
434 	static_call_unlock();
435 	cpus_read_unlock();
436 
437 	return notifier_from_errno(ret);
438 }
439 
440 static struct notifier_block static_call_module_nb = {
441 	.notifier_call = static_call_module_notify,
442 };
443 
444 #else
445 
/* Without module support there are no module call sites to conflict with. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	return 0;
}
450 
451 #endif /* CONFIG_MODULES */
452 
453 int static_call_text_reserved(void *start, void *end)
454 {
455 	int ret = __static_call_text_reserved(__start_static_call_sites,
456 			__stop_static_call_sites, start, end);
457 
458 	if (ret)
459 		return ret;
460 
461 	return __static_call_mod_text_reserved(start, end);
462 }
463 
464 int __init static_call_init(void)
465 {
466 	int ret;
467 
468 	if (static_call_initialized)
469 		return 0;
470 
471 	cpus_read_lock();
472 	static_call_lock();
473 	ret = __static_call_init(NULL, __start_static_call_sites,
474 				 __stop_static_call_sites);
475 	static_call_unlock();
476 	cpus_read_unlock();
477 
478 	if (ret) {
479 		pr_err("Failed to allocate memory for static_call!\n");
480 		BUG();
481 	}
482 
483 	static_call_initialized = true;
484 
485 #ifdef CONFIG_MODULES
486 	register_module_notifier(&static_call_module_nb);
487 #endif
488 	return 0;
489 }
490 early_initcall(static_call_init);
491 
/*
 * Trivial function that always returns 0. Not referenced in this file;
 * presumably used as a static call target elsewhere - confirm with callers.
 */
long __static_call_return0(void)
{
	return 0L;
}
496 
497 #ifdef CONFIG_STATIC_CALL_SELFTEST
498 
/* Selftest target: returns its argument plus one. */
static int func_a(int x)
{
	return x + 1;
}

/* Selftest target: returns its argument plus two. */
static int func_b(int x)
{
	return x + 2;
}
508 
DEFINE_STATIC_CALL(sc_selftest, func_a);

/*
 * Selftest steps, run in order: if ->func is set, retarget sc_selftest to
 * it first; then call sc_selftest(->val) and expect ->expect.
 */
static struct static_call_data {
	int (*func)(int);
	int val;
	int expect;
} static_call_data[] __initdata = {
	{ .func = NULL,   .val = 2, .expect = 3 },	/* still func_a */
	{ .func = func_b, .val = 2, .expect = 4 },
	{ .func = func_a, .val = 2, .expect = 3 },
};
520 
521 static int __init test_static_call_init(void)
522 {
523       int i;
524 
525       for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
526 	      struct static_call_data *scd = &static_call_data[i];
527 
528               if (scd->func)
529                       static_call_update(sc_selftest, scd->func);
530 
531               WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
532       }
533 
534       return 0;
535 }
536 early_initcall(test_static_call_init);
537 
538 #endif /* CONFIG_STATIC_CALL_SELFTEST */
539