1 // SPDX-License-Identifier: GPL-2.0 2 #include <linux/init.h> 3 #include <linux/static_call.h> 4 #include <linux/bug.h> 5 #include <linux/smp.h> 6 #include <linux/sort.h> 7 #include <linux/slab.h> 8 #include <linux/module.h> 9 #include <linux/cpu.h> 10 #include <linux/processor.h> 11 #include <asm/sections.h> 12 13 extern struct static_call_site __start_static_call_sites[], 14 __stop_static_call_sites[]; 15 extern struct static_call_tramp_key __start_static_call_tramp_key[], 16 __stop_static_call_tramp_key[]; 17 18 static bool static_call_initialized; 19 20 /* mutex to protect key modules/sites */ 21 static DEFINE_MUTEX(static_call_mutex); 22 23 static void static_call_lock(void) 24 { 25 mutex_lock(&static_call_mutex); 26 } 27 28 static void static_call_unlock(void) 29 { 30 mutex_unlock(&static_call_mutex); 31 } 32 33 static inline void *static_call_addr(struct static_call_site *site) 34 { 35 return (void *)((long)site->addr + (long)&site->addr); 36 } 37 38 static inline unsigned long __static_call_key(const struct static_call_site *site) 39 { 40 return (long)site->key + (long)&site->key; 41 } 42 43 static inline struct static_call_key *static_call_key(const struct static_call_site *site) 44 { 45 return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS); 46 } 47 48 /* These assume the key is word-aligned. */ 49 static inline bool static_call_is_init(struct static_call_site *site) 50 { 51 return __static_call_key(site) & STATIC_CALL_SITE_INIT; 52 } 53 54 static inline bool static_call_is_tail(struct static_call_site *site) 55 { 56 return __static_call_key(site) & STATIC_CALL_SITE_TAIL; 57 } 58 59 static inline void static_call_set_init(struct static_call_site *site) 60 { 61 site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) - 62 (long)&site->key; 63 } 64 65 static int static_call_site_cmp(const void *_a, const void *_b) 66 { 67 const struct static_call_site *a = _a; 68 const struct static_call_site *b = _b; 69 const struct static_call_key *key_a = static_call_key(a); 70 const struct static_call_key *key_b = static_call_key(b); 71 72 if (key_a < key_b) 73 return -1; 74 75 if (key_a > key_b) 76 return 1; 77 78 return 0; 79 } 80 81 static void static_call_site_swap(void *_a, void *_b, int size) 82 { 83 long delta = (unsigned long)_a - (unsigned long)_b; 84 struct static_call_site *a = _a; 85 struct static_call_site *b = _b; 86 struct static_call_site tmp = *a; 87 88 a->addr = b->addr - delta; 89 a->key = b->key - delta; 90 91 b->addr = tmp.addr + delta; 92 b->key = tmp.key + delta; 93 } 94 95 static inline void static_call_sort_entries(struct static_call_site *start, 96 struct static_call_site *stop) 97 { 98 sort(start, stop - start, sizeof(struct static_call_site), 99 static_call_site_cmp, static_call_site_swap); 100 } 101 102 static inline bool static_call_key_has_mods(struct static_call_key *key) 103 { 104 return !(key->type & 1); 105 } 106 107 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key) 108 { 109 if (!static_call_key_has_mods(key)) 110 return NULL; 111 112 return key->mods; 113 } 114 115 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key) 116 { 117 if (static_call_key_has_mods(key)) 118 return NULL; 119 120 return (struct static_call_site *)(key->type & ~1); 121 } 122 123 void __static_call_update(struct static_call_key *key, void *tramp, void *func) 124 { 125 struct static_call_site *site, *stop; 126 struct static_call_mod *site_mod, first; 127 128 cpus_read_lock(); 129 static_call_lock(); 130 131 if (key->func == func) 132 goto done; 133 134 key->func = func; 135 136 arch_static_call_transform(NULL, tramp, func, false); 137 138 /* 139 * If uninitialized, we'll not update the callsites, but they still 140 * point to the trampoline and we just patched that. 141 */ 142 if (WARN_ON_ONCE(!static_call_initialized)) 143 goto done; 144 145 first = (struct static_call_mod){ 146 .next = static_call_key_next(key), 147 .mod = NULL, 148 .sites = static_call_key_sites(key), 149 }; 150 151 for (site_mod = &first; site_mod; site_mod = site_mod->next) { 152 bool init = system_state < SYSTEM_RUNNING; 153 struct module *mod = site_mod->mod; 154 155 if (!site_mod->sites) { 156 /* 157 * This can happen if the static call key is defined in 158 * a module which doesn't use it. 159 * 160 * It also happens in the has_mods case, where the 161 * 'first' entry has no sites associated with it. 162 */ 163 continue; 164 } 165 166 stop = __stop_static_call_sites; 167 168 if (mod) { 169 #ifdef CONFIG_MODULES 170 stop = mod->static_call_sites + 171 mod->num_static_call_sites; 172 init = mod->state == MODULE_STATE_COMING; 173 #endif 174 } 175 176 for (site = site_mod->sites; 177 site < stop && static_call_key(site) == key; site++) { 178 void *site_addr = static_call_addr(site); 179 180 if (!init && static_call_is_init(site)) 181 continue; 182 183 if (!kernel_text_address((unsigned long)site_addr)) { 184 /* 185 * This skips patching built-in __exit, which 186 * is part of init_section_contains() but is 187 * not part of kernel_text_address(). 188 * 189 * Skipping built-in __exit is fine since it 190 * will never be executed. 191 */ 192 WARN_ONCE(!static_call_is_init(site), 193 "can't patch static call site at %pS", 194 site_addr); 195 continue; 196 } 197 198 arch_static_call_transform(site_addr, NULL, func, 199 static_call_is_tail(site)); 200 } 201 } 202 203 done: 204 static_call_unlock(); 205 cpus_read_unlock(); 206 } 207 EXPORT_SYMBOL_GPL(__static_call_update); 208 209 static int __static_call_init(struct module *mod, 210 struct static_call_site *start, 211 struct static_call_site *stop) 212 { 213 struct static_call_site *site; 214 struct static_call_key *key, *prev_key = NULL; 215 struct static_call_mod *site_mod; 216 217 if (start == stop) 218 return 0; 219 220 static_call_sort_entries(start, stop); 221 222 for (site = start; site < stop; site++) { 223 void *site_addr = static_call_addr(site); 224 225 if ((mod && within_module_init((unsigned long)site_addr, mod)) || 226 (!mod && init_section_contains(site_addr, 1))) 227 static_call_set_init(site); 228 229 key = static_call_key(site); 230 if (key != prev_key) { 231 prev_key = key; 232 233 /* 234 * For vmlinux (!mod) avoid the allocation by storing 235 * the sites pointer in the key itself. Also see 236 * __static_call_update()'s @first. 237 * 238 * This allows architectures (eg. x86) to call 239 * static_call_init() before memory allocation works. 240 */ 241 if (!mod) { 242 key->sites = site; 243 key->type |= 1; 244 goto do_transform; 245 } 246 247 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 248 if (!site_mod) 249 return -ENOMEM; 250 251 /* 252 * When the key has a direct sites pointer, extract 253 * that into an explicit struct static_call_mod, so we 254 * can have a list of modules. 255 */ 256 if (static_call_key_sites(key)) { 257 site_mod->mod = NULL; 258 site_mod->next = NULL; 259 site_mod->sites = static_call_key_sites(key); 260 261 key->mods = site_mod; 262 263 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 264 if (!site_mod) 265 return -ENOMEM; 266 } 267 268 site_mod->mod = mod; 269 site_mod->sites = site; 270 site_mod->next = static_call_key_next(key); 271 key->mods = site_mod; 272 } 273 274 do_transform: 275 arch_static_call_transform(site_addr, NULL, key->func, 276 static_call_is_tail(site)); 277 } 278 279 return 0; 280 } 281 282 static int addr_conflict(struct static_call_site *site, void *start, void *end) 283 { 284 unsigned long addr = (unsigned long)static_call_addr(site); 285 286 if (addr <= (unsigned long)end && 287 addr + CALL_INSN_SIZE > (unsigned long)start) 288 return 1; 289 290 return 0; 291 } 292 293 static int __static_call_text_reserved(struct static_call_site *iter_start, 294 struct static_call_site *iter_stop, 295 void *start, void *end) 296 { 297 struct static_call_site *iter = iter_start; 298 299 while (iter < iter_stop) { 300 if (addr_conflict(iter, start, end)) 301 return 1; 302 iter++; 303 } 304 305 return 0; 306 } 307 308 #ifdef CONFIG_MODULES 309 310 static int __static_call_mod_text_reserved(void *start, void *end) 311 { 312 struct module *mod; 313 int ret; 314 315 preempt_disable(); 316 mod = __module_text_address((unsigned long)start); 317 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod); 318 if (!try_module_get(mod)) 319 mod = NULL; 320 preempt_enable(); 321 322 if (!mod) 323 return 0; 324 325 ret = __static_call_text_reserved(mod->static_call_sites, 326 mod->static_call_sites + mod->num_static_call_sites, 327 start, end); 328 329 module_put(mod); 330 331 return ret; 332 } 333 334 static unsigned long tramp_key_lookup(unsigned long addr) 335 { 336 struct static_call_tramp_key *start = __start_static_call_tramp_key; 337 struct static_call_tramp_key *stop = __stop_static_call_tramp_key; 338 struct static_call_tramp_key *tramp_key; 339 340 for (tramp_key = start; tramp_key != stop; tramp_key++) { 341 unsigned long tramp; 342 343 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp; 344 if (tramp == addr) 345 return (long)tramp_key->key + (long)&tramp_key->key; 346 } 347 348 return 0; 349 } 350 351 static int static_call_add_module(struct module *mod) 352 { 353 struct static_call_site *start = mod->static_call_sites; 354 struct static_call_site *stop = start + mod->num_static_call_sites; 355 struct static_call_site *site; 356 357 for (site = start; site != stop; site++) { 358 unsigned long s_key = __static_call_key(site); 359 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS; 360 unsigned long key; 361 362 /* 363 * Is the key is exported, 'addr' points to the key, which 364 * means modules are allowed to call static_call_update() on 365 * it. 366 * 367 * Otherwise, the key isn't exported, and 'addr' points to the 368 * trampoline so we need to lookup the key. 369 * 370 * We go through this dance to prevent crazy modules from 371 * abusing sensitive static calls. 372 */ 373 if (!kernel_text_address(addr)) 374 continue; 375 376 key = tramp_key_lookup(addr); 377 if (!key) { 378 pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n", 379 static_call_addr(site)); 380 return -EINVAL; 381 } 382 383 key |= s_key & STATIC_CALL_SITE_FLAGS; 384 site->key = key - (long)&site->key; 385 } 386 387 return __static_call_init(mod, start, stop); 388 } 389 390 static void static_call_del_module(struct module *mod) 391 { 392 struct static_call_site *start = mod->static_call_sites; 393 struct static_call_site *stop = mod->static_call_sites + 394 mod->num_static_call_sites; 395 struct static_call_key *key, *prev_key = NULL; 396 struct static_call_mod *site_mod, **prev; 397 struct static_call_site *site; 398 399 for (site = start; site < stop; site++) { 400 key = static_call_key(site); 401 if (key == prev_key) 402 continue; 403 404 prev_key = key; 405 406 for (prev = &key->mods, site_mod = key->mods; 407 site_mod && site_mod->mod != mod; 408 prev = &site_mod->next, site_mod = site_mod->next) 409 ; 410 411 if (!site_mod) 412 continue; 413 414 *prev = site_mod->next; 415 kfree(site_mod); 416 } 417 } 418 419 static int static_call_module_notify(struct notifier_block *nb, 420 unsigned long val, void *data) 421 { 422 struct module *mod = data; 423 int ret = 0; 424 425 cpus_read_lock(); 426 static_call_lock(); 427 428 switch (val) { 429 case MODULE_STATE_COMING: 430 ret = static_call_add_module(mod); 431 if (ret) { 432 WARN(1, "Failed to allocate memory for static calls"); 433 static_call_del_module(mod); 434 } 435 break; 436 case MODULE_STATE_GOING: 437 static_call_del_module(mod); 438 break; 439 } 440 441 static_call_unlock(); 442 cpus_read_unlock(); 443 444 return notifier_from_errno(ret); 445 } 446 447 static struct notifier_block static_call_module_nb = { 448 .notifier_call = static_call_module_notify, 449 }; 450 451 #else 452 453 static inline int __static_call_mod_text_reserved(void *start, void *end) 454 { 455 return 0; 456 } 457 458 #endif /* CONFIG_MODULES */ 459 460 int static_call_text_reserved(void *start, void *end) 461 { 462 int ret = __static_call_text_reserved(__start_static_call_sites, 463 __stop_static_call_sites, start, end); 464 465 if (ret) 466 return ret; 467 468 return __static_call_mod_text_reserved(start, end); 469 } 470 471 int __init static_call_init(void) 472 { 473 int ret; 474 475 if (static_call_initialized) 476 return 0; 477 478 cpus_read_lock(); 479 static_call_lock(); 480 ret = __static_call_init(NULL, __start_static_call_sites, 481 __stop_static_call_sites); 482 static_call_unlock(); 483 cpus_read_unlock(); 484 485 if (ret) { 486 pr_err("Failed to allocate memory for static_call!\n"); 487 BUG(); 488 } 489 490 static_call_initialized = true; 491 492 #ifdef CONFIG_MODULES 493 register_module_notifier(&static_call_module_nb); 494 #endif 495 return 0; 496 } 497 early_initcall(static_call_init); 498 499 long __static_call_return0(void) 500 { 501 return 0; 502 } 503 504 #ifdef CONFIG_STATIC_CALL_SELFTEST 505 506 static int func_a(int x) 507 { 508 return x+1; 509 } 510 511 static int func_b(int x) 512 { 513 return x+2; 514 } 515 516 DEFINE_STATIC_CALL(sc_selftest, func_a); 517 518 static struct static_call_data { 519 int (*func)(int); 520 int val; 521 int expect; 522 } static_call_data [] __initdata = { 523 { NULL, 2, 3 }, 524 { func_b, 2, 4 }, 525 { func_a, 2, 3 } 526 }; 527 528 static int __init test_static_call_init(void) 529 { 530 int i; 531 532 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) { 533 struct static_call_data *scd = &static_call_data[i]; 534 535 if (scd->func) 536 static_call_update(sc_selftest, scd->func); 537 538 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect); 539 } 540 541 return 0; 542 } 543 early_initcall(test_static_call_init); 544 545 #endif /* CONFIG_STATIC_CALL_SELFTEST */ 546