1 // SPDX-License-Identifier: GPL-2.0 2 #include <linux/init.h> 3 #include <linux/static_call.h> 4 #include <linux/bug.h> 5 #include <linux/smp.h> 6 #include <linux/sort.h> 7 #include <linux/slab.h> 8 #include <linux/module.h> 9 #include <linux/cpu.h> 10 #include <linux/processor.h> 11 #include <asm/sections.h> 12 13 extern struct static_call_site __start_static_call_sites[], 14 __stop_static_call_sites[]; 15 extern struct static_call_tramp_key __start_static_call_tramp_key[], 16 __stop_static_call_tramp_key[]; 17 18 int static_call_initialized; 19 20 /* 21 * Must be called before early_initcall() to be effective. 22 */ 23 void static_call_force_reinit(void) 24 { 25 if (WARN_ON_ONCE(!static_call_initialized)) 26 return; 27 28 static_call_initialized++; 29 } 30 31 /* mutex to protect key modules/sites */ 32 static DEFINE_MUTEX(static_call_mutex); 33 34 static void static_call_lock(void) 35 { 36 mutex_lock(&static_call_mutex); 37 } 38 39 static void static_call_unlock(void) 40 { 41 mutex_unlock(&static_call_mutex); 42 } 43 44 static inline void *static_call_addr(struct static_call_site *site) 45 { 46 return (void *)((long)site->addr + (long)&site->addr); 47 } 48 49 static inline unsigned long __static_call_key(const struct static_call_site *site) 50 { 51 return (long)site->key + (long)&site->key; 52 } 53 54 static inline struct static_call_key *static_call_key(const struct static_call_site *site) 55 { 56 return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS); 57 } 58 59 /* These assume the key is word-aligned. */ 60 static inline bool static_call_is_init(struct static_call_site *site) 61 { 62 return __static_call_key(site) & STATIC_CALL_SITE_INIT; 63 } 64 65 static inline bool static_call_is_tail(struct static_call_site *site) 66 { 67 return __static_call_key(site) & STATIC_CALL_SITE_TAIL; 68 } 69 70 static inline void static_call_set_init(struct static_call_site *site) 71 { 72 site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) - 73 (long)&site->key; 74 } 75 76 static int static_call_site_cmp(const void *_a, const void *_b) 77 { 78 const struct static_call_site *a = _a; 79 const struct static_call_site *b = _b; 80 const struct static_call_key *key_a = static_call_key(a); 81 const struct static_call_key *key_b = static_call_key(b); 82 83 if (key_a < key_b) 84 return -1; 85 86 if (key_a > key_b) 87 return 1; 88 89 return 0; 90 } 91 92 static void static_call_site_swap(void *_a, void *_b, int size) 93 { 94 long delta = (unsigned long)_a - (unsigned long)_b; 95 struct static_call_site *a = _a; 96 struct static_call_site *b = _b; 97 struct static_call_site tmp = *a; 98 99 a->addr = b->addr - delta; 100 a->key = b->key - delta; 101 102 b->addr = tmp.addr + delta; 103 b->key = tmp.key + delta; 104 } 105 106 static inline void static_call_sort_entries(struct static_call_site *start, 107 struct static_call_site *stop) 108 { 109 sort(start, stop - start, sizeof(struct static_call_site), 110 static_call_site_cmp, static_call_site_swap); 111 } 112 113 static inline bool static_call_key_has_mods(struct static_call_key *key) 114 { 115 return !(key->type & 1); 116 } 117 118 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key) 119 { 120 if (!static_call_key_has_mods(key)) 121 return NULL; 122 123 return key->mods; 124 } 125 126 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key) 127 { 128 if (static_call_key_has_mods(key)) 129 return NULL; 130 131 return (struct static_call_site *)(key->type & ~1); 132 } 133 134 void __static_call_update(struct static_call_key *key, void *tramp, void *func) 135 { 136 struct static_call_site *site, *stop; 137 struct static_call_mod *site_mod, first; 138 139 cpus_read_lock(); 140 static_call_lock(); 141 142 if (key->func == func) 143 goto done; 144 145 key->func = func; 146 147 arch_static_call_transform(NULL, tramp, func, false); 148 149 /* 150 * If uninitialized, we'll not update the callsites, but they still 151 * point to the trampoline and we just patched that. 152 */ 153 if (WARN_ON_ONCE(!static_call_initialized)) 154 goto done; 155 156 first = (struct static_call_mod){ 157 .next = static_call_key_next(key), 158 .mod = NULL, 159 .sites = static_call_key_sites(key), 160 }; 161 162 for (site_mod = &first; site_mod; site_mod = site_mod->next) { 163 bool init = system_state < SYSTEM_RUNNING; 164 struct module *mod = site_mod->mod; 165 166 if (!site_mod->sites) { 167 /* 168 * This can happen if the static call key is defined in 169 * a module which doesn't use it. 170 * 171 * It also happens in the has_mods case, where the 172 * 'first' entry has no sites associated with it. 173 */ 174 continue; 175 } 176 177 stop = __stop_static_call_sites; 178 179 if (mod) { 180 #ifdef CONFIG_MODULES 181 stop = mod->static_call_sites + 182 mod->num_static_call_sites; 183 init = mod->state == MODULE_STATE_COMING; 184 #endif 185 } 186 187 for (site = site_mod->sites; 188 site < stop && static_call_key(site) == key; site++) { 189 void *site_addr = static_call_addr(site); 190 191 if (!init && static_call_is_init(site)) 192 continue; 193 194 if (!kernel_text_address((unsigned long)site_addr)) { 195 /* 196 * This skips patching built-in __exit, which 197 * is part of init_section_contains() but is 198 * not part of kernel_text_address(). 199 * 200 * Skipping built-in __exit is fine since it 201 * will never be executed. 202 */ 203 WARN_ONCE(!static_call_is_init(site), 204 "can't patch static call site at %pS", 205 site_addr); 206 continue; 207 } 208 209 arch_static_call_transform(site_addr, NULL, func, 210 static_call_is_tail(site)); 211 } 212 } 213 214 done: 215 static_call_unlock(); 216 cpus_read_unlock(); 217 } 218 EXPORT_SYMBOL_GPL(__static_call_update); 219 220 static int __static_call_init(struct module *mod, 221 struct static_call_site *start, 222 struct static_call_site *stop) 223 { 224 struct static_call_site *site; 225 struct static_call_key *key, *prev_key = NULL; 226 struct static_call_mod *site_mod; 227 228 if (start == stop) 229 return 0; 230 231 static_call_sort_entries(start, stop); 232 233 for (site = start; site < stop; site++) { 234 void *site_addr = static_call_addr(site); 235 236 if ((mod && within_module_init((unsigned long)site_addr, mod)) || 237 (!mod && init_section_contains(site_addr, 1))) 238 static_call_set_init(site); 239 240 key = static_call_key(site); 241 if (key != prev_key) { 242 prev_key = key; 243 244 /* 245 * For vmlinux (!mod) avoid the allocation by storing 246 * the sites pointer in the key itself. Also see 247 * __static_call_update()'s @first. 248 * 249 * This allows architectures (eg. x86) to call 250 * static_call_init() before memory allocation works. 251 */ 252 if (!mod) { 253 key->sites = site; 254 key->type |= 1; 255 goto do_transform; 256 } 257 258 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 259 if (!site_mod) 260 return -ENOMEM; 261 262 /* 263 * When the key has a direct sites pointer, extract 264 * that into an explicit struct static_call_mod, so we 265 * can have a list of modules. 266 */ 267 if (static_call_key_sites(key)) { 268 site_mod->mod = NULL; 269 site_mod->next = NULL; 270 site_mod->sites = static_call_key_sites(key); 271 272 key->mods = site_mod; 273 274 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 275 if (!site_mod) 276 return -ENOMEM; 277 } 278 279 site_mod->mod = mod; 280 site_mod->sites = site; 281 site_mod->next = static_call_key_next(key); 282 key->mods = site_mod; 283 } 284 285 do_transform: 286 arch_static_call_transform(site_addr, NULL, key->func, 287 static_call_is_tail(site)); 288 } 289 290 return 0; 291 } 292 293 static int addr_conflict(struct static_call_site *site, void *start, void *end) 294 { 295 unsigned long addr = (unsigned long)static_call_addr(site); 296 297 if (addr <= (unsigned long)end && 298 addr + CALL_INSN_SIZE > (unsigned long)start) 299 return 1; 300 301 return 0; 302 } 303 304 static int __static_call_text_reserved(struct static_call_site *iter_start, 305 struct static_call_site *iter_stop, 306 void *start, void *end, bool init) 307 { 308 struct static_call_site *iter = iter_start; 309 310 while (iter < iter_stop) { 311 if (init || !static_call_is_init(iter)) { 312 if (addr_conflict(iter, start, end)) 313 return 1; 314 } 315 iter++; 316 } 317 318 return 0; 319 } 320 321 #ifdef CONFIG_MODULES 322 323 static int __static_call_mod_text_reserved(void *start, void *end) 324 { 325 struct module *mod; 326 int ret; 327 328 preempt_disable(); 329 mod = __module_text_address((unsigned long)start); 330 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod); 331 if (!try_module_get(mod)) 332 mod = NULL; 333 preempt_enable(); 334 335 if (!mod) 336 return 0; 337 338 ret = __static_call_text_reserved(mod->static_call_sites, 339 mod->static_call_sites + mod->num_static_call_sites, 340 start, end, mod->state == MODULE_STATE_COMING); 341 342 module_put(mod); 343 344 return ret; 345 } 346 347 static unsigned long tramp_key_lookup(unsigned long addr) 348 { 349 struct static_call_tramp_key *start = __start_static_call_tramp_key; 350 struct static_call_tramp_key *stop = __stop_static_call_tramp_key; 351 struct static_call_tramp_key *tramp_key; 352 353 for (tramp_key = start; tramp_key != stop; tramp_key++) { 354 unsigned long tramp; 355 356 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp; 357 if (tramp == addr) 358 return (long)tramp_key->key + (long)&tramp_key->key; 359 } 360 361 return 0; 362 } 363 364 static int static_call_add_module(struct module *mod) 365 { 366 struct static_call_site *start = mod->static_call_sites; 367 struct static_call_site *stop = start + mod->num_static_call_sites; 368 struct static_call_site *site; 369 370 for (site = start; site != stop; site++) { 371 unsigned long s_key = __static_call_key(site); 372 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS; 373 unsigned long key; 374 375 /* 376 * Is the key is exported, 'addr' points to the key, which 377 * means modules are allowed to call static_call_update() on 378 * it. 379 * 380 * Otherwise, the key isn't exported, and 'addr' points to the 381 * trampoline so we need to lookup the key. 382 * 383 * We go through this dance to prevent crazy modules from 384 * abusing sensitive static calls. 385 */ 386 if (!kernel_text_address(addr)) 387 continue; 388 389 key = tramp_key_lookup(addr); 390 if (!key) { 391 pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n", 392 static_call_addr(site)); 393 return -EINVAL; 394 } 395 396 key |= s_key & STATIC_CALL_SITE_FLAGS; 397 site->key = key - (long)&site->key; 398 } 399 400 return __static_call_init(mod, start, stop); 401 } 402 403 static void static_call_del_module(struct module *mod) 404 { 405 struct static_call_site *start = mod->static_call_sites; 406 struct static_call_site *stop = mod->static_call_sites + 407 mod->num_static_call_sites; 408 struct static_call_key *key, *prev_key = NULL; 409 struct static_call_mod *site_mod, **prev; 410 struct static_call_site *site; 411 412 for (site = start; site < stop; site++) { 413 key = static_call_key(site); 414 415 /* 416 * If the key was not updated due to a memory allocation 417 * failure in __static_call_init() then treating key::sites 418 * as key::mods in the code below would cause random memory 419 * access and #GP. In that case all subsequent sites have 420 * not been touched either, so stop iterating. 421 */ 422 if (!static_call_key_has_mods(key)) 423 break; 424 425 if (key == prev_key) 426 continue; 427 428 prev_key = key; 429 430 for (prev = &key->mods, site_mod = key->mods; 431 site_mod && site_mod->mod != mod; 432 prev = &site_mod->next, site_mod = site_mod->next) 433 ; 434 435 if (!site_mod) 436 continue; 437 438 *prev = site_mod->next; 439 kfree(site_mod); 440 } 441 } 442 443 static int static_call_module_notify(struct notifier_block *nb, 444 unsigned long val, void *data) 445 { 446 struct module *mod = data; 447 int ret = 0; 448 449 cpus_read_lock(); 450 static_call_lock(); 451 452 switch (val) { 453 case MODULE_STATE_COMING: 454 ret = static_call_add_module(mod); 455 if (ret) { 456 pr_warn("Failed to allocate memory for static calls\n"); 457 static_call_del_module(mod); 458 } 459 break; 460 case MODULE_STATE_GOING: 461 static_call_del_module(mod); 462 break; 463 } 464 465 static_call_unlock(); 466 cpus_read_unlock(); 467 468 return notifier_from_errno(ret); 469 } 470 471 static struct notifier_block static_call_module_nb = { 472 .notifier_call = static_call_module_notify, 473 }; 474 475 #else 476 477 static inline int __static_call_mod_text_reserved(void *start, void *end) 478 { 479 return 0; 480 } 481 482 #endif /* CONFIG_MODULES */ 483 484 int static_call_text_reserved(void *start, void *end) 485 { 486 bool init = system_state < SYSTEM_RUNNING; 487 int ret = __static_call_text_reserved(__start_static_call_sites, 488 __stop_static_call_sites, start, end, init); 489 490 if (ret) 491 return ret; 492 493 return __static_call_mod_text_reserved(start, end); 494 } 495 496 int __init static_call_init(void) 497 { 498 int ret; 499 500 /* See static_call_force_reinit(). */ 501 if (static_call_initialized == 1) 502 return 0; 503 504 cpus_read_lock(); 505 static_call_lock(); 506 ret = __static_call_init(NULL, __start_static_call_sites, 507 __stop_static_call_sites); 508 static_call_unlock(); 509 cpus_read_unlock(); 510 511 if (ret) { 512 pr_err("Failed to allocate memory for static_call!\n"); 513 BUG(); 514 } 515 516 #ifdef CONFIG_MODULES 517 if (!static_call_initialized) 518 register_module_notifier(&static_call_module_nb); 519 #endif 520 521 static_call_initialized = 1; 522 return 0; 523 } 524 early_initcall(static_call_init); 525 526 #ifdef CONFIG_STATIC_CALL_SELFTEST 527 528 static int func_a(int x) 529 { 530 return x+1; 531 } 532 533 static int func_b(int x) 534 { 535 return x+2; 536 } 537 538 DEFINE_STATIC_CALL(sc_selftest, func_a); 539 540 static struct static_call_data { 541 int (*func)(int); 542 int val; 543 int expect; 544 } static_call_data [] __initdata = { 545 { NULL, 2, 3 }, 546 { func_b, 2, 4 }, 547 { func_a, 2, 3 } 548 }; 549 550 static int __init test_static_call_init(void) 551 { 552 int i; 553 554 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) { 555 struct static_call_data *scd = &static_call_data[i]; 556 557 if (scd->func) 558 static_call_update(sc_selftest, scd->func); 559 560 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect); 561 } 562 563 return 0; 564 } 565 early_initcall(test_static_call_init); 566 567 #endif /* CONFIG_STATIC_CALL_SELFTEST */ 568