xref: /linux/tools/net/ynl/ynl-gen-c.py (revision eb01fe7abbe2d0b38824d2a93fdb4cc3eaf2ccc1)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import os
8import re
9import shutil
10import tempfile
11import yaml
12
13from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
14
15
def c_upper(name):
    """Turn a spec name like 'foo-bar' into a C macro-style name 'FOO_BAR'."""
    return name.upper().translate({ord('-'): '_'})
18
19
def c_lower(name):
    """Turn a spec name like 'Foo-Bar' into a C identifier 'foo_bar'."""
    return name.lower().translate({ord('-'): '_'})
22
23
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The accepted form is ``<u|s><bits>-<min|max>``, for example
    ``u8-max`` -> 255, ``s16-min`` -> -32768, ``u64-min`` -> 0.

    Raises:
        ValueError: if ``name`` is not of that form (previously a string such
            as 'u32-foo' was silently treated as 'u32-max').
    """
    match = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit "{name}", expected e.g. u32-max or s64-min')

    signed = match.group(1) == 's'
    width = int(match.group(2))
    is_min = match.group(3) == 'min'

    if is_min and not signed:
        return 0
    # One bit of a signed type is the sign bit.
    if signed:
        width -= 1
    value = (1 << width) - 1
    if signed and is_min:
        value = -value - 1
    return value
37
38
class BaseNlLib:
    """Supplies C expressions that generated user-space code embeds verbatim."""

    def get_family_id(self):
        """Return the C expression reading the ``family_id`` field of ``ys``."""
        expr = 'ys->family_id'
        return expr
42
43
class Type(SpecAttr):
    """Code-generation model of one attribute of an attribute set.

    Subclasses override the type-specific hooks (``presence_type``,
    ``_complex_member_type``, ``_attr_policy``, ``_attr_typol``, ``attr_put``,
    ``_attr_get``, ``_setter_lines``, ...); the public methods here combine
    them and emit C source through ``ri.cw`` / ``cw``.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE: when the spec has a 'checks' dict this aliases it, so keys that
        # subclasses add to self.checks also appear in the spec dict.
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.name}")
            else:
                self.nested_render_name = c_lower(f"{family.name}_{self.nested_attrs}")

            # Same rule as Struct: a nested set that shares its name with a
            # family constant gets a trailing '_' on its struct name.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Names found in _C_KW (presumably the C keywords) get a '_' suffix.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve(); declared then deleted so that any use before
        # resolve() raises AttributeError rather than reading None.
        self.enum_name = None
        delattr(self, "enum_name")

    def get_limit(self, limit, default=None):
        """Return check ``limit`` (e.g. 'min', 'max-len') ready for C output.

        Returns ``default`` (possibly None) when the check is absent, the
        upper-cased C name when the value names a family constant, an int
        unchanged, and otherwise the number for a limit such as 'u32-max'.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if value in self.family.consts:
            return c_upper(f"{self.family['name']}-{value}")
        if not isinstance(value, int):
            value = limit_to_number(value)
        return value

    def resolve(self):
        """Compute the C enum name, honouring a per-attribute 'name-prefix'."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Whether the member is an array with an n_ count (falsy here)."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        # Recursion is only special-cased when rendering outside an op (ri.op falsy).
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Presence tracking kind.

        One of 'bit' (a _present bitfield), 'len' (a _present length),
        'count' (an n_ counter) or '' (none).
        """
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` member declaration, or None.

        A declaration is only produced when this attribute's presence type
        equals ``type_filter``; for ``space == 'user'`` 'u32' becomes '__u32'.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None
        pfx = '__' if space == 'user' else ''
        if presence == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """C type of the struct member for non-trivial types, or None."""
        return None

    def free_needs_iter(self):
        """Whether free() emits a loop that uses an 'i' variable."""
        return False

    def free(self, ri, var, ref):
        """Emit the free() of this member for array- or length-tracked members."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declarations a setter takes for this attribute.

        Raises:
            Exception: if the type provides no complex member type.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit this attribute's member(s) of the generated C struct."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            # Recursive nests (outside an op) are held by pointer.
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the C initializer of the kernel policy entry."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the kernel policy table entry for this attribute."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the fields of the user-space type policy entry."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the user-space type policy entry for this attribute."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit ``line`` as a statement, guarded by the presence check if any."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ynl_attr_put_<put_type>() of this member."""
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars) of the parse branch body.

        The first two may each be a single string or a list of strings.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attribute and return True.

        ``first`` selects 'if' (first branch) versus 'else if'.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Multi-value attributes skip per-attribute validation and presence here.
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the C statements storing the setter argument(s) into ``member``."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a static inline setter for this attribute.

        ``ref`` is the path of enclosing member names (for nested structs);
        ``space`` is accepted for interface compatibility and unused here.
        """
        ref = (ref or []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        # Mark every level of the path present; 'presence' ends up naming the
        # innermost _present field, which _setter_lines() may build on.
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        frees = any('free(' in x for x in code)
        allocs = any('alloc(' in x for x in code)
        # Setters that free() without allocating get a '__' prefix
        # (presumably marking them as internal helpers).
        if frees and not allocs:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
242
243
class TypeUnused(Type):
    """Attribute slot of type 'unused'.

    Nothing is emitted for it (no policy, put, parse branch or setter); its
    type policy entry is YNL_PT_REJECT.
    """

    def presence_type(self):
        # Unused slots have no presence tracking at all.
        return ''

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        reject = ['return YNL_PARSE_CB_ERROR;']
        return reject, None, None

    def attr_policy(self, cw):
        """No kernel policy entry."""

    def attr_put(self, ri, var):
        """Never put into a message."""

    def attr_get(self, ri, var, first):
        """No parse branch is generated."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter is generated."""
268
269
class TypePad(Type):
    """Padding attribute.

    It has no value and no presence tracking. Nothing is emitted for it
    except a YNL_PT_IGNORE type policy entry.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """No kernel policy entry for padding."""

    def attr_put(self, ri, var):
        """Padding is never put into a message."""

    def attr_get(self, ri, var, first):
        """No parse branch for padding."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Padding has no setter."""
291
292
class TypeScalar(Type):
    """Integer attribute with optional min/max/mask validation checks."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Enum-typed scalars default their checks to the enum's value range;
        # 'min' is only added when not already implied (0 for unsigned).
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if 'min' not in self.checks:
                if low != 0 or self.type[0] == 's':
                    self.checks['min'] = low
            if 'max' not in self.checks:
                self.checks['max'] = high

        # The validation below compares limits numerically. A limit naming a
        # family constant resolves to a C macro string, which would otherwise
        # fail with an opaque TypeError (str vs int), so reject it explicitly.
        for limit in ('min', 'max'):
            if isinstance(self.get_limit(limit), str):
                raise Exception(f'Invalid limit for "{self.name}" {limit}: '
                                'family constants are not supported for scalar min/max')

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range switch to the full-range policy form.
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve(); deleted so premature use raises AttributeError.
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Decide whether the value is a bitfield and pick its C type name."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        """Pick the policy macro: mask, then full range, range, min, max."""
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit('min')}, {self.get_limit('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
379
380
class TypeFlag(Type):
    """Zero-length flag attribute.

    Only its presence bit is tracked; the setter takes no value argument.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        return list()

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to copy out; the presence bit is set by the common parse code.
        return list(), None, None

    def _setter_lines(self, ri, member, presence):
        return list()
