xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 99b76908a7a3df629a83555333ce0b97824e4734)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
def c_upper(name):
    """Return *name* as an upper-case C identifier (dashes become underscores)."""
    return name.replace('-', '_').upper()
21
22
def c_lower(name):
    """Return *name* as a lower-case C identifier (dashes become underscores)."""
    return name.replace('-', '_').lower()
25
26
def limit_to_number(name):
    """
    Convert a symbolic limit such as ``u32-max`` or ``s64-min`` to an int.

    The first character selects signedness ('s' means signed), the digits
    between it and the four-character ``-min``/``-max`` suffix give the bit
    width. The minimum of an unsigned type is always 0.
    """
    signed = name[0] == 's'
    is_min = name.endswith('-min')
    if is_min and name[0] == 'u':
        return 0

    bits = int(name[1:-4])
    if signed:
        # One bit is taken by the sign.
        bits -= 1
    top = (1 << bits) - 1
    if signed and is_min:
        return -top - 1
    return top
40
41
class BaseNlLib:
    """C expressions referring to state of the base YNL user-space library."""

    def get_family_id(self):
        """Return the C expression used to access the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Code-generation model of a single attribute of an attribute set.

    This base class holds the logic shared by all attribute types (C naming,
    request/reply tracking, check/limit lookup, presence bookkeeping, freeing
    and setter generation). Subclasses override type-specific hooks such as
    _complex_member_type(), _attr_policy(), _attr_typol(), _attr_get(),
    _setter_lines() and attr_put().
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE(review): when the spec provides 'checks' this is the spec's own
        # dict, so checks derived later (e.g. by TypeScalar) are written into it.
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply().
        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        nested = attr.get('nested-attributes')
        if nested:
            self.nested_attrs = nested
            # A nest of the set named after the family is rendered with just
            # the family name; other sets get the set name appended.
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # A trailing '_' is added when a family constant shares the set's
            # name, presumably to avoid clashing with a C type emitted for it.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make the member name a valid C identifier: names listed in _C_KW
        # (defined elsewhere; presumably C keywords) get a '_' suffix and
        # names starting with a digit get a '_' prefix.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve(). The attribute is declared and then removed,
        # presumably so that reading it before resolve() fails loudly.
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # If the attr is for a subset, return the "real" attr it mirrors
        # (just one level down, does not recurse).
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests (also marks the real attr of a subset)."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies (also marks the real attr of a subset)."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numeric value of check *limit*.

        The check may be an int, the name of a family constant (its "value"
        is used) or a symbolic limit understood by limit_to_number().
        Returns None when the check is absent and *default* is None.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check *limit* rendered as C source text ('' when absent).

        Integers are emitted literally with *suffix* appended. Constants with
        a 'header' are emitted under their own upper-cased name, other
        constants get the family name as prefix; symbolic limits are
        upper-cased (e.g. 'u32-max' -> 'U32_MAX').
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """
        Compute enum_name; reject subset attrs whose checks differ from the real attr.

        Raises:
            Exception: a subset attr overrides the checks of its real attr.
        """
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        """Return truthy when the attr holds several values (None here)."""
        return None

    def is_scalar(self):
        """Return True for the integer types listed below."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        """Return True when the attr is part of a recursive nest (never here)."""
        return False

    def is_recursive_for_op(self, ri):
        """Return True when the attr is recursive and *ri* has no operation."""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Return the presence tracking kind: 'present' (a bit), 'len' or 'count'."""
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return the C declaration of this attr's presence member, or None.

        Only returned when presence_type() equals *type_filter*. For
        *space* == 'user' the type is spelled '__u32', otherwise 'u32'.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'present':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() in {'len', 'count'}:
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name};"

    def _complex_member_type(self, ri):
        """Return the C type of a complex struct member, or None for simple types."""
        return None

    def free_needs_iter(self):
        """Return True when the generated free code needs a loop index 'i'."""
        return False

    def _free_lines(self, ri, var, ref):
        """Return C statements freeing the heap copy of this attr's value."""
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the statements from _free_lines() to *ri.cw*."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """
        Return the C parameter declarations used to pass this attr to a setter.

        Raises:
            Exception: the subclass provides no complex member type.
        """
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """
        Emit the struct member declaration(s) for this attr.

        Complex members become pointers when multi-valued or recursive for
        this request; otherwise the setter arguments are reused as members.
        """
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the kernel policy initializer for *policy*."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's kernel policy entry; big-endian u16/u32 use NLA_BE*."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the user-space parse policy fields (subclasses must override)."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's user-space parse policy entry."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line* as a statement, guarded by the presence bit or length if tracked."""
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ynl_attr_put_<put_type>() call for this attr."""
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit code serializing this attr (subclasses must override)."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (get_lines, init_lines, local_vars) for parsing (override)."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit this attr's branch of the parser's attribute-type dispatch.

        *first* selects 'if' versus 'else if'. Single-valued attrs are
        validated and, with bit presence, marked present before the
        type-specific lines run. Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return C statements storing the setter argument (subclasses must override)."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a static inline C setter for this attr.

        *ref* is the list of enclosing nest member names (empty at top
        level). Presence bits of all enclosing nests are set; this attr's
        own presence is a bit or, for len/count presence, is handled by
        _setter_lines(). Any previous value is freed first. The function
        name gets a '__' prefix when the body frees but never allocates.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        for i in range(len(ref)):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'present':
                presence = f"{var}->{'.'.join(ref[:i] + [''])}_{self.presence_type()}.{ref[i]}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
304
305
class TypeUnused(Type):
    """
    Attr marked as unused.

    Only a user-space parse policy entry (YNL_PT_REJECT) is produced; no
    struct member, kernel policy, put, get or setter is generated.
    """

    def presence_type(self):
        # No presence tracking at all.
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        return ['return YNL_PARSE_CB_ERROR;'], None, None

    # The generators below deliberately emit nothing.

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
330
331
class TypePad(Type):
    """
    Padding attr.

    The user-space parser ignores it (YNL_PT_IGNORE); no struct member,
    kernel policy, put, get or setter is generated.
    """

    def presence_type(self):
        # No presence tracking at all.
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The generators below deliberately emit nothing.

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
353
354
class TypeScalar(Type):
    """
    Integer scalar attr.

    Besides the common Type plumbing this derives kernel policy checks
    (min/max/range/full-range/sparse) from the spec and any referenced enum.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Appended to setter argument declarations, e.g. " /* big-endian */".
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve(). Declared then removed, presumably so that
        # reading them before resolve() fails loudly.
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Determine whether the attr is a bitfield and its user-space C type."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Auto-sized scalars are always exposed as 64-bit values.
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """
        Derive the checks used for the kernel policy.

        Enum-typed attrs take min/max from the enum's value_range() unless
        explicitly given, or are marked 'sparse' when value_range() reports
        no bounds. 'range' is set when both bounds exist, 'full-range' when
        a bound lies outside [-32768, 32767].

        Raises:
            Exception: min > max, or a negative bound on an unsigned type.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                # A min of 0 is not recorded for unsigned types.
                if 'min' not in self.checks:
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # _attr_policy() emits NLA_POLICY_FULL_RANGE for these bounds.
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        """
        Return the kernel policy initializer, picking the first applicable
        form: mask, full range, range, min, max, sparse validator, plain type.
        """
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the unsigned parse policy of the same width.
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
452
453
class TypeFlag(Type):
    """Flag attr: carries no payload, only its presence bit is meaningful."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # Setting a flag takes no value argument.
        return []

    def _setter_lines(self, ri, member, presence):
        # The presence bit set by the generic setter is all there is.
        return []

    def _attr_get(self, ri, var):
        # The generic attr_get() already records presence; nothing to copy.
        return [], None, None

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)
469
470
class TypeString(Type):
    """String attr, stored as a heap-allocated NUL-terminated char * plus length."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the kernel policy to NLA_STRING.
        policy = 'NLA_STRING' if self.checks.get('unterminated-ok', False) else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        fields = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + self.get_limit_str('max-len'))
        return ', '.join(fields) + ', }'

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        field = f"{var}->{self.c_name}"
        len_mem = var + '->_len.' + self.c_name
        # Copy the string and always NUL-terminate the copy.
        get_lines = [
            f"{len_mem} = len;",
            f"{field} = malloc(len + 1);",
            f"memcpy({field}, ynl_attr_get_str(attr), len);",
            f"{field}[len] = 0;",
        ]
        # strnlen() bounds the length by the attr payload size.
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = strlen({self.c_name});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {self.c_name}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
520
521
class TypeBinary(Type):
    """Opaque binary attr, stored as a heap-allocated void * plus length."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """
        Return the kernel policy for the binary attr.

        *policy* is not used. At most one of 'exact-len', 'min-len',
        'max-len' is supported; anything else raises.
        """
        if not self.checks:
            return '{ .type = NLA_BINARY, }'
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')

        (check_name,) = self.checks
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if check_name == 'max-len':
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        raise Exception('Unsupported check for binary type: ' + check_name)

    def attr_put(self, ri, var):
        data = f"{var}->{self.c_name}"
        length = f"{var}->_len.{self.c_name}"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, {data}, {length})")

    def _attr_get(self, ri, var):
        field = f"{var}->{self.c_name}"
        len_mem = var + '->_len.' + self.c_name
        get_lines = [
            f"{len_mem} = len;",
            f"{field} = malloc(len);",
            f"memcpy({field}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = len;",
            f"{member} = malloc({presence});",
            f'memcpy({member}, {self.c_name}, {presence});',
        ]
572
573
class TypeBinaryStruct(TypeBinary):
    """Binary attr whose payload is the C struct named by the spec's 'struct' key."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = 'sizeof(struct ' + c_lower(self.get("struct")) + ')'
        field = f"{var}->{self.c_name}"
        len_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        # A payload shorter than the struct gets a zeroed full-size buffer,
        # presumably so members beyond the received length read as zero.
        get_lines = [
            f"{len_mem} = len;",
            f"if (len < {struct_sz})",
            f"{field} = calloc(1, {struct_sz});",
            "else",
            f"{field} = malloc(len);",
            f"memcpy({field}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
589
590
class TypeBinaryScalarArray(TypeBinary):
    """Binary attr holding a packed array of 'sub-type' scalars, tracked by element count."""

    def presence_type(self):
        return 'count'

    def arg_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        return [f'{elem_type} *{self.c_name}', 'size_t count']

    def struct_member(self, ri):
        ri.cw.p(f'__{self.get("sub-type")} *{self.c_name};')

    def attr_put(self, ri, var):
        count_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        ri.cw.block_start(line=f"if ({count_mem})")
        # 'i' (declared by the surrounding generated code) holds the byte size.
        ri.cw.p(f"i = {count_mem} * sizeof(__{self.get('sub-type')});")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        count_mem = var + '->_count.' + self.c_name
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        field = f"{var}->{self.c_name}"
        # Round len down to whole elements, dropping any trailing partial one.
        get_lines = [
            f"{count_mem} = len / {elem_sz};",
            f"len = {count_mem} * {elem_sz};",
            f"{field} = malloc(len);",
            f"memcpy({field}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = count;",
            f"count *= sizeof(__{self.get('sub-type')});",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
623
624
class TypeBitfield32(Type):
    """Attr carrying a struct nla_bitfield32 whose allowed bits come from an enum."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """
        Return NLA_POLICY_BITFIELD32() with the referenced enum's flag mask.

        Raises:
            Exception: the attr has no 'enum' to derive the mask from.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
648
649
class TypeNest(Type):
    """Attr containing a nested attribute set, stored as that set's struct."""

    def is_recursive(self):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def _free_lines(self, ri, var, ref):
        member = f'{var}->{ref}{self.c_name}'
        if self.is_recursive_for_op(ri):
            # Recursive nests are held by pointer, which may be NULL.
            return [f'if ({member})', f'{self.nested_render_name}_free({member});']
        return [f'{self.nested_render_name}_free(&{member});']

    def attr_put(self, ri, var):
        # Recursive nests are already pointers; others need their address.
        at = '' if self.is_recursive_for_op(ri) else '&'
        put_call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, {at}{var}->{self.c_name})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({self.nested_render_name}_parse(&parg, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for each non-recursive member of the nest (none for the nest itself)."""
        path = (ref if ref else []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
691
692
class TypeMultiAttr(Type):
    """
    Attr that may appear several times; the values are collected into an
    array tracked by a 'count' presence member.

    *base_type* is the single-instance type object whose kernel policy and
    user-space parse policy are reused.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        kind = self.attr['type']
        if kind == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if kind == 'string':
            return 'struct ynl_string *'
        if kind in scalars:
            return ('__' if ri.ku_space == 'user' else '') + kind
        raise Exception(f"Sub-type {kind} not supported yet")

    def arg_member(self, ri):
        if self.type == 'binary' and 'struct' in self.attr:
            return [f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def free_needs_iter(self):
        # Strings and nests are freed element by element.
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        kind = self.attr['type']
        member = f"{var}->{ref}{self.c_name}"
        if kind in scalars or kind == 'binary':
            return [f"free({member});"]
        if kind == 'string':
            return [
                f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
                f"free({member}[i]);",
                f"free({member});",
            ]
        if kind == 'nest':
            return [
                f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
                f'{self.nested_render_name}_free(&{member}[i]);',
                f"free({member});",
            ]
        raise Exception(f"Free of MultiAttr sub-type {kind} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        kind = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        item = f"{var}->{self.c_name}[i]"
        if kind in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {item});")
        elif kind == 'binary' and 'struct' in self.attr:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{item}, sizeof(struct {c_lower(self.attr['struct'])}));")
        elif kind == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {item}->str);")
        elif kind == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{item})")
        else:
            raise Exception(f"Put of MultiAttr sub-type {kind} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [
            f"{member} = {self.c_name};",
            f"{presence} = n_{self.c_name};",
        ]
779
780
class TypeArrayNest(Type):
    """
    Attr whose payload is a nest of indexed entries (entry i is put with
    attr type i), collected into an array tracked by a 'count' member.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        sub = self.attr['sub-type']
        if sub in scalars:
            return ('__' if ri.ku_space == 'user' else '') + sub
        if sub == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        raise Exception(f"Sub-type {sub} not supported yet")

    def arg_member(self, ri):
        if self.sub_type == 'binary' and 'exact-len' in self.checks:
            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        sub = self.attr['sub-type']
        if sub in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        if sub == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the nest and count/validate its entries; copying the
        # values is presumably done later via attr_<name>.
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr) {',
            '\tif (ynl_attr_validate(yarg, attr2))',
            '\t\treturn YNL_PARSE_CB_ERROR;',
            f'\t{var}->_count.{self.c_name}++;',
            '}',
        ]
        return get_lines, None, ['const struct nlattr *attr2;']

    def attr_put(self, ri, var):
        loop = f'for (i = 0; i < {var}->_count.{self.c_name}; i++)'
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            ri.cw.block_start(line=loop)
            ri.cw.p(f"ynl_attr_put_{self.sub_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(loop)
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put for ArrayNest sub-type {self.attr['sub-type']} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [
            f"{member} = {self.c_name};",
            f"{presence} = n_{self.c_name};",
        ]
843
844
class TypeNestTypeValue(Type):
    """Nest attribute whose payload sits behind 'type-value' levels.

    For every name listed under the spec's 'type-value' key, the generated
    parser descends one nesting level, stores that level's attribute type
    in a __u32 of the same name, and finally passes those values as extra
    arguments to the nested struct's _parse() function.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        innermost = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                # Step one level down and record the attr type found there
                get_lines.append(f'attr_{level} = ynl_attr_data({innermost});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                innermost = 'attr_' + level
            extra_args = ', ' + ', '.join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{extra_args});")
        return get_lines, init_lines, local_vars
873
874
class Struct:
    """C structure generated for an attribute set, or a subset of it.

    With type_list=None the struct covers every attribute of the set and is
    flagged as nested; otherwise only the attributes named in type_list are
    included.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Default to a list rather than a set so that a later set_inherited()
        # with an empty set is still reported as a mismatch.
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = c_lower(family.ident_name)
        else:
            self.render_name = c_lower(family.ident_name + '-' + space_name)
        # Nested structs named like a family constant get a trailing '_'
        # (presumably to avoid clashing with that constant's C type)
        suffix = '_' if self.nested and space_name in family.consts else ''
        self.struct_name = 'struct ' + self.render_name + suffix
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr and legacy arrays

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]
        self.attrs = dict(self.attr_list)

        # Member with the highest attribute value; later members win ties
        self.attr_max_val = None
        top = 0
        for _, attr in self.attr_list:
            if attr.value >= top:
                top = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Finalize the inherited members; they must match what was given at init."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]

    def free_needs_iter(self):
        """True if any member reports that freeing it needs an iterator."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
937
938
class EnumEntry(SpecEnumEntry):
    """Enum entry carrying C-specific naming (c_name) and value_change."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change: the value does not simply continue from the previous
        # entry (or, for the first entry, is not 0); always set for flags.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
957
958
class EnumSet(SpecEnumSet):
    """Enum/flags definition with its C names and value prefix."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # No 'enum-name' key -> derived name; key present but empty -> anonymous
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        # Anonymous enums are carried as plain int in C
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if entry values are contiguous, else (None, None)."""
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 == len(self.entries):
            return low, high
        return None, None
994
995
class AttrSet(SpecAttrSet):
    """Attribute set with C enum prefix, max/count names and a C name."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the C enum naming of the set it is a subset of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        c_name = c_lower(self.name)
        # Avoid C keywords by appending '_'
        if c_name in _C_KW:
            c_name += '_'
        # The set named after the family gets an empty C name
        self.c_name = '' if c_name == self.family.c_name else c_name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching the spec element's 'type'."""
        kind = elem['type']
        if kind in scalars:
            cls = TypeScalar
        elif kind == 'binary':
            if 'struct' in elem:
                cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                cls = TypeBinaryScalarArray
            else:
                cls = TypeBinary
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            simple_types = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'nest-type-value': TypeNestTypeValue,
            }
            if kind not in simple_types:
                raise Exception(f"No typed class for type {elem['type']}")
            cls = simple_types[kind]

        t = cls(self.family, self, elem, value)

        # multi-attr wraps whatever base type was chosen above
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1062
1063
class Operation(SpecOperation):
    """Operation with its C enum name and rendering flags."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # every existing request/reply section gets an 'attributes' list.
        sections = [(mode, direction)
                    for mode in ('do', 'dump', 'event')
                    for direction in ('request', 'reply')]
        for mode, direction in sections:
            try:
                section = yaml[mode][direction]
            except KeyError:
                continue
            section.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True when both 'do' and 'dump' have a request section
        has_do_req = 'do' in yaml and 'request' in yaml['do']
        has_dump_req = 'dump' in yaml and 'request' in yaml['dump']
        self.dual_policy = has_do_req and has_dump_req

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
1097
1098
class Family(SpecFamily):
    """Spec family extended with the state needed to generate C code.

    On top of SpecFamily this computes C identifiers and op prefixes, the
    UAPI header name, the pre/post hook lists, and splits attribute sets
    into root sets (referenced directly by operations) and pure nested
    structs (reachable only through nest attributes).
    """

    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # (declared here for readers, then deleted so they stay unset until
        #  they are assigned later, e.g. in resolve() below)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/foo.h" -> "foo"; any other header falls back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def is_classic(self):
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        # Flag every op that some other message names in its 'notify' key
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per attribute set used by ops, the request and reply attrs."""
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Move structs after the (non-recursive) structs they nest, if possible.

        A struct referencing a not-yet-placed nest is moved to the end of
        the dict; the number of attempts is bounded by n**2.
        """
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register the set nested by *spec* as a pure nested struct; return its name.

        Raises if that set is also used as a root set.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_sets(self):
        """Discover pure nested structs and compute their usage flags."""
        # Breadth-first walk from the root sets through nest attributes
        attr_set_queue = collections.deque(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while attr_set_queue:
            a_set = attr_set_queue.popleft()
            for _, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue
                nested = self._load_nested_set_nest(spec)

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        # Nests used directly by root sets inherit request/reply use from them
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

                # Check only after this attribute's nests were recorded;
                # checking first missed a self-reference in the last nest attr.
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark attributes as used in requests and/or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_global_policy(self):
        """Build the single request policy shared by all ops ('global' kernel-policy).

        Raises if ops use different attribute sets.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        # Keep the attribute set's own ordering
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per op mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1382
1383
class RenderInfo:
    """Context for rendering one operation, or one bare attribute set.

    Bundles the code writer, family, kernel/user space and op mode with
    derived details: the fixed header, whether the do and dump replies can
    share a type, the type name, and the request/reply Structs.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)
            if op.fixed_header != family.fixed_header:
                # Only classic netlink families may override the header per op
                if not family.is_classic():
                    raise Exception(f"Per-op fixed header not supported, yet")
                self.fixed_hdr_len = f"sizeof({self.fixed_hdr})"

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_oneside = True
            else:
                do_has_reply = 'reply' in op['do']
                if do_has_reply != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif do_has_reply and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False

        self.attr_set = attr_set if attr_set else op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        self.struct = {}
        # Notifications use the layout of their 'do' (otherwise 'dump') op
        layout_mode = op_mode
        if layout_mode == 'notify':
            layout_mode = 'do' if 'do' in op else 'dump'
        if op:
            section = op[layout_mode]
            for op_dir in ('request', 'reply'):
                members = section[op_dir]['attributes'] if op_dir in section else []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=members)
        if layout_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])

    def type_empty(self, key):
        """True when the *key* struct has no attributes and there is no fixed header."""
        if self.struct[key].attr_list:
            return False
        return self.fixed_hdr is None

    def needs_nlflags(self, direction):
        """True only for a 'do' request of a classic netlink family."""
        if self.op_mode != 'do' or direction != 'request':
            return False
        return self.family.is_classic()
1447
1448
class CodeWriter:
    """Indentation-aware writer for generated C code.

    Output goes to stdout or, when out_file is given, into a temporary file
    that close_out_file() copies to out_file. With overwrite=False the copy
    is skipped when out_file already has identical content.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # '}' pending; may merge with a following 'else'
        self._silent_block = False  # next line is the body of a brace-less if/for/while/else
        self._ind = 0
        self._ifdef_block = None
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy buffered output to out_file (no-op when writing to stdout)."""
        if self._out == os.sys.stdout:
            return
        # Avoid modifying the file if contents didn't change
        self._out.flush()
        if not self._overwrite and os.path.isfile(self._out_file):
            if filecmp.cmp(self._out.name, self._out_file, shallow=False):
                return
        with open(self._out_file, 'w+') as out_file:
            self._out.seek(0)
            shutil.copyfileobj(self._out, out_file)
            self._out.close()
        self._out = os.sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Flushes a pending '}' (merged as '} else' when *line* starts with
        'else') and a pending blank line first. Lines ending in ':' (such
        as labels) are outdented one level, lines starting with '#' go to
        column 0, and the line following a brace-less if/for/while/else gets
        one extra level. *line* must be non-empty.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line[-1] == ':':
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line[0] == '#':
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write '<line> {' and increase the indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block; *line* (e.g. ';') is appended to the '}'."""
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write *doc* as ' *'-prefixed comment lines wrapped before column 79."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping arguments when it exceeds 80 cols."""
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            # Long return types go on their own line
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines align with the opening parenthesis (tabs + spaces)
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        *local_vars* may be a single string or an iterable of strings; the
        caller's list is left untouched.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() is stable like list.sort() but does not mutate the argument
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, locals and body lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write (name, value) pairs as tab-aligned #define lines.

        int values are written as-is, str values are quoted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write designated initializers '.name<tabs>= value,' aligned on '='."""
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active '#ifdef CONFIG_<config>' section (None closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1645
1646
# Spec attribute types handled as plain integers (e.g. read back with
# ynl_attr_get_<type>() by the parsers).
scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}

# C type-name suffix for each message direction; '' is the direction-less
# form used for nested attribute sets.
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': '',
}

# Suffix of the wrapper type per op mode (used by op_prefix()).
op_mode_to_wrapper = dict(do='', dump='_list', notify='_ntf', event='')

# C keywords -- presumably consulted elsewhere in the file to avoid
# rendering identifiers that collide with them.
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1698
1699
def rdir(direction):
    """Return the opposite message direction ('request' <-> 'reply').

    Any other value (e.g. the direction-less '') is returned unchanged.
    """
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1706
1707
def op_prefix(ri, direction, deref=False):
    """Return the C type-name prefix for a message of ri's op in ``direction``.

    The base is ``<family>_<type_name>``.  'do' ops (or no op mode) just add
    the direction suffix.  For other modes, requests become ``_req`` plus
    ``_dump`` unless the type is one-sided.  Replies use either the
    direction suffix (deref) or the mode's wrapper suffix when the type is
    consistent; otherwise they use ``_rsp_dump`` (deref) or ``_rsp_list``.
    """
    base = f"{ri.family.c_name}_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        return base + direction_to_suffix[direction]

    if direction == 'request':
        return base + ('_req' if ri.type_oneside else '_req_dump')

    if not ri.type_consistent:
        return base + ('_rsp_dump' if deref else '_rsp_list')

    if deref:
        return base + direction_to_suffix[direction]
    return base + op_mode_to_wrapper[ri.op_mode]
1729
1730
def type_name(ri, direction, deref=False):
    """Return the full C struct type (``struct <op_prefix>``) for a message."""
    prefix = op_prefix(ri, direction, deref=deref)
    return "struct " + prefix
1733
1734
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-space helper for ri's op and mode.

    The helper takes the socket plus, when the mode has a request, a pointer
    to the request struct.  It returns a pointer to the reply struct when
    the mode has a reply, and ``int`` otherwise.  Dump helpers get a
    ``_dump`` name suffix.  With ``terminate`` the prototype ends in ';'.
    """
    mode_spec = ri.op[ri.op_mode]
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        args.append(f"{type_name(ri, direction)} *{direction_to_suffix[direction][1:]}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1751
1752
def print_req_prototype(ri):
    """Emit the terminated request-helper prototype, attaching the op's doc."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1755
1756
def print_dump_prototype(ri):
    """Emit the terminated prototype of the dump helper (no doc comment)."""
    print_prototype(ri, direction="request")
1759
1760
def put_typol_fwd(cw, struct):
    """Emit an ``extern`` declaration of the struct's ynl_policy_nest."""
    nest_name = struct.render_name + '_nest'
    cw.p('extern const struct ynl_policy_nest ' + nest_name + ';')
1763
1764
def put_typol(cw, struct):
    """Emit the user-space policy table and its ynl_policy_nest wrapper.

    ``<name>_policy`` is sized by the attribute set's max enum name, and
    every member contributes its own entry through ``attr_typol()``.
    ``<name>_nest`` then points at that table.
    """
    name = struct.render_name
    max_name = struct.attr_set.max_name

    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_name} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for line in (f'.max_attr = {max_name},', f'.table = {name}_policy,'):
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()
1780
1781
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit ``<render_name>_str()``, which looks a value up in ``map_name``.

    The argument is ``int`` unless an enum is given, in which case the
    enum's user type is used.  For a 'flags' enum the value is first turned
    into a bit index with ``ffs() - 1``.  Out-of-range indices return NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for stmt in body:
        cw.p(stmt)
    cw.block_end()
    cw.nl()
