1 // SPDX-License-Identifier: GPL-2.0 2 #include <linux/init.h> 3 #include <linux/static_call.h> 4 #include <linux/bug.h> 5 #include <linux/smp.h> 6 #include <linux/sort.h> 7 #include <linux/slab.h> 8 #include <linux/module.h> 9 #include <linux/cpu.h> 10 #include <linux/processor.h> 11 #include <asm/sections.h> 12 13 extern struct static_call_site __start_static_call_sites[], 14 __stop_static_call_sites[]; 15 extern struct static_call_tramp_key __start_static_call_tramp_key[], 16 __stop_static_call_tramp_key[]; 17 18 static bool static_call_initialized; 19 20 /* mutex to protect key modules/sites */ 21 static DEFINE_MUTEX(static_call_mutex); 22 23 static void static_call_lock(void) 24 { 25 mutex_lock(&static_call_mutex); 26 } 27 28 static void static_call_unlock(void) 29 { 30 mutex_unlock(&static_call_mutex); 31 } 32 33 static inline void *static_call_addr(struct static_call_site *site) 34 { 35 return (void *)((long)site->addr + (long)&site->addr); 36 } 37 38 static inline unsigned long __static_call_key(const struct static_call_site *site) 39 { 40 return (long)site->key + (long)&site->key; 41 } 42 43 static inline struct static_call_key *static_call_key(const struct static_call_site *site) 44 { 45 return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS); 46 } 47 48 /* These assume the key is word-aligned. */ 49 static inline bool static_call_is_init(struct static_call_site *site) 50 { 51 return __static_call_key(site) & STATIC_CALL_SITE_INIT; 52 } 53 54 static inline bool static_call_is_tail(struct static_call_site *site) 55 { 56 return __static_call_key(site) & STATIC_CALL_SITE_TAIL; 57 } 58 59 static inline void static_call_set_init(struct static_call_site *site) 60 { 61 site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) - 62 (long)&site->key; 63 } 64 65 static int static_call_site_cmp(const void *_a, const void *_b) 66 { 67 const struct static_call_site *a = _a; 68 const struct static_call_site *b = _b; 69 const struct static_call_key *key_a = static_call_key(a); 70 const struct static_call_key *key_b = static_call_key(b); 71 72 if (key_a < key_b) 73 return -1; 74 75 if (key_a > key_b) 76 return 1; 77 78 return 0; 79 } 80 81 static void static_call_site_swap(void *_a, void *_b, int size) 82 { 83 long delta = (unsigned long)_a - (unsigned long)_b; 84 struct static_call_site *a = _a; 85 struct static_call_site *b = _b; 86 struct static_call_site tmp = *a; 87 88 a->addr = b->addr - delta; 89 a->key = b->key - delta; 90 91 b->addr = tmp.addr + delta; 92 b->key = tmp.key + delta; 93 } 94 95 static inline void static_call_sort_entries(struct static_call_site *start, 96 struct static_call_site *stop) 97 { 98 sort(start, stop - start, sizeof(struct static_call_site), 99 static_call_site_cmp, static_call_site_swap); 100 } 101 102 static inline bool static_call_key_has_mods(struct static_call_key *key) 103 { 104 return !(key->type & 1); 105 } 106 107 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key) 108 { 109 if (!static_call_key_has_mods(key)) 110 return NULL; 111 112 return key->mods; 113 } 114 115 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key) 116 { 117 if (static_call_key_has_mods(key)) 118 return NULL; 119 120 return (struct static_call_site *)(key->type & ~1); 121 } 122 123 void __static_call_update(struct static_call_key *key, void *tramp, void *func) 124 { 125 struct static_call_site *site, *stop; 126 struct static_call_mod *site_mod, first; 127 128 cpus_read_lock(); 129 static_call_lock(); 130 131 if (key->func == func) 132 goto done; 133 134 key->func = func; 135 136 arch_static_call_transform(NULL, tramp, func, false); 137 138 /* 139 * If uninitialized, we'll not update the callsites, but they still 140 * point to the trampoline and we just patched that. 141 */ 142 if (WARN_ON_ONCE(!static_call_initialized)) 143 goto done; 144 145 first = (struct static_call_mod){ 146 .next = static_call_key_next(key), 147 .mod = NULL, 148 .sites = static_call_key_sites(key), 149 }; 150 151 for (site_mod = &first; site_mod; site_mod = site_mod->next) { 152 bool init = system_state < SYSTEM_RUNNING; 153 struct module *mod = site_mod->mod; 154 155 if (!site_mod->sites) { 156 /* 157 * This can happen if the static call key is defined in 158 * a module which doesn't use it. 159 * 160 * It also happens in the has_mods case, where the 161 * 'first' entry has no sites associated with it. 162 */ 163 continue; 164 } 165 166 stop = __stop_static_call_sites; 167 168 if (mod) { 169 #ifdef CONFIG_MODULES 170 stop = mod->static_call_sites + 171 mod->num_static_call_sites; 172 init = mod->state == MODULE_STATE_COMING; 173 #endif 174 } 175 176 for (site = site_mod->sites; 177 site < stop && static_call_key(site) == key; site++) { 178 void *site_addr = static_call_addr(site); 179 180 if (!init && static_call_is_init(site)) 181 continue; 182 183 if (!kernel_text_address((unsigned long)site_addr)) { 184 /* 185 * This skips patching built-in __exit, which 186 * is part of init_section_contains() but is 187 * not part of kernel_text_address(). 188 * 189 * Skipping built-in __exit is fine since it 190 * will never be executed. 191 */ 192 WARN_ONCE(!static_call_is_init(site), 193 "can't patch static call site at %pS", 194 site_addr); 195 continue; 196 } 197 198 arch_static_call_transform(site_addr, NULL, func, 199 static_call_is_tail(site)); 200 } 201 } 202 203 done: 204 static_call_unlock(); 205 cpus_read_unlock(); 206 } 207 EXPORT_SYMBOL_GPL(__static_call_update); 208 209 static int __static_call_init(struct module *mod, 210 struct static_call_site *start, 211 struct static_call_site *stop) 212 { 213 struct static_call_site *site; 214 struct static_call_key *key, *prev_key = NULL; 215 struct static_call_mod *site_mod; 216 217 if (start == stop) 218 return 0; 219 220 static_call_sort_entries(start, stop); 221 222 for (site = start; site < stop; site++) { 223 void *site_addr = static_call_addr(site); 224 225 if ((mod && within_module_init((unsigned long)site_addr, mod)) || 226 (!mod && init_section_contains(site_addr, 1))) 227 static_call_set_init(site); 228 229 key = static_call_key(site); 230 if (key != prev_key) { 231 prev_key = key; 232 233 /* 234 * For vmlinux (!mod) avoid the allocation by storing 235 * the sites pointer in the key itself. Also see 236 * __static_call_update()'s @first. 237 * 238 * This allows architectures (eg. x86) to call 239 * static_call_init() before memory allocation works. 240 */ 241 if (!mod) { 242 key->sites = site; 243 key->type |= 1; 244 goto do_transform; 245 } 246 247 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 248 if (!site_mod) 249 return -ENOMEM; 250 251 /* 252 * When the key has a direct sites pointer, extract 253 * that into an explicit struct static_call_mod, so we 254 * can have a list of modules. 255 */ 256 if (static_call_key_sites(key)) { 257 site_mod->mod = NULL; 258 site_mod->next = NULL; 259 site_mod->sites = static_call_key_sites(key); 260 261 key->mods = site_mod; 262 263 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 264 if (!site_mod) 265 return -ENOMEM; 266 } 267 268 site_mod->mod = mod; 269 site_mod->sites = site; 270 site_mod->next = static_call_key_next(key); 271 key->mods = site_mod; 272 } 273 274 do_transform: 275 arch_static_call_transform(site_addr, NULL, key->func, 276 static_call_is_tail(site)); 277 } 278 279 return 0; 280 } 281 282 static int addr_conflict(struct static_call_site *site, void *start, void *end) 283 { 284 unsigned long addr = (unsigned long)static_call_addr(site); 285 286 if (addr <= (unsigned long)end && 287 addr + CALL_INSN_SIZE > (unsigned long)start) 288 return 1; 289 290 return 0; 291 } 292 293 static int __static_call_text_reserved(struct static_call_site *iter_start, 294 struct static_call_site *iter_stop, 295 void *start, void *end, bool init) 296 { 297 struct static_call_site *iter = iter_start; 298 299 while (iter < iter_stop) { 300 if (init || !static_call_is_init(iter)) { 301 if (addr_conflict(iter, start, end)) 302 return 1; 303 } 304 iter++; 305 } 306 307 return 0; 308 } 309 310 #ifdef CONFIG_MODULES 311 312 static int __static_call_mod_text_reserved(void *start, void *end) 313 { 314 struct module *mod; 315 int ret; 316 317 preempt_disable(); 318 mod = __module_text_address((unsigned long)start); 319 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod); 320 if (!try_module_get(mod)) 321 mod = NULL; 322 preempt_enable(); 323 324 if (!mod) 325 return 0; 326 327 ret = __static_call_text_reserved(mod->static_call_sites, 328 mod->static_call_sites + mod->num_static_call_sites, 329 start, end, mod->state == MODULE_STATE_COMING); 330 331 module_put(mod); 332 333 return ret; 334 } 335 336 static unsigned long tramp_key_lookup(unsigned long addr) 337 { 338 struct static_call_tramp_key *start = __start_static_call_tramp_key; 339 struct static_call_tramp_key *stop = __stop_static_call_tramp_key; 340 struct static_call_tramp_key *tramp_key; 341 342 for (tramp_key = start; tramp_key != stop; tramp_key++) { 343 unsigned long tramp; 344 345 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp; 346 if (tramp == addr) 347 return (long)tramp_key->key + (long)&tramp_key->key; 348 } 349 350 return 0; 351 } 352 353 static int static_call_add_module(struct module *mod) 354 { 355 struct static_call_site *start = mod->static_call_sites; 356 struct static_call_site *stop = start + mod->num_static_call_sites; 357 struct static_call_site *site; 358 359 for (site = start; site != stop; site++) { 360 unsigned long s_key = __static_call_key(site); 361 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS; 362 unsigned long key; 363 364 /* 365 * Is the key is exported, 'addr' points to the key, which 366 * means modules are allowed to call static_call_update() on 367 * it. 368 * 369 * Otherwise, the key isn't exported, and 'addr' points to the 370 * trampoline so we need to lookup the key. 371 * 372 * We go through this dance to prevent crazy modules from 373 * abusing sensitive static calls. 374 */ 375 if (!kernel_text_address(addr)) 376 continue; 377 378 key = tramp_key_lookup(addr); 379 if (!key) { 380 pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n", 381 static_call_addr(site)); 382 return -EINVAL; 383 } 384 385 key |= s_key & STATIC_CALL_SITE_FLAGS; 386 site->key = key - (long)&site->key; 387 } 388 389 return __static_call_init(mod, start, stop); 390 } 391 392 static void static_call_del_module(struct module *mod) 393 { 394 struct static_call_site *start = mod->static_call_sites; 395 struct static_call_site *stop = mod->static_call_sites + 396 mod->num_static_call_sites; 397 struct static_call_key *key, *prev_key = NULL; 398 struct static_call_mod *site_mod, **prev; 399 struct static_call_site *site; 400 401 for (site = start; site < stop; site++) { 402 key = static_call_key(site); 403 if (key == prev_key) 404 continue; 405 406 prev_key = key; 407 408 for (prev = &key->mods, site_mod = key->mods; 409 site_mod && site_mod->mod != mod; 410 prev = &site_mod->next, site_mod = site_mod->next) 411 ; 412 413 if (!site_mod) 414 continue; 415 416 *prev = site_mod->next; 417 kfree(site_mod); 418 } 419 } 420 421 static int static_call_module_notify(struct notifier_block *nb, 422 unsigned long val, void *data) 423 { 424 struct module *mod = data; 425 int ret = 0; 426 427 cpus_read_lock(); 428 static_call_lock(); 429 430 switch (val) { 431 case MODULE_STATE_COMING: 432 ret = static_call_add_module(mod); 433 if (ret) { 434 WARN(1, "Failed to allocate memory for static calls"); 435 static_call_del_module(mod); 436 } 437 break; 438 case MODULE_STATE_GOING: 439 static_call_del_module(mod); 440 break; 441 } 442 443 static_call_unlock(); 444 cpus_read_unlock(); 445 446 return notifier_from_errno(ret); 447 } 448 449 static struct notifier_block static_call_module_nb = { 450 .notifier_call = static_call_module_notify, 451 }; 452 453 #else 454 455 static inline int __static_call_mod_text_reserved(void *start, void *end) 456 { 457 return 0; 458 } 459 460 #endif /* CONFIG_MODULES */ 461 462 int static_call_text_reserved(void *start, void *end) 463 { 464 bool init = system_state < SYSTEM_RUNNING; 465 int ret = __static_call_text_reserved(__start_static_call_sites, 466 __stop_static_call_sites, start, end, init); 467 468 if (ret) 469 return ret; 470 471 return __static_call_mod_text_reserved(start, end); 472 } 473 474 int __init static_call_init(void) 475 { 476 int ret; 477 478 if (static_call_initialized) 479 return 0; 480 481 cpus_read_lock(); 482 static_call_lock(); 483 ret = __static_call_init(NULL, __start_static_call_sites, 484 __stop_static_call_sites); 485 static_call_unlock(); 486 cpus_read_unlock(); 487 488 if (ret) { 489 pr_err("Failed to allocate memory for static_call!\n"); 490 BUG(); 491 } 492 493 static_call_initialized = true; 494 495 #ifdef CONFIG_MODULES 496 register_module_notifier(&static_call_module_nb); 497 #endif 498 return 0; 499 } 500 early_initcall(static_call_init); 501 502 long __static_call_return0(void) 503 { 504 return 0; 505 } 506 507 #ifdef CONFIG_STATIC_CALL_SELFTEST 508 509 static int func_a(int x) 510 { 511 return x+1; 512 } 513 514 static int func_b(int x) 515 { 516 return x+2; 517 } 518 519 DEFINE_STATIC_CALL(sc_selftest, func_a); 520 521 static struct static_call_data { 522 int (*func)(int); 523 int val; 524 int expect; 525 } static_call_data [] __initdata = { 526 { NULL, 2, 3 }, 527 { func_b, 2, 4 }, 528 { func_a, 2, 3 } 529 }; 530 531 static int __init test_static_call_init(void) 532 { 533 int i; 534 535 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) { 536 struct static_call_data *scd = &static_call_data[i]; 537 538 if (scd->func) 539 static_call_update(sc_selftest, scd->func); 540 541 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect); 542 } 543 544 return 0; 545 } 546 early_initcall(test_static_call_init); 547 548 #endif /* CONFIG_STATIC_CALL_SELFTEST */ 549