396
397
class TypeString(Type):
    """String attribute; the copied length is kept in ``_present.<name>_len``."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        """Return the policy initializer honouring 'exact-len' / 'max-len'."""
        if 'exact-len' in self.checks:
            # Resolve through get_limit() like 'max-len' below, so a family
            # constant renders as its C macro name and 'u8-max' as a number
            # instead of pasting the raw spec string into C.
            mem = 'NLA_POLICY_EXACT_LEN(' + str(self.get_limit('exact-len')) + ')'
        else:
            mem = '{ .type = ' + policy
            if 'max-len' in self.checks:
                mem += ', .len = ' + str(self.get_limit('max-len'))
            mem += ', }'
        return mem

    def attr_policy(self, cw):
        # 'unterminated-ok' selects NLA_STRING instead of NLA_NUL_STRING.
        if self.checks.get('unterminated-ok', False):
            policy = 'NLA_STRING'
        else:
            policy = 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        # The copy length is bounded by the payload (strnlen) and the copy is
        # always NUL-terminated.
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len + 1);",
                f"memcpy({var}->{self.c_name}, ynl_attr_get_str(attr), len);",
                f"{var}->{self.c_name}[len] = 0;"], \
               ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Replace any previous value with a NUL-terminated copy.
        return [f"free({member});",
                f"{presence}_len = strlen({self.c_name});",
                f"{member} = malloc({presence}_len + 1);",
                f'memcpy({member}, {self.c_name}, {presence}_len);',
                f'{member}[{presence}_len] = 0;']
448
449
class TypeBinary(Type):
    """Opaque byte-blob attribute; its length is kept in ``_present.<name>_len``."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Return the policy initializer; only 'exact-len' or 'min-len' checks are supported.

        Raises:
            Exception: for any other combination of checks.
        """
        if 'exact-len' in self.checks:
            # Resolve through get_limit() like 'min-len' below, so family
            # constants and limits such as 'u16-max' render as valid C.
            mem = 'NLA_POLICY_EXACT_LEN(' + str(self.get_limit('exact-len')) + ')'
        else:
            mem = '{ '
            if len(self.checks) == 1 and 'min-len' in self.checks:
                mem += '.len = ' + str(self.get_limit('min-len'))
            elif len(self.checks) == 0:
                mem += '.type = NLA_BINARY'
            else:
                raise Exception('One or more of binary type checks not implemented, yet')
            mem += ', }'
        return mem

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->{self.c_name}, {var}->_present.{self.c_name}_len)")

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Replace any previous value with a copy of the caller's buffer.
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
494
495
class TypeBitfield32(Type):
    """``struct nla_bitfield32`` attribute whose allowed bits come from an enum."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        # The policy mask is derived from the enum, so one is mandatory.
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var,
                            f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, "
                            "sizeof(struct nla_bitfield32))")

    def _attr_get(self, ri, var):
        copy_line = (f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), "
                     "sizeof(struct nla_bitfield32));")
        return copy_line, None, None

    def _setter_lines(self, ri, member, presence):
        copy_line = f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"
        return [copy_line]
519
520
class TypeNest(Type):
    """Attribute holding a nested attribute set, rendered as its own C struct."""

    def is_recursive(self):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        target = f'{var}->{ref}{self.c_name}'
        if self.is_recursive_for_op(ri):
            # Held by pointer in this case, so guard against NULL.
            ri.cw.p(f'if ({target})')
            ri.cw.p(f'{self.nested_render_name}_free({target});')
        else:
            ri.cw.p(f'{self.nested_render_name}_free(&{target});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        call = (f"{self.nested_render_name}_put(nlh, "
                f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        setup = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        parse = [
            f"if ({self.nested_render_name}_parse(&parg, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return parse, setup, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # One setter per non-recursive member of the nested struct, with this
        # attribute's name appended to the member path.
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
560
561
class TypeMultiAttr(Type):
    """Repeated attribute.

    Occurrences are stored as an array with an ``n_<name>`` count.
    Policies are delegated to ``base_type``, the Type of a single
    occurrence.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        if self.attr['type'] not in scalars:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")
        prefix = '__' if ri.ku_space == 'user' else ''
        return prefix + self.attr['type']

    def free_needs_iter(self):
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def free(self, ri, var, ref):
        array = f"{var}->{ref}{self.c_name}"
        if self.attr['type'] in scalars:
            ri.cw.p(f"free({array});")
            return
        if 'type' in self.attr and self.attr['type'] != 'nest':
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")
        # Nested elements own memory of their own: free each, then the array.
        ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
        ri.cw.p(f'{self.nested_render_name}_free(&{array}[i]);')
        ri.cw.p(f"free({array});")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Occurrences are only counted in this pass.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        loop = f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)"
        if self.attr['type'] in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # Multi-attrs keep a count, not a presence flag: swap the trailing
        # "_present.<name>" of the presence path for "n_<name>".
        cut = len('_present.') + len(self.c_name)
        count = f"{presence[:-cut]}n_{self.c_name}"
        return [
            f"free({member});",
            f"{member} = {self.c_name};",
            f"{count} = n_{self.c_name};",
        ]
623
624
class TypeArrayNest(Type):
    """'array-nest' attribute: a nest whose children are array elements.

    Elements are stored as a C array with an ``n_<name>`` count.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def free_needs_iter(self):
        # Nested elements are freed in a loop over 'i' (see free()), as
        # TypeMultiAttr does for repeated nests.
        return 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest'

    def free(self, ri, var, ref):
        """Free the element array, first freeing each nested element.

        The inherited free() only released the array itself, leaking whatever
        the nested element structs own.
        """
        if self.free_needs_iter():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
        ri.cw.p(f"free({var}->{ref}{self.c_name});")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the nest (attr_<name>) and count its children in this pass.
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr)',
                     f'\t{var}->n_{self.c_name}++;']
        return get_lines, None, local_vars
650
651
class TypeNestTypeValue(Type):
    """Nest whose attribute *types* carry values.

    Each 'type-value' level is decoded from the type of a nested
    attribute.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            # Descend one nest per level, recording each level's attr type.
            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({cursor});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                cursor = f'attr_{level}'
            extra_args = f", {', '.join(levels)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
680
681
class Struct:
    """C struct rendered for an attribute set.

    With ``type_list`` only those attributes are included. Without it, the
    struct is 'nested' and contains every attribute of the set.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = c_lower(family.name)
        else:
            self.render_name = c_lower(family.name + '-' + space_name)
        self.struct_name = 'struct ' + self.render_name
        # A nested set named like a family constant gets a trailing '_'.
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = {}

        # Track the attribute with the highest value (last one wins on ties).
        max_val = 0
        self.attr_max_val = None
        for name, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members, which must equal those given at construction."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = list(map(c_lower, sorted(self._inherited)))
