xref: /linux/kernel/jump_label.c (revision 3bc70ad12097c19cd6c687bc5b12c31da14b63f7)
// SPDX-License-Identifier: GPL-2.0-only
/*
 * jump label support
 *
 * Copyright (C) 2009 Jason Baron <jbaron@redhat.com>
 * Copyright (C) 2011 Peter Zijlstra
 *
 */
#include <linux/memory.h>
#include <linux/uaccess.h>
#include <linux/module.h>
#include <linux/list.h>
#include <linux/slab.h>
#include <linux/sort.h>
#include <linux/err.h>
#include <linux/static_key.h>
#include <linux/jump_label_ratelimit.h>
#include <linux/bug.h>
#include <linux/cpu.h>
#include <asm/sections.h>

/* mutex to protect coming/going of the jump_label table */
static DEFINE_MUTEX(jump_label_mutex);

void jump_label_lock(void)
{
	mutex_lock(&jump_label_mutex);
}

void jump_label_unlock(void)
{
	mutex_unlock(&jump_label_mutex);
}

static int jump_label_cmp(const void *a, const void *b)
{
	const struct jump_entry *jea = a;
	const struct jump_entry *jeb = b;

	/*
	 * Entries are sorted by key.
	 */
	if (jump_entry_key(jea) < jump_entry_key(jeb))
		return -1;

	if (jump_entry_key(jea) > jump_entry_key(jeb))
		return 1;

	/*
	 * In batching mode, entries should also be sorted by code address
	 * within the already key-sorted list, enabling a bsearch in the
	 * vector.
	 */
	if (jump_entry_code(jea) < jump_entry_code(jeb))
		return -1;

	if (jump_entry_code(jea) > jump_entry_code(jeb))
		return 1;

	return 0;
}

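/*
 * With CONFIG_HAVE_ARCH_JUMP_LABEL_RELATIVE the code/target/key fields are
 * stored as offsets relative to the entry itself, so swapping two entries
 * also has to rebias each field by the distance between the two slots.
 */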
static void jump_label_swap(void *a, void *b, int size)
{
	long delta = (unsigned long)a - (unsigned long)b;
	struct jump_entry *jea = a;
	struct jump_entry *jeb = b;
	struct jump_entry tmp = *jea;

	jea->code	= jeb->code - delta;
	jea->target	= jeb->target - delta;
	jea->key	= jeb->key - delta;

	jeb->code	= tmp.code + delta;
	jeb->target	= tmp.target + delta;
	jeb->key	= tmp.key + delta;
}

static void
jump_label_sort_entries(struct jump_entry *start, struct jump_entry *stop)
{
	unsigned long size;
	void *swapfn = NULL;

	if (IS_ENABLED(CONFIG_HAVE_ARCH_JUMP_LABEL_RELATIVE))
		swapfn = jump_label_swap;

	size = (((unsigned long)stop - (unsigned long)start)
					/ sizeof(struct jump_entry));
	sort(start, size, sizeof(struct jump_entry), jump_label_cmp, swapfn);
}

static void jump_label_update(struct static_key *key);

/*
 * There are similar definitions for the !CONFIG_JUMP_LABEL case in jump_label.h.
 * The use of 'atomic_read()' requires atomic.h and it's problematic for some
 * kernel headers such as kernel.h and others. Since static_key_count() is not
 * used in the branch statements as it is for the !CONFIG_JUMP_LABEL case, it's
 * OK to have it be a function here. Similarly for 'static_key_enable()' and
 * 'static_key_disable()', which require bug.h. This should allow jump_label.h
 * to be included from most/all places for CONFIG_JUMP_LABEL.
 */
int static_key_count(struct static_key *key)
{
	/*
	 * -1 means the first static_key_slow_inc() is in progress.
	 *  static_key_enabled() must return true, so return 1 here.
	 */
	int n = atomic_read(&key->enabled);

	return n >= 0 ? n : 1;
}
EXPORT_SYMBOL_GPL(static_key_count);

/*
 * static_key_fast_inc_not_disabled - adds a user for a static key
 * @key: static key that must be already enabled
 *
 * The caller must make sure that the static key can't get disabled while
 * in this function. It doesn't patch jump labels, only adds a user to
 * an already enabled static key.
 *
 * Returns true if the increment was done. Unlike refcount_t the ref counter
 * is not saturated, but will fail to increment on overflow.
 */
bool static_key_fast_inc_not_disabled(struct static_key *key)
{
	int v;

	STATIC_KEY_CHECK_USE(key);
	/*
	 * Negative key->enabled has a special meaning: it sends
	 * static_key_slow_inc/dec() down the slow path, and it is non-zero
	 * so it counts as "enabled" in jump_label_update().
	 *
	 * The INT_MAX overflow condition is either used by the networking
	 * code to reset or detected in the slow path of
	 * static_key_slow_inc_cpuslocked().
	 */
	v = atomic_read(&key->enabled);
	do {
		if (v <= 0 || v == INT_MAX)
			return false;
	} while (!likely(atomic_try_cmpxchg(&key->enabled, &v, v + 1)));

	return true;
}
EXPORT_SYMBOL_GPL(static_key_fast_inc_not_disabled);
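
/*
 * Illustrative caller sketch (the key and error code below are made up, not
 * taken from this file): grab a fast reference only when an existing
 * reference already pins the key enabled, e.g. when cloning an object that
 * itself holds a reference:
 *
 *	if (!static_key_fast_inc_not_disabled(&parent->key))
 *		return -EOVERFLOW;
 */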

bool static_key_slow_inc_cpuslocked(struct static_key *key)
{
	lockdep_assert_cpus_held();

	/*
	 * Careful if we get concurrent static_key_slow_inc/dec() calls;
	 * later calls must wait for the first one to _finish_ the
	 * jump_label_update() process.  At the same time, however,
	 * the jump_label_update() call below wants to see
	 * static_key_enabled(&key) for jumps to be updated properly.
	 */
	if (static_key_fast_inc_not_disabled(key))
		return true;

	guard(mutex)(&jump_label_mutex);
	/* Try to mark it as 'enabling in progress'. */
	if (!atomic_cmpxchg(&key->enabled, 0, -1)) {
		jump_label_update(key);
		/*
		 * Ensure that when static_key_fast_inc_not_disabled() or
		 * static_key_slow_try_dec() observe the positive value,
		 * they must also observe all the text changes.
		 */
		atomic_set_release(&key->enabled, 1);
	} else {
		/*
		 * While holding the mutex this should never observe
		 * anything other than a value >= 1 and succeed.
		 */
		if (WARN_ON_ONCE(!static_key_fast_inc_not_disabled(key)))
			return false;
	}
	return true;
}

bool static_key_slow_inc(struct static_key *key)
{
	bool ret;

	cpus_read_lock();
	ret = static_key_slow_inc_cpuslocked(key);
	cpus_read_unlock();
	return ret;
}
EXPORT_SYMBOL_GPL(static_key_slow_inc);

void static_key_enable_cpuslocked(struct static_key *key)
{
	STATIC_KEY_CHECK_USE(key);
	lockdep_assert_cpus_held();

	if (atomic_read(&key->enabled) > 0) {
		WARN_ON_ONCE(atomic_read(&key->enabled) != 1);
		return;
	}

	jump_label_lock();
	if (atomic_read(&key->enabled) == 0) {
		atomic_set(&key->enabled, -1);
		jump_label_update(key);
		/*
		 * See static_key_slow_inc().
		 */
		atomic_set_release(&key->enabled, 1);
	}
	jump_label_unlock();
}
EXPORT_SYMBOL_GPL(static_key_enable_cpuslocked);

void static_key_enable(struct static_key *key)
{
	cpus_read_lock();
	static_key_enable_cpuslocked(key);
	cpus_read_unlock();
}
EXPORT_SYMBOL_GPL(static_key_enable);

void static_key_disable_cpuslocked(struct static_key *key)
{
	STATIC_KEY_CHECK_USE(key);
	lockdep_assert_cpus_held();

	if (atomic_read(&key->enabled) != 1) {
		WARN_ON_ONCE(atomic_read(&key->enabled) != 0);
		return;
	}

	jump_label_lock();
	if (atomic_cmpxchg(&key->enabled, 1, 0) == 1)
		jump_label_update(key);
	jump_label_unlock();
}
EXPORT_SYMBOL_GPL(static_key_disable_cpuslocked);

void static_key_disable(struct static_key *key)
{
	cpus_read_lock();
	static_key_disable_cpuslocked(key);
	cpus_read_unlock();
}
EXPORT_SYMBOL_GPL(static_key_disable);
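
/*
 * Typical usage, sketched with a made-up key (the enable/disable calls belong
 * in slow/setup paths, never in the fast path being patched):
 *
 *	static DEFINE_STATIC_KEY_FALSE(my_feature);
 *
 *	if (static_branch_unlikely(&my_feature))
 *		do_feature_work();
 *
 *	static_key_enable(&my_feature.key);
 *	...
 *	static_key_disable(&my_feature.key);
 */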

static bool static_key_slow_try_dec(struct static_key *key)
{
	int v;

	/*
	 * Go into the slow path if key::enabled is less than or equal to
	 * one. One is valid to shut down the key, anything less than one
	 * is an imbalance, which is handled at the call site.
	 *
	 * That includes the special case of '-1' which is set in
	 * static_key_slow_inc_cpuslocked(), but that's harmless as it is
	 * fully serialized in the slow path below. By the time this task
	 * acquires the jump label lock the value is back to one and the
	 * retry under the lock must succeed.
	 */
	v = atomic_read(&key->enabled);
	do {
		/*
		 * Warn about the '-1' case though, since that means a
		 * decrement is concurrent with a first (0->1) increment. IOW
		 * people are trying to disable something that wasn't yet fully
		 * enabled. This suggests an ordering problem on the user side.
		 */
		WARN_ON_ONCE(v < 0);
		if (v <= 1)
			return false;
	} while (!likely(atomic_try_cmpxchg(&key->enabled, &v, v - 1)));

	return true;
}

static void __static_key_slow_dec_cpuslocked(struct static_key *key)
{
	lockdep_assert_cpus_held();

	if (static_key_slow_try_dec(key))
		return;

	guard(mutex)(&jump_label_mutex);
	if (atomic_cmpxchg(&key->enabled, 1, 0) == 1)
		jump_label_update(key);
	else
		WARN_ON_ONCE(!static_key_slow_try_dec(key));
}

static void __static_key_slow_dec(struct static_key *key)
{
	cpus_read_lock();
	__static_key_slow_dec_cpuslocked(key);
	cpus_read_unlock();
}

void jump_label_update_timeout(struct work_struct *work)
{
	struct static_key_deferred *key =
		container_of(work, struct static_key_deferred, work.work);
	__static_key_slow_dec(&key->key);
}
EXPORT_SYMBOL_GPL(jump_label_update_timeout);

void static_key_slow_dec(struct static_key *key)
{
	STATIC_KEY_CHECK_USE(key);
	__static_key_slow_dec(key);
}
EXPORT_SYMBOL_GPL(static_key_slow_dec);

void static_key_slow_dec_cpuslocked(struct static_key *key)
{
	STATIC_KEY_CHECK_USE(key);
	__static_key_slow_dec_cpuslocked(key);
}

void __static_key_slow_dec_deferred(struct static_key *key,
				    struct delayed_work *work,
				    unsigned long timeout)
{
	STATIC_KEY_CHECK_USE(key);

	if (static_key_slow_try_dec(key))
		return;

	schedule_delayed_work(work, timeout);
}
EXPORT_SYMBOL_GPL(__static_key_slow_dec_deferred);

void __static_key_deferred_flush(void *key, struct delayed_work *work)
{
	STATIC_KEY_CHECK_USE(key);
	flush_delayed_work(work);
}
EXPORT_SYMBOL_GPL(__static_key_deferred_flush);

void jump_label_rate_limit(struct static_key_deferred *key,
		unsigned long rl)
{
	STATIC_KEY_CHECK_USE(key);
	key->timeout = rl;
	INIT_DELAYED_WORK(&key->work, jump_label_update_timeout);
}
EXPORT_SYMBOL_GPL(jump_label_rate_limit);
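
/*
 * Sketch of a rate-limited key (my_key is a made-up struct
 * static_key_deferred): decrements are deferred by the configured timeout so
 * a rapidly toggling user does not force a text patch on every transition.
 *
 *	jump_label_rate_limit(&my_key, HZ);
 *	static_key_slow_inc(&my_key.key);
 *	...
 *	static_key_slow_dec_deferred(&my_key);
 */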

static int addr_conflict(struct jump_entry *entry, void *start, void *end)
{
	if (jump_entry_code(entry) <= (unsigned long)end &&
	    jump_entry_code(entry) + jump_entry_size(entry) > (unsigned long)start)
		return 1;

	return 0;
}

static int __jump_label_text_reserved(struct jump_entry *iter_start,
		struct jump_entry *iter_stop, void *start, void *end, bool init)
{
	struct jump_entry *iter;

	iter = iter_start;
	while (iter < iter_stop) {
		if (init || !jump_entry_is_init(iter)) {
			if (addr_conflict(iter, start, end))
				return 1;
		}
		iter++;
	}

	return 0;
}

#ifndef arch_jump_label_transform_static
static void arch_jump_label_transform_static(struct jump_entry *entry,
					     enum jump_label_type type)
{
	/* nothing to do on most architectures */
}
#endif

static inline struct jump_entry *static_key_entries(struct static_key *key)
{
	WARN_ON_ONCE(key->type & JUMP_TYPE_LINKED);
	return (struct jump_entry *)(key->type & ~JUMP_TYPE_MASK);
}

static inline bool static_key_type(struct static_key *key)
{
	return key->type & JUMP_TYPE_TRUE;
}

static inline bool static_key_linked(struct static_key *key)
{
	return key->type & JUMP_TYPE_LINKED;
}

static inline void static_key_clear_linked(struct static_key *key)
{
	key->type &= ~JUMP_TYPE_LINKED;
}

static inline void static_key_set_linked(struct static_key *key)
{
	key->type |= JUMP_TYPE_LINKED;
}

/***
 * A 'struct static_key' uses a union such that it either points directly
 * to a table of 'struct jump_entry' or to a linked list of modules which in
 * turn point to 'struct jump_entry' tables.
 *
 * The two lower bits of the pointer are used to keep track of which pointer
 * type is in use and to store the initial branch direction; we use an access
 * function which preserves these bits.
 */
static void static_key_set_entries(struct static_key *key,
				   struct jump_entry *entries)
{
	unsigned long type;

	WARN_ON_ONCE((unsigned long)entries & JUMP_TYPE_MASK);
	type = key->type & JUMP_TYPE_MASK;
	key->entries = entries;
	key->type |= type;
}
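
/*
 * Resulting key->type layout, with the low bits defined in jump_label.h:
 *
 *	[ pointer to jump_entry table or static_key_mod list | LINKED | TRUE ]
 *	                                                       bit 1    bit 0
 *
 * (key->type & ~JUMP_TYPE_MASK) recovers the pointer, JUMP_TYPE_TRUE holds
 * the initial branch direction and JUMP_TYPE_LINKED selects which pointer
 * is stored.
 */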

static enum jump_label_type jump_label_type(struct jump_entry *entry)
{
	struct static_key *key = jump_entry_key(entry);
	bool enabled = static_key_enabled(key);
	bool branch = jump_entry_is_branch(entry);

	/* See the comment in linux/jump_label.h */
	return enabled ^ branch;
}
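
/*
 * The XOR above, spelled out (JUMP_LABEL_NOP == 0, JUMP_LABEL_JMP == 1):
 *
 *	enabled	branch	resulting type
 *	   0	   0	JUMP_LABEL_NOP
 *	   0	   1	JUMP_LABEL_JMP
 *	   1	   0	JUMP_LABEL_JMP
 *	   1	   1	JUMP_LABEL_NOP
 */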

static bool jump_label_can_update(struct jump_entry *entry, bool init)
{
	/*
	 * Cannot update code that was in an init text area.
	 */
	if (!init && jump_entry_is_init(entry))
		return false;

	if (!kernel_text_address(jump_entry_code(entry))) {
		/*
		 * This skips patching built-in __exit, which
		 * is part of init_section_contains() but is
		 * not part of kernel_text_address().
		 *
		 * Skipping built-in __exit is fine since it
		 * will never be executed.
		 */
		WARN_ONCE(!jump_entry_is_init(entry),
			  "can't patch jump_label at %pS",
			  (void *)jump_entry_code(entry));
		return false;
	}

	return true;
}

#ifndef HAVE_JUMP_LABEL_BATCH
static void __jump_label_update(struct static_key *key,
				struct jump_entry *entry,
				struct jump_entry *stop,
				bool init)
{
	for (; (entry < stop) && (jump_entry_key(entry) == key); entry++) {
		if (jump_label_can_update(entry, init))
			arch_jump_label_transform(entry, jump_label_type(entry));
	}
}
#else
static void __jump_label_update(struct static_key *key,
				struct jump_entry *entry,
				struct jump_entry *stop,
				bool init)
{
	for (; (entry < stop) && (jump_entry_key(entry) == key); entry++) {

		if (!jump_label_can_update(entry, init))
			continue;

		if (!arch_jump_label_transform_queue(entry, jump_label_type(entry))) {
			/*
			 * Queue is full: Apply the current queue and try again.
			 */
			arch_jump_label_transform_apply();
			BUG_ON(!arch_jump_label_transform_queue(entry, jump_label_type(entry)));
		}
	}
	arch_jump_label_transform_apply();
}
#endif

void __init jump_label_init(void)
{
	struct jump_entry *iter_start = __start___jump_table;
	struct jump_entry *iter_stop = __stop___jump_table;
	struct static_key *key = NULL;
	struct jump_entry *iter;

	/*
	 * Since we are initializing the static_key.enabled field with
	 * the 'raw' int values (to avoid pulling in atomic.h) in
	 * jump_label.h, let's make sure that is safe. There are only two
	 * cases to check since we initialize to 0 or 1.
	 */
	BUILD_BUG_ON((int)ATOMIC_INIT(0) != 0);
	BUILD_BUG_ON((int)ATOMIC_INIT(1) != 1);

	if (static_key_initialized)
		return;

	cpus_read_lock();
	jump_label_lock();
	jump_label_sort_entries(iter_start, iter_stop);

	for (iter = iter_start; iter < iter_stop; iter++) {
		struct static_key *iterk;
		bool in_init;

		/* rewrite NOPs */
		if (jump_label_type(iter) == JUMP_LABEL_NOP)
			arch_jump_label_transform_static(iter, JUMP_LABEL_NOP);

		in_init = init_section_contains((void *)jump_entry_code(iter), 1);
		jump_entry_set_init(iter, in_init);

		iterk = jump_entry_key(iter);
		if (iterk == key)
			continue;

		key = iterk;
		static_key_set_entries(key, iter);
	}
	static_key_initialized = true;
	jump_label_unlock();
	cpus_read_unlock();
}

static inline bool static_key_sealed(struct static_key *key)
{
	return (key->type & JUMP_TYPE_LINKED) && !(key->type & ~JUMP_TYPE_MASK);
}

static inline void static_key_seal(struct static_key *key)
{
	unsigned long type = key->type & JUMP_TYPE_TRUE;
	key->type = JUMP_TYPE_LINKED | type;
}

void jump_label_init_ro(void)
{
	struct jump_entry *iter_start = __start___jump_table;
	struct jump_entry *iter_stop = __stop___jump_table;
	struct jump_entry *iter;

	if (WARN_ON_ONCE(!static_key_initialized))
		return;

	cpus_read_lock();
	jump_label_lock();

	for (iter = iter_start; iter < iter_stop; iter++) {
		struct static_key *iterk = jump_entry_key(iter);

		if (!is_kernel_ro_after_init((unsigned long)iterk))
			continue;

		if (static_key_sealed(iterk))
			continue;

		static_key_seal(iterk);
	}

	jump_label_unlock();
	cpus_read_unlock();
}

#ifdef CONFIG_MODULES

enum jump_label_type jump_label_init_type(struct jump_entry *entry)
{
	struct static_key *key = jump_entry_key(entry);
	bool type = static_key_type(key);
	bool branch = jump_entry_is_branch(entry);

	/* See the comment in linux/jump_label.h */
	return type ^ branch;
}

struct static_key_mod {
	struct static_key_mod *next;
	struct jump_entry *entries;
	struct module *mod;
};

static inline struct static_key_mod *static_key_mod(struct static_key *key)
{
	WARN_ON_ONCE(!static_key_linked(key));
	return (struct static_key_mod *)(key->type & ~JUMP_TYPE_MASK);
}

/***
 * key->type and key->next are the same via union.
 * This sets key->next and preserves the type bits.
 *
 * See additional comments above static_key_set_entries().
 */
static void static_key_set_mod(struct static_key *key,
			       struct static_key_mod *mod)
{
	unsigned long type;

	WARN_ON_ONCE((unsigned long)mod & JUMP_TYPE_MASK);
	type = key->type & JUMP_TYPE_MASK;
	key->next = mod;
	key->type |= type;
}

static int __jump_label_mod_text_reserved(void *start, void *end)
{
	struct module *mod;
	int ret;

	preempt_disable();
	mod = __module_text_address((unsigned long)start);
	WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
	if (!try_module_get(mod))
		mod = NULL;
	preempt_enable();

	if (!mod)
		return 0;

	ret = __jump_label_text_reserved(mod->jump_entries,
				mod->jump_entries + mod->num_jump_entries,
				start, end, mod->state == MODULE_STATE_COMING);

	module_put(mod);

	return ret;
}

static void __jump_label_mod_update(struct static_key *key)
{
	struct static_key_mod *mod;

	for (mod = static_key_mod(key); mod; mod = mod->next) {
		struct jump_entry *stop;
		struct module *m;

		/*
		 * NULL if the static_key is defined in a module
		 * that does not use it
		 */
		if (!mod->entries)
			continue;

		m = mod->mod;
		if (!m)
			stop = __stop___jump_table;
		else
			stop = m->jump_entries + m->num_jump_entries;
		__jump_label_update(key, mod->entries, stop,
				    m && m->state == MODULE_STATE_COMING);
	}
}

static int jump_label_add_module(struct module *mod)
{
	struct jump_entry *iter_start = mod->jump_entries;
	struct jump_entry *iter_stop = iter_start + mod->num_jump_entries;
	struct jump_entry *iter;
	struct static_key *key = NULL;
	struct static_key_mod *jlm, *jlm2;

	/* if the module doesn't have jump label entries, just return */
	if (iter_start == iter_stop)
		return 0;

	jump_label_sort_entries(iter_start, iter_stop);

	for (iter = iter_start; iter < iter_stop; iter++) {
		struct static_key *iterk;
		bool in_init;

		in_init = within_module_init(jump_entry_code(iter), mod);
		jump_entry_set_init(iter, in_init);

		iterk = jump_entry_key(iter);
		if (iterk == key)
			continue;

		key = iterk;
		if (within_module((unsigned long)key, mod)) {
			static_key_set_entries(key, iter);
			continue;
		}

		/*
		 * If the key was sealed at init, then there's no need to keep a
		 * reference to its module entries - just patch them now and be
		 * done with it.
		 */
		if (static_key_sealed(key))
			goto do_poke;

		jlm = kzalloc(sizeof(struct static_key_mod), GFP_KERNEL);
		if (!jlm)
			return -ENOMEM;
		if (!static_key_linked(key)) {
			jlm2 = kzalloc(sizeof(struct static_key_mod),
				       GFP_KERNEL);
			if (!jlm2) {
				kfree(jlm);
				return -ENOMEM;
			}
			preempt_disable();
			jlm2->mod = __module_address((unsigned long)key);
			preempt_enable();
			jlm2->entries = static_key_entries(key);
			jlm2->next = NULL;
			static_key_set_mod(key, jlm2);
			static_key_set_linked(key);
		}
		jlm->mod = mod;
		jlm->entries = iter;
		jlm->next = static_key_mod(key);
		static_key_set_mod(key, jlm);
		static_key_set_linked(key);

		/* Only update if we've changed from our initial state */
do_poke:
		if (jump_label_type(iter) != jump_label_init_type(iter))
			__jump_label_update(key, iter, iter_stop, true);
	}

	return 0;
}

static void jump_label_del_module(struct module *mod)
{
	struct jump_entry *iter_start = mod->jump_entries;
	struct jump_entry *iter_stop = iter_start + mod->num_jump_entries;
	struct jump_entry *iter;
	struct static_key *key = NULL;
	struct static_key_mod *jlm, **prev;

	for (iter = iter_start; iter < iter_stop; iter++) {
		if (jump_entry_key(iter) == key)
			continue;

		key = jump_entry_key(iter);

		if (within_module((unsigned long)key, mod))
			continue;

		/* No @jlm allocated because key was sealed at init. */
		if (static_key_sealed(key))
			continue;

		/* No memory during module load */
		if (WARN_ON(!static_key_linked(key)))
			continue;

		prev = &key->next;
		jlm = static_key_mod(key);

		while (jlm && jlm->mod != mod) {
			prev = &jlm->next;
			jlm = jlm->next;
		}

		/* No memory during module load */
		if (WARN_ON(!jlm))
			continue;

		if (prev == &key->next)
			static_key_set_mod(key, jlm->next);
		else
			*prev = jlm->next;

		kfree(jlm);

		jlm = static_key_mod(key);
		/* if only one entry is left, fold it back into the static_key */
		if (jlm->next == NULL) {
			static_key_set_entries(key, jlm->entries);
			static_key_clear_linked(key);
			kfree(jlm);
		}
	}
}

static int
jump_label_module_notify(struct notifier_block *self, unsigned long val,
			 void *data)
{
	struct module *mod = data;
	int ret = 0;

	cpus_read_lock();
	jump_label_lock();

	switch (val) {
	case MODULE_STATE_COMING:
		ret = jump_label_add_module(mod);
		if (ret) {
			WARN(1, "Failed to allocate memory: jump_label may not work properly.\n");
			jump_label_del_module(mod);
		}
		break;
	case MODULE_STATE_GOING:
		jump_label_del_module(mod);
		break;
	}

	jump_label_unlock();
	cpus_read_unlock();

	return notifier_from_errno(ret);
}

static struct notifier_block jump_label_module_nb = {
	.notifier_call = jump_label_module_notify,
	.priority = 1, /* higher than tracepoints */
};

static __init int jump_label_init_module(void)
{
	return register_module_notifier(&jump_label_module_nb);
}
early_initcall(jump_label_init_module);

#endif /* CONFIG_MODULES */

/***
 * jump_label_text_reserved - check if addr range is reserved
 * @start: start text addr
 * @end: end text addr
 *
 * Checks if the text addresses between @start and @end
 * overlap with any of the jump label patch addresses. Code
 * that wants to modify kernel text should first verify that
 * it does not overlap with any of the jump label addresses.
 * Caller must hold jump_label_mutex.
 *
 * Returns 1 if there is an overlap, 0 otherwise.
 */
int jump_label_text_reserved(void *start, void *end)
{
	bool init = system_state < SYSTEM_RUNNING;
	int ret = __jump_label_text_reserved(__start___jump_table,
			__stop___jump_table, start, end, init);

	if (ret)
		return ret;

#ifdef CONFIG_MODULES
	ret = __jump_label_mod_text_reserved(start, end);
#endif
	return ret;
}
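
/*
 * A possible caller pattern (patch_text_range() and the error code are made
 * up for illustration), roughly mirroring how kprobes rejects probes on
 * reserved addresses:
 *
 *	jump_label_lock();
 *	if (jump_label_text_reserved(addr, addr + len))
 *		ret = -EBUSY;
 *	else
 *		ret = patch_text_range(addr, len);
 *	jump_label_unlock();
 */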

static void jump_label_update(struct static_key *key)
{
	struct jump_entry *stop = __stop___jump_table;
	bool init = system_state < SYSTEM_RUNNING;
	struct jump_entry *entry;
#ifdef CONFIG_MODULES
	struct module *mod;

	if (static_key_linked(key)) {
		__jump_label_mod_update(key);
		return;
	}

	preempt_disable();
	mod = __module_address((unsigned long)key);
	if (mod) {
		stop = mod->jump_entries + mod->num_jump_entries;
		init = mod->state == MODULE_STATE_COMING;
	}
	preempt_enable();
#endif
	entry = static_key_entries(key);
	/* if there are no users, entry can be NULL */
	if (entry)
		__jump_label_update(key, entry, stop, init);
}

#ifdef CONFIG_STATIC_KEYS_SELFTEST
static DEFINE_STATIC_KEY_TRUE(sk_true);
static DEFINE_STATIC_KEY_FALSE(sk_false);

static __init int jump_label_test(void)
{
	int i;

	for (i = 0; i < 2; i++) {
		WARN_ON(static_key_enabled(&sk_true.key) != true);
		WARN_ON(static_key_enabled(&sk_false.key) != false);

		WARN_ON(!static_branch_likely(&sk_true));
		WARN_ON(!static_branch_unlikely(&sk_true));
		WARN_ON(static_branch_likely(&sk_false));
		WARN_ON(static_branch_unlikely(&sk_false));

		static_branch_disable(&sk_true);
		static_branch_enable(&sk_false);

		WARN_ON(static_key_enabled(&sk_true.key) == true);
		WARN_ON(static_key_enabled(&sk_false.key) == false);

		WARN_ON(static_branch_likely(&sk_true));
		WARN_ON(static_branch_unlikely(&sk_true));
		WARN_ON(!static_branch_likely(&sk_false));
		WARN_ON(!static_branch_unlikely(&sk_false));

		static_branch_enable(&sk_true);
		static_branch_disable(&sk_false);
	}

	return 0;
}
early_initcall(jump_label_test);
#endif /* CONFIG_STATIC_KEYS_SELFTEST */