1795
1796
def put_op_name_fwd(family, cw):
    """Emit the declaration of ``<family>_op_str()``."""
    func = family.c_name + '_op_str'
    cw.write_func_prot('const char *', func, ['int op'], suffix=';')
1799
1800
def put_op_name(family, cw):
    """Emit the op-value -> op-name string table and ``<family>_op_str()``.

    Only ops with a response value get an entry.  The entry is indexed by
    the op's enum name when request and response values match, and by the
    numeric response value otherwise.  When several ops share one response
    value (legacy families), only the op recorded in
    ``family.rsp_by_value`` gets the entry; the others leave a comment.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1820
1821
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the declaration of ``<enum>_str()``; ``family`` is unused."""
    cw.write_func_prot('const char *', enum.render_name + '_str',
                       [f'{enum.user_type} value'], suffix=';')
1825
1826
def put_enum_to_str(family, cw, enum):
    """Emit the enum's value -> name string table and its ``_str()`` lookup."""
    map_name = enum.render_name + '_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in (f'[{e.value}] = "{e.name}",' for e in enum.entries.values()):
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1836
1837
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of ``<struct>_put()``, the nested-attribute writer."""
    ri.cw.write_func_prot(
        'int',
        f'{struct.render_name}_put',
        ['struct nlmsghdr *nlh', 'unsigned int attr_type', f'{struct.ptr_name}obj'],
        suffix=suffix,
    )
1845
1846
def put_req_nested(ri, struct):
    """Emit ``<struct>_put()``, which serializes ``obj`` as a nested attribute.

    The generated function opens a nest of type ``attr_type``, lets every
    member emit its put code, closes the nest and returns 0.  An ``array``
    local is declared for indexed-array members, and a loop index ``i`` for
    members whose presence is tracked by a count.
    """
    members = [arg for _, arg in struct.member_list()]

    local_vars = ['struct nlattr *nest;']
    if any(arg.type == 'indexed-array' for arg in members):
        local_vars.append('struct nlattr *array;')
    if any(arg.presence_type() == 'count' for arg in members):
        local_vars.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    cw = ri.cw
    cw.block_start()
    cw.write_func_lvar(local_vars)
    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for arg in members:
        arg.attr_put(ri, "obj")

    cw.p("ynl_attr_nest_end(nlh, nest);")
    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1880
1881
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body (from the opening brace) of an attribute parser.

    Shared by nested-attribute parsers (``struct.nested``) and top-level
    message parsers.  The emitted C code:

      * declares ``local_vars`` (extended here with what the members need)
        and runs ``init_lines``;
      * copies inherited values and, if ``ri.fixed_hdr`` is set, the fixed
        header into ``dst``;
      * fails if an indexed-array / multi-attr member is already populated;
      * walks the attributes once, letting each member's ``attr_get()``
        emit its handling.  NOTE(review): the ``n_<member>`` counters and
        ``attr_<member>`` pointers used later are presumably filled by
        that ``attr_get()`` code; confirm there.
      * allocates and fills indexed-array and multi-attr members in extra
        passes over the attributes.

    Args:
        ri: render info for the op being generated.
        struct: the struct whose attributes are parsed.
        init_lines: C statements emitted after the declarations; may be
            appended to.
        local_vars: C declarations; may be extended in place.

    Raises:
        Exception: for a per-op fixed header in a non-classic family, an
            unsupported indexed-array sub-type, or an unsupported
            multi-attr / nested element type.
    """
    if struct.nested:
        # Nested parsers walk the attributes inside 'nested'.
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        # Message parsers start past yarg->ys->family->hdr_len header bytes.
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            # The op overrides the family's fixed header; only classic
            # families support sizing the offset from the op's own header.
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({ri.fixed_hdr}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Classify array-like members.  Indexed arrays get an attr_<name>
    # pointer to their nest; multi-attrs are collected by re-walking.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        # Nested members are parsed by calling their own _parse() with parg.
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # Element counters; they size the calloc()s emitted further down.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        # Non-classic families carry a struct genlmsghdr before the fixed header.
        if ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Refuse a dst whose array-like members are already set; the calloc()s
    # below would otherwise overwrite those pointers.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    # Main pass: each member emits its handling inside the attribute loop.
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Indexed arrays: allocate n_<name> elements and fill them by walking
    # the nest referenced by attr_<name>.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            # The attribute type (array index) is passed as an extra argument.
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Multi-attrs: allocate n_<name> elements, then re-walk the attributes
    # and pick every occurrence of the member's attribute type.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most one element's worth; a shorter payload leaves
            # the remainder as zeroed by calloc().
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Each string becomes a malloc()'d ynl_string holding the length
            # and a NUL-terminated copy bounded by the payload length.
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message callbacks return YNL_PARSE_CB_OK.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2027
2028
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of ``<struct>_parse()`` for a nested attribute.

    Every name in ``struct.inherited`` becomes an extra ``__u32`` argument.
    """
    func_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    func_args += [f'__u32 {name}' for name in struct.inherited]
    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args, suffix=suffix)
2037
2038
def parse_rsp_nested(ri, struct):
    """Emit ``<struct>_parse()`` for a nested attribute (prototype and body)."""
    parse_rsp_nested_prototype(ri, struct, suffix='')

    if not struct.member_list():
        # Empty nest: nothing to parse.
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    local_vars = ['const struct nlattr *attr;', f'{struct.ptr_name}dst = yarg->data;']
    _multi_parse(ri, struct, [], local_vars)
2054
2055
def parse_rsp_msg(ri, deref=False):
    """Emit the top-level reply-message parse callback for ri's op.

    Emitted only when the op mode declares a reply, or for events.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
                  'const struct nlattr *attr;']
    parser = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parser,
                          ['const struct nlmsghdr *nlh', 'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Reply without attributes: nothing to parse.
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'], local_vars)
2077
2078
def print_req(ri):
    """Emit the user-space request helper for ri's op (prototype and body).

    The generated function builds the request (fixed header, then every
    request attribute) and runs it with ynl_exec().  When the op mode has a
    reply, it returns a calloc()'d reply filled by the op's ``_parse``
    callback, or frees it and returns NULL on error.  Without a reply it
    returns 0 on success and -1 on error.
    """
    # Return expressions of the generated function; switched below when the
    # mode has a reply.
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # Members tracked by a count need a loop index while being put.
    for _, attr in ri.struct["request"].member_list():
        if attr.presence_type() == 'count':
            local_vars += ['unsigned int i;']
            break

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    # Classic families pass the request's own nlmsg flags; genetlink
    # messages are addressed by family id instead.
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Expected reply command: the op's enum if it has a value of its
        # own, otherwise its rsp_value.
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    # Error path: release the partially parsed reply.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
2152
2153
def print_dump(ri):
    """Emit the user-space dump helper for ri's op (prototype and body).

    The generated function prepares a ynl_dump_state (reply policy,
    per-node allocation size, parse callback, expected reply command),
    builds the dump request and runs ynl_exec_dump().  It returns the head
    of the result list (``yds.first``), or frees the list and returns NULL
    on error.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    # alloc_sz uses the (non-deref) reply wrapper type while the callback
    # parses the deref'd type -- presumably each list node is a wrapper.
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # NOTE(review): this references req even when the mode declares no
    # request; presumably a fixed header implies one -- confirm.
    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    # Error path: release whatever part of the list was collected.
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2208
2209
def call_free(ri, direction, var):
    """Return the C statement that frees ``var`` with the type's ``_free()``."""
    free_fn = f"{op_prefix(ri, direction)}_free"
    return f"{free_fn}({var});"
2212
2213
def free_arg_name(direction):
    """Return the argument name used by a ``_free()`` helper.

    'req'/'rsp' for the request/reply directions (their type suffix without
    the underscore), 'obj' for the direction-less (nested) case.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
2218
2219
def print_alloc_wrapper(ri, direction):
    """Emit a static inline ``<type>_alloc()`` that returns a calloc()'d struct."""
    name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot(f'static inline struct {name} *', name + '_alloc', ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {name}));')
    cw.block_end()