737
738
class EnumEntry(SpecEnumEntry):
    """Enum entry extended with the C naming used when rendering it."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value does not simply follow the previous entry
        # (0 for the first entry), and always for 'flags' sets.
        implicit = prev.value + 1 if prev else 0
        self.value_change = self.value != implicit or self.enum_set['type'] == 'flags'

        # Added by resolve; deleted so premature use raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
757
758
class EnumSet(SpecEnumSet):
    """Enum or flags definition together with its C naming."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                # A falsy 'enum-name' leaves the enum without a C type name.
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # C type for values of this enum; nameless enums fall back to int.
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (lowest, highest) entry value.

        Raises:
            Exception: if the distinct values do not form a contiguous range,
                since a min/max check derived from them would then admit
                values that are not members of the enum.
        """
        values = [x.value for x in self.entries.values()]
        low = min(values)
        high = max(values)

        # Count distinct values so duplicates can neither hide a gap nor make
        # a contiguous enum look noncontiguous.
        if high - low + 1 != len(set(values)):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
792
793
class AttrSet(SpecAttrSet):
    """Attribute set with the C names of its enum prefix and max/count entries."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are carved from.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve; deleted so premature use raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        # The set named after the family itself gets an empty C name.
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching ``elem['type']``.

        Attributes marked 'multi-attr' are wrapped in TypeMultiAttr.
        """
        kind = elem['type']
        if kind in scalars:
            cls = TypeScalar
        else:
            cls = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(kind)
            if cls is None:
                raise Exception(f"No typed class for type {elem['type']}")
        t = cls(self.family, self, elem, value)

        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
852
853
class Operation(SpecOperation):
    """Operation of the family, with the C names used when rendering it."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.name + '_' + self.name)

        # True when both the 'do' and the 'dump' mode define a request.
        self.dual_policy = all('request' in yaml.get(mode, {}) for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve; deleted so premature use raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
879
880
class Family(SpecFamily):
    """C code-generation view of a netlink family spec.

    On top of SpecFamily this computes C naming (family and op enum
    prefixes, UAPI header name, family name/version keys) and the tables the
    renderers consume:
      * root_sets: attribute sets used directly by messages, with the
        attributes used in requests and in replies;
      * pure_nested_structs: attribute sets reachable only through nesting,
        as Struct objects sorted so dependencies come first;
      * per-attribute request/reply usage flags;
      * pre/post hooks per op mode;
      * optionally one global policy (kernel-policy: global).
    """

    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # Each attribute below is created and immediately deleted purely to
        # document what gets added later. This is done before
        # super().__init__(), presumably because resolution is triggered from
        # the base constructor - TODO confirm against lib.SpecFamily.
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Upper-cased C identifiers for the family name / version
        # (overridable via c-family-name / c-version-name)
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.name}.h"
        # "linux/<name>.h" -> "<name>"; anything else falls back to the family name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.name

    def resolve(self):
        """Compute C naming and every derived code-generation table.

        Raises:
            Exception: if the spec's protocol is not one of the genetlink
                variants, or if the nesting/global-policy constraints checked
                by the _load_* helpers are violated.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode]: 'set' de-duplicates, 'list' keeps first-seen order
        self.hooks = {when: {op_mode: {'set': set(), 'list': []}
                             for op_mode in ['do', 'dump']}
                      for when in ['pre', 'post']}

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every op that another message names in its 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per attribute set used by a message, the attribute names
        appearing in requests and in replies (events count as replies)."""
        for op in self.msgs.values():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                # A request/reply without an attribute list contributes
                # nothing, matching the tolerant lookup in _load_global_policy()
                if 'request' in op[op_mode]:
                    req_attrs.update(op[op_mode]['request'].get('attributes', []))
                if 'reply' in op[op_mode]:
                    rsp_attrs.update(op[op_mode]['reply'].get('attributes', []))
            if 'event' in op:
                rsp_attrs.update(op['event']['attributes'])

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct comes after the
        non-recursive nests it contains.

        A struct whose nested set has not been placed yet is moved to the end
        of the (insertion-ordered) dict and retried; the number of rounds is
        capped at len(structs) ** 2.
        """
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Create a Struct for every attribute set reached only via nesting.

        Walks breadth-first from the root sets, records members inherited
        from the parent (the 'type-value' list, or 'idx' for array-nest),
        then propagates request/reply usage and recursion flags from parents
        to children and sorts the structs by dependency.

        Raises:
            Exception: if an attribute set is used both as a message root and
                as a nest.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for _, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Set the request/reply flags on each attribute Type object.

        Members of a pure nested struct inherit that struct's flags; root-set
        attributes are flagged from the names collected in root_sets.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.request = True
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.reply = True

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.request = True
                if attr in rs_members['reply']:
                    spec.reply = True

    def _load_global_policy(self):
        """Build one policy listing every attribute any op sends in a request.

        Sets global_policy (attribute names, in the attribute set's own
        order) and global_policy_set (the shared attribute set's name).

        Raises:
            Exception: if ops use different attribute sets, or if no op uses
                an attribute set at all.
        """
        global_set = set()
        attr_set_name = None
        for op in self.ops.values():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        if attr_set_name is None:
            # Without this check the lookup below fails with a bare KeyError(None)
            raise Exception('For a global policy at least one op must use an attribute set')

        self.global_policy = [attr for attr in self.attr_sets[attr_set_name]
                              if attr in global_set]
        self.global_policy_set = attr_set_name

    def _load_hooks(self):
        """Register each distinct pre/post hook name per op mode, in the
        order first encountered across ops."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1142
1143
class RenderInfo:
    """Context for rendering one op in one mode, or a bare attribute set.

    Bundles the family, code writer and netlink lib with the op being
    rendered, and precomputes the fixed-header struct name, the C type name,
    whether dump replies can share the 'do' reply type, and the
    request/reply Struct layouts.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_consistent = False
            else:
                do_has_reply = 'reply' in op['do']
                dump_has_reply = 'reply' in op["dump"]
                self.type_consistent = (do_has_reply == dump_has_reply and
                                        (not do_has_reply or
                                         op["do"]["reply"] == op["dump"]["reply"]))

        self.attr_set = attr_set or op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            # Rendering a bare attribute set; its name may clash with a
            # same-named definition in family.consts
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        self.struct = dict()
        # Notifications reuse the 'do' layout
        mode = 'do' if op_mode == 'notify' else op_mode
        if op:
            for op_dir in ['request', 'reply']:
                mode_spec = op[mode]
                type_list = mode_spec[op_dir]['attributes'] if op_dir in mode_spec else []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1192
1193
class CodeWriter:
    """Indentation- and brace-aware writer for generated C source.

    Output goes to stdout, or - when out_file is given - to a temporary file
    whose contents close_out_file() copies over out_file. A closing bracket
    from block_end() is held back so that a following 'else' can be joined
    onto the same line.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False
        self._block_end = False
        self._silent_block = False
        self._ind = 0
        self._ifdef_block = None
        # Target path while output is buffered in a temp file; None otherwise
        self._out_file = None
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        # __init__ may have failed before _out was assigned
        if hasattr(self, '_out'):
            self.close_out_file()

    def close_out_file(self):
        """Finish the output.

        A closing bracket still deferred by block_end() is written first so
        the output never ends with an unterminated block. For a file-backed
        writer the temporary buffer is then copied over the target - unless
        overwrite is disabled and the target already has identical contents,
        in which case the target is left untouched. Either way the buffer is
        closed and output reverts to stdout, so calling this again (e.g. from
        __del__) does nothing.
        """
        if self._block_end:
            self._block_end = False
            self._out.write('\t' * self._ind + '}\n')

        if self._out_file is None:
            return

        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        self._out.close()
        self._out = os.sys.stdout
        self._out_file = None

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Flushes a deferred '}' first (merging it as '} else ...' when the
        line starts with 'else'). Lines ending in ':' (labels) are outdented
        by one, a brace-less 'if/while/for (...)' indents the next line,
        preprocessor lines go to column 0, and add_ind adds extra levels.
        An empty line is allowed and is written as indentation only.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block; a bare '}' is deferred (see p())."""
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write doc as ' *'-prefixed comment lines wrapped before column 79."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping arguments when it would not
        fit in 80 columns."""
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        Accepts a single string or a list; the caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define NAME value' lines with values tab-aligned; ints are
        written bare, strings quoted."""
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write tab-aligned designated initializers ('.name = value,').

        An empty members list writes nothing.
        """
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active '#ifdef CONFIG_<config>' region, closing any
        previous one; a falsy config just closes it."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1389
1390
# Attribute types handled as plain scalar values
scalars = set(['u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'])

# Type-name suffix for each message direction ('' = no direction)
direction_to_suffix = dict([
    ('reply', '_rsp'),
    ('request', '_req'),
    ('', ''),
])

# Type-name suffix wrapping a reply for each op mode
op_mode_to_wrapper = dict([
    ('do', ''),
    ('dump', '_list'),
    ('notify', '_ntf'),
    ('event', ''),
])

# C keywords; generated identifiers equal to one of these need escaping
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1442
1443
def rdir(direction):
    """Return the opposite message direction; any other value is returned as is."""
    opposite = {'reply': 'request', 'request': 'reply'}
    return opposite[direction] if direction in ('reply', 'request') else direction
1450
1451
def op_prefix(ri, direction, deref=False):
    """Return the C type-name stem '<family>_<type_name><suffix>'.

    The suffix depends on the op mode and direction:
      * no mode or 'do': the per-direction suffix ('_req', '_rsp', '');
      * other modes, request: '_req_dump';
      * other modes, reply whose type matches the 'do' reply: the direction
        suffix when deref is set, else the mode wrapper ('_list', '_ntf', ...);
      * other modes, reply with a distinct type: '_rsp_dump' when deref is
        set, else '_rsp_list'.
    """
    suffix = f"_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        suffix += direction_to_suffix[direction]
    elif direction == 'request':
        suffix += '_req_dump'
    elif not ri.type_consistent:
        suffix += '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        suffix += direction_to_suffix[direction]
    else:
        suffix += op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}{suffix}"
