xref: /linux/rust/macros/module.rs (revision 3d5731a6be6a43d8c90d766e1404502c44545241)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 use std::ffi::CString;
4 
5 use proc_macro2::{
6     Literal,
7     TokenStream, //
8 };
9 use quote::{
10     format_ident,
11     quote, //
12 };
13 use syn::{
14     braced,
15     bracketed,
16     ext::IdentExt,
17     parse::{
18         Parse,
19         ParseStream, //
20     },
21     parse_quote,
22     punctuated::Punctuated,
23     Error,
24     Expr,
25     Ident,
26     LitStr,
27     Path,
28     Result,
29     Token,
30     Type, //
31 };
32 
33 use crate::helpers::*;
34 
35 struct ModInfoBuilder<'a> {
36     module: &'a str,
37     counter: usize,
38     ts: TokenStream,
39     param_ts: TokenStream,
40 }
41 
42 impl<'a> ModInfoBuilder<'a> {
43     fn new(module: &'a str) -> Self {
44         ModInfoBuilder {
45             module,
46             counter: 0,
47             ts: TokenStream::new(),
48             param_ts: TokenStream::new(),
49         }
50     }
51 
52     fn emit_base(&mut self, field: &str, content: &str, builtin: bool, param: bool) {
53         let string = if builtin {
54             // Built-in modules prefix their modinfo strings by `module.`.
55             format!(
56                 "{module}.{field}={content}\0",
57                 module = self.module,
58                 field = field,
59                 content = content
60             )
61         } else {
62             // Loadable modules' modinfo strings go as-is.
63             format!("{field}={content}\0")
64         };
65         let length = string.len();
66         let string = Literal::byte_string(string.as_bytes());
67         let cfg = if builtin {
68             quote!(#[cfg(not(MODULE))])
69         } else {
70             quote!(#[cfg(MODULE)])
71         };
72 
73         let counter = format_ident!(
74             "__{module}_{counter}",
75             module = self.module.to_uppercase(),
76             counter = self.counter
77         );
78         let item = quote! {
79             #cfg
80             #[doc(hidden)]
81             #[cfg_attr(not(target_os = "macos"), link_section = ".modinfo")]
82             #[used(compiler)]
83             pub static #counter: [u8; #length] = *#string;
84         };
85 
86         if param {
87             self.param_ts.extend(item);
88         } else {
89             self.ts.extend(item);
90         }
91 
92         self.counter += 1;
93     }
94 
95     fn emit_only_builtin(&mut self, field: &str, content: &str, param: bool) {
96         self.emit_base(field, content, true, param)
97     }
98 
99     fn emit_only_loadable(&mut self, field: &str, content: &str, param: bool) {
100         self.emit_base(field, content, false, param)
101     }
102 
103     fn emit(&mut self, field: &str, content: &str) {
104         self.emit_internal(field, content, false);
105     }
106 
107     fn emit_internal(&mut self, field: &str, content: &str, param: bool) {
108         self.emit_only_builtin(field, content, param);
109         self.emit_only_loadable(field, content, param);
110     }
111 
112     fn emit_param(&mut self, field: &str, param: &str, content: &str) {
113         let content = format!("{param}:{content}", param = param, content = content);
114         self.emit_internal(field, &content, true);
115     }
116 
117     fn emit_params(&mut self, info: &ModuleInfo) {
118         let Some(params) = &info.params else {
119             return;
120         };
121 
122         for param in params {
123             let param_name_str = param.name.to_string();
124             let param_type_str = param.ptype.to_string();
125 
126             let ops = param_ops_path(&param_type_str);
127 
128             // Note: The spelling of these fields is dictated by the user space
129             // tool `modinfo`.
130             self.emit_param("parmtype", &param_name_str, &param_type_str);
131             self.emit_param("parm", &param_name_str, &param.description.value());
132 
133             let static_name = format_ident!("__{}_{}_struct", self.module, param.name);
134             let param_name_cstr =
135                 CString::new(param_name_str).expect("name contains NUL-terminator");
136             let param_name_cstr_with_module =
137                 CString::new(format!("{}.{}", self.module, param.name))
138                     .expect("name contains NUL-terminator");
139 
140             let param_name = &param.name;
141             let param_type = &param.ptype;
142             let param_default = &param.default;
143 
144             self.param_ts.extend(quote! {
145                 #[allow(non_upper_case_globals)]
146                 pub(crate) static #param_name:
147                     ::kernel::module_param::ModuleParamAccess<#param_type> =
148                         ::kernel::module_param::ModuleParamAccess::new(#param_default);
149 
150                 const _: () = {
151                     #[allow(non_upper_case_globals)]
152                     #[link_section = "__param"]
153                     #[used(compiler)]
154                     static #static_name:
155                         ::kernel::module_param::KernelParam =
156                         ::kernel::module_param::KernelParam::new(
157                             ::kernel::bindings::kernel_param {
158                                 name: kernel::str::as_char_ptr_in_const_context(
159                                     if ::core::cfg!(MODULE) {
160                                         #param_name_cstr
161                                     } else {
162                                         #param_name_cstr_with_module
163                                     }
164                                 ),
165                                 // SAFETY: `__this_module` is constructed by the kernel at load
166                                 // time and will not be freed until the module is unloaded.
167                                 #[cfg(MODULE)]
168                                 mod_: unsafe {
169                                     core::ptr::from_ref(&::kernel::bindings::__this_module)
170                                         .cast_mut()
171                                 },
172                                 #[cfg(not(MODULE))]
173                                 mod_: ::core::ptr::null_mut(),
174                                 ops: core::ptr::from_ref(&#ops),
175                                 perm: 0, // Will not appear in sysfs
176                                 level: -1,
177                                 flags: 0,
178                                 __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {
179                                     arg: #param_name.as_void_ptr()
180                                 },
181                             }
182                         );
183                 };
184             });
185         }
186     }
187 }
188 
189 fn param_ops_path(param_type: &str) -> Path {
190     match param_type {
191         "i8" => parse_quote!(::kernel::module_param::PARAM_OPS_I8),
192         "u8" => parse_quote!(::kernel::module_param::PARAM_OPS_U8),
193         "i16" => parse_quote!(::kernel::module_param::PARAM_OPS_I16),
194         "u16" => parse_quote!(::kernel::module_param::PARAM_OPS_U16),
195         "i32" => parse_quote!(::kernel::module_param::PARAM_OPS_I32),
196         "u32" => parse_quote!(::kernel::module_param::PARAM_OPS_U32),
197         "i64" => parse_quote!(::kernel::module_param::PARAM_OPS_I64),
198         "u64" => parse_quote!(::kernel::module_param::PARAM_OPS_U64),
199         "isize" => parse_quote!(::kernel::module_param::PARAM_OPS_ISIZE),
200         "usize" => parse_quote!(::kernel::module_param::PARAM_OPS_USIZE),
201         t => panic!("Unsupported parameter type {}", t),
202     }
203 }
204 
/// Parses `key: value,` fields that are required to use a specific order.
///
/// As fields must follow a specific order, we *could* just parse fields one by one by peeking.
/// However, the error messages produced that way are not very friendly.
///
/// So instead fields are accepted in any order, and the ordering is only enforced after
/// parsing. If the wrong order is used, the error message tells the user the proper order.
///
/// Every field, including the last one, must be followed by a comma. Each key may appear at
/// most once, and unknown keys are rejected.
///
/// Usage looks like this:
/// ```ignore
/// parse_ordered_fields! {
///     from input;
///
///     // This will extract "foo: <field>" into a variable named "foo".
///     // The variable will have type `Option<_>`.
///     foo => <expression that parses the field>,
///
///     // If you need the variable name to be different than the key name.
///     // This extracts "baz: <field>" into a variable named "bar".
///     // You might want this if "baz" is a keyword.
///     baz as bar => <expression that parses the field>,
///
///     // You can mark a key as required, and the variable will no longer be `Option`.
///     // foobar will be of type `Expr` instead of `Option<Expr>`.
///     foobar [required] => input.parse::<Expr>()?,
/// }
/// ```
macro_rules! parse_ordered_fields {
    // Final step: every field specification has been collected into a `[name; key; parser]`
    // triple, and every required one additionally into a `[name; key]` pair.
    (@gen
        [$input:expr]
        [$([$name:ident; $key:ident; $parser:expr])*]
        [$([$req_name:ident; $req_key:ident])*]
    ) => {
        $(let mut $name = None;)*

        const EXPECTED_KEYS: &[&str] = &[$(stringify!($key),)*];
        const REQUIRED_KEYS: &[&str] = &[$(stringify!($req_key),)*];

        let span = $input.span();
        let mut seen_keys: Vec<Ident> = Vec::new();

        while !$input.is_empty() {
            let key = $input.call(Ident::parse_any)?;

            if seen_keys.contains(&key) {
                return Err(Error::new_spanned(
                    &key,
                    format!(r#"duplicated key "{key}". Keys can only be specified once."#),
                ));
            }

            $input.parse::<Token![:]>()?;

            match &*key.to_string() {
                $(stringify!($key) => $name = Some($parser),)*
                _ => {
                    return Err(Error::new_spanned(
                        &key,
                        format!(r#"unknown key "{key}". Valid keys are: {EXPECTED_KEYS:?}."#),
                    ));
                }
            }

            $input.parse::<Token![,]>()?;
            seen_keys.push(key);
        }

        // Report the first required key (in declaration order) that was not given.
        if let Some(key) = REQUIRED_KEYS
            .iter()
            .find(|&&required| !seen_keys.iter().any(|seen| seen == required))
        {
            return Err(Error::new(span, format!(r#"missing required key "{key}""#)));
        }

        // The keys that were actually given, listed in the order they are expected in.
        let ordered_keys: Vec<&str> = EXPECTED_KEYS
            .iter()
            .copied()
            .filter(|&expected| seen_keys.iter().any(|seen| seen == expected))
            .collect();

        if seen_keys != ordered_keys {
            return Err(Error::new(
                span,
                format!(r#"keys are not ordered as expected. Order them like: {ordered_keys:?}."#),
            ));
        }

        $(let $req_name = $req_name.expect("required field");)*
    };

    // Handle required fields.
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $key:ident as $name:ident [required] => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)* [$name; $key]] $($rest)*
        )
    };
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $name:ident [required] => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)* [$name; $name]] $($rest)*
        )
    };

    // Handle optional fields.
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $key:ident as $name:ident => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)*] $($rest)*
        )
    };
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $name:ident => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)*] $($rest)*
        )
    };

    // Entry point: start with empty field and required-field accumulators.
    (from $input:expr; $($tok:tt)*) => {
        parse_ordered_fields!(@gen [$input] [] [] $($tok)*)
    }
}
341 
342 struct Parameter {
343     name: Ident,
344     ptype: Ident,
345     default: Expr,
346     description: LitStr,
347 }
348 
349 impl Parse for Parameter {
350     fn parse(input: ParseStream<'_>) -> Result<Self> {
351         let name = input.parse()?;
352         input.parse::<Token![:]>()?;
353         let ptype = input.parse()?;
354 
355         let fields;
356         braced!(fields in input);
357 
358         parse_ordered_fields! {
359             from fields;
360             default [required] => fields.parse()?,
361             description [required] => fields.parse()?,
362         }
363 
364         Ok(Self {
365             name,
366             ptype,
367             default,
368             description,
369         })
370     }
371 }
372 
373 pub(crate) struct ModuleInfo {
374     type_: Type,
375     license: AsciiLitStr,
376     name: AsciiLitStr,
377     authors: Option<Punctuated<AsciiLitStr, Token![,]>>,
378     description: Option<LitStr>,
379     alias: Option<Punctuated<AsciiLitStr, Token![,]>>,
380     firmware: Option<Punctuated<AsciiLitStr, Token![,]>>,
381     imports_ns: Option<Punctuated<AsciiLitStr, Token![,]>>,
382     params: Option<Punctuated<Parameter, Token![,]>>,
383 }
384 
385 impl Parse for ModuleInfo {
386     fn parse(input: ParseStream<'_>) -> Result<Self> {
387         parse_ordered_fields!(
388             from input;
389             type as type_ [required] => input.parse()?,
390             name [required] => input.parse()?,
391             authors => {
392                 let list;
393                 bracketed!(list in input);
394                 Punctuated::parse_terminated(&list)?
395             },
396             description => input.parse()?,
397             license [required] => input.parse()?,
398             alias => {
399                 let list;
400                 bracketed!(list in input);
401                 Punctuated::parse_terminated(&list)?
402             },
403             firmware => {
404                 let list;
405                 bracketed!(list in input);
406                 Punctuated::parse_terminated(&list)?
407             },
408             imports_ns => {
409                 let list;
410                 bracketed!(list in input);
411                 Punctuated::parse_terminated(&list)?
412             },
413             params => {
414                 let list;
415                 braced!(list in input);
416                 Punctuated::parse_terminated(&list)?
417             },
418         );
419 
420         Ok(ModuleInfo {
421             type_,
422             license,
423             name,
424             authors,
425             description,
426             alias,
427             firmware,
428             imports_ns,
429             params,
430         })
431     }
432 }
433 
434 pub(crate) fn module(info: ModuleInfo) -> Result<TokenStream> {
435     let ModuleInfo {
436         type_,
437         license,
438         name,
439         authors,
440         description,
441         alias,
442         firmware,
443         imports_ns,
444         params: _,
445     } = &info;
446 
447     // Rust does not allow hyphens in identifiers, use underscore instead.
448     let ident = name.value().replace('-', "_");
449     let mut modinfo = ModInfoBuilder::new(ident.as_ref());
450     if let Some(authors) = authors {
451         for author in authors {
452             modinfo.emit("author", &author.value());
453         }
454     }
455     if let Some(description) = description {
456         modinfo.emit("description", &description.value());
457     }
458     modinfo.emit("license", &license.value());
459     if let Some(aliases) = alias {
460         for alias in aliases {
461             modinfo.emit("alias", &alias.value());
462         }
463     }
464     if let Some(firmware) = firmware {
465         for fw in firmware {
466             modinfo.emit("firmware", &fw.value());
467         }
468     }
469     if let Some(imports) = imports_ns {
470         for ns in imports {
471             modinfo.emit("import_ns", &ns.value());
472         }
473     }
474 
475     // Built-in modules also export the `file` modinfo string.
476     let file =
477         std::env::var("RUST_MODFILE").expect("Unable to fetch RUST_MODFILE environmental variable");
478     modinfo.emit_only_builtin("file", &file, false);
479 
480     modinfo.emit_params(&info);
481 
482     let modinfo_ts = modinfo.ts;
483     let params_ts = modinfo.param_ts;
484 
485     let ident_init = format_ident!("__{ident}_init");
486     let ident_exit = format_ident!("__{ident}_exit");
487     let ident_initcall = format_ident!("__{ident}_initcall");
488     let initcall_section = ".initcall6.init";
489 
490     let global_asm = format!(
491         r#".section "{initcall_section}", "a"
492         __{ident}_initcall:
493             .long   __{ident}_init - .
494             .previous
495         "#
496     );
497 
498     let name_cstr = CString::new(name.value()).expect("name contains NUL-terminator");
499 
500     Ok(quote! {
501         /// The module name.
502         ///
503         /// Used by the printing macros, e.g. [`info!`].
504         const __LOG_PREFIX: &[u8] = #name_cstr.to_bytes_with_nul();
505 
506         // SAFETY: `__this_module` is constructed by the kernel at load time and will not be
507         // freed until the module is unloaded.
508         #[cfg(MODULE)]
509         static THIS_MODULE: ::kernel::ThisModule = unsafe {
510             extern "C" {
511                 static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>;
512             };
513 
514             ::kernel::ThisModule::from_ptr(__this_module.get())
515         };
516 
517         #[cfg(not(MODULE))]
518         static THIS_MODULE: ::kernel::ThisModule = unsafe {
519             ::kernel::ThisModule::from_ptr(::core::ptr::null_mut())
520         };
521 
522         /// The `LocalModule` type is the type of the module created by `module!`,
523         /// `module_pci_driver!`, `module_platform_driver!`, etc.
524         type LocalModule = #type_;
525 
526         impl ::kernel::ModuleMetadata for #type_ {
527             const NAME: &'static ::kernel::str::CStr = #name_cstr;
528         }
529 
530         // Double nested modules, since then nobody can access the public items inside.
531         mod __module_init {
532             mod __module_init {
533                 use pin_init::PinInit;
534 
535                 /// The "Rust loadable module" mark.
536                 //
537                 // This may be best done another way later on, e.g. as a new modinfo
538                 // key or a new section. For the moment, keep it simple.
539                 #[cfg(MODULE)]
540                 #[doc(hidden)]
541                 #[used(compiler)]
542                 static __IS_RUST_MODULE: () = ();
543 
544                 static mut __MOD: ::core::mem::MaybeUninit<super::super::LocalModule> =
545                     ::core::mem::MaybeUninit::uninit();
546 
547                 // Loadable modules need to export the `{init,cleanup}_module` identifiers.
548                 /// # Safety
549                 ///
550                 /// This function must not be called after module initialization, because it may be
551                 /// freed after that completes.
552                 #[cfg(MODULE)]
553                 #[doc(hidden)]
554                 #[no_mangle]
555                 #[link_section = ".init.text"]
556                 pub unsafe extern "C" fn init_module() -> ::kernel::ffi::c_int {
557                     // SAFETY: This function is inaccessible to the outside due to the double
558                     // module wrapping it. It is called exactly once by the C side via its
559                     // unique name.
560                     unsafe { __init() }
561                 }
562 
563                 #[cfg(MODULE)]
564                 #[doc(hidden)]
565                 #[used(compiler)]
566                 #[link_section = ".init.data"]
567                 static __UNIQUE_ID___addressable_init_module: unsafe extern "C" fn() -> i32 =
568                     init_module;
569 
570                 #[cfg(MODULE)]
571                 #[doc(hidden)]
572                 #[no_mangle]
573                 #[link_section = ".exit.text"]
574                 pub extern "C" fn cleanup_module() {
575                     // SAFETY:
576                     // - This function is inaccessible to the outside due to the double
577                     //   module wrapping it. It is called exactly once by the C side via its
578                     //   unique name,
579                     // - furthermore it is only called after `init_module` has returned `0`
580                     //   (which delegates to `__init`).
581                     unsafe { __exit() }
582                 }
583 
584                 #[cfg(MODULE)]
585                 #[doc(hidden)]
586                 #[used(compiler)]
587                 #[link_section = ".exit.data"]
588                 static __UNIQUE_ID___addressable_cleanup_module: extern "C" fn() = cleanup_module;
589 
590                 // Built-in modules are initialized through an initcall pointer
591                 // and the identifiers need to be unique.
592                 #[cfg(not(MODULE))]
593                 #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))]
594                 #[doc(hidden)]
595                 #[link_section = #initcall_section]
596                 #[used(compiler)]
597                 pub static #ident_initcall: extern "C" fn() ->
598                     ::kernel::ffi::c_int = #ident_init;
599 
600                 #[cfg(not(MODULE))]
601                 #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)]
602                 ::core::arch::global_asm!(#global_asm);
603 
604                 #[cfg(not(MODULE))]
605                 #[doc(hidden)]
606                 #[no_mangle]
607                 pub extern "C" fn #ident_init() -> ::kernel::ffi::c_int {
608                     // SAFETY: This function is inaccessible to the outside due to the double
609                     // module wrapping it. It is called exactly once by the C side via its
610                     // placement above in the initcall section.
611                     unsafe { __init() }
612                 }
613 
614                 #[cfg(not(MODULE))]
615                 #[doc(hidden)]
616                 #[no_mangle]
617                 pub extern "C" fn #ident_exit() {
618                     // SAFETY:
619                     // - This function is inaccessible to the outside due to the double
620                     //   module wrapping it. It is called exactly once by the C side via its
621                     //   unique name,
622                     // - furthermore it is only called after `#ident_init` has
623                     //   returned `0` (which delegates to `__init`).
624                     unsafe { __exit() }
625                 }
626 
627                 /// # Safety
628                 ///
629                 /// This function must only be called once.
630                 unsafe fn __init() -> ::kernel::ffi::c_int {
631                     let initer = <super::super::LocalModule as ::kernel::InPlaceModule>::init(
632                         &super::super::THIS_MODULE
633                     );
634                     // SAFETY: No data race, since `__MOD` can only be accessed by this module
635                     // and there only `__init` and `__exit` access it. These functions are only
636                     // called once and `__exit` cannot be called before or during `__init`.
637                     match unsafe { initer.__pinned_init(__MOD.as_mut_ptr()) } {
638                         Ok(m) => 0,
639                         Err(e) => e.to_errno(),
640                     }
641                 }
642 
643                 /// # Safety
644                 ///
645                 /// This function must
646                 /// - only be called once,
647                 /// - be called after `__init` has been called and returned `0`.
648                 unsafe fn __exit() {
649                     // SAFETY: No data race, since `__MOD` can only be accessed by this module
650                     // and there only `__init` and `__exit` access it. These functions are only
651                     // called once and `__init` was already called.
652                     unsafe {
653                         // Invokes `drop()` on `__MOD`, which should be used for cleanup.
654                         __MOD.assume_init_drop();
655                     }
656                 }
657 
658                 #modinfo_ts
659             }
660         }
661 
662         mod module_parameters {
663             #params_ts
664         }
665     })
666 }
667