2226
2227
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of ``<type>_free()``.

    When ``ri.type_name_conflict`` is set, the argument's struct name gets a
    trailing underscore; the function name does not.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
2235
2236
def print_nlflags_set(ri, direction):
    """Emit a static inline setter for the request's ``_nlmsg_flags`` field."""
    name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot('static inline void', name + '_set_nlflags',
                       [f"struct {name} *req", "__u16 nl_flags"])
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2245
2246
def _print_type(ri, direction, struct):
    """Emit the C struct definition for a message (or nest) of ri.

    Layout: an optional ``_nlmsg_flags`` field and fixed header, then the
    ``_present``, ``_len`` and ``_count`` bookkeeping sub-structs (each only
    when some member contributes a line), then inherited ``__u32`` fields
    and finally the members themselves.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.type_name_conflict and not direction:
        suffix += '_'
    if not ri.type_oneside and ri.op_mode == 'dump':
        suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if ri.fixed_hdr:
        cw.p(ri.fixed_hdr + ' _hdr;')
        cw.nl()

    for kind in ('present', 'len', 'count'):
        meta = [attr.presence_member(ri.ku_space, kind) for _, attr in struct.member_list()]
        meta = [line for line in meta if line]
        if not meta:
            continue
        cw.block_start(line="struct")
        for line in meta:
            cw.p(line)
        cw.block_end(line=f'_{kind};')
    cw.nl()

    for name in struct.inherited:
        cw.p(f"__u32 {name};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2285
2286
def print_type(ri, direction):
    """Emit the struct definition for ri's request or reply."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2289