1471
1472
def type_name(ri, direction, deref=False):
    """Return the full C struct type ('struct <stem>') for the given direction."""
    stem = op_prefix(ri, direction, deref=deref)
    return "struct " + stem
1475
1476
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user call for ri.op in ri.op_mode.

    The function takes the socket plus, when the mode has a request, a
    pointer to the request struct; it returns a pointer to the reply struct
    when the mode has a reply, otherwise int. 'dump' mode appends '_dump' to
    the name. With terminate the prototype ends in ';'.
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in ri.op[ri.op_mode]:
        param = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{param}")

    if 'reply' in ri.op[ri.op_mode]:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1493
1494
def print_req_prototype(ri):
    """Emit the request-direction prototype, preceded by the op's doc text."""
    op_doc = ri.op['doc']
    print_prototype(ri, 'request', doc=op_doc)
1497
1498
def print_dump_prototype(ri):
    """Emit the request-direction prototype without a doc comment."""
    print_prototype(ri, direction='request')
1501
1502
def put_typol_fwd(cw, struct):
    """Emit the extern declaration of struct's '<render_name>_nest' policy."""
    nest_name = f'{struct.render_name}_nest'
    cw.p(f'extern struct ynl_policy_nest {nest_name};')
1505
1506
def put_typol(cw, struct):
    """Emit struct's per-attribute policy array and the nest descriptor
    ('.max_attr' / '.table') that points at it."""
    max_attr = struct.attr_set.max_name
    base = struct.render_name

    cw.block_start(line=f'struct ynl_policy_attr {base}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {base}_nest =')
    for field in (f'.max_attr = {max_attr},', f'.table = {base}_policy,'):
        cw.p(field)
    cw.block_end(line=';')
    cw.nl()
1522
1523
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit 'const char *<render_name>_str(<arg>)' that indexes map_name.

    The argument is typed enum.user_type when an enum is given, else int.
    For 'flags' enums the value is first converted to a bit index with
    ffs(). Out-of-range indexes return NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = [
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    ]
    for stmt in body:
        cw.p(stmt)
    cw.block_end()
    cw.nl()
1537
1538
def put_op_name_fwd(family, cw):
    """Emit the prototype of '<family>_op_str(int op)'."""
    func_name = family.c_name + '_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1541
1542
def put_op_name(family, cw):
    """Emit the op-value -> name string table and its '<family>_op_str()'.

    Only messages with a reply value get an entry. The entry is indexed by the
    op enum name when request and reply values match, else by the raw reply
    value.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1562
1563
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the prototype of '<enum>_str(<user_type> value)'."""
    cw.write_func_prot('const char *', f'{enum.render_name}_str',
                       [f'{enum.user_type} value'], suffix=';')
1567
1568
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string table of an enum and its '_str()' helper."""
    map_name = enum.render_name + '_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    rows = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]
    for row in rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1578
1579
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of 'int <struct>_put(nlh, attr_type, obj)'."""
    ri.cw.write_func_prot('int', f'{struct.render_name}_put',
                          ['struct nlmsghdr *nlh',
                           'unsigned int attr_type',
                           f'{struct.ptr_name}obj'],
                          suffix=suffix)
1587
1588
def put_req_nested(ri, struct):
    """Emit '<struct>_put()': open a nest of attr_type, put every member of
    obj into it, close the nest and return 0."""
    cw = ri.cw
    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1605
1606
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parse function for struct; the prototype is
    expected to have been printed by the caller.

    Emits the locals and init lines, copies inherited keys and the fixed
    header into dst, rejects a dst whose array-nest/multi-attr members are
    already set, then loops over attributes calling each member's
    attr_get(). Afterwards array-nest and multi-attr members are allocated
    from their n_<name> counters and filled by a second pass over the
    attributes. The n_<name> counters are declared here; presumably the
    members' attr_get() increments them - TODO confirm.

    Args:
        ri: RenderInfo of the message being rendered.
        struct: Struct whose members are parsed into dst.
        init_lines: C statements emitted after the locals; extended in place.
        local_vars: C local declarations; extended in place.

    Raises:
        Exception: for a multi-attr member that is neither nested nor scalar.
    """
    if struct.nested:
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        # Only a truthy 'multi-attr' makes a member multi-attr - the same test
        # AttrSet.new_attr() uses when choosing TypeMultiAttr. Checking mere
        # presence would treat an explicit 'multi-attr: false' as an array.
        if 'multi-attr' in aspec and aspec['multi-attr']:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        # The element's attribute type is passed as the inherited 'idx'
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
        ri.cw.p('return YNL_PARSE_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1715
1716
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of 'int <struct>_parse(yarg, nested, ...)'.

    Each inherited key of the struct becomes an extra '__u32 <key>' argument.
    """
    func_args = ['struct ynl_parse_arg *yarg',
                 'const struct nlattr *nested']
    func_args += [f'__u32 {arg}' for arg in struct.inherited]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
                          suffix=suffix)
