xref: /linux/tools/net/ynl/ynl-gen-c.py (revision 79ac11393328fb1717d17c12e3c0eef0e9fa0647)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import os
8import re
9import shutil
10import tempfile
11import yaml
12
13from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
14
15
def c_upper(name):
    """Return @name as an upper-case C identifier ('-' replaced by '_')."""
    return name.replace('-', '_').upper()
18
19
def c_lower(name):
    """Return @name as a lower-case C identifier ('-' replaced by '_')."""
    return name.replace('-', '_').lower()
22
23
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The name is 'u' (unsigned) or 's' (signed), a bit width, then '-min'
    or '-max'. Anything else raises ValueError rather than silently being
    interpreted as some maximum value.
    """
    match = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit name "{name}", expected e.g. u32-max or s64-min')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u':
        return 0 if bound == 'min' else (1 << width) - 1

    # Signed: the top bit of the width is the sign bit
    magnitude = (1 << (width - 1)) - 1
    return -magnitude - 1 if bound == 'min' else magnitude
37
38
class BaseNlLib:
    """C snippets for the netlink helper library calls used by generated code."""

    def get_family_id(self):
        """Return the C expression holding the resolved family ID."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """
        Build the mnl_cb_run2() call which runs callback @cb with @data.

        Dumps pass 0 for seq / portid, other requests pass ys->seq and
        ys->portid. Continuation lines are indented by @indent extra tabs.
        """
        sep = '\n\t\t' + indent * '\t' + ' '
        tail = "ynl_cb_array, NLMSG_MIN_TYPE)"
        if is_dump:
            head = f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},"
        else:
            head = f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{sep}{cb}, {data},"
        return head + sep + tail
50
51
class Type(SpecAttr):
    """
    Base class for generating C code for a single netlink attribute.

    Subclasses specialize the per-type hooks (_attr_typol, _attr_policy,
    _attr_get, attr_put, _setter_lines, ...); this class stitches them into
    policy entries, struct members, parse / put code and setters.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(family.name)
            else:
                self.nested_render_name = c_lower(f"{family.name}_{self.nested_attrs}")

            # A nest named like one of the family's definitions gets a
            # trailing '_' on its struct name, presumably to avoid clashing
            # with the type generated for that definition
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # C identifier for members / arguments, suffixed if it is a C keyword
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def get_limit(self, limit, default=None):
        """
        Return the numeric value of check @limit (e.g. 'min', 'max').

        String limits such as 'u32-max' are converted with limit_to_number().
        Falls back to @default (converted the same way) if the check is absent.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if not isinstance(value, int):
            value = limit_to_number(value)
        return value

    def resolve(self):
        """Compute the C enum name (attribute ID) of this attribute."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """How presence is tracked: 'bit', 'len', 'count' or '' (none)."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """
        Return the _present struct member for this attribute, or None.

        Only emitted when presence_type() equals @type_filter; 'user' space
        uses the '__u32' spelling.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def free(self, ri, var, ref):
        # Multi-value and length-tracked members are heap allocated
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C function argument(s) used to pass this attribute."""
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member(s) holding this attribute."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the kernel policy table entry for this attribute."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _mnl_type(self):
        # mnl does not have helpers for signed integer types
        # turn signed type into unsigned
        # this only makes sense for scalar types
        t = self.type
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the user space type-policy entry for this attribute."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the presence marker, if the type has one
        presence = self.presence_type()
        if presence == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif presence == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the parse branch for this attribute.

        @first selects 'if' vs 'else if' for the branch. Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Multi-value attributes are neither validated nor marked present here
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a static inline setter for this attribute.

        @ref is the path of enclosing member names for attributes inside
        nests; presence bits are set at every level of that path.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i in range(len(ref)):
            # e.g. req->_present.a, then req->a._present.b, ...
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters which free the old value but don't allocate a copy take
        # ownership of caller memory; they get a '__' prefix
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
249
250
class TypeUnused(Type):
    """Attribute ID which is reserved / not used by the family."""

    def attr_policy(self, cw):
        # No kernel policy entry is generated
        return None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        # Seeing this attribute on parse is an error
        lines = ['return MNL_CB_ERROR;']
        return lines, None, None

    def arg_member(self, ri):
        return []

    def presence_type(self):
        return ''
266
267
class TypePad(Type):
    """Padding attribute: no storage, no policy, ignored on parse."""

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The generators below intentionally emit nothing for padding

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
289
290
class TypeScalar(Type):
    """
    Fixed-width integer attribute.

    Policy limits come from the spec 'checks' and, for enum-typed
    attributes without explicit limits, from the enum's value range.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Enum-typed scalars default to the enum's value range
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if 'min' not in self.checks:
                if low != 0 or self.type[0] == 's':
                    self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        # Parse the (possibly symbolic) limits once
        min_val = self.get_limit('min', 0)
        max_val = self.get_limit('max', 0)

        if 'min' in self.checks and 'max' in self.checks:
            if min_val > max_val:
                raise Exception(f'Invalid limit for "{self.name}" min: {min_val} max: {max_val}')
            self.checks['range'] = True

        low = min(min_val, max_val)
        high = max(min_val, max_val)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range are emitted with NLA_POLICY_FULL_RANGE
        # (see _attr_policy) instead of the inline range macros
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Decide whether this is a bitfield and pick the C type name."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        maybe_enum = not self.is_bitfield and 'enum' in self.attr
        if maybe_enum and self.family.consts[self.attr['enum']].enum_name:
            self.type_name = c_lower(f"enum {self.family.name}_{self.attr['enum']}")
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def mnl_type(self):
        return self._mnl_type()

    def _attr_policy(self, policy):
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit('min')}, {self.get_limit('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.mnl_type())

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = mnl_attr_get_{self.mnl_type()}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
381
382
class TypeFlag(Type):
    """Flag attribute: no payload, presence alone carries the value."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter takes no value; it only marks the flag present
        return []

    def attr_put(self, ri, var):
        # Zero-length attribute
        put = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        # Nothing to copy out of the message
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        return []
398
399
class TypeString(Type):
    """String attribute, stored as a heap-allocated NUL-terminated copy."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        # Presence is tracked via the stored string length
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + str(self.checks['exact-len']) + ')'
        mem = '{ .type = ' + policy
        if 'max-len' in self.checks:
            mem += ', .len = ' + str(self.get_limit('max-len'))
        return mem + ', }'

    def attr_policy(self, cw):
        # 'unterminated-ok' selects NLA_STRING instead of NLA_NUL_STRING
        if self.checks.get('unterminated-ok', False):
            policy = 'NLA_STRING'
        else:
            policy = 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        # len is the string length bounded by the payload size; the copy
        # gets one extra byte for an explicit terminator
        local_vars = ['unsigned int len;']
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        get_lines = [f"{len_mem} = len;",
                     f"{var}->{self.c_name} = malloc(len + 1);",
                     f"memcpy({var}->{self.c_name}, mnl_attr_get_str(attr), len);",
                     f"{var}->{self.c_name}[len] = 0;"]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        # Replace any previous value with a private copy of the argument
        return [f"free({member});",
                f"{presence}_len = strlen({self.c_name});",
                f"{member} = malloc({presence}_len + 1);",
                f'memcpy({member}, {self.c_name}, {presence}_len);',
                f'{member}[{presence}_len] = 0;']
450
451
class TypeBinary(Type):
    """Opaque binary attribute, stored as a heap-allocated copy plus length."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        # Presence is tracked via the stored payload length
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # NOTE(review): unlike the other typol strings there is no trailing
        # space here; kept as-is so generated output does not change
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + str(self.checks['exact-len']) + ')'
        # Only a lone 'min-len' check (or no checks at all) is supported
        if len(self.checks) == 1 and 'min-len' in self.checks:
            body = '.len = ' + str(self.get_limit('min-len'))
        elif len(self.checks) == 0:
            body = '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        return '{ ' + body + ', }'

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        local_vars = ['unsigned int len;']
        init_lines = ['len = mnl_attr_get_payload_len(attr);']
        get_lines = [f"{len_mem} = len;",
                     f"{var}->{self.c_name} = malloc(len);",
                     f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        # Replace any previous value with a private copy of the argument
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
496
497
class TypeBitfield32(Type):
    """struct nla_bitfield32 attribute; its valid bits come from an enum."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"mnl_attr_put(nlh, {self.enum_name}, sizeof(struct nla_bitfield32), &{var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, mnl_attr_get_payload(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
521
522
class TypeNest(Type):
    """Nested attribute set, stored as an embedded struct."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        target = f"&{var}->{ref}{self.c_name}"
        ri.cw.p(f"{self.nested_render_name}_free({target});")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f"NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)"

    def attr_put(self, ri, var):
        call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{var}->{self.c_name})"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        name = self.nested_render_name
        # Point the parse arg at the nest's policy and destination struct,
        # then hand off to the nest's own parser
        init_lines = [f"parg.rsp_policy = &{name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({name}_parse(&parg, attr))",
                     "return MNL_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # Emit one setter per member of the nest, prefixed with this path
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
552
553
class TypeMultiAttr(Type):
    """
    Attribute which may repeat; stored as an array plus an n_<name> count.

    @base_type is the single-instance type object, used for policies.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def mnl_type(self):
        return self._mnl_type()

    def _is_nest(self):
        """True if the repeated elements are nests (the default with no 'type')."""
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def _complex_member_type(self, ri):
        if self._is_nest():
            return self.nested_struct_type
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return self._is_nest()

    def free(self, ri, var, ref):
        # Check the nest case first so the missing-'type' guard is reachable,
        # consistently with _complex_member_type() / free_needs_iter()
        if self._is_nest():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only counts occurrences; presumably the array is filled in a
        # separate pass elsewhere in the generator
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        if self._is_nest():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self.mnl_type()
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence, hack up the presence:
        # replace the trailing "_present.<name>" with "n_<name>"
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
618
619
class TypeArrayNest(Type):
    """Nest wrapping an array of elements; stored as an array plus a count."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # Elements default to nests when no 'sub-type' is given
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            return ('__' if ri.ku_space == 'user' else '') + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the outer attr and count the nested elements
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'mnl_attr_for_each_nested(attr2, attr)',
            f'\t{var}->n_{self.c_name}++;',
        ]
        return get_lines, None, ['const struct nlattr *attr2;']
645
646
class TypeNestTypeValue(Type):
    """Nest whose outer levels encode values in the attribute types ('type-value')."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(names)};')
            local_vars.append(f'__u32 {", ".join(names)};')
            # Descend one level per name, capturing each level's attr type
            for name in names:
                get_lines.append(f'attr_{name} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{name} = mnl_attr_get_type(attr_{name});')
                cursor = f'attr_{name}'
            extra_args = ', ' + ', '.join(names)

        # NOTE(review): unlike TypeNest, the parse return value is not
        # checked here - confirm whether errors should be propagated
        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
675
676
class Struct:
    """
    C struct generated for an attribute set.

    Without @type_list the struct is a nested one covering every attribute
    of the set; otherwise it holds only the listed attributes.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if c_lower(space_name) == family.name:
            self.render_name = c_lower(family.name)
        else:
            self.render_name = c_lower(family.name + '-' + space_name)
        self.struct_name = 'struct ' + self.render_name
        # Nested structs named like a definition get a '_' suffix
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = dict(self.attr_list)

        # Attribute with the highest value; later attributes win ties
        self.attr_max_val = None
        highest = 0
        for _, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        for name in self.attrs:
            yield name

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members; they must match what was given at init."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]
729
730
class EnumEntry(SpecEnumEntry):
    """Single enum / flags entry, with its C name resolved later."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change is set when the value differs from what the C
        # compiler would assign implicitly (previous + 1, or 0 for the first
        # entry), and always for flags - presumably used to decide whether
        # an explicit '= value' has to be rendered
        implicit = prev.value + 1 if prev else 0
        self.value_change = self.value != implicit or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
749
750
class EnumSet(SpecEnumSet):
    """Enum or flags definition with its C naming."""

    def __init__(self, family, yaml):
        # These are set before super().__init__(), which presumably builds
        # the entries (EnumEntry reads enum_set.value_pfx)
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        # 'enum-name' overrides the C enum type name; an empty value means
        # no enum type name is used at all
        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # Without an enum type name, values are passed around as int
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """
        Return (low, high) of the entry values.

        Raises ValueError for an enum without entries, and Exception if the
        values are not contiguous.
        """
        values = [x.value for x in self.entries.values()]
        if not values:
            raise ValueError("Can't get value range for an empty enum")

        low, high = min(values), max(values)
        if high - low + 1 != len(values):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
784
785
class AttrSet(SpecAttrSet):
    """Attribute set with C naming and typed attribute objects."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are carved out of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        # The set named like the family gets an empty C name
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching the spec 'type' of @elem."""
        kind = elem['type']
        if kind in scalars:
            cls = TypeScalar
        else:
            classes = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }
            if kind not in classes:
                raise Exception(f"No typed class for type {kind}")
            cls = classes[kind]
        t = cls(self.family, self, elem, value)

        # Repeated attributes wrap the single-instance type
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
844
845
class Operation(SpecOperation):
    """Netlink operation with its C naming."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.name + '_' + self.name)

        def has_request(mode):
            return mode in yaml and 'request' in yaml[mode]

        # Set when both 'do' and 'dump' define a request, presumably so
        # they get separate policies
        self.dual_policy = has_request('do') and has_request('dump')

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
871
872
class Family(SpecFamily):
    """Genetlink family spec extended with the state the C generator needs.

    On top of SpecFamily this records C naming (c_name, command enum
    prefixes), uAPI header naming, multicast groups, pre/post hook lists,
    which attribute sets are used as message roots vs. only as nests
    (root_sets / pure_nested_structs) and, for kernel-policy 'global',
    the single global policy.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # (assign-then-delattr declares the attribute without leaving it
        # set. These are declared before super().__init__(); presumably the
        # base constructor ends up invoking resolve(). NOTE(review): confirm
        # against SpecFamily. 'consts' is not assigned by the resolve()
        # visible here, so it is presumably filled in by the base class.)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Names of the C macros for the family name / version; overridable
        # via 'c-family-name' / 'c-version-name' in the spec.
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        # uAPI header path (default "linux/<name>.h"); for the "linux/X.h"
        # form also keep the bare "X", otherwise fall back to the family name.
        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.name}.h"
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.name

    def resolve(self):
        """Derive C naming, hooks and attribute-set usage from the spec.

        Raises:
            Exception: if 'protocol' is not one of the genetlink variants.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.name)
        # Command enum prefixes: explicit 'name-prefix' / 'async-prefix' or
        # "<NAME>_CMD_" with the async prefix defaulting to the same value.
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': dedup set, 'list': ordered names}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> set('request', 'reply')
        self.pure_nested_structs = dict()

        # Order matters: events must be mocked up as 'do' replies before
        # the root sets are computed from the ops.
        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        """Override: build the generator's EnumSet for *elem*."""
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        """Override: build the generator's AttrSet for *elem*."""
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        """Override: build the generator's Operation for *elem*."""
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every op that another message names in its 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per attribute set, the attrs used in requests and replies.

        Only messages with an 'attribute-set' count; event attributes are
        counted as reply attributes.
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _load_nested_sets(self):
        """Build pure_nested_structs for every set reachable by nesting.

        Walks breadth-first from the root sets, creating a Struct for each
        nested set and recording inherited members ('type-value' entries,
        or 'idx' for array-nest). Then marks structs used under root
        request/reply attrs, reorders them so a struct comes after the
        structs it nests (best effort, bounded rounds), and propagates the
        request/reply flags from parents down to children.

        Raises:
            Exception: if a set is used both as a root and as a nest.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                # Note: the root-set check below cannot trigger, the branch
                # above has already raised for root sets.
                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)
        # Propagate the request / reply
        # (reversed order visits containing structs before the structs they nest)
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply

    def _load_attr_use(self):
        """Set .request / .reply on attribute specs according to their use."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.request = True
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.reply = True

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.request = True
                if attr in rs_members['reply']:
                    spec.reply = True

    def _load_global_policy(self):
        """Compute the single policy used when kernel-policy is 'global'.

        Collects request attributes of all do/dump ops, in attribute-set
        order, into global_policy; global_policy_set names the shared set.
        NOTE(review): if no op has an 'attribute-set', attr_set_name stays
        None and the lookup at the end fails.

        Raises:
            Exception: if ops reference different attribute sets.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Gather pre/post hook names per mode, deduplicated, first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1118
1119
class RenderInfo:
    """Rendering context for one operation in one mode, or a bare attr set.

    Bundles the code writer, family, output space (ku_space), the op and
    its mode, the attribute set, the C type-name stem and the
    request/reply Struct views consumed by the emit helpers.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        # 'do' and 'dump' response parsing is identical only when both modes
        # exist and agree on whether (and with what) they reply.
        consistent = True
        if op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                do_replies = 'reply' in op['do']
                dump_replies = 'reply' in op["dump"]
                consistent = (do_replies == dump_replies and
                              (not do_replies or op["do"]["reply"] == op["dump"]["reply"]))
            else:
                consistent = False
        self.type_consistent = consistent

        # Default to the op's own attribute set.
        self.attr_set = attr_set or op['attribute-set']

        if op:
            self.type_name_conflict = False
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            # The bare set's name is also present in family.consts.
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        # Notifications are laid out like the 'do' reply.
        mode = 'do' if op_mode == 'notify' else op_mode
        self.struct = {}
        if op:
            mode_spec = op[mode]
            for op_dir in ('request', 'reply'):
                type_list = mode_spec[op_dir]['attributes'] if op_dir in mode_spec else []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1164
1165
class CodeWriter:
    """Writer for generated C code that tracks indentation and braces.

    Lines go to stdout or, when *out_file* is given, to a temporary file
    that close_out_file() copies over *out_file*. When *overwrite* is False
    and the existing file already has identical contents, the copy is
    skipped. *nlib* is kept as ``self.nlib`` for the render helpers.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # a blank line is pending before the next line
        self._block_end = False     # a '}' is pending (so "else" can join it)
        self._silent_block = False  # previous line was a brace-less if/while/for
        self._ind = 0
        self._ifdef_block = None
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy buffered output to the target file and release the temp file.

        Safe to call more than once; afterwards further output goes to
        stdout. Also tolerates a partially constructed instance (called
        from __del__ when __init__ failed before _out was set).
        """
        out = getattr(self, '_out', None)
        if out is None or out == os.sys.stdout:
            return
        out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and os.path.isfile(self._out_file) and
                     filecmp.cmp(out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                out.seek(0)
                shutil.copyfileobj(out, out_file)
        # Close the temp file on both paths; the unchanged path used to leave
        # it open and keep writing into it.
        out.close()
        self._out = os.sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Flushes a pending closing brace (merging it with a following
        "else") and a pending blank line first. Lines ending in ':' are
        outdented one level, the statement after a brace-less
        if/while/for is indented one level, and lines starting with '#'
        go to column 0. An empty line is written as a bare newline.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write(('\t' * ind + line if line else '') + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write "<line> {" and increase the indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close a block; a bare '}' is delayed in case "else" follows."""
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write *doc* as ' *' comment lines wrapped before 79 columns."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping arguments past ~80 columns."""
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        Accepts a single string or a list; the caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a whole function: prototype, locals and *body* lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write tab-aligned '#define NAME VALUE' lines.

        Strings are quoted, ints printed as-is; any other value type
        (including bool, due to the exact type check) emits no value.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write tab-aligned '.name = value,' designated initializers."""
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the enclosing '#ifdef CONFIG_<config>' block (None closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1361
1362
# Netlink attribute types rendered as plain C integers.
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'}

# Suffix appended to a type stem for each message direction.
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix of the reply wrapper type for each op mode.
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# Identifiers that cannot be used verbatim as C names; a trailing '_' is
# appended to clashing names. 'bool' is included because <stdbool.h> turns
# it into a macro; 'restrict' is a C99 keyword.
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'restrict',
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1414
1415
def rdir(direction):
    """Return the opposite message direction.

    'request' and 'reply' swap; anything else is returned unchanged.
    """
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1422
1423
def op_prefix(ri, direction, deref=False):
    """Return the C name stem "<family>_<type_name><suffix>" for *ri*.

    The suffix depends on the mode:
      - no mode or 'do': the direction suffix ("_req" / "_rsp");
      - other modes, requests: "_req_dump";
      - other modes, replies when ri.type_consistent: the direction suffix
        if *deref*, else the mode wrapper ("_list", "_ntf", ...);
      - otherwise "_rsp_dump" if *deref*, else "_rsp_list".
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1443
1444
def type_name(ri, direction, deref=False):
    """Return the full C struct type ("struct <op_prefix>") for *direction*."""
    stem = op_prefix(ri, direction, deref=deref)
    return 'struct ' + stem
1447
1448
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the user-facing request function for *ri*.

    The function takes the socket plus, when the mode has a request, a
    pointer to the request struct. It returns a pointer to the reply
    struct when the mode has a reply, else int. Dump functions get a
    "_dump" suffix; *terminate* appends ';'.
    """
    mode_spec = ri.op[ri.op_mode]
    fname = f"{ri.op.render_name}_dump" if ri.op_mode == 'dump' else ri.op.render_name

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    ret = f"{type_name(ri, rdir(direction))} *" if 'reply' in mode_spec else 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1465
1466
def print_req_prototype(ri):
    """Declare the 'do' request function, using the op's doc as its comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1469
1470
def print_dump_prototype(ri):
    """Declare the dump request function (no doc comment)."""
    print_prototype(ri, direction="request")
1473
1474
def put_typol(cw, struct):
    """Emit the ynl_policy_attr table and its ynl_policy_nest for *struct*.

    Each member writes its own policy entry via attr_typol(); the table is
    sized by the attribute set's max_name.
    """
    max_attr = struct.attr_set.max_name
    base = struct.render_name

    cw.block_start(line=f'struct ynl_policy_attr {base}_policy[{max_attr} + 1] =')
    for _name, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {base}_nest =')
    for init in (f'.max_attr = {max_attr},', f'.table = {base}_policy,'):
        cw.p(init)
    cw.block_end(line=';')
    cw.nl()
1490
1491
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit `const char *<render_name>_str(<type> <arg_name>)`.

    The generated function indexes *map_name* and returns NULL for
    out-of-range values. The argument type is the enum's user_type, or
    int without an enum. For 'flags' enums the value is first turned into
    a bit index with ffs(), which picks the lowest set bit.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()

    body = []
    if enum and enum.type == 'flags':
        body.append(f'{arg_name} = ffs({arg_name}) - 1;')
    body.append(f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))')
    body.append('return NULL;')
    body.append(f'return {map_name}[{arg_name}];')
    for stmt in body:
        cw.p(stmt)

    cw.block_end()
    cw.nl()
1505
1506
def put_op_name_fwd(family, cw):
    """Declare the `<family>_op_str(int op)` name-lookup helper."""
    fname = family.c_name + '_op_str'
    cw.write_func_prot('const char *', fname, args=['int op'], suffix=';')
1509
1510
def put_op_name(family, cw):
    """Emit the op-ID to name string table plus its <family>_op_str() lookup.

    Only messages with a reply value get an entry. The index is the C enum
    name when request and reply values match, else the numeric reply value.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1530
1531
def put_enum_to_str_fwd(family, cw, enum):
    """Declare `<enum>_str(<user_type> value)`; *family* is unused."""
    proto_args = [f'{enum.user_type} value']
    cw.write_func_prot('const char *', enum.render_name + '_str', proto_args, suffix=';')
1535
1536
def put_enum_to_str(family, cw, enum):
    """Emit the value to name string table for *enum* and its lookup function."""
    map_name = f'{enum.render_name}_strmap'
    rows = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for row in rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1546
1547
def put_req_nested(ri, struct):
    """Emit `int <struct>_put(nlh, attr_type, obj)` for a nested attr set.

    The generated function opens a nest of type attr_type, lets every
    member write its own put code against "obj", closes the nest and
    returns 0.
    """
    cw = ri.cw
    params = ['struct nlmsghdr *nlh', 'unsigned int attr_type', f'{struct.ptr_name}obj']
    cw.write_func_prot('int', f'{struct.render_name}_put', params)

    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')
    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _name, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")
    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1568
1569
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a C parse function for *struct*.

    Shared by nested-attribute parsers and top-level message parsers. The
    generated code walks the attributes once, letting each member emit
    its handling via attr_get(). It then allocates arrays for array-nest
    and multi-attr members, sized by the n_<member> counters declared
    here (presumably incremented by the members' attr_get() code). A
    second walk fills those arrays in.

    Args:
        ri: RenderInfo whose code writer receives the output.
        struct: Struct being parsed.
        init_lines: C statements emitted right after the locals; extended
            in place.
        local_vars: C local variable declarations; extended in place.

    Raises:
        Exception: for a multi-attr member that is neither nested nor scalar.
    """
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        # Same test as AttrSet.new_attr(): 'multi-attr: false' is not multi.
        if 'multi-attr' in aspec and aspec['multi-attr']:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # The arrays are allocated below; refuse to overwrite existing ones.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        # Size the array by the element count, not by the pointer field.
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        # The nested attr type is passed as the inherited 'idx' member.
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{aspec.mnl_type()}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1673
1674
def parse_rsp_nested(ri, struct):
    """Emit `int <struct>_parse(yarg, nested, ...)` for a nested attr set.

    Each inherited member (struct.inherited) becomes an extra __u32
    parameter; the body is produced by _multi_parse().
    """
    func_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    func_args.extend('__u32 ' + arg for arg in struct.inherited)

    local_vars = ['const struct nlattr *attr;', f'{struct.ptr_name}dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)
    _multi_parse(ri, struct, [], local_vars)
1688
1689
def parse_rsp_msg(ri, deref=False):
    """Emit the libmnl callback parsing a reply message for *ri*.

    Nothing is emitted when the mode has no reply, unless it is an event.
    A reply without members gets a stub that returns MNL_CB_OK.
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    parse_fn = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    func_args = ['const struct nlmsghdr *nlh', 'void *data']
    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
                  'struct ynl_parse_arg *yarg = data;',
                  'const struct nlattr *attr;']
    init_lines = ['dst = yarg->data;']

    ri.cw.write_func_prot('int', parse_fn, func_args)

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return MNL_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, init_lines, local_vars)
1712
1713
def print_req(ri):
    """Emit the body of the blocking 'do' request function for *ri*.

    The generated code builds the message from the request struct and
    runs it with ynl_exec(). When the op replies, it allocates the reply
    struct, returns it on success and frees it (returning NULL) on error.
    Otherwise it returns 0 / -1.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if has_reply:
        ret_ok, ret_err = 'rsp', 'NULL'
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    else:
        ret_ok, ret_err = '0', '-1'

    cw = ri.cw
    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    for _name, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()

    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {ret_ok};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        cw.p(f"return {ret_err};")

    cw.block_end()
1767
1768
def print_dump(ri):
    """Emit the body of a user-space dump request function.

    The generated C fills a ``ynl_dump_state``, builds the dump request
    (adding request attributes when the op defines a request for this
    mode), runs ynl_exec_dump() and returns ``yds.first``.  On error it
    frees the partially built list and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    for decl in ('struct ynl_dump_state yds = {};',
                 'struct nlmsghdr *nlh;',
                 'int err;'):
        ri.cw.p(decl)
    ri.cw.nl()

    ri.cw.p('yds.ys = ys;')
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    ri.cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    ri.cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        request = ri.struct['request']
        ri.cw.p(f"ys->req_policy = &{request.render_name}_nest;")
        ri.cw.nl()
        for _, member in request.member_list():
            member.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()

    # Error path: release whatever was accumulated so far.
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1810
1811
def call_free(ri, direction, var):
    """Return the C statement that frees ``var`` with the op's free helper."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1814
1815
def free_arg_name(direction):
    """Return the parameter name used in generated ``*_free()`` prototypes.

    With a (non-empty) direction this is the direction's struct suffix from
    ``direction_to_suffix`` minus its first character (presumably a leading
    underscore).  With no direction the name is 'obj'.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1820
1821
def print_alloc_wrapper(ri, direction):
    """Emit a ``static inline`` ``<name>_alloc()`` returning a calloc'ed struct."""
    name = op_prefix(ri, direction)
    struct_ref = f'struct {name}'
    ri.cw.write_func_prot(f'static inline {struct_ref} *', f"{name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof({struct_ref}));')
    ri.cw.block_end()
1828
1829
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the ``<name>_free()`` prototype, or definition header when suffix=''.

    When ``ri.type_name_conflict`` is set the argument's struct type gets a
    trailing underscore, while the function name does not.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
1837
1838
def _print_type(ri, direction, struct):
    """Emit the C struct definition for ``struct``.

    The struct is named ``<family c_name>_<type_name><direction suffix>``.
    An extra '_' is appended for direction-less types when
    ``ri.type_name_conflict`` is set, and '_dump' is appended in dump mode.
    Presence members ('len' and 'bit' kinds) are grouped in an inner
    ``_present`` struct that is emitted only when there is at least one.
    Inherited ``__u32`` fields and the regular members follow.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    presence_lines = []
    for _, member in struct.member_list():
        for kind in ('len', 'bit'):
            line = member.presence_member(ri.ku_space, kind)
            if line:
                presence_lines.append(line)

    if presence_lines:
        ri.cw.block_start(line="struct")
        for line in presence_lines:
            ri.cw.p(line)
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for inherited in struct.inherited:
        ri.cw.p(f"__u32 {inherited};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
1870
1871
def print_type(ri, direction):
    """Emit the C struct for the op's ``direction`` ('request' or 'reply')."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1874
1875
def print_type_full(ri, struct):
    """Emit the C struct for ``struct`` with an empty direction (no suffix)."""
    _print_type(ri, direction="", struct=struct)
1878
1879
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype and, for user-space requests, member setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1888
1889
def print_req_type_helpers(ri):
    """Emit alloc/free/setter helpers, but only for a non-empty request struct."""
    request_attrs = ri.struct["request"].attr_list
    if len(request_attrs) > 0:
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
1895
1896
def print_rsp_type_helpers(ri):
    """Emit reply struct helpers when the op has a reply in the current mode."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1901
1902
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of ``<op>_rsp_parse`` / ``<op>_req_parse``.

    The struct argument is named ``req`` for both directions.  With
    ``terminate`` the prototype ends in ';', otherwise it is left open for a
    function body.
    """
    struct_name = ri.op.render_name + ("_rsp" if direction == "reply" else "_req")
    params = ['const struct nlattr **tb', f"struct {struct_name} *req"]
    ri.cw.write_func_prot('void', f"{struct_name}_parse", params,
                          suffix=';' if terminate else '')
1911
1912
def print_req_type(ri):
    """Emit the request struct unless it has no attributes."""
    if len(ri.struct["request"].attr_list) > 0:
        print_type(ri, "request")
1917
1918
def print_req_free(ri):
    """Emit the request free() function if the op has a request in this mode."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1923
1924
def print_rsp_type(ri):
    """Emit the reply struct for events, and for do/dump ops that have a reply."""
    mode = ri.op_mode
    if mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode]):
        print_type(ri, 'reply')
1933
1934
def print_wrapped_type(ri):
    """Emit the wrapper struct that embeds the reply object as ``obj``.

    Dump wrappers get a ``next`` pointer of the wrapper type.  Notify/event
    wrappers get family/cmd fields, a generic ``next`` and a ``free``
    callback.  ``obj`` is declared 8-byte aligned.  The reply free()
    prototype follows the struct.
    """
    wrapper = type_name(ri, 'reply')
    ri.cw.block_start(line=wrapper)
    if ri.op_mode == 'dump':
        ri.cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        ri.cw.p('__u16 family;')
        ri.cw.p('__u8 cmd;')
        ri.cw.p('struct ynl_ntf_base_type *next;')
        ri.cw.p(f"void (*free)({wrapper} *ntf);")
    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()
    print_free_prototype(ri, 'reply')
    ri.cw.nl()
1949
1950
1951def _free_type_members_iter(ri, struct):
1952    for _, attr in struct.member_list():
1953        if attr.free_needs_iter():
1954            ri.cw.p('unsigned int i;')
1955            ri.cw.nl()
1956            break
1957
1958
1959def _free_type_members(ri, var, struct, ref=''):
1960    for _, attr in struct.member_list():
1961        attr.free(ri, var, ref)
1962
1963
def _free_type(ri, direction, struct):
    """Emit a C free function for ``struct``.

    The function frees each member.  Only when ``direction`` is non-empty
    does it also free() the object itself; with an empty direction only the
    members are released.
    """
    var = free_arg_name(direction)
    frees_object = bool(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, var, struct)
    if frees_object:
        ri.cw.p(f'free({var});')
    ri.cw.block_end()
    ri.cw.nl()
1975
1976
def free_rsp_nested(ri, struct):
    """Emit the members-only free function for ``struct`` (empty direction)."""
    _free_type(ri, direction="", struct=struct)
1979
1980
def print_rsp_free(ri):
    """Emit the reply free() function if the op has a reply in this mode."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1985
1986
def print_dump_type_free(ri):
    """Emit the free function for a dump reply list.

    The generated loop walks the list via ``next`` until YNL_LIST_END,
    freeing each node's members (reached through ``obj.``) and the node.
    """
    node_type = type_name(ri, 'reply')
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{node_type} *next = rsp;")
    ri.cw.nl()

    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    # Advance before freeing so the successor pointer is not lost.
    ri.cw.p('rsp = next;')
    ri.cw.p('next = rsp->next;')
    ri.cw.nl()
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()  # while

    ri.cw.block_end()  # function
    ri.cw.nl()
2005
2006
def print_ntf_type_free(ri):
    """Emit the free function for a notification: members via ``obj.``, then itself."""
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2015
2016
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the ``nla_policy`` array declaration/definition-opening line.

    With ``terminate`` the line is an ``extern`` declaration ending in ';'.
    Nothing is emitted when ``ri`` is given and the family's policies are
    static.  Without ``terminate`` the line opens the definition (``= {``),
    prefixed with ``static`` in that same static case.

    The array is named after the op when ``ri`` is given (with
    ``_<op_mode>`` appended for dual-policy ops), otherwise after the struct.
    """
    make_static = bool(ri) and policy_should_be_static(struct.family)
    if terminate and make_static:
        return

    if terminate:
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if make_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2039
2040
def print_req_policy(cw, struct, ri=None):
    """Emit the complete ``nla_policy`` array definition for ``struct``.

    When an op is available, its 'config-cond' (if any) is passed to
    ``cw.ifdef_block()`` before the definition.  ``cw.ifdef_block(None)`` is
    always called afterwards (presumably closing any open conditional).
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2050
2051
def kernel_can_gen_family_struct(family):
    """Return True if the kernel ``genl_family`` struct is generated here.

    Only families whose protocol is exactly 'genetlink' qualify.
    """
    is_plain_genetlink = family.proto == 'genetlink'
    return is_plain_genetlink
2054
2055
def policy_should_be_static(family):
    """Return True if generated kernel policies should be ``static``.

    That is the case for 'split' kernel policies, and for any family whose
    genl_family struct is generated here.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2058
2059
def print_kernel_policy_ranges(family, cw):
    """Emit ``netlink_range_validation`` structs for 'full-range' request attrs.

    Only attributes used in requests whose checks include 'full-range' are
    considered.  Subset attribute sets are skipped.  The ``_signed`` struct
    variant is used unless the attribute type starts with 'u', and limit
    literals get a matching ULL/LL suffix.  A comment header precedes the
    first struct.
    """
    printed_header = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not printed_header:
                cw.p('/* Integer value ranges */')
                printed_header = True

            is_unsigned = attr.type[0] == 'u'
            sign = '' if is_unsigned else '_signed'
            literal_suffix = 'ULL' if is_unsigned else 'LL'
            struct_name = f'{c_lower(attr.enum_name)}_range'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {struct_name} =')
            members = [(bound, str(attr.get_limit(bound)) + literal_suffix)
                       for bound in ('min', 'max')
                       if bound in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2087
2088
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops-table declaration and, for headers, handler prototypes.

    With ``terminate=False`` this opens the ops array definition; the caller
    emits the entries and closes the block.  With ``terminate=True`` it emits
    an ``extern`` declaration of the array (only when the array is exported,
    i.e. the family struct is not generated here), followed by prototypes for
    the pre/post hooks and for every non-async do/dump handler.

    When the array is exported it carries an explicit size.  That size must
    equal the number of entries print_kernel_op_table() emits: one per
    do/dump handler for 'split' policies, one per op otherwise.  In both
    cases async (notification/event) ops are skipped.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        if not exported:
            cnt = ""
        else:
            # The table definition has no entries for async ops, so counting
            # them would leave zero-initialized trailing entries in the array.
            sync_ops = [op for op in family.ops.values() if not op.is_async]
            if family.kernel_policy == 'split':
                # One entry per do/dump handler.
                cnt = sum(('do' in op) + ('dump' in op) for op in sync_ops)
            else:
                # One entry per operation.
                cnt = len(sync_ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2154
2155
def print_kernel_op_table_hdr(family, cw):
    """Emit the header side of the ops table: extern decl and prototypes."""
    print_kernel_op_table_fwd(family, cw, True)
2158
2159
def print_kernel_op_table(family, cw):
    """Emit the kernel ops array definition (entries and closing brace).

    For 'global'/'per-op' policies there is one entry per non-async op,
    carrying doit/dumpit (plus policy/maxattr for 'per-op').  For 'split'
    policies there is one entry per (non-async op, do/dump) pair, with
    pre/post callbacks, a per-mode 'dont-validate' filter and a
    GENL_CMD_CAP_<MODE> flag.  Each entry is wrapped in the op's
    'config-cond' ifdef, if any.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            # Notifications/events have no handlers in the table.
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): this assumes every op has a 'do' section with a
                # 'request'; a dump-only op would raise KeyError here - confirm
                # whether per-op specs can contain such ops.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Struct field names for the pre/post hooks, per mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the 'dont-validate' values relevant to this mode:
                    # 'dump'/'dump-strict' are dropped for do, 'strict' for dump.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops get a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry advertises its mode via GENL_CMD_CAP_*.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2238
2239
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum of multicast group IDs; nothing if there are none."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']}") + ',')
    cw.block_end(';')
    cw.nl()
2250
2251
def print_kernel_mcgrp_src(family, cw):
    """Emit the static ``genl_multicast_group`` array indexed by group ID."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        grp_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2263
2264
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern declaration of the generated ``genl_family``, if any."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
        cw.nl()
2271
2272
def print_kernel_family_struct_src(family, cw):
    """Emit the definition of the family's ``genl_family`` struct.

    This is only done when kernel_can_gen_family_struct() allows it.  The
    variable is named from ``family.c_name`` so that it is a valid C
    identifier and matches the ``extern`` declaration emitted by
    print_kernel_family_struct_hdr().
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    # NOTE(review): no ops are wired up for the 'global' kernel policy here -
    # confirm that combination is not expected to reach this point.
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    cw.block_end(';')
2293
2294
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a C ``enum`` block for ``obj``.

    If ``obj`` has the ``enum_name`` key, its value names the enum, and an
    empty value yields an anonymous enum.  Otherwise, if ``ckey`` is given
    and present, the enum is named ``<family c_name>_<obj[ckey]>``.  Failing
    both, the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2303
2304
def render_uapi(family, cw):
    """Render the uAPI header for ``family``.

    Emits, in order:
      * the include guard;
      * the family name/version defines;
      * the 'definitions' (enums/flags with kdoc, consts as #defines);
      * one enum per attribute set, skipping subsets;
      * the command enum, plus a separate async enum when the operations
        declare 'async-prefix';
      * the multicast group name defines.

    A const's #define name can be overridden per const via its own
    'c-define-name' key, just as multicast groups are handled.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            # Flush pending const #defines before any other definition so
            # the output keeps the spec's ordering.
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The 'c-define-name' override belongs to the individual const,
            # not to the family as a whole.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2448
2449
def _render_user_ntf_entry(ri, op):
    """Emit one ``[<op enum>] = { ... },`` entry of the ynl_ntf_info table."""
    ri.cw.block_start(line=f"[{op.enum_name}] = ")
    fields = (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    )
    for field in fields:
        ri.cw.p(field)
    ri.cw.block_end(line=',')
2457
2458
def render_user_family(family, cw, prototype):
    """Emit the user-space ``ynl_family`` descriptor for ``family``.

    With ``prototype`` set, only the ``extern`` declaration is emitted.
    Otherwise the definition is emitted.  When the family has notifications,
    it is preceded by a static ``ynl_ntf_info`` table indexed by the
    notification enum values.

    The table variable is named from ``family.c_name`` rather than the raw
    spec name, which may contain dashes and so would not be a valid C
    identifier.

    Raises:
        Exception: if a notification entry has neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    ntf_info = f'{family.c_name}_ntf_info'
    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_info}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        # NOTE(review): events are rendered here and possibly also above via
        # family.ntfs - confirm family.ntfs never contains event ops.
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_info},")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({ntf_info}),")
    cw.block_end(line=';')
2490
2491
def family_contains_bitfield32(family):
    """Return True if any non-subset attribute set has a 'bitfield32' attribute."""
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
2500
2501
def find_kernel_root(full_path):
    """Locate the kernel source root above ``full_path``.

    The walk goes up from ``full_path`` one directory at a time until it
    finds a directory containing a MAINTAINERS file.

    Returns:
        (root, sub_path): the directory holding MAINTAINERS, and the path of
        ``full_path`` relative to it.

    Raises:
        FileNotFoundError: if the top of the path (the filesystem root, or
        the start of a relative path) is reached without finding MAINTAINERS.
        Without this check the loop would never terminate, because
        os.path.dirname() stops changing at that point.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if parent == full_path:
            raise FileNotFoundError(
                f"no MAINTAINERS file found in any parent of '{start}'")
        full_path = parent
        if os.path.exists(os.path.join(full_path, "MAINTAINERS")):
            # Drop the trailing separator left by the final join.
            return full_path, sub_path[:-1]
2510
2511
2512def main():
2513    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
2514    parser.add_argument('--mode', dest='mode', type=str, required=True)
2515    parser.add_argument('--spec', dest='spec', type=str, required=True)
2516    parser.add_argument('--header', dest='header', action='store_true', default=None)
2517    parser.add_argument('--source', dest='header', action='store_false')
2518    parser.add_argument('--user-header', nargs='+', default=[])
2519    parser.add_argument('--cmp-out', action='store_true', default=None,
2520                        help='Do not overwrite the output file if the new output is identical to the old')
2521    parser.add_argument('--exclude-op', action='append', default=[])
2522    parser.add_argument('-o', dest='out_file', type=str, default=None)
2523    args = parser.parse_args()
2524
2525    if args.header is None:
2526        parser.error("--header or --source is required")
2527
2528    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
2529
2530    try:
2531        parsed = Family(args.spec, exclude_ops)
2532        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
2533            print('Spec license:', parsed.license)
2534            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
2535            os.sys.exit(1)
2536    except yaml.YAMLError as exc:
2537        print(exc)
2538        os.sys.exit(1)
2539        return
2540
2541    supported_models = ['unified']
2542    if args.mode in ['user', 'kernel']:
2543        supported_models += ['directional']
2544    if parsed.msg_id_model not in supported_models:
2545        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
2546        os.sys.exit(1)
2547
2548    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
2549
2550    _, spec_kernel = find_kernel_root(args.spec)
2551    if args.mode == 'uapi' or args.header:
2552        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
2553    else:
2554        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
2555    cw.p("/* Do not edit directly, auto-generated from: */")
2556    cw.p(f"/*\t{spec_kernel} */")
2557    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
2558    if args.exclude_op or args.user_header:
2559        line = ''
2560        line += ' --user-header '.join([''] + args.user_header)
2561        line += ' --exclude-op '.join([''] + args.exclude_op)
2562        cw.p(f'/* YNL-ARG{line} */')
2563    cw.nl()
2564
2565    if args.mode == 'uapi':
2566        render_uapi(parsed, cw)
2567        return
2568
2569    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
2570    if args.header:
2571        cw.p('#ifndef ' + hdr_prot)
2572        cw.p('#define ' + hdr_prot)
2573        cw.nl()
2574
2575    if args.mode == 'kernel':
2576        cw.p('#include <net/netlink.h>')
2577        cw.p('#include <net/genetlink.h>')
2578        cw.nl()
2579        if not args.header:
2580            if args.out_file:
2581                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
2582            cw.nl()
2583        headers = ['uapi/' + parsed.uapi_header]
2584    else:
2585        cw.p('#include <stdlib.h>')
2586        cw.p('#include <string.h>')
2587        if args.header:
2588            cw.p('#include <linux/types.h>')
2589            if family_contains_bitfield32(parsed):
2590                cw.p('#include <linux/netlink.h>')
2591        else:
2592            cw.p(f'#include "{parsed.name}-user.h"')
2593            cw.p('#include "ynl.h"')
2594        headers = [parsed.uapi_header]
2595    for definition in parsed['definitions']:
2596        if 'header' in definition:
2597            headers.append(definition['header'])
2598    for one in headers:
2599        cw.p(f"#include <{one}>")
2600    cw.nl()
2601
2602    if args.mode == "user":
2603        if not args.header:
2604            cw.p("#include <libmnl/libmnl.h>")
2605            cw.p("#include <linux/genetlink.h>")
2606            cw.nl()
2607            for one in args.user_header:
2608                cw.p(f'#include "{one}"')
2609        else:
2610            cw.p('struct ynl_sock;')
2611            cw.nl()
2612            render_user_family(parsed, cw, True)
2613        cw.nl()
2614
2615    if args.mode == "kernel":
2616        if args.header:
2617            for _, struct in sorted(parsed.pure_nested_structs.items()):
2618                if struct.request:
2619                    cw.p('/* Common nested types */')
2620                    break
2621            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2622                if struct.request:
2623                    print_req_policy_fwd(cw, struct)
2624            cw.nl()
2625
2626            if parsed.kernel_policy == 'global':
2627                cw.p(f"/* Global operation policy for {parsed.name} */")
2628
2629                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2630                print_req_policy_fwd(cw, struct)
2631                cw.nl()
2632
2633            if parsed.kernel_policy in {'per-op', 'split'}:
2634                for op_name, op in parsed.ops.items():
2635                    if 'do' in op and 'event' not in op:
2636                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
2637                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
2638                        cw.nl()
2639
2640            print_kernel_op_table_hdr(parsed, cw)
2641            print_kernel_mcgrp_hdr(parsed, cw)
2642            print_kernel_family_struct_hdr(parsed, cw)
2643        else:
2644            print_kernel_policy_ranges(parsed, cw)
2645
2646            for _, struct in sorted(parsed.pure_nested_structs.items()):
2647                if struct.request:
2648                    cw.p('/* Common nested types */')
2649                    break
2650            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2651                if struct.request:
2652                    print_req_policy(cw, struct)
2653            cw.nl()
2654
2655            if parsed.kernel_policy == 'global':
2656                cw.p(f"/* Global operation policy for {parsed.name} */")
2657
2658                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2659                print_req_policy(cw, struct)
2660                cw.nl()
2661
2662            for op_name, op in parsed.ops.items():
2663                if parsed.kernel_policy in {'per-op', 'split'}:
2664                    for op_mode in ['do', 'dump']:
2665                        if op_mode in op and 'request' in op[op_mode]:
2666                            cw.p(f"/* {op.enum_name} - {op_mode} */")
2667                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
2668                            print_req_policy(cw, ri.struct['request'], ri=ri)
2669                            cw.nl()
2670
2671            print_kernel_op_table(parsed, cw)
2672            print_kernel_mcgrp_src(parsed, cw)
2673            print_kernel_family_struct_src(parsed, cw)
2674
2675    if args.mode == "user":
2676        if args.header:
2677            cw.p('/* Enums */')
2678            put_op_name_fwd(parsed, cw)
2679
2680            for name, const in parsed.consts.items():
2681                if isinstance(const, EnumSet):
2682                    put_enum_to_str_fwd(parsed, cw, const)
2683            cw.nl()
2684
2685            cw.p('/* Common nested types */')
2686            for attr_set, struct in parsed.pure_nested_structs.items():
2687                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
2688                print_type_full(ri, struct)
2689
2690            for op_name, op in parsed.ops.items():
2691                cw.p(f"/* ============== {op.enum_name} ============== */")
2692
2693                if 'do' in op and 'event' not in op:
2694                    cw.p(f"/* {op.enum_name} - do */")
2695                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
2696                    print_req_type(ri)
2697                    print_req_type_helpers(ri)
2698                    cw.nl()
2699                    print_rsp_type(ri)
2700                    print_rsp_type_helpers(ri)
2701                    cw.nl()
2702                    print_req_prototype(ri)
2703                    cw.nl()
2704
2705                if 'dump' in op:
2706                    cw.p(f"/* {op.enum_name} - dump */")
2707                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
2708                    print_req_type(ri)
2709                    print_req_type_helpers(ri)
2710                    if not ri.type_consistent:
2711                        print_rsp_type(ri)
2712                    print_wrapped_type(ri)
2713                    print_dump_prototype(ri)
2714                    cw.nl()
2715
2716                if op.has_ntf:
2717                    cw.p(f"/* {op.enum_name} - notify */")
2718                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
2719                    if not ri.type_consistent:
2720                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
2721                    print_wrapped_type(ri)
2722
2723            for op_name, op in parsed.ntfs.items():
2724                if 'event' in op:
2725                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
2726                    cw.p(f"/* {op.enum_name} - event */")
2727                    print_rsp_type(ri)
2728                    cw.nl()
2729                    print_wrapped_type(ri)
2730            cw.nl()
2731        else:
2732            cw.p('/* Enums */')
2733            put_op_name(parsed, cw)
2734
2735            for name, const in parsed.consts.items():
2736                if isinstance(const, EnumSet):
2737                    put_enum_to_str(parsed, cw, const)
2738            cw.nl()
2739
2740            cw.p('/* Policies */')
2741            for name in parsed.pure_nested_structs:
2742                struct = Struct(parsed, name)
2743                put_typol(cw, struct)
2744            for name in parsed.root_sets:
2745                struct = Struct(parsed, name)
2746                put_typol(cw, struct)
2747
2748            cw.p('/* Common nested types */')
2749            for attr_set, struct in parsed.pure_nested_structs.items():
2750                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
2751
2752                free_rsp_nested(ri, struct)
2753                if struct.request:
2754                    put_req_nested(ri, struct)
2755                if struct.reply:
2756                    parse_rsp_nested(ri, struct)
2757
2758            for op_name, op in parsed.ops.items():
2759                cw.p(f"/* ============== {op.enum_name} ============== */")
2760                if 'do' in op and 'event' not in op:
2761                    cw.p(f"/* {op.enum_name} - do */")
2762                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
2763                    print_req_free(ri)
2764                    print_rsp_free(ri)
2765                    parse_rsp_msg(ri)
2766                    print_req(ri)
2767                    cw.nl()
2768
2769                if 'dump' in op:
2770                    cw.p(f"/* {op.enum_name} - dump */")
2771                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
2772                    if not ri.type_consistent:
2773                        parse_rsp_msg(ri, deref=True)
2774                    print_dump_type_free(ri)
2775                    print_dump(ri)
2776                    cw.nl()
2777
2778                if op.has_ntf:
2779                    cw.p(f"/* {op.enum_name} - notify */")
2780                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
2781                    if not ri.type_consistent:
2782                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
2783                    print_ntf_type_free(ri)
2784
2785            for op_name, op in parsed.ntfs.items():
2786                if 'event' in op:
2787                    cw.p(f"/* {op.enum_name} - event */")
2788
2789                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
2790                    parse_rsp_msg(ri)
2791
2792                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
2793                    print_ntf_type_free(ri)
2794            cw.nl()
2795            render_user_family(parsed, cw, False)
2796
2797    if args.header:
2798        cw.p(f'#endif /* {hdr_prot} */')
2799
2800
if __name__ == '__main__':
    # Entry point when the generator is executed as a script (not imported):
    # hand control to main(), which parses the command line and emits code.
    main()
2803