2290
def print_type_full(ri, struct):
    """Emit the direction-less struct definition for an attribute set."""
    _print_type(ri, direction="", struct=struct)
2293
2294
def print_type_helpers(ri, direction, deref=False):
    """Emit the helper declarations for one of ri's types.

    This covers the ``_free()`` prototype, the nlflags setter when needed,
    and, for user-space requests, a setter per request attribute.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    members = ri.struct[direction].member_list() if wants_setters else []
    for _, attr in members:
        attr.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2306
2307
def print_req_type_helpers(ri):
    """Emit alloc/free/setter helpers for the request type unless it is empty."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2313
2314
def print_rsp_type_helpers(ri):
    """Emit the reply type's helpers when the op mode declares a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2319
2320
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of ``<op>_rsp_parse()`` or ``<op>_req_parse()``.

    The 'reply' direction selects the ``_rsp`` flavour; any other value
    selects ``_req``.
    """
    base = ri.op.render_name + ("_rsp" if direction == "reply" else "_req")
    ri.cw.write_func_prot('void', f"{base}_parse",
                          ['const struct nlattr **tb', f"struct {base} *req"],
                          suffix=';' if terminate else '')
2329
2330
def print_req_type(ri):
    """Emit the request struct definition unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2335
2336
def print_req_free(ri):
    """Emit the request type's ``_free()`` when the op mode carries a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2341
