]> git.proxmox.com Git - mirror_ubuntu-hirsute-kernel.git/blame - kernel/static_call.c
static_call: Fix the module key fixup
[mirror_ubuntu-hirsute-kernel.git] / kernel / static_call.c
CommitLineData
9183c3f9
JP
1// SPDX-License-Identifier: GPL-2.0
2#include <linux/init.h>
3#include <linux/static_call.h>
4#include <linux/bug.h>
5#include <linux/smp.h>
6#include <linux/sort.h>
7#include <linux/slab.h>
8#include <linux/module.h>
9#include <linux/cpu.h>
10#include <linux/processor.h>
11#include <asm/sections.h>
12
13extern struct static_call_site __start_static_call_sites[],
14 __stop_static_call_sites[];
6e2f698e
JP
15extern struct static_call_tramp_key __start_static_call_tramp_key[],
16 __stop_static_call_tramp_key[];
9183c3f9
JP
17
18static bool static_call_initialized;
19
9183c3f9
JP
20/* mutex to protect key modules/sites */
21static DEFINE_MUTEX(static_call_mutex);
22
23static void static_call_lock(void)
24{
25 mutex_lock(&static_call_mutex);
26}
27
28static void static_call_unlock(void)
29{
30 mutex_unlock(&static_call_mutex);
31}
32
33static inline void *static_call_addr(struct static_call_site *site)
34{
35 return (void *)((long)site->addr + (long)&site->addr);
36}
37
38
39static inline struct static_call_key *static_call_key(const struct static_call_site *site)
40{
41 return (struct static_call_key *)
5b06fd3b 42 (((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
9183c3f9
JP
43}
44
45/* These assume the key is word-aligned. */
46static inline bool static_call_is_init(struct static_call_site *site)
47{
5b06fd3b
PZ
48 return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
49}
50
51static inline bool static_call_is_tail(struct static_call_site *site)
52{
53 return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
9183c3f9
JP
54}
55
56static inline void static_call_set_init(struct static_call_site *site)
57{
5b06fd3b 58 site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) -
9183c3f9
JP
59 (long)&site->key;
60}
61
/* sort() comparator: order call sites by the address of their key. */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_key *ka = static_call_key(_a);
	const struct static_call_key *kb = static_call_key(_b);

	if (ka == kb)
		return 0;

	return ka < kb ? -1 : 1;
}
77
78static void static_call_site_swap(void *_a, void *_b, int size)
79{
80 long delta = (unsigned long)_a - (unsigned long)_b;
81 struct static_call_site *a = _a;
82 struct static_call_site *b = _b;
83 struct static_call_site tmp = *a;
84
85 a->addr = b->addr - delta;
86 a->key = b->key - delta;
87
88 b->addr = tmp.addr + delta;
89 b->key = tmp.key + delta;
90}
91
92static inline void static_call_sort_entries(struct static_call_site *start,
93 struct static_call_site *stop)
94{
95 sort(start, stop - start, sizeof(struct static_call_site),
96 static_call_site_cmp, static_call_site_swap);
97}
98
a945c834
PZ
99static inline bool static_call_key_has_mods(struct static_call_key *key)
100{
101 return !(key->type & 1);
102}
103
104static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
105{
106 if (!static_call_key_has_mods(key))
107 return NULL;
108
109 return key->mods;
110}
111
112static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
113{
114 if (static_call_key_has_mods(key))
115 return NULL;
116
117 return (struct static_call_site *)(key->type & ~1);
118}
119
9183c3f9
JP
120void __static_call_update(struct static_call_key *key, void *tramp, void *func)
121{
122 struct static_call_site *site, *stop;
a945c834 123 struct static_call_mod *site_mod, first;
9183c3f9
JP
124
125 cpus_read_lock();
126 static_call_lock();
127
128 if (key->func == func)
129 goto done;
130
131 key->func = func;
132
5b06fd3b 133 arch_static_call_transform(NULL, tramp, func, false);
9183c3f9
JP
134
135 /*
136 * If uninitialized, we'll not update the callsites, but they still
137 * point to the trampoline and we just patched that.
138 */
139 if (WARN_ON_ONCE(!static_call_initialized))
140 goto done;
141
a945c834
PZ
142 first = (struct static_call_mod){
143 .next = static_call_key_next(key),
144 .mod = NULL,
145 .sites = static_call_key_sites(key),
146 };
147
148 for (site_mod = &first; site_mod; site_mod = site_mod->next) {
9183c3f9
JP
149 struct module *mod = site_mod->mod;
150
151 if (!site_mod->sites) {
152 /*
153 * This can happen if the static call key is defined in
154 * a module which doesn't use it.
a945c834
PZ
155 *
156 * It also happens in the has_mods case, where the
157 * 'first' entry has no sites associated with it.
9183c3f9
JP
158 */
159 continue;
160 }
161
162 stop = __stop_static_call_sites;
163
164#ifdef CONFIG_MODULES
165 if (mod) {
166 stop = mod->static_call_sites +
167 mod->num_static_call_sites;
168 }
169#endif
170
171 for (site = site_mod->sites;
172 site < stop && static_call_key(site) == key; site++) {
173 void *site_addr = static_call_addr(site);
174
175 if (static_call_is_init(site)) {
176 /*
177 * Don't write to call sites which were in
178 * initmem and have since been freed.
179 */
180 if (!mod && system_state >= SYSTEM_RUNNING)
181 continue;
182 if (mod && !within_module_init((unsigned long)site_addr, mod))
183 continue;
184 }
185
186 if (!kernel_text_address((unsigned long)site_addr)) {
d4ea1929
PZ
187 /*
188 * This skips patching built-in __exit, which
189 * is part of init_section_contains() but is
190 * not part of kernel_text_address().
191 *
192 * Skipping built-in __exit is fine since it
193 * will never be executed.
194 */
195 WARN_ONCE(!static_call_is_init(site),
196 "can't patch static call site at %pS",
9183c3f9
JP
197 site_addr);
198 continue;
199 }
200
5b06fd3b
PZ
201 arch_static_call_transform(site_addr, NULL, func,
202 static_call_is_tail(site));
9183c3f9
JP
203 }
204 }
205
206done:
207 static_call_unlock();
208 cpus_read_unlock();
209}
210EXPORT_SYMBOL_GPL(__static_call_update);
211
212static int __static_call_init(struct module *mod,
213 struct static_call_site *start,
214 struct static_call_site *stop)
215{
216 struct static_call_site *site;
217 struct static_call_key *key, *prev_key = NULL;
218 struct static_call_mod *site_mod;
219
220 if (start == stop)
221 return 0;
222
223 static_call_sort_entries(start, stop);
224
225 for (site = start; site < stop; site++) {
226 void *site_addr = static_call_addr(site);
227
228 if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
229 (!mod && init_section_contains(site_addr, 1)))
230 static_call_set_init(site);
231
232 key = static_call_key(site);
233 if (key != prev_key) {
234 prev_key = key;
235
a945c834
PZ
236 /*
237 * For vmlinux (!mod) avoid the allocation by storing
238 * the sites pointer in the key itself. Also see
239 * __static_call_update()'s @first.
240 *
241 * This allows architectures (eg. x86) to call
242 * static_call_init() before memory allocation works.
243 */
244 if (!mod) {
245 key->sites = site;
246 key->type |= 1;
247 goto do_transform;
248 }
249
9183c3f9
JP
250 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
251 if (!site_mod)
252 return -ENOMEM;
253
a945c834
PZ
254 /*
255 * When the key has a direct sites pointer, extract
256 * that into an explicit struct static_call_mod, so we
257 * can have a list of modules.
258 */
259 if (static_call_key_sites(key)) {
260 site_mod->mod = NULL;
261 site_mod->next = NULL;
262 site_mod->sites = static_call_key_sites(key);
263
264 key->mods = site_mod;
265
266 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
267 if (!site_mod)
268 return -ENOMEM;
269 }
270
9183c3f9
JP
271 site_mod->mod = mod;
272 site_mod->sites = site;
a945c834 273 site_mod->next = static_call_key_next(key);
9183c3f9
JP
274 key->mods = site_mod;
275 }
276
a945c834 277do_transform:
5b06fd3b
PZ
278 arch_static_call_transform(site_addr, NULL, key->func,
279 static_call_is_tail(site));
9183c3f9
JP
280 }
281
282 return 0;
283}
284
6333e8f7
PZ
285static int addr_conflict(struct static_call_site *site, void *start, void *end)
286{
287 unsigned long addr = (unsigned long)static_call_addr(site);
288
289 if (addr <= (unsigned long)end &&
290 addr + CALL_INSN_SIZE > (unsigned long)start)
291 return 1;
292
293 return 0;
294}
295
296static int __static_call_text_reserved(struct static_call_site *iter_start,
297 struct static_call_site *iter_stop,
298 void *start, void *end)
299{
300 struct static_call_site *iter = iter_start;
301
302 while (iter < iter_stop) {
303 if (addr_conflict(iter, start, end))
304 return 1;
305 iter++;
306 }
307
308 return 0;
309}
310
9183c3f9
JP
311#ifdef CONFIG_MODULES
312
6333e8f7
PZ
313static int __static_call_mod_text_reserved(void *start, void *end)
314{
315 struct module *mod;
316 int ret;
317
318 preempt_disable();
319 mod = __module_text_address((unsigned long)start);
320 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
321 if (!try_module_get(mod))
322 mod = NULL;
323 preempt_enable();
324
325 if (!mod)
326 return 0;
327
328 ret = __static_call_text_reserved(mod->static_call_sites,
329 mod->static_call_sites + mod->num_static_call_sites,
330 start, end);
331
332 module_put(mod);
333
334 return ret;
335}
336
6e2f698e
JP
337static unsigned long tramp_key_lookup(unsigned long addr)
338{
339 struct static_call_tramp_key *start = __start_static_call_tramp_key;
340 struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
341 struct static_call_tramp_key *tramp_key;
342
343 for (tramp_key = start; tramp_key != stop; tramp_key++) {
344 unsigned long tramp;
345
346 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
347 if (tramp == addr)
348 return (long)tramp_key->key + (long)&tramp_key->key;
349 }
350
351 return 0;
352}
353
9183c3f9
JP
354static int static_call_add_module(struct module *mod)
355{
6e2f698e
JP
356 struct static_call_site *start = mod->static_call_sites;
357 struct static_call_site *stop = start + mod->num_static_call_sites;
358 struct static_call_site *site;
359
360 for (site = start; site != stop; site++) {
373c7c37
PZ
361 unsigned long s_key = (long)site->key + (long)&site->key;
362 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS;
6e2f698e
JP
363 unsigned long key;
364
365 /*
366 * Is the key is exported, 'addr' points to the key, which
367 * means modules are allowed to call static_call_update() on
368 * it.
369 *
370 * Otherwise, the key isn't exported, and 'addr' points to the
371 * trampoline so we need to lookup the key.
372 *
373 * We go through this dance to prevent crazy modules from
374 * abusing sensitive static calls.
375 */
376 if (!kernel_text_address(addr))
377 continue;
378
379 key = tramp_key_lookup(addr);
380 if (!key) {
381 pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
382 static_call_addr(site));
383 return -EINVAL;
384 }
385
373c7c37
PZ
386 key |= s_key & STATIC_CALL_SITE_FLAGS;
387 site->key = key - (long)&site->key;
6e2f698e
JP
388 }
389
390 return __static_call_init(mod, start, stop);
9183c3f9
JP
391}
392
393static void static_call_del_module(struct module *mod)
394{
395 struct static_call_site *start = mod->static_call_sites;
396 struct static_call_site *stop = mod->static_call_sites +
397 mod->num_static_call_sites;
398 struct static_call_key *key, *prev_key = NULL;
399 struct static_call_mod *site_mod, **prev;
400 struct static_call_site *site;
401
402 for (site = start; site < stop; site++) {
403 key = static_call_key(site);
404 if (key == prev_key)
405 continue;
406
407 prev_key = key;
408
409 for (prev = &key->mods, site_mod = key->mods;
410 site_mod && site_mod->mod != mod;
411 prev = &site_mod->next, site_mod = site_mod->next)
412 ;
413
414 if (!site_mod)
415 continue;
416
417 *prev = site_mod->next;
418 kfree(site_mod);
419 }
420}
421
422static int static_call_module_notify(struct notifier_block *nb,
423 unsigned long val, void *data)
424{
425 struct module *mod = data;
426 int ret = 0;
427
428 cpus_read_lock();
429 static_call_lock();
430
431 switch (val) {
432 case MODULE_STATE_COMING:
433 ret = static_call_add_module(mod);
434 if (ret) {
435 WARN(1, "Failed to allocate memory for static calls");
436 static_call_del_module(mod);
437 }
438 break;
439 case MODULE_STATE_GOING:
440 static_call_del_module(mod);
441 break;
442 }
443
444 static_call_unlock();
445 cpus_read_unlock();
446
447 return notifier_from_errno(ret);
448}
449
450static struct notifier_block static_call_module_nb = {
451 .notifier_call = static_call_module_notify,
452};
453
6333e8f7
PZ
454#else
455
/* Without module support there are no module call sites that could overlap. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	return 0;
}
460
9183c3f9
JP
461#endif /* CONFIG_MODULES */
462
6333e8f7
PZ
463int static_call_text_reserved(void *start, void *end)
464{
465 int ret = __static_call_text_reserved(__start_static_call_sites,
466 __stop_static_call_sites, start, end);
467
468 if (ret)
469 return ret;
470
471 return __static_call_mod_text_reserved(start, end);
472}
473
69e0ad37 474int __init static_call_init(void)
9183c3f9
JP
475{
476 int ret;
477
478 if (static_call_initialized)
69e0ad37 479 return 0;
9183c3f9
JP
480
481 cpus_read_lock();
482 static_call_lock();
483 ret = __static_call_init(NULL, __start_static_call_sites,
484 __stop_static_call_sites);
485 static_call_unlock();
486 cpus_read_unlock();
487
488 if (ret) {
489 pr_err("Failed to allocate memory for static_call!\n");
490 BUG();
491 }
492
493 static_call_initialized = true;
494
495#ifdef CONFIG_MODULES
496 register_module_notifier(&static_call_module_nb);
497#endif
69e0ad37 498 return 0;
9183c3f9
JP
499}
500early_initcall(static_call_init);
f03c4129
PZ
501
502#ifdef CONFIG_STATIC_CALL_SELFTEST
503
/* Selftest target: returns its argument plus one. */
static int func_a(int x)
{
	return 1 + x;
}

/* Selftest target: returns its argument plus two. */
static int func_b(int x)
{
	return 2 + x;
}
513
DEFINE_STATIC_CALL(sc_selftest, func_a);

/*
 * Selftest steps: if .func is non-NULL, retarget sc_selftest to it first,
 * then calling sc_selftest(.val) must return .expect.
 */
static struct static_call_data {
	int (*func)(int);
	int val;
	int expect;
} static_call_data[] __initdata = {
	{ .func = NULL,   .val = 2, .expect = 3 },
	{ .func = func_b, .val = 2, .expect = 4 },
	{ .func = func_a, .val = 2, .expect = 3 }
};
525
526static int __init test_static_call_init(void)
527{
528 int i;
529
530 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
531 struct static_call_data *scd = &static_call_data[i];
532
533 if (scd->func)
534 static_call_update(sc_selftest, scd->func);
535
536 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
537 }
538
539 return 0;
540}
541early_initcall(test_static_call_init);
542
543#endif /* CONFIG_STATIC_CALL_SELFTEST */