xref: /linux/rust/macros/module.rs (revision d421fa4f73f59c512b30338a3453b62ed8fd9122)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 use std::ffi::CString;
4 
5 use proc_macro2::{
6     Literal,
7     TokenStream, //
8 };
9 use quote::{
10     format_ident,
11     quote, //
12 };
13 use syn::{
14     braced,
15     bracketed,
16     ext::IdentExt,
17     parse::{
18         Parse,
19         ParseStream, //
20     },
21     parse_quote,
22     punctuated::Punctuated,
23     Error,
24     Expr,
25     Ident,
26     LitStr,
27     Path,
28     Result,
29     Token,
30     Type, //
31 };
32 
33 use crate::helpers::*;
34 
35 struct ModInfoBuilder<'a> {
36     module: &'a str,
37     counter: usize,
38     ts: TokenStream,
39     param_ts: TokenStream,
40 }
41 
42 impl<'a> ModInfoBuilder<'a> {
43     fn new(module: &'a str) -> Self {
44         ModInfoBuilder {
45             module,
46             counter: 0,
47             ts: TokenStream::new(),
48             param_ts: TokenStream::new(),
49         }
50     }
51 
52     fn emit_base(&mut self, field: &str, content: &str, builtin: bool, param: bool) {
53         let string = if builtin {
54             // Built-in modules prefix their modinfo strings by `module.`.
55             format!(
56                 "{module}.{field}={content}\0",
57                 module = self.module,
58                 field = field,
59                 content = content
60             )
61         } else {
62             // Loadable modules' modinfo strings go as-is.
63             format!("{field}={content}\0")
64         };
65         let length = string.len();
66         let string = Literal::byte_string(string.as_bytes());
67         let cfg = if builtin {
68             quote!(#[cfg(not(MODULE))])
69         } else {
70             quote!(#[cfg(MODULE)])
71         };
72 
73         let counter = format_ident!(
74             "__{module}_{counter}",
75             module = self.module.to_uppercase(),
76             counter = self.counter
77         );
78         let item = quote! {
79             #cfg
80             #[cfg_attr(not(target_os = "macos"), link_section = ".modinfo")]
81             #[used(compiler)]
82             pub static #counter: [u8; #length] = *#string;
83         };
84 
85         if param {
86             self.param_ts.extend(item);
87         } else {
88             self.ts.extend(item);
89         }
90 
91         self.counter += 1;
92     }
93 
94     fn emit_only_builtin(&mut self, field: &str, content: &str, param: bool) {
95         self.emit_base(field, content, true, param)
96     }
97 
98     fn emit_only_loadable(&mut self, field: &str, content: &str, param: bool) {
99         self.emit_base(field, content, false, param)
100     }
101 
102     fn emit(&mut self, field: &str, content: &str) {
103         self.emit_internal(field, content, false);
104     }
105 
106     fn emit_internal(&mut self, field: &str, content: &str, param: bool) {
107         self.emit_only_builtin(field, content, param);
108         self.emit_only_loadable(field, content, param);
109     }
110 
111     fn emit_param(&mut self, field: &str, param: &str, content: &str) {
112         let content = format!("{param}:{content}", param = param, content = content);
113         self.emit_internal(field, &content, true);
114     }
115 
116     fn emit_params(&mut self, info: &ModuleInfo) {
117         let Some(params) = &info.params else {
118             return;
119         };
120 
121         for param in params {
122             let param_name_str = param.name.to_string();
123             let param_type_str = param.ptype.to_string();
124 
125             let ops = param_ops_path(&param_type_str);
126 
127             // Note: The spelling of these fields is dictated by the user space
128             // tool `modinfo`.
129             self.emit_param("parmtype", &param_name_str, &param_type_str);
130             self.emit_param("parm", &param_name_str, &param.description.value());
131 
132             let static_name = format_ident!("__{}_{}_struct", self.module, param.name);
133             let param_name_cstr =
134                 CString::new(param_name_str).expect("name contains NUL-terminator");
135             let param_name_cstr_with_module =
136                 CString::new(format!("{}.{}", self.module, param.name))
137                     .expect("name contains NUL-terminator");
138 
139             let param_name = &param.name;
140             let param_type = &param.ptype;
141             let param_default = &param.default;
142 
143             self.param_ts.extend(quote! {
144                 #[allow(non_upper_case_globals)]
145                 pub(crate) static #param_name:
146                     ::kernel::module_param::ModuleParamAccess<#param_type> =
147                         ::kernel::module_param::ModuleParamAccess::new(#param_default);
148 
149                 const _: () = {
150                     #[allow(non_upper_case_globals)]
151                     #[link_section = "__param"]
152                     #[used(compiler)]
153                     static #static_name:
154                         ::kernel::module_param::KernelParam =
155                         ::kernel::module_param::KernelParam::new(
156                             ::kernel::bindings::kernel_param {
157                                 name: kernel::str::as_char_ptr_in_const_context(
158                                     if ::core::cfg!(MODULE) {
159                                         #param_name_cstr
160                                     } else {
161                                         #param_name_cstr_with_module
162                                     }
163                                 ),
164                                 // SAFETY: `__this_module` is constructed by the kernel at load
165                                 // time and will not be freed until the module is unloaded.
166                                 #[cfg(MODULE)]
167                                 mod_: unsafe {
168                                     core::ptr::from_ref(&::kernel::bindings::__this_module)
169                                         .cast_mut()
170                                 },
171                                 #[cfg(not(MODULE))]
172                                 mod_: ::core::ptr::null_mut(),
173                                 ops: core::ptr::from_ref(&#ops),
174                                 perm: 0, // Will not appear in sysfs
175                                 level: -1,
176                                 flags: 0,
177                                 __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {
178                                     arg: #param_name.as_void_ptr()
179                                 },
180                             }
181                         );
182                 };
183             });
184         }
185     }
186 }
187 
188 fn param_ops_path(param_type: &str) -> Path {
189     match param_type {
190         "i8" => parse_quote!(::kernel::module_param::PARAM_OPS_I8),
191         "u8" => parse_quote!(::kernel::module_param::PARAM_OPS_U8),
192         "i16" => parse_quote!(::kernel::module_param::PARAM_OPS_I16),
193         "u16" => parse_quote!(::kernel::module_param::PARAM_OPS_U16),
194         "i32" => parse_quote!(::kernel::module_param::PARAM_OPS_I32),
195         "u32" => parse_quote!(::kernel::module_param::PARAM_OPS_U32),
196         "i64" => parse_quote!(::kernel::module_param::PARAM_OPS_I64),
197         "u64" => parse_quote!(::kernel::module_param::PARAM_OPS_U64),
198         "isize" => parse_quote!(::kernel::module_param::PARAM_OPS_ISIZE),
199         "usize" => parse_quote!(::kernel::module_param::PARAM_OPS_USIZE),
200         t => panic!("Unsupported parameter type {}", t),
201     }
202 }
203 
/// Parse fields that are required to use a specific order.
///
/// As fields must follow a specific order, we *could* just parse fields one by one by peeking.
/// However the error message generated when implementing that way is not very friendly.
///
/// So instead we parse fields in an arbitrary order, but only enforce the ordering after parsing,
/// and if the wrong order is used, the proper order is communicated to the user with error message.
///
/// The expected order is the order in which the keys are listed in the macro invocation. Every
/// `key: value` entry, including the last one, must be followed by a comma. Each field's parser
/// expression is evaluated when its key is encountered, with the input positioned right after
/// the `:`.
///
/// The generated code returns (via `?`) a `syn::Error` when:
/// - a key appears more than once (spanned at the repeated key),
/// - a key is not one of the listed keys (spanned at the key; lists the valid keys),
/// - a `[required]` key is missing (spanned at the start of the input),
/// - the keys are not in the expected order (spanned at the start of the input; lists the order
///   the given keys should be written in).
///
/// Usage looks like this:
/// ```ignore
/// parse_ordered_fields! {
///     from input;
///
///     // This will extract "foo: <field>" into a variable named "foo".
///     // The variable will have type `Option<_>`.
///     foo => <expression that parses the field>,
///
///     // If you need the variable name to be different than the key name.
///     // This extracts "baz: <field>" into a variable named "bar".
///     // You might want this if "baz" is a keyword.
///     baz as bar => <expression that parses the field>,
///
///     // You can mark a key as required, and the variable will no longer be `Option`.
///     // foobar will be of type `Expr` instead of `Option<Expr>`.
///     foobar [required] => input.parse::<Expr>()?,
/// }
/// ```
macro_rules! parse_ordered_fields {
    // Final expansion: all entries have been collected into
    // `[name; key; parser]` triples (all fields) and `[name; key]` pairs (required fields).
    (@gen
        [$input:expr]
        [$([$name:ident; $key:ident; $parser:expr])*]
        [$([$req_name:ident; $req_key:ident])*]
    ) => {
        // One `Option` per field, filled in when its key is parsed.
        $(let mut $name = None;)*

        // Key order as listed by the caller; this is the order users must follow.
        const EXPECTED_KEYS: &[&str] = &[$(stringify!($key),)*];
        const REQUIRED_KEYS: &[&str] = &[$(stringify!($req_key),)*];

        let span = $input.span();
        // Keys in the order the user actually wrote them.
        let mut seen_keys = Vec::new();

        while !$input.is_empty() {
            // `parse_any` also accepts keywords, so keys such as `type` work.
            let key = $input.call(Ident::parse_any)?;

            if seen_keys.contains(&key) {
                Err(Error::new_spanned(
                    &key,
                    format!(r#"duplicated key "{key}". Keys can only be specified once."#),
                ))?
            }

            $input.parse::<Token![:]>()?;

            match &*key.to_string() {
                $(
                    stringify!($key) => $name = Some($parser),
                )*
                _ => {
                    Err(Error::new_spanned(
                        &key,
                        format!(r#"unknown key "{key}". Valid keys are: {EXPECTED_KEYS:?}."#),
                    ))?
                }
            }

            // A comma is mandatory after every entry, including the last one.
            $input.parse::<Token![,]>()?;
            seen_keys.push(key);
        }

        for key in REQUIRED_KEYS {
            if !seen_keys.iter().any(|e| e == key) {
                Err(Error::new(span, format!(r#"missing required key "{key}""#)))?
            }
        }

        // The keys the user wrote, rearranged into the expected order.
        let mut ordered_keys: Vec<&str> = Vec::new();
        for key in EXPECTED_KEYS {
            if seen_keys.iter().any(|e| e == key) {
                ordered_keys.push(key);
            }
        }

        if seen_keys != ordered_keys {
            Err(Error::new(
                span,
                format!(r#"keys are not ordered as expected. Order them like: {ordered_keys:?}."#),
            ))?
        }

        // Presence of every required key was checked above, so these cannot fail.
        $(let $req_name = $req_name.expect("required field");)*
    };

    // Handle required fields.
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $key:ident as $name:ident [required] => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)* [$name; $key]] $($rest)*
        )
    };
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $name:ident [required] => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)* [$name; $name]] $($rest)*
        )
    };

    // Handle optional fields.
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $key:ident as $name:ident => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)*] $($rest)*
        )
    };
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $name:ident => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)*] $($rest)*
        )
    };

    // Public entry point: start accumulating with empty field and required lists.
    (from $input:expr; $($tok:tt)*) => {
        parse_ordered_fields!(@gen [$input] [] [] $($tok)*)
    }
}
340 
341 struct Parameter {
342     name: Ident,
343     ptype: Ident,
344     default: Expr,
345     description: LitStr,
346 }
347 
348 impl Parse for Parameter {
349     fn parse(input: ParseStream<'_>) -> Result<Self> {
350         let name = input.parse()?;
351         input.parse::<Token![:]>()?;
352         let ptype = input.parse()?;
353 
354         let fields;
355         braced!(fields in input);
356 
357         parse_ordered_fields! {
358             from fields;
359             default [required] => fields.parse()?,
360             description [required] => fields.parse()?,
361         }
362 
363         Ok(Self {
364             name,
365             ptype,
366             default,
367             description,
368         })
369     }
370 }
371 
372 pub(crate) struct ModuleInfo {
373     type_: Type,
374     license: AsciiLitStr,
375     name: AsciiLitStr,
376     authors: Option<Punctuated<AsciiLitStr, Token![,]>>,
377     description: Option<LitStr>,
378     alias: Option<Punctuated<AsciiLitStr, Token![,]>>,
379     firmware: Option<Punctuated<AsciiLitStr, Token![,]>>,
380     imports_ns: Option<Punctuated<AsciiLitStr, Token![,]>>,
381     params: Option<Punctuated<Parameter, Token![,]>>,
382 }
383 
384 impl Parse for ModuleInfo {
385     fn parse(input: ParseStream<'_>) -> Result<Self> {
386         parse_ordered_fields!(
387             from input;
388             type as type_ [required] => input.parse()?,
389             name [required] => input.parse()?,
390             authors => {
391                 let list;
392                 bracketed!(list in input);
393                 Punctuated::parse_terminated(&list)?
394             },
395             description => input.parse()?,
396             license [required] => input.parse()?,
397             alias => {
398                 let list;
399                 bracketed!(list in input);
400                 Punctuated::parse_terminated(&list)?
401             },
402             firmware => {
403                 let list;
404                 bracketed!(list in input);
405                 Punctuated::parse_terminated(&list)?
406             },
407             imports_ns => {
408                 let list;
409                 bracketed!(list in input);
410                 Punctuated::parse_terminated(&list)?
411             },
412             params => {
413                 let list;
414                 braced!(list in input);
415                 Punctuated::parse_terminated(&list)?
416             },
417         );
418 
419         Ok(ModuleInfo {
420             type_,
421             license,
422             name,
423             authors,
424             description,
425             alias,
426             firmware,
427             imports_ns,
428             params,
429         })
430     }
431 }
432 
433 pub(crate) fn module(info: ModuleInfo) -> Result<TokenStream> {
434     let ModuleInfo {
435         type_,
436         license,
437         name,
438         authors,
439         description,
440         alias,
441         firmware,
442         imports_ns,
443         params: _,
444     } = &info;
445 
446     // Rust does not allow hyphens in identifiers, use underscore instead.
447     let ident = name.value().replace('-', "_");
448     let mut modinfo = ModInfoBuilder::new(ident.as_ref());
449     if let Some(authors) = authors {
450         for author in authors {
451             modinfo.emit("author", &author.value());
452         }
453     }
454     if let Some(description) = description {
455         modinfo.emit("description", &description.value());
456     }
457     modinfo.emit("license", &license.value());
458     if let Some(aliases) = alias {
459         for alias in aliases {
460             modinfo.emit("alias", &alias.value());
461         }
462     }
463     if let Some(firmware) = firmware {
464         for fw in firmware {
465             modinfo.emit("firmware", &fw.value());
466         }
467     }
468     if let Some(imports) = imports_ns {
469         for ns in imports {
470             modinfo.emit("import_ns", &ns.value());
471         }
472     }
473 
474     // Built-in modules also export the `file` modinfo string.
475     let file =
476         std::env::var("RUST_MODFILE").expect("Unable to fetch RUST_MODFILE environmental variable");
477     modinfo.emit_only_builtin("file", &file, false);
478 
479     modinfo.emit_params(&info);
480 
481     let modinfo_ts = modinfo.ts;
482     let params_ts = modinfo.param_ts;
483 
484     let ident_init = format_ident!("__{ident}_init");
485     let ident_exit = format_ident!("__{ident}_exit");
486     let ident_initcall = format_ident!("__{ident}_initcall");
487     let initcall_section = ".initcall6.init";
488 
489     let global_asm = format!(
490         r#".section "{initcall_section}", "a"
491         __{ident}_initcall:
492             .long   __{ident}_init - .
493             .previous
494         "#
495     );
496 
497     let name_cstr = CString::new(name.value()).expect("name contains NUL-terminator");
498 
499     Ok(quote! {
500         /// The module name.
501         ///
502         /// Used by the printing macros, e.g. [`info!`].
503         const __LOG_PREFIX: &[u8] = #name_cstr.to_bytes_with_nul();
504 
505         // SAFETY: `__this_module` is constructed by the kernel at load time and will not be
506         // freed until the module is unloaded.
507         #[cfg(MODULE)]
508         static THIS_MODULE: ::kernel::ThisModule = unsafe {
509             extern "C" {
510                 static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>;
511             };
512 
513             ::kernel::ThisModule::from_ptr(__this_module.get())
514         };
515 
516         #[cfg(not(MODULE))]
517         static THIS_MODULE: ::kernel::ThisModule = unsafe {
518             ::kernel::ThisModule::from_ptr(::core::ptr::null_mut())
519         };
520 
521         /// The `LocalModule` type is the type of the module created by `module!`,
522         /// `module_pci_driver!`, `module_platform_driver!`, etc.
523         type LocalModule = #type_;
524 
525         impl ::kernel::ModuleMetadata for #type_ {
526             const NAME: &'static ::kernel::str::CStr = #name_cstr;
527         }
528 
529         // Double nested modules, since then nobody can access the public items inside.
530         #[doc(hidden)]
531         mod __module_init {
532             mod __module_init {
533                 use pin_init::PinInit;
534 
535                 /// The "Rust loadable module" mark.
536                 //
537                 // This may be best done another way later on, e.g. as a new modinfo
538                 // key or a new section. For the moment, keep it simple.
539                 #[cfg(MODULE)]
540                 #[used(compiler)]
541                 static __IS_RUST_MODULE: () = ();
542 
543                 static mut __MOD: ::core::mem::MaybeUninit<super::super::LocalModule> =
544                     ::core::mem::MaybeUninit::uninit();
545 
546                 // Loadable modules need to export the `{init,cleanup}_module` identifiers.
547                 /// # Safety
548                 ///
549                 /// This function must not be called after module initialization, because it may be
550                 /// freed after that completes.
551                 #[cfg(MODULE)]
552                 #[no_mangle]
553                 #[link_section = ".init.text"]
554                 pub unsafe extern "C" fn init_module() -> ::kernel::ffi::c_int {
555                     // SAFETY: This function is inaccessible to the outside due to the double
556                     // module wrapping it. It is called exactly once by the C side via its
557                     // unique name.
558                     unsafe { __init() }
559                 }
560 
561                 #[cfg(MODULE)]
562                 #[used(compiler)]
563                 #[link_section = ".init.data"]
564                 static __UNIQUE_ID___addressable_init_module: unsafe extern "C" fn() -> i32 =
565                     init_module;
566 
567                 #[cfg(MODULE)]
568                 #[no_mangle]
569                 #[link_section = ".exit.text"]
570                 pub extern "C" fn cleanup_module() {
571                     // SAFETY:
572                     // - This function is inaccessible to the outside due to the double
573                     //   module wrapping it. It is called exactly once by the C side via its
574                     //   unique name,
575                     // - furthermore it is only called after `init_module` has returned `0`
576                     //   (which delegates to `__init`).
577                     unsafe { __exit() }
578                 }
579 
580                 #[cfg(MODULE)]
581                 #[used(compiler)]
582                 #[link_section = ".exit.data"]
583                 static __UNIQUE_ID___addressable_cleanup_module: extern "C" fn() = cleanup_module;
584 
585                 // Built-in modules are initialized through an initcall pointer
586                 // and the identifiers need to be unique.
587                 #[cfg(not(MODULE))]
588                 #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))]
589                 #[link_section = #initcall_section]
590                 #[used(compiler)]
591                 pub static #ident_initcall: extern "C" fn() ->
592                     ::kernel::ffi::c_int = #ident_init;
593 
594                 #[cfg(not(MODULE))]
595                 #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)]
596                 ::core::arch::global_asm!(#global_asm);
597 
598                 #[cfg(not(MODULE))]
599                 #[no_mangle]
600                 pub extern "C" fn #ident_init() -> ::kernel::ffi::c_int {
601                     // SAFETY: This function is inaccessible to the outside due to the double
602                     // module wrapping it. It is called exactly once by the C side via its
603                     // placement above in the initcall section.
604                     unsafe { __init() }
605                 }
606 
607                 #[cfg(not(MODULE))]
608                 #[no_mangle]
609                 pub extern "C" fn #ident_exit() {
610                     // SAFETY:
611                     // - This function is inaccessible to the outside due to the double
612                     //   module wrapping it. It is called exactly once by the C side via its
613                     //   unique name,
614                     // - furthermore it is only called after `#ident_init` has
615                     //   returned `0` (which delegates to `__init`).
616                     unsafe { __exit() }
617                 }
618 
619                 /// # Safety
620                 ///
621                 /// This function must only be called once.
622                 unsafe fn __init() -> ::kernel::ffi::c_int {
623                     let initer = <super::super::LocalModule as ::kernel::InPlaceModule>::init(
624                         &super::super::THIS_MODULE
625                     );
626                     // SAFETY: No data race, since `__MOD` can only be accessed by this module
627                     // and there only `__init` and `__exit` access it. These functions are only
628                     // called once and `__exit` cannot be called before or during `__init`.
629                     match unsafe { initer.__pinned_init(__MOD.as_mut_ptr()) } {
630                         Ok(m) => 0,
631                         Err(e) => e.to_errno(),
632                     }
633                 }
634 
635                 /// # Safety
636                 ///
637                 /// This function must
638                 /// - only be called once,
639                 /// - be called after `__init` has been called and returned `0`.
640                 unsafe fn __exit() {
641                     // SAFETY: No data race, since `__MOD` can only be accessed by this module
642                     // and there only `__init` and `__exit` access it. These functions are only
643                     // called once and `__init` was already called.
644                     unsafe {
645                         // Invokes `drop()` on `__MOD`, which should be used for cleanup.
646                         __MOD.assume_init_drop();
647                     }
648                 }
649 
650                 #modinfo_ts
651             }
652         }
653 
654         mod module_parameters {
655             #params_ts
656         }
657     })
658 }
659