2342
def print_rsp_type(ri):
    """Emit the reply struct definition.

    This happens for do/dump modes that declare a reply, and always for
    events.
    """
    mode = ri.op_mode
    wants_reply = mode == 'event' or (mode in ('do', 'dump') and 'reply' in ri.op[mode])
    if wants_reply:
        print_type(ri, 'reply')
2351
2352
def print_wrapped_type(ri):
    """Emit the list/notification wrapper struct around the reply type.

    Dumps get a ``next`` pointer.  Notifications and events get family,
    cmd, a ``next`` pointer and a ``free`` callback.  The wrapped reply is
    the 8-byte-aligned ``obj`` member.  The wrapper's ``_free()`` prototype
    follows.
    """
    wrapper = type_name(ri, 'reply')
    cw = ri.cw

    cw.block_start(line=wrapper)
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for line in ('__u16 family;',
                     '__u8 cmd;',
                     'struct ynl_ntf_base_type *next;',
                     f"void (*free)({wrapper} *ntf);"):
            cw.p(line)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2367
2368
def _free_type_members_iter(ri, struct):
    """Declare the loop index ``i`` when freeing the struct needs iteration."""
    if not struct.free_needs_iter():
        return
    ri.cw.p('unsigned int i;')
    ri.cw.nl()
2373
2374
def _free_type_members(ri, var, struct, ref=''):
    """Have each member of ``struct`` emit code freeing its storage in ``var``.

    ``ref`` is a field-path prefix (e.g. ``'obj.'``) forwarded to members.
    """
    for _name, member in struct.member_list():
        member.free(ri, var, ref)
2378
2379
def _free_type(ri, direction, struct):
    """Emit a type's ``_free()`` helper.

    Member storage is released first.  For request/reply types (non-empty
    ``direction``) the object itself is freed as well; direction-less
    (nested) types only have their members released.
    """
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        ri.cw.p('free(' + arg + ');')
    ri.cw.block_end()
    ri.cw.nl()
2391
2392
def free_rsp_nested_prototype(ri):
    """Emit the ``_free()`` declaration for a direction-less (nested) type."""
    print_free_prototype(ri, direction="")
2395
2396
def free_rsp_nested(ri, struct):
    """Emit the ``_free()`` helper for a direction-less (nested) type."""
    _free_type(ri, direction="", struct=struct)
2399
2400
def print_rsp_free(ri):
    """Emit the reply type's ``_free()`` when the op mode declares a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2405
2406
def print_dump_type_free(ri):
    """Emit ``_free()`` for a dump result list.

    The generated loop follows ``next`` pointers until YNL_LIST_END. For each
    node it frees the members (reached through ``obj.``) and then the node.
    """
    node_type = type_name(ri, 'reply')
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    for line in ('rsp = next;', 'next = rsp->next;'):
        cw.p(line)
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2425
2426
def print_ntf_type_free(ri):
    """Emit ``_free()`` for a notification.

    It releases the members (reached through ``obj.``) and then the
    notification itself.
    """
    cw = ri.cw
    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2435
2436
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit a kernel ``nla_policy`` array declaration or definition opener.

    With ``terminate`` an ``extern`` declaration is written, unless ri is
    given and the family's policies are static; then nothing is written.
    Without ``terminate`` the definition header (``... = {``) is written,
    prefixed with ``static`` in that same case.  The array is named after
    the op (plus ``_<mode>`` for dual policies), or after the struct when no
    ``ri`` is given.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix, suffix = ('static ' if is_static else ''), ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2459
2460
def print_req_policy(cw, struct, ri=None):
    """Emit a kernel ``nla_policy`` array definition.

    When ri has an op, the definition is placed inside that op's
    ``config-cond`` guard via ifdef_block(); the guard is closed
    afterwards.
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2470
2471
def kernel_can_gen_family_struct(family):
    """Return True when the kernel family struct can be generated.

    This is the case only for genetlink families.
    """
    proto = family.proto
    return proto == 'genetlink'