1725
1726
1727def parse_rsp_nested(ri, struct):
1728    parse_rsp_nested_prototype(ri, struct, suffix='')
1729
1730    local_vars = ['const struct nlattr *attr;',
1731                  f'{struct.ptr_name}dst = yarg->data;']
1732    init_lines = []
1733
1734    _multi_parse(ri, struct, init_lines, local_vars)
1735
1736
1737def parse_rsp_msg(ri, deref=False):
1738    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
1739        return
1740
1741    func_args = ['const struct nlmsghdr *nlh',
1742                 'struct ynl_parse_arg *yarg']
1743
1744    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
1745                  'const struct nlattr *attr;']
1746    init_lines = ['dst = yarg->data;']
1747
1748    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', func_args)
1749
1750    if ri.struct["reply"].member_list():
1751        _multi_parse(ri, ri.struct["reply"], init_lines, local_vars)
1752    else:
1753        # Empty reply
1754        ri.cw.block_start()
1755        ri.cw.p('return YNL_PARSE_CB_OK;')
1756        ri.cw.block_end()
1757        ri.cw.nl()
1758
1759
1760def print_req(ri):
1761    ret_ok = '0'
1762    ret_err = '-1'
1763    direction = "request"
1764    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
1765                  'struct nlmsghdr *nlh;',
1766                  'int err;']
1767
1768    if 'reply' in ri.op[ri.op_mode]:
1769        ret_ok = 'rsp'
1770        ret_err = 'NULL'
1771        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']
1772
1773    if ri.fixed_hdr:
1774        local_vars += ['size_t hdr_len;',
1775                       'void *hdr;']
1776
1777    print_prototype(ri, direction, terminate=False)
1778    ri.cw.block_start()
1779    ri.cw.write_func_lvar(local_vars)
1780
1781    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
1782
1783    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
1784    if 'reply' in ri.op[ri.op_mode]:
1785        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
1786    ri.cw.nl()
1787
1788    if ri.fixed_hdr:
1789        ri.cw.p("hdr_len = sizeof(req->_hdr);")
1790        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
1791        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
1792        ri.cw.nl()
1793
1794    for _, attr in ri.struct["request"].member_list():
1795        attr.attr_put(ri, "req")
1796    ri.cw.nl()
1797
1798    if 'reply' in ri.op[ri.op_mode]:
1799        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
1800        ri.cw.p('yrs.yarg.data = rsp;')
1801        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
1802        if ri.op.value is not None:
1803            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
1804        else:
1805            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
1806        ri.cw.nl()
1807    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
1808    ri.cw.p('if (err < 0)')
1809    if 'reply' in ri.op[ri.op_mode]:
1810        ri.cw.p('goto err_free;')
1811    else:
1812        ri.cw.p('return -1;')
1813    ri.cw.nl()
1814
1815    ri.cw.p(f"return {ret_ok};")
1816    ri.cw.nl()
1817
1818    if 'reply' in ri.op[ri.op_mode]:
1819        ri.cw.p('err_free:')
1820        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
1821        ri.cw.p(f"return {ret_err};")
1822
1823    ri.cw.block_end()
1824
1825
1826def print_dump(ri):
1827    direction = "request"
1828    print_prototype(ri, direction, terminate=False)
1829    ri.cw.block_start()
1830    local_vars = ['struct ynl_dump_state yds = {};',
1831                  'struct nlmsghdr *nlh;',
1832                  'int err;']
1833
1834    if ri.fixed_hdr:
1835        local_vars += ['size_t hdr_len;',
1836                       'void *hdr;']
1837
1838    ri.cw.write_func_lvar(local_vars)
1839
1840    ri.cw.p('yds.yarg.ys = ys;')
1841    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
1842    ri.cw.p("yds.yarg.data = NULL;")
1843    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
1844    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
1845    if ri.op.value is not None:
1846        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
1847    else:
1848        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
1849    ri.cw.nl()
1850    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
1851
1852    if ri.fixed_hdr:
1853        ri.cw.p("hdr_len = sizeof(req->_hdr);")
1854        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
1855        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
1856        ri.cw.nl()
1857
1858    if "request" in ri.op[ri.op_mode]:
1859        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
1860        ri.cw.nl()
1861        for _, attr in ri.struct["request"].member_list():
1862            attr.attr_put(ri, "req")
1863    ri.cw.nl()
1864
1865    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
1866    ri.cw.p('if (err < 0)')
1867    ri.cw.p('goto free_list;')
1868    ri.cw.nl()
1869
1870    ri.cw.p('return yds.first;')
1871    ri.cw.nl()
1872    ri.cw.p('free_list:')
1873    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
1874    ri.cw.p('return NULL;')
1875    ri.cw.block_end()
1876
1877
1878def call_free(ri, direction, var):
1879    return f"{op_prefix(ri, direction)}_free({var});"
1880
1881
1882def free_arg_name(direction):
1883    if direction:
1884        return direction_to_suffix[direction][1:]
1885    return 'obj'
1886
1887
1888def print_alloc_wrapper(ri, direction):
1889    name = op_prefix(ri, direction)
1890    ri.cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", [f"void"])
1891    ri.cw.block_start()
1892    ri.cw.p(f'return calloc(1, sizeof(struct {name}));')
1893    ri.cw.block_end()
1894
1895
1896def print_free_prototype(ri, direction, suffix=';'):
1897    name = op_prefix(ri, direction)
1898    struct_name = name
1899    if ri.type_name_conflict:
1900        struct_name += '_'
1901    arg = free_arg_name(direction)
1902    ri.cw.write_func_prot('void', f"{name}_free", [f"struct {struct_name} *{arg}"], suffix=suffix)
1903
1904
1905def _print_type(ri, direction, struct):
1906    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
1907    if not direction and ri.type_name_conflict:
1908        suffix += '_'
1909
1910    if ri.op_mode == 'dump':
1911        suffix += '_dump'
1912
1913    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")
1914
1915    if ri.fixed_hdr:
1916        ri.cw.p(ri.fixed_hdr + ' _hdr;')
1917        ri.cw.nl()
1918
1919    meta_started = False
1920    for _, attr in struct.member_list():
1921        for type_filter in ['len', 'bit']:
1922            line = attr.presence_member(ri.ku_space, type_filter)
1923            if line:
1924                if not meta_started:
1925                    ri.cw.block_start(line=f"struct")
1926                    meta_started = True
1927                ri.cw.p(line)
1928    if meta_started:
1929        ri.cw.block_end(line='_present;')
1930        ri.cw.nl()
1931
1932    for arg in struct.inherited:
1933        ri.cw.p(f"__u32 {arg};")
1934
1935    for _, attr in struct.member_list():
1936        attr.struct_member(ri)
1937
1938    ri.cw.block_end(line=';')
1939    ri.cw.nl()
1940
1941
1942def print_type(ri, direction):
1943    _print_type(ri, direction, ri.struct[direction])
1944
1945
1946def print_type_full(ri, struct):
1947    _print_type(ri, "", struct)
1948
1949
1950def print_type_helpers(ri, direction, deref=False):
1951    print_free_prototype(ri, direction)
1952    ri.cw.nl()
1953
1954    if ri.ku_space == 'user' and direction == 'request':
1955        for _, attr in ri.struct[direction].member_list():
1956            attr.setter(ri, ri.attr_set, direction, deref=deref)
1957    ri.cw.nl()
1958
1959
1960def print_req_type_helpers(ri):
1961    if len(ri.struct["request"].attr_list) == 0:
1962        return
1963    print_alloc_wrapper(ri, "request")
1964    print_type_helpers(ri, "request")
1965
1966
1967def print_rsp_type_helpers(ri):
1968    if 'reply' not in ri.op[ri.op_mode]:
1969        return
1970    print_type_helpers(ri, "reply")
1971
1972
1973def print_parse_prototype(ri, direction, terminate=True):
1974    suffix = "_rsp" if direction == "reply" else "_req"
1975    term = ';' if terminate else ''
1976
1977    ri.cw.write_func_prot('void', f"{ri.op.render_name}{suffix}_parse",
1978                          ['const struct nlattr **tb',
1979                           f"struct {ri.op.render_name}{suffix} *req"],
1980                          suffix=term)
1981
1982
1983def print_req_type(ri):
1984    if len(ri.struct["request"].attr_list) == 0:
1985        return
1986    print_type(ri, "request")
1987
1988
1989def print_req_free(ri):
1990    if 'request' not in ri.op[ri.op_mode]:
1991        return
1992    _free_type(ri, 'request', ri.struct['request'])
1993
1994
1995def print_rsp_type(ri):
1996    if (ri.op_mode == 'do' or ri.op_mode == 'dump') and 'reply' in ri.op[ri.op_mode]:
1997        direction = 'reply'
1998    elif ri.op_mode == 'event':
1999        direction = 'reply'
2000    else:
2001        return
2002    print_type(ri, direction)
2003
2004
2005def print_wrapped_type(ri):
2006    ri.cw.block_start(line=f"{type_name(ri, 'reply')}")
2007    if ri.op_mode == 'dump':
2008        ri.cw.p(f"{type_name(ri, 'reply')} *next;")
2009    elif ri.op_mode == 'notify' or ri.op_mode == 'event':
2010        ri.cw.p('__u16 family;')
2011        ri.cw.p('__u8 cmd;')
2012        ri.cw.p('struct ynl_ntf_base_type *next;')
2013        ri.cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
2014    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
2015    ri.cw.block_end(line=';')
2016    ri.cw.nl()
2017    print_free_prototype(ri, 'reply')
2018    ri.cw.nl()
2019
2020
2021def _free_type_members_iter(ri, struct):
2022    for _, attr in struct.member_list():
2023        if attr.free_needs_iter():
2024            ri.cw.p('unsigned int i;')
2025            ri.cw.nl()
2026            break
2027
2028
2029def _free_type_members(ri, var, struct, ref=''):
2030    for _, attr in struct.member_list():
2031        attr.free(ri, var, ref)
2032
2033
2034def _free_type(ri, direction, struct):
2035    var = free_arg_name(direction)
2036
2037    print_free_prototype(ri, direction, suffix='')
2038    ri.cw.block_start()
2039    _free_type_members_iter(ri, struct)
2040    _free_type_members(ri, var, struct)
2041    if direction:
2042        ri.cw.p(f'free({var});')
2043    ri.cw.block_end()
2044    ri.cw.nl()
2045
2046
2047def free_rsp_nested_prototype(ri):
2048        print_free_prototype(ri, "")
2049
2050
2051def free_rsp_nested(ri, struct):
2052    _free_type(ri, "", struct)
2053
2054
2055def print_rsp_free(ri):
2056    if 'reply' not in ri.op[ri.op_mode]:
2057        return
2058    _free_type(ri, 'reply', ri.struct['reply'])
2059
2060
2061def print_dump_type_free(ri):
2062    sub_type = type_name(ri, 'reply')
2063
2064    print_free_prototype(ri, 'reply', suffix='')
2065    ri.cw.block_start()
2066    ri.cw.p(f"{sub_type} *next = rsp;")
2067    ri.cw.nl()
2068    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
2069    _free_type_members_iter(ri, ri.struct['reply'])
2070    ri.cw.p('rsp = next;')
2071    ri.cw.p('next = rsp->next;')
2072    ri.cw.nl()
2073
2074    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2075    ri.cw.p(f'free(rsp);')
2076    ri.cw.block_end()
2077    ri.cw.block_end()
2078    ri.cw.nl()
2079
2080
2081def print_ntf_type_free(ri):
2082    print_free_prototype(ri, 'reply', suffix='')
2083    ri.cw.block_start()
2084    _free_type_members_iter(ri, ri.struct['reply'])
2085    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2086    ri.cw.p(f'free(rsp);')
2087    ri.cw.block_end()
2088    ri.cw.nl()
2089
2090
2091def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
2092    if terminate and ri and policy_should_be_static(struct.family):
2093        return
2094
2095    if terminate:
2096        prefix = 'extern '
2097    else:
2098        if ri and policy_should_be_static(struct.family):
2099            prefix = 'static '
2100        else:
2101            prefix = ''
2102
2103    suffix = ';' if terminate else ' = {'
2104
2105    max_attr = struct.attr_max_val
2106    if ri:
2107        name = ri.op.render_name
2108        if ri.op.dual_policy:
2109            name += '_' + ri.op_mode
2110    else:
2111        name = struct.render_name
2112    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2113
2114
2115def print_req_policy(cw, struct, ri=None):
2116    if ri and ri.op:
2117        cw.ifdef_block(ri.op.get('config-cond', None))
2118    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
2119    for _, arg in struct.member_list():
2120        arg.attr_policy(cw)
2121    cw.p("};")
2122    cw.ifdef_block(None)
2123    cw.nl()
2124
2125
2126def kernel_can_gen_family_struct(family):
2127    return family.proto == 'genetlink'
2128
2129
2130def policy_should_be_static(family):
2131    return family.kernel_policy == 'split' or kernel_can_gen_family_struct(family)
2132
2133
2134def print_kernel_policy_ranges(family, cw):
2135    first = True
2136    for _, attr_set in family.attr_sets.items():
2137        if attr_set.subset_of:
2138            continue
2139
2140        for _, attr in attr_set.items():
2141            if not attr.request:
2142                continue
2143            if 'full-range' not in attr.checks:
2144                continue
2145
2146            if first:
2147                cw.p('/* Integer value ranges */')
2148                first = False
2149
2150            sign = '' if attr.type[0] == 'u' else '_signed'
2151            suffix = 'ULL' if attr.type[0] == 'u' else 'LL'
2152            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
2153            members = []
2154            if 'min' in attr.checks:
2155                members.append(('min', str(attr.get_limit('min')) + suffix))
2156            if 'max' in attr.checks:
2157                members.append(('max', str(attr.get_limit('max')) + suffix))
2158            cw.write_struct_init(members)
2159            cw.block_end(line=';')
2160            cw.nl()
2161
2162
2163def print_kernel_op_table_fwd(family, cw, terminate):
2164    exported = not kernel_can_gen_family_struct(family)
2165
2166    if not terminate or exported:
2167        cw.p(f"/* Ops table for {family.name} */")
2168
2169        pol_to_struct = {'global': 'genl_small_ops',
2170                         'per-op': 'genl_ops',
2171                         'split': 'genl_split_ops'}
2172        struct_type = pol_to_struct[family.kernel_policy]
2173
2174        if not exported:
2175            cnt = ""
2176        elif family.kernel_policy == 'split':
2177            cnt = 0
2178            for op in family.ops.values():
2179                if 'do' in op:
2180                    cnt += 1
2181                if 'dump' in op:
2182                    cnt += 1
2183        else:
2184            cnt = len(family.ops)
2185
2186        qual = 'static const' if not exported else 'const'
2187        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
2188        if terminate:
2189            cw.p(f"extern {line};")
2190        else:
2191            cw.block_start(line=line + ' =')
2192
2193    if not terminate:
2194        return
2195
2196    cw.nl()
2197    for name in family.hooks['pre']['do']['list']:
2198        cw.write_func_prot('int', c_lower(name),
2199                           ['const struct genl_split_ops *ops',
2200                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2201    for name in family.hooks['post']['do']['list']:
2202        cw.write_func_prot('void', c_lower(name),
2203                           ['const struct genl_split_ops *ops',
2204                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2205    for name in family.hooks['pre']['dump']['list']:
2206        cw.write_func_prot('int', c_lower(name),
2207                           ['struct netlink_callback *cb'], suffix=';')
2208    for name in family.hooks['post']['dump']['list']:
2209        cw.write_func_prot('int', c_lower(name),
2210                           ['struct netlink_callback *cb'], suffix=';')
2211
2212    cw.nl()
2213
2214    for op_name, op in family.ops.items():
2215        if op.is_async:
2216            continue
2217
2218        if 'do' in op:
2219            name = c_lower(f"{family.name}-nl-{op_name}-doit")
2220            cw.write_func_prot('int', name,
2221                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2222
2223        if 'dump' in op:
2224            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
2225            cw.write_func_prot('int', name,
2226                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
2227    cw.nl()
2228
2229
2230def print_kernel_op_table_hdr(family, cw):
2231    print_kernel_op_table_fwd(family, cw, terminate=True)
2232
2233
2234def print_kernel_op_table(family, cw):
2235    print_kernel_op_table_fwd(family, cw, terminate=False)
2236    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
2237        for op_name, op in family.ops.items():
2238            if op.is_async:
2239                continue
2240
2241            cw.ifdef_block(op.get('config-cond', None))
2242            cw.block_start()
2243            members = [('cmd', op.enum_name)]
2244            if 'dont-validate' in op:
2245                members.append(('validate',
2246                                ' | '.join([c_upper('genl-dont-validate-' + x)
2247                                            for x in op['dont-validate']])), )
2248            for op_mode in ['do', 'dump']:
2249                if op_mode in op:
2250                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
2251                    members.append((op_mode + 'it', name))
2252            if family.kernel_policy == 'per-op':
2253                struct = Struct(family, op['attribute-set'],
2254                                type_list=op['do']['request']['attributes'])
2255
2256                name = c_lower(f"{family.name}-{op_name}-nl-policy")
2257                members.append(('policy', name))
2258                members.append(('maxattr', struct.attr_max_val.enum_name))
2259            if 'flags' in op:
2260                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
2261            cw.write_struct_init(members)
2262            cw.block_end(line=',')
2263    elif family.kernel_policy == 'split':
2264        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
2265                    'dump': {'pre': 'start', 'post': 'done'}}
2266
2267        for op_name, op in family.ops.items():
2268            for op_mode in ['do', 'dump']:
2269                if op.is_async or op_mode not in op:
2270                    continue
2271
2272                cw.ifdef_block(op.get('config-cond', None))
2273                cw.block_start()
2274                members = [('cmd', op.enum_name)]
2275                if 'dont-validate' in op:
2276                    dont_validate = []
2277                    for x in op['dont-validate']:
2278                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
2279                            continue
2280                        if op_mode == "dump" and x == 'strict':
2281                            continue
2282                        dont_validate.append(x)
2283
2284                    if dont_validate:
2285                        members.append(('validate',
2286                                        ' | '.join([c_upper('genl-dont-validate-' + x)
2287                                                    for x in dont_validate])), )
2288                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
2289                if 'pre' in op[op_mode]:
2290                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
2291                members.append((op_mode + 'it', name))
2292                if 'post' in op[op_mode]:
2293                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
2294                if 'request' in op[op_mode]:
2295                    struct = Struct(family, op['attribute-set'],
2296                                    type_list=op[op_mode]['request']['attributes'])
2297
2298                    if op.dual_policy:
2299                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
2300                    else:
2301                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
2302                    members.append(('policy', name))
2303                    members.append(('maxattr', struct.attr_max_val.enum_name))
2304                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
2305                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
2306                cw.write_struct_init(members)
2307                cw.block_end(line=',')
2308    cw.ifdef_block(None)
2309
2310    cw.block_end(line=';')
2311    cw.nl()
2312
2313
2314def print_kernel_mcgrp_hdr(family, cw):
2315    if not family.mcgrps['list']:
2316        return
2317
2318    cw.block_start('enum')
2319    for grp in family.mcgrps['list']:
2320        grp_id = c_upper(f"{family.name}-nlgrp-{grp['name']},")
2321        cw.p(grp_id)
2322    cw.block_end(';')
2323    cw.nl()
2324
2325
2326def print_kernel_mcgrp_src(family, cw):
2327    if not family.mcgrps['list']:
2328        return
2329
2330    cw.block_start('static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] =')
2331    for grp in family.mcgrps['list']:
2332        name = grp['name']
2333        grp_id = c_upper(f"{family.name}-nlgrp-{name}")
2334        cw.p('[' + grp_id + '] = { "' + name + '", },')
2335    cw.block_end(';')
2336    cw.nl()
2337
2338
2339def print_kernel_family_struct_hdr(family, cw):
2340    if not kernel_can_gen_family_struct(family):
2341        return
2342
2343    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
2344    cw.nl()
2345    if 'sock-priv' in family.kernel_family:
2346        cw.p(f'void {family.c_name}_nl_sock_priv_init({family.kernel_family["sock-priv"]} *priv);')
2347        cw.p(f'void {family.c_name}_nl_sock_priv_destroy({family.kernel_family["sock-priv"]} *priv);')
2348        cw.nl()
2349
2350
2351def print_kernel_family_struct_src(family, cw):
2352    if not kernel_can_gen_family_struct(family):
2353        return
2354
2355    cw.block_start(f"struct genl_family {family.name}_nl_family __ro_after_init =")
2356    cw.p('.name\t\t= ' + family.fam_key + ',')
2357    cw.p('.version\t= ' + family.ver_key + ',')
2358    cw.p('.netnsok\t= true,')
2359    cw.p('.parallel_ops\t= true,')
2360    cw.p('.module\t\t= THIS_MODULE,')
2361    if family.kernel_policy == 'per-op':
2362        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
2363        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
2364    elif family.kernel_policy == 'split':
2365        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
2366        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
2367    if family.mcgrps['list']:
2368        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
2369        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
2370    if 'sock-priv' in family.kernel_family:
2371        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
2372        # Force cast here, actual helpers take pointer to the real type.
2373        cw.p(f'.sock_priv_init\t= (void *){family.c_name}_nl_sock_priv_init,')
2374        cw.p(f'.sock_priv_destroy = (void *){family.c_name}_nl_sock_priv_destroy,')
2375    cw.block_end(';')
2376
2377
2378def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
2379    start_line = 'enum'
2380    if enum_name in obj:
2381        if obj[enum_name]:
2382            start_line = 'enum ' + c_lower(obj[enum_name])
2383    elif ckey and ckey in obj:
2384        start_line = 'enum ' + family.c_name + '_' + c_lower(obj[ckey])
2385    cw.block_start(line=start_line)
2386
2387
2388def render_uapi(family, cw):
2389    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
2390    cw.p('#ifndef ' + hdr_prot)
2391    cw.p('#define ' + hdr_prot)
2392    cw.nl()
2393
2394    defines = [(family.fam_key, family["name"]),
2395               (family.ver_key, family.get('version', 1))]
2396    cw.writes_defines(defines)
2397    cw.nl()
2398
2399    defines = []
2400    for const in family['definitions']:
2401        if const['type'] != 'const':
2402            cw.writes_defines(defines)
2403            defines = []
2404            cw.nl()
2405
2406        # Write kdoc for enum and flags (one day maybe also structs)
2407        if const['type'] == 'enum' or const['type'] == 'flags':
2408            enum = family.consts[const['name']]
2409
2410            if enum.has_doc():
2411                cw.p('/**')
2412                doc = ''
2413                if 'doc' in enum:
2414                    doc = ' - ' + enum['doc']
2415                cw.write_doc_line(enum.enum_name + doc)
2416                for entry in enum.entries.values():
2417                    if entry.has_doc():
2418                        doc = '@' + entry.c_name + ': ' + entry['doc']
2419                        cw.write_doc_line(doc)
2420                cw.p(' */')
2421
2422            uapi_enum_start(family, cw, const, 'name')
2423            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
2424            for entry in enum.entries.values():
2425                suffix = ','
2426                if entry.value_change:
2427                    suffix = f" = {entry.user_value()}" + suffix
2428                cw.p(entry.c_name + suffix)
2429
2430            if const.get('render-max', False):
2431                cw.nl()
2432                cw.p('/* private: */')
2433                if const['type'] == 'flags':
2434                    max_name = c_upper(name_pfx + 'mask')
2435                    max_val = f' = {enum.get_mask()},'
2436                    cw.p(max_name + max_val)
2437                else:
2438                    max_name = c_upper(name_pfx + 'max')
2439                    cw.p('__' + max_name + ',')
2440                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
2441            cw.block_end(line=';')
2442            cw.nl()
2443        elif const['type'] == 'const':
2444            defines.append([c_upper(family.get('c-define-name',
2445                                               f"{family.name}-{const['name']}")),
2446                            const['value']])
2447
2448    if defines:
2449        cw.writes_defines(defines)
2450        cw.nl()
2451
2452    max_by_define = family.get('max-by-define', False)
2453
2454    for _, attr_set in family.attr_sets.items():
2455        if attr_set.subset_of:
2456            continue
2457
2458        max_value = f"({attr_set.cnt_name} - 1)"
2459
2460        val = 0
2461        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
2462        for _, attr in attr_set.items():
2463            suffix = ','
2464            if attr.value != val:
2465                suffix = f" = {attr.value},"
2466                val = attr.value
2467            val += 1
2468            cw.p(attr.enum_name + suffix)
2469        cw.nl()
2470        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
2471        if not max_by_define:
2472            cw.p(f"{attr_set.max_name} = {max_value}")
2473        cw.block_end(line=';')
2474        if max_by_define:
2475            cw.p(f"#define {attr_set.max_name} {max_value}")
2476        cw.nl()
2477
2478    # Commands
2479    separate_ntf = 'async-prefix' in family['operations']
2480
2481    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
2482    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
2483    max_value = f"({cnt_name} - 1)"
2484
2485    uapi_enum_start(family, cw, family['operations'], 'enum-name')
2486    val = 0
2487    for op in family.msgs.values():
2488        if separate_ntf and ('notify' in op or 'event' in op):
2489            continue
2490
2491        suffix = ','
2492        if op.value != val:
2493            suffix = f" = {op.value},"
2494            val = op.value
2495        cw.p(op.enum_name + suffix)
2496        val += 1
2497    cw.nl()
2498    cw.p(cnt_name + ('' if max_by_define else ','))
2499    if not max_by_define:
2500        cw.p(f"{max_name} = {max_value}")
2501    cw.block_end(line=';')
2502    if max_by_define:
2503        cw.p(f"#define {max_name} {max_value}")
2504    cw.nl()
2505
2506    if separate_ntf:
2507        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
2508        for op in family.msgs.values():
2509            if separate_ntf and not ('notify' in op or 'event' in op):
2510                continue
2511
2512            suffix = ','
2513            if 'value' in op:
2514                suffix = f" = {op['value']},"
2515            cw.p(op.enum_name + suffix)
2516        cw.block_end(line=';')
2517        cw.nl()
2518
2519    # Multicast
2520    defines = []
2521    for grp in family.mcgrps['list']:
2522        name = grp['name']
2523        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
2524                        f'{name}'])
2525    cw.nl()
2526    if defines:
2527        cw.writes_defines(defines)
2528        cw.nl()
2529
2530    cw.p(f'#endif /* {hdr_prot} */')
2531
2532
2533def _render_user_ntf_entry(ri, op):
2534    ri.cw.block_start(line=f"[{op.enum_name}] = ")
2535    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
2536    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
2537    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
2538    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
2539    ri.cw.block_end(line=',')
2540
2541
def render_user_family(family, cw, prototype):
    """Emit the ``struct ynl_family`` descriptor for *family* through *cw*.

    Args:
        family: the parsed spec family.
        cw: code writer that receives the generated C.
        prototype: when true, only an ``extern`` declaration is written
            (header output); otherwise the full definition is written.

    When the family has notifications, the definition is preceded by a static
    ``ynl_ntf_info`` table. That table has one entry per item of
    ``family.ntfs`` plus one per op carrying an ``'event'`` key, and the
    family descriptor points at it.

    Raises:
        Exception: if a notification entry has neither ``'notify'`` nor
            ``'event'``.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # Use the C-safe name: the raw spec name may contain '-', which is not
    # valid in a C identifier.
    ntf_info_name = f'{family.c_name}_ntf_info'

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_info_name}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Notification sharing the reply layout of another op.
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op in family.ops.values():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_info_name},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_info_name}),")
    cw.block_end(line=';')
2577
2578
def family_contains_bitfield32(family):
    """Report whether *family* defines any ``bitfield32`` attribute.

    Attribute sets that are a ``subset_of`` another set are skipped. The scan
    stops at the first match.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
2587
2588
def find_kernel_root(full_path):
    """Locate the kernel source root above *full_path*.

    Walks up the directory hierarchy from *full_path* until it reaches a
    directory containing a ``MAINTAINERS`` file.

    Args:
        full_path: path (absolute, or relative to the current directory) of a
            file inside the kernel tree, e.g. a YAML spec.

    Returns:
        ``(root, sub_path)``: the directory containing MAINTAINERS, and the
        path of *full_path* relative to that directory.

    Raises:
        Exception: if no ancestor directory of *full_path* contains
            MAINTAINERS.
    """
    start_path = full_path
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        # dirname() reaches a fixed point at '/' (absolute paths) or ''
        # (relative paths); stop there instead of looping forever.
        if parent == full_path:
            raise Exception(f'Cannot find kernel root (no MAINTAINERS file above {start_path})')
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        full_path = parent
        maintainers = os.path.join(full_path, "MAINTAINERS")
        if os.path.exists(maintainers):
            # sub_path carries a trailing separator from the last join().
            return full_path, sub_path[:-1]
2597
2598
def main():
    """Command-line entry point of the YNL C code generator.

    Parses the arguments, loads the YAML spec given by ``--spec`` and checks
    its license and message enum-model. It then writes generated C through a
    ``CodeWriter`` targeting ``-o``:

    * ``--mode uapi``: only the uapi header, via ``render_uapi``;
    * ``--mode kernel`` / ``--mode user``: a header (``--header``) or
      source (``--source``) file for that side.

    Exits with status 1 on a YAML parse error, an unexpected spec license, or
    an unsupported message enum-model. Argument errors are reported through
    ``parser.error``.
    """
    # Explicit import: 'os.sys' only works because 'os' happens to import sys.
    import sys

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination; None means neither was given.
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    # Only 'user' and 'kernel' generation handle the directional model.
    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # File banner: license, spec path relative to the kernel root, and the
    # generator arguments that produced this output.
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # Includes.
    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # The kernel source includes its companion header: same base
                # name as the output file, with a '.h' extension.
                out_base = os.path.splitext(os.path.basename(args.out_file))[0]
                cw.p(f'#include "{out_base}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            # Print the section comment once if any nested struct is a request.
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)

            # Print the section comment once if any nested struct is a request.
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            # Recursive nests need forward-declared policies before any
            # policy table is emitted.
            has_recursive_nests = False
            cw.p('/* Policies */')
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            if has_recursive_nests:
                # Prototypes first, so mutually/self-referencing nested
                # helpers can call each other.
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    # The event's parser is rendered through a "do"-mode
                    # RenderInfo; its free function through "event" mode.
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2903
2904
if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()
2907