2474
2475
def policy_should_be_static(family):
    """Return True when kernel policies should be file-local (``static``).

    This holds for split-policy families and for families whose struct is
    generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2478
2479
def print_kernel_policy_ranges(family, cw):
    """Emit ``netlink_range_validation[_signed]`` initializers.

    One is emitted for every request attribute with a ``full-range`` check,
    skipping attribute sets that are subsets of another.  Types starting
    with 'u' use the unsigned struct and ULL limit suffixes; all others use
    the signed struct and LL.
    """
    header_done = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_done:
                cw.p('/* Integer value ranges */')
                header_done = True

            is_unsigned = attr.type[0] == 'u'
            sign = '' if is_unsigned else '_signed'
            suffix = 'ULL' if is_unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(limit, attr.get_limit_str(limit, suffix=suffix))
                       for limit in ('min', 'max') if limit in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2507
2508
def print_kernel_policy_sparse_enum_validates(family, cw):
    """Emit ``<attr>_validate()`` callbacks for sparse enum request attributes.

    A callback is emitted for every request attribute (outside subset
    attribute sets) that has an enum name and a ``sparse`` check.  The
    generated switch accepts exactly the entries of
    ``family.consts[attr['enum']]`` and returns 0.  Any other value sets an
    extack message and returns -EINVAL.
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or not attr.enum_name or 'sparse' not in attr.checks:
                continue

            if first:
                cw.p('/* Sparse enums validation callbacks */')
                first = False

            # (The unused 'sign'/'suffix' locals that used to be computed
            # here were dead code and have been dropped.)
            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            for idx, entry in enumerate(enum.entries.values()):
                # Consecutive case labels are joined with an explicit fallthrough.
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2547
2548
2549def print_kernel_op_table_fwd(family, cw, terminate):
2550    exported = not kernel_can_gen_family_struct(family)
2551
2552    if not terminate or exported:
2553        cw.p(f"/* Ops table for {family.ident_name} */")
2554
2555        pol_to_struct = {'global': 'genl_small_ops',
2556                         'per-op': 'genl_ops',
2557                         'split': 'genl_split_ops'}
2558        struct_type = pol_to_struct[family.kernel_policy]
2559
2560        if not exported:
2561            cnt = ""
2562        elif family.kernel_policy == 'split':
2563            cnt = 0
2564            for op in family.ops.values():
2565                if 'do' in op:
2566                    cnt += 1
2567                if 'dump' in op:
2568                    cnt += 1
2569        else:
2570            cnt = len(family.ops)
2571
2572        qual = 'static const' if not exported else 'const'
2573        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
2574        if terminate:
2575            cw.p(f"extern {line};")
2576        else:
2577            cw.block_start(line=line + ' =')
2578
2579    if not terminate:
2580        return
2581
2582    cw.nl()
2583    for name in family.hooks['pre']['do']['list']:
2584        cw.write_func_prot('int', c_lower(name),
2585                           ['const struct genl_split_ops *ops',
2586                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2587    for name in family.hooks['post']['do']['list']:
2588        cw.write_func_prot('void', c_lower(name),
2589                           ['const struct genl_split_ops *ops',
2590                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2591    for name in family.hooks['pre']['dump']['list']:
2592        cw.write_func_prot('int', c_lower(name),
2593                           ['struct netlink_callback *cb'], suffix=';')
2594    for name in family.hooks['post']['dump']['list']:
2595        cw.write_func_prot('int', c_lower(name),
2596                           ['struct netlink_callback *cb'], suffix=';')
2597
2598    cw.nl()
2599
2600    for op_name, op in family.ops.items():
2601        if op.is_async:
2602            continue
2603
2604        if 'do' in op:
2605            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
2606            cw.write_func_prot('int', name,
2607                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2608
2609        if 'dump' in op:
2610            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
2611            cw.write_func_prot('int', name,
2612                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
2613    cw.nl()
2614
2615
2616def print_kernel_op_table_hdr(family, cw):
2617    print_kernel_op_table_fwd(family, cw, terminate=True)
2618
2619
2620def print_kernel_op_table(family, cw):
2621    print_kernel_op_table_fwd(family, cw, terminate=False)
2622    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
2623        for op_name, op in family.ops.items():
2624            if op.is_async:
2625                continue
2626
2627            cw.ifdef_block(op.get('config-cond', None))
2628            cw.block_start()
2629            members = [('cmd', op.enum_name)]
2630            if 'dont-validate' in op:
2631                members.append(('validate',
2632                                ' | '.join([c_upper('genl-dont-validate-' + x)
2633                                            for x in op['dont-validate']])), )
2634            for op_mode in ['do', 'dump']:
2635                if op_mode in op:
2636                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2637                    members.append((op_mode + 'it', name))
2638            if family.kernel_policy == 'per-op':
2639                struct = Struct(family, op['attribute-set'],
2640                                type_list=op['do']['request']['attributes'])
2641
2642                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2643                members.append(('policy', name))
2644                members.append(('maxattr', struct.attr_max_val.enum_name))
2645            if 'flags' in op:
2646                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
2647            cw.write_struct_init(members)
2648            cw.block_end(line=',')
2649    elif family.kernel_policy == 'split':
2650        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
2651                    'dump': {'pre': 'start', 'post': 'done'}}
2652
2653        for op_name, op in family.ops.items():
2654            for op_mode in ['do', 'dump']:
2655                if op.is_async or op_mode not in op:
2656                    continue
2657
2658                cw.ifdef_block(op.get('config-cond', None))
2659                cw.block_start()
2660                members = [('cmd', op.enum_name)]
2661                if 'dont-validate' in op:
2662                    dont_validate = []
2663                    for x in op['dont-validate']:
2664                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
2665                            continue
2666                        if op_mode == "dump" and x == 'strict':
2667                            continue
2668                        dont_validate.append(x)
2669
2670                    if dont_validate:
2671                        members.append(('validate',
2672                                        ' | '.join([c_upper('genl-dont-validate-' + x)
2673                                                    for x in dont_validate])), )
2674                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2675                if 'pre' in op[op_mode]:
2676                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
2677                members.append((op_mode + 'it', name))
2678                if 'post' in op[op_mode]:
2679                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
2680                if 'request' in op[op_mode]:
2681                    struct = Struct(family, op['attribute-set'],
2682                                    type_list=op[op_mode]['request']['attributes'])
2683
2684                    if op.dual_policy:
2685                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
2686                    else:
2687                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2688                    members.append(('policy', name))
2689                    members.append(('maxattr', struct.attr_max_val.enum_name))
2690                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
2691                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
2692                cw.write_struct_init(members)
2693                cw.block_end(line=',')
2694    cw.ifdef_block(None)
2695
2696    cw.block_end(line=';')
2697    cw.nl()
2698
2699
2700def print_kernel_mcgrp_hdr(family, cw):
2701    if not family.mcgrps['list']:
2702        return
2703
2704    cw.block_start('enum')
2705    for grp in family.mcgrps['list']:
2706        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp['name']},")
2707        cw.p(grp_id)
2708    cw.block_end(';')
2709    cw.nl()
2710
2711
2712def print_kernel_mcgrp_src(family, cw):
2713    if not family.mcgrps['list']:
2714        return
2715
2716    cw.block_start('static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] =')
2717    for grp in family.mcgrps['list']:
2718        name = grp['name']
2719        grp_id = c_upper(f"{family.ident_name}-nlgrp-{name}")
2720        cw.p('[' + grp_id + '] = { "' + name + '", },')
2721    cw.block_end(';')
2722    cw.nl()
2723
2724
2725def print_kernel_family_struct_hdr(family, cw):
2726    if not kernel_can_gen_family_struct(family):
2727        return
2728
2729    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
2730    cw.nl()
2731    if 'sock-priv' in family.kernel_family:
2732        cw.p(f'void {family.c_name}_nl_sock_priv_init({family.kernel_family["sock-priv"]} *priv);')
2733        cw.p(f'void {family.c_name}_nl_sock_priv_destroy({family.kernel_family["sock-priv"]} *priv);')
2734        cw.nl()
2735
2736
2737def print_kernel_family_struct_src(family, cw):
2738    if not kernel_can_gen_family_struct(family):
2739        return
2740
2741    if 'sock-priv' in family.kernel_family:
2742        # Generate "trampolines" to make CFI happy
2743        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
2744                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
2745                      ["void *priv"])
2746        cw.nl()
2747        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
2748                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
2749                      ["void *priv"])
2750        cw.nl()
2751
2752    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")
2753    cw.p('.name\t\t= ' + family.fam_key + ',')
2754    cw.p('.version\t= ' + family.ver_key + ',')
2755    cw.p('.netnsok\t= true,')
2756    cw.p('.parallel_ops\t= true,')
2757    cw.p('.module\t\t= THIS_MODULE,')
2758    if family.kernel_policy == 'per-op':
2759        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
2760        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
2761    elif family.kernel_policy == 'split':
2762        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
2763        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
2764    if family.mcgrps['list']:
2765        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
2766        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
2767    if 'sock-priv' in family.kernel_family:
2768        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
2769        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
2770        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
2771    cw.block_end(';')
2772
2773
2774def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
2775    start_line = 'enum'
2776    if enum_name in obj:
2777        if obj[enum_name]:
2778            start_line = 'enum ' + c_lower(obj[enum_name])
2779    elif ckey and ckey in obj:
2780        start_line = 'enum ' + family.c_name + '_' + c_lower(obj[ckey])
2781    cw.block_start(line=start_line)
2782
2783
2784def render_uapi_unified(family, cw, max_by_define, separate_ntf):
2785    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
2786    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
2787    max_value = f"({cnt_name} - 1)"
2788
2789    uapi_enum_start(family, cw, family['operations'], 'enum-name')
2790    val = 0
2791    for op in family.msgs.values():
2792        if separate_ntf and ('notify' in op or 'event' in op):
2793            continue
2794
2795        suffix = ','
2796        if op.value != val:
2797            suffix = f" = {op.value},"
2798            val = op.value
2799        cw.p(op.enum_name + suffix)
2800        val += 1
2801    cw.nl()
2802    cw.p(cnt_name + ('' if max_by_define else ','))
2803    if not max_by_define:
2804        cw.p(f"{max_name} = {max_value}")
2805    cw.block_end(line=';')
2806    if max_by_define:
2807        cw.p(f"#define {max_name} {max_value}")
2808    cw.nl()
2809
2810
2811def render_uapi_directional(family, cw, max_by_define):
2812    max_name = f"{family.op_prefix}USER_MAX"
2813    cnt_name = f"__{family.op_prefix}USER_CNT"
2814    max_value = f"({cnt_name} - 1)"
2815
2816    cw.block_start(line='enum')
2817    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
2818    val = 0
2819    for op in family.msgs.values():
2820        if 'do' in op and 'event' not in op:
2821            suffix = ','
2822            if op.value and op.value != val:
2823                suffix = f" = {op.value},"
2824                val = op.value
2825            cw.p(op.enum_name + suffix)
2826            val += 1
2827    cw.nl()
2828    cw.p(cnt_name + ('' if max_by_define else ','))
2829    if not max_by_define:
2830        cw.p(f"{max_name} = {max_value}")
2831    cw.block_end(line=';')
2832    if max_by_define:
2833        cw.p(f"#define {max_name} {max_value}")
2834    cw.nl()
2835
2836    max_name = f"{family.op_prefix}KERNEL_MAX"
2837    cnt_name = f"__{family.op_prefix}KERNEL_CNT"
2838    max_value = f"({cnt_name} - 1)"
2839
2840    cw.block_start(line='enum')
2841    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
2842    val = 0
2843    for op in family.msgs.values():
2844        if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
2845            enum_name = op.enum_name
2846            if 'event' not in op and 'notify' not in op:
2847                enum_name = f'{enum_name}_REPLY'
2848
2849            suffix = ','
2850            if op.value and op.value != val:
2851                suffix = f" = {op.value},"
2852                val = op.value
2853            cw.p(enum_name + suffix)
2854            val += 1
2855    cw.nl()
2856    cw.p(cnt_name + ('' if max_by_define else ','))
2857    if not max_by_define:
2858        cw.p(f"{max_name} = {max_value}")
2859    cw.block_end(line=';')
2860    if max_by_define:
2861        cw.p(f"#define {max_name} {max_value}")
2862    cw.nl()
2863
2864
2865def render_uapi(family, cw):
2866    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
2867    hdr_prot = hdr_prot.replace('/', '_')
2868    cw.p('#ifndef ' + hdr_prot)
2869    cw.p('#define ' + hdr_prot)
2870    cw.nl()
2871
2872    defines = [(family.fam_key, family["name"]),
2873               (family.ver_key, family.get('version', 1))]
2874    cw.writes_defines(defines)
2875    cw.nl()
2876
2877    defines = []
2878    for const in family['definitions']:
2879        if const.get('header'):
2880            continue
2881
2882        if const['type'] != 'const':
2883            cw.writes_defines(defines)
2884            defines = []
2885            cw.nl()
2886
2887        # Write kdoc for enum and flags (one day maybe also structs)
2888        if const['type'] == 'enum' or const['type'] == 'flags':
2889            enum = family.consts[const['name']]
2890
2891            if enum.header:
2892                continue
2893
2894            if enum.has_doc():
2895                if enum.has_entry_doc():
2896                    cw.p('/**')
2897                    doc = ''
2898                    if 'doc' in enum:
2899                        doc = ' - ' + enum['doc']
2900                    cw.write_doc_line(enum.enum_name + doc)
2901                else:
2902                    cw.p('/*')
2903                    cw.write_doc_line(enum['doc'], indent=False)
2904                for entry in enum.entries.values():
2905                    if entry.has_doc():
2906                        doc = '@' + entry.c_name + ': ' + entry['doc']
2907                        cw.write_doc_line(doc)
2908                cw.p(' */')
2909
2910            uapi_enum_start(family, cw, const, 'name')
2911            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
2912            for entry in enum.entries.values():
2913                suffix = ','
2914                if entry.value_change:
2915                    suffix = f" = {entry.user_value()}" + suffix
2916                cw.p(entry.c_name + suffix)
2917
2918            if const.get('render-max', False):
2919                cw.nl()
2920                cw.p('/* private: */')
2921                if const['type'] == 'flags':
2922                    max_name = c_upper(name_pfx + 'mask')
2923                    max_val = f' = {enum.get_mask()},'
2924                    cw.p(max_name + max_val)
2925                else:
2926                    cnt_name = enum.enum_cnt_name
2927                    max_name = c_upper(name_pfx + 'max')
2928                    if not cnt_name:
2929                        cnt_name = '__' + name_pfx + 'max'
2930                    cw.p(c_upper(cnt_name) + ',')
2931                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
2932            cw.block_end(line=';')
2933            cw.nl()
2934        elif const['type'] == 'const':
2935            defines.append([c_upper(family.get('c-define-name',
2936                                               f"{family.ident_name}-{const['name']}")),
2937                            const['value']])
2938
2939    if defines:
2940        cw.writes_defines(defines)
2941        cw.nl()
2942
2943    max_by_define = family.get('max-by-define', False)
2944
2945    for _, attr_set in family.attr_sets.items():
2946        if attr_set.subset_of:
2947            continue
2948
2949        max_value = f"({attr_set.cnt_name} - 1)"
2950
2951        val = 0
2952        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
2953        for _, attr in attr_set.items():
2954            suffix = ','
2955            if attr.value != val:
2956                suffix = f" = {attr.value},"
2957                val = attr.value
2958            val += 1
2959            cw.p(attr.enum_name + suffix)
2960        if attr_set.items():
2961            cw.nl()
2962        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
2963        if not max_by_define:
2964            cw.p(f"{attr_set.max_name} = {max_value}")
2965        cw.block_end(line=';')
2966        if max_by_define:
2967            cw.p(f"#define {attr_set.max_name} {max_value}")
2968        cw.nl()
2969
2970    # Commands
2971    separate_ntf = 'async-prefix' in family['operations']
2972
2973    if family.msg_id_model == 'unified':
2974        render_uapi_unified(family, cw, max_by_define, separate_ntf)
2975    elif family.msg_id_model == 'directional':
2976        render_uapi_directional(family, cw, max_by_define)
2977    else:
2978        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')
2979
2980    if separate_ntf:
2981        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
2982        for op in family.msgs.values():
2983            if separate_ntf and not ('notify' in op or 'event' in op):
2984                continue
2985
2986            suffix = ','
2987            if 'value' in op:
2988                suffix = f" = {op['value']},"
2989            cw.p(op.enum_name + suffix)
2990        cw.block_end(line=';')
2991        cw.nl()
2992
2993    # Multicast
2994    defines = []
2995    for grp in family.mcgrps['list']:
2996        name = grp['name']
2997        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
2998                        f'{name}'])
2999    cw.nl()
3000    if defines:
3001        cw.writes_defines(defines)
3002        cw.nl()
3003
3004    cw.p(f'#endif /* {hdr_prot} */')
3005
3006
3007def _render_user_ntf_entry(ri, op):
3008    if not ri.family.is_classic():
3009        ri.cw.block_start(line=f"[{op.enum_name}] = ")
3010    else:
3011        crud_op = ri.family.req_by_value[op.rsp_value]
3012        ri.cw.block_start(line=f"[{crud_op.enum_name}] = ")
3013    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
3014    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
3015    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
3016    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
3017    ri.cw.block_end(line=',')
3018
3019
3020def render_user_family(family, cw, prototype):
3021    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
3022    if prototype:
3023        cw.p(f'extern {symbol};')
3024        return
3025
3026    if family.ntfs:
3027        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
3028        for ntf_op_name, ntf_op in family.ntfs.items():
3029            if 'notify' in ntf_op:
3030                op = family.ops[ntf_op['notify']]
3031                ri = RenderInfo(cw, family, "user", op, "notify")
3032            elif 'event' in ntf_op:
3033                ri = RenderInfo(cw, family, "user", ntf_op, "event")
3034            else:
3035                raise Exception('Invalid notification ' + ntf_op_name)
3036            _render_user_ntf_entry(ri, ntf_op)
3037        for op_name, op in family.ops.items():
3038            if 'event' not in op:
3039                continue
3040            ri = RenderInfo(cw, family, "user", op, "event")
3041            _render_user_ntf_entry(ri, op)
3042        cw.block_end(line=";")
3043        cw.nl()
3044
3045    cw.block_start(f'{symbol} = ')
3046    cw.p(f'.name\t\t= "{family.c_name}",')
3047    if family.is_classic():
3048        cw.p(f'.is_classic\t= true,')
3049        cw.p(f'.classic_id\t= {family.get("protonum")},')
3050    if family.is_classic():
3051        if family.fixed_header:
3052            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
3053    elif family.fixed_header:
3054        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
3055    else:
3056        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
3057    if family.ntfs:
3058        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
3059        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
3060    cw.block_end(line=';')
3061
3062
3063def family_contains_bitfield32(family):
3064    for _, attr_set in family.attr_sets.items():
3065        if attr_set.subset_of:
3066            continue
3067        for _, attr in attr_set.items():
3068            if attr.type == "bitfield32":
3069                return True
3070    return False
3071
3072
3073def find_kernel_root(full_path):
3074    sub_path = ''
3075    while True:
3076        sub_path = os.path.join(os.path.basename(full_path), sub_path)
3077        full_path = os.path.dirname(full_path)
3078        maintainers = os.path.join(full_path, "MAINTAINERS")
3079        if os.path.exists(maintainers):
3080            return full_path, sub_path[:-1]
3081
3082
3083def main():
3084    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
3085    parser.add_argument('--mode', dest='mode', type=str, required=True,
3086                        choices=('user', 'kernel', 'uapi'))
3087    parser.add_argument('--spec', dest='spec', type=str, required=True)
3088    parser.add_argument('--header', dest='header', action='store_true', default=None)
3089    parser.add_argument('--source', dest='header', action='store_false')
3090    parser.add_argument('--user-header', nargs='+', default=[])
3091    parser.add_argument('--cmp-out', action='store_true', default=None,
3092                        help='Do not overwrite the output file if the new output is identical to the old')
3093    parser.add_argument('--exclude-op', action='append', default=[])
3094    parser.add_argument('-o', dest='out_file', type=str, default=None)
3095    args = parser.parse_args()
3096
3097    if args.header is None:
3098        parser.error("--header or --source is required")
3099
3100    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
3101
3102    try:
3103        parsed = Family(args.spec, exclude_ops)
3104        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
3105            print('Spec license:', parsed.license)
3106            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
3107            os.sys.exit(1)
3108    except yaml.YAMLError as exc:
3109        print(exc)
3110        os.sys.exit(1)
3111        return
3112
3113    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
3114
3115    _, spec_kernel = find_kernel_root(args.spec)
3116    if args.mode == 'uapi' or args.header:
3117        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
3118    else:
3119        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
3120    cw.p("/* Do not edit directly, auto-generated from: */")
3121    cw.p(f"/*\t{spec_kernel} */")
3122    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
3123    if args.exclude_op or args.user_header:
3124        line = ''
3125        line += ' --user-header '.join([''] + args.user_header)
3126        line += ' --exclude-op '.join([''] + args.exclude_op)
3127        cw.p(f'/* YNL-ARG{line} */')
3128    cw.nl()
3129
3130    if args.mode == 'uapi':
3131        render_uapi(parsed, cw)
3132        return
3133
3134    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
3135    if args.header:
3136        cw.p('#ifndef ' + hdr_prot)
3137        cw.p('#define ' + hdr_prot)
3138        cw.nl()
3139
3140    if args.out_file:
3141        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
3142    else:
3143        hdr_file = "generated_header_file.h"
3144
3145    if args.mode == 'kernel':
3146        cw.p('#include <net/netlink.h>')
3147        cw.p('#include <net/genetlink.h>')
3148        cw.nl()
3149        if not args.header:
3150            if args.out_file:
3151                cw.p(f'#include "{hdr_file}"')
3152            cw.nl()
3153        headers = ['uapi/' + parsed.uapi_header]
3154        headers += parsed.kernel_family.get('headers', [])
3155    else:
3156        cw.p('#include <stdlib.h>')
3157        cw.p('#include <string.h>')
3158        if args.header:
3159            cw.p('#include <linux/types.h>')
3160            if family_contains_bitfield32(parsed):
3161                cw.p('#include <linux/netlink.h>')
3162        else:
3163            cw.p(f'#include "{hdr_file}"')
3164            cw.p('#include "ynl.h"')
3165        headers = []
3166    for definition in parsed['definitions'] + parsed['attribute-sets']:
3167        if 'header' in definition:
3168            headers.append(definition['header'])
3169    if args.mode == 'user':
3170        headers.append(parsed.uapi_header)
3171    seen_header = []
3172    for one in headers:
3173        if one not in seen_header:
3174            cw.p(f"#include <{one}>")
3175            seen_header.append(one)
3176    cw.nl()
3177
3178    if args.mode == "user":
3179        if not args.header:
3180            cw.p("#include <linux/genetlink.h>")
3181            cw.nl()
3182            for one in args.user_header:
3183                cw.p(f'#include "{one}"')
3184        else:
3185            cw.p('struct ynl_sock;')
3186            cw.nl()
3187            render_user_family(parsed, cw, True)
3188        cw.nl()
3189
3190    if args.mode == "kernel":
3191        if args.header:
3192            for _, struct in sorted(parsed.pure_nested_structs.items()):
3193                if struct.request:
3194                    cw.p('/* Common nested types */')
3195                    break
3196            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3197                if struct.request:
3198                    print_req_policy_fwd(cw, struct)
3199            cw.nl()
3200
3201            if parsed.kernel_policy == 'global':
3202                cw.p(f"/* Global operation policy for {parsed.name} */")
3203
3204                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3205                print_req_policy_fwd(cw, struct)
3206                cw.nl()
3207
3208            if parsed.kernel_policy in {'per-op', 'split'}:
3209                for op_name, op in parsed.ops.items():
3210                    if 'do' in op and 'event' not in op:
3211                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
3212                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
3213                        cw.nl()
3214
3215            print_kernel_op_table_hdr(parsed, cw)
3216            print_kernel_mcgrp_hdr(parsed, cw)
3217            print_kernel_family_struct_hdr(parsed, cw)
3218        else:
3219            print_kernel_policy_ranges(parsed, cw)
3220            print_kernel_policy_sparse_enum_validates(parsed, cw)
3221
3222            for _, struct in sorted(parsed.pure_nested_structs.items()):
3223                if struct.request:
3224                    cw.p('/* Common nested types */')
3225                    break
3226            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3227                if struct.request:
3228                    print_req_policy(cw, struct)
3229            cw.nl()
3230
3231            if parsed.kernel_policy == 'global':
3232                cw.p(f"/* Global operation policy for {parsed.name} */")
3233
3234                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3235                print_req_policy(cw, struct)
3236                cw.nl()
3237
3238            for op_name, op in parsed.ops.items():
3239                if parsed.kernel_policy in {'per-op', 'split'}:
3240                    for op_mode in ['do', 'dump']:
3241                        if op_mode in op and 'request' in op[op_mode]:
3242                            cw.p(f"/* {op.enum_name} - {op_mode} */")
3243                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
3244                            print_req_policy(cw, ri.struct['request'], ri=ri)
3245                            cw.nl()
3246
3247            print_kernel_op_table(parsed, cw)
3248            print_kernel_mcgrp_src(parsed, cw)
3249            print_kernel_family_struct_src(parsed, cw)
3250
3251    if args.mode == "user":
3252        if args.header:
3253            cw.p('/* Enums */')
3254            put_op_name_fwd(parsed, cw)
3255
3256            for name, const in parsed.consts.items():
3257                if isinstance(const, EnumSet):
3258                    put_enum_to_str_fwd(parsed, cw, const)
3259            cw.nl()
3260
3261            cw.p('/* Common nested types */')
3262            for attr_set, struct in parsed.pure_nested_structs.items():
3263                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3264                print_type_full(ri, struct)
3265                if struct.request and struct.in_multi_val:
3266                    free_rsp_nested_prototype(ri)
3267                    cw.nl()
3268
3269            for op_name, op in parsed.ops.items():
3270                cw.p(f"/* ============== {op.enum_name} ============== */")
3271
3272                if 'do' in op and 'event' not in op:
3273                    cw.p(f"/* {op.enum_name} - do */")
3274                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3275                    print_req_type(ri)
3276                    print_req_type_helpers(ri)
3277                    cw.nl()
3278                    print_rsp_type(ri)
3279                    print_rsp_type_helpers(ri)
3280                    cw.nl()
3281                    print_req_prototype(ri)
3282                    cw.nl()
3283
3284                if 'dump' in op:
3285                    cw.p(f"/* {op.enum_name} - dump */")
3286                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
3287                    print_req_type(ri)
3288                    print_req_type_helpers(ri)
3289                    if not ri.type_consistent or ri.type_oneside:
3290                        print_rsp_type(ri)
3291                    print_wrapped_type(ri)
3292                    print_dump_prototype(ri)
3293                    cw.nl()
3294
3295                if op.has_ntf:
3296                    cw.p(f"/* {op.enum_name} - notify */")
3297                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3298                    if not ri.type_consistent:
3299                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3300                    print_wrapped_type(ri)
3301
3302            for op_name, op in parsed.ntfs.items():
3303                if 'event' in op:
3304                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
3305                    cw.p(f"/* {op.enum_name} - event */")
3306                    print_rsp_type(ri)
3307                    cw.nl()
3308                    print_wrapped_type(ri)
3309            cw.nl()
3310        else:
3311            cw.p('/* Enums */')
3312            put_op_name(parsed, cw)
3313
3314            for name, const in parsed.consts.items():
3315                if isinstance(const, EnumSet):
3316                    put_enum_to_str(parsed, cw, const)
3317            cw.nl()
3318
3319            has_recursive_nests = False
3320            cw.p('/* Policies */')
3321            for struct in parsed.pure_nested_structs.values():
3322                if struct.recursive:
3323                    put_typol_fwd(cw, struct)
3324                    has_recursive_nests = True
3325            if has_recursive_nests:
3326                cw.nl()
3327            for struct in parsed.pure_nested_structs.values():
3328                put_typol(cw, struct)
3329            for name in parsed.root_sets:
3330                struct = Struct(parsed, name)
3331                put_typol(cw, struct)
3332
3333            cw.p('/* Common nested types */')
3334            if has_recursive_nests:
3335                for attr_set, struct in parsed.pure_nested_structs.items():
3336                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3337                    free_rsp_nested_prototype(ri)
3338                    if struct.request:
3339                        put_req_nested_prototype(ri, struct)
3340                    if struct.reply:
3341                        parse_rsp_nested_prototype(ri, struct)
3342                cw.nl()
3343            for attr_set, struct in parsed.pure_nested_structs.items():
3344                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3345
3346                free_rsp_nested(ri, struct)
3347                if struct.request:
3348                    put_req_nested(ri, struct)
3349                if struct.reply:
3350                    parse_rsp_nested(ri, struct)
3351
3352            for op_name, op in parsed.ops.items():
3353                cw.p(f"/* ============== {op.enum_name} ============== */")
3354                if 'do' in op and 'event' not in op:
3355                    cw.p(f"/* {op.enum_name} - do */")
3356                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3357                    print_req_free(ri)
3358                    print_rsp_free(ri)
3359                    parse_rsp_msg(ri)
3360                    print_req(ri)
3361                    cw.nl()
3362
3363                if 'dump' in op:
3364                    cw.p(f"/* {op.enum_name} - dump */")
3365                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
3366                    if not ri.type_consistent or ri.type_oneside:
3367                        parse_rsp_msg(ri, deref=True)
3368                    print_req_free(ri)
3369                    print_dump_type_free(ri)
3370                    print_dump(ri)
3371                    cw.nl()
3372
3373                if op.has_ntf:
3374                    cw.p(f"/* {op.enum_name} - notify */")
3375                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3376                    if not ri.type_consistent:
3377                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3378                    print_ntf_type_free(ri)
3379
3380            for op_name, op in parsed.ntfs.items():
3381                if 'event' in op:
3382                    cw.p(f"/* {op.enum_name} - event */")
3383
3384                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3385                    parse_rsp_msg(ri)
3386
3387                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
3388                    print_ntf_type_free(ri)
3389            cw.nl()
3390            render_user_family(parsed, cw, False)
3391
3392    if args.header:
3393        cw.p(f'#endif /* {hdr_prot} */')
3394
3395
if __name__ == "__main__":
    # Script entry point: main() runs only when this file is executed directly
    # (e.g. through the python3 shebang). Importing the module does not run it.
    main()
3398