xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 3186a8e55ae3428ec1e06af09075e20885376e4e)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
def c_upper(name):
    """Turn a spec identifier into C upper case, e.g. 'foo-bar' -> 'FOO_BAR'."""
    underscored = name.replace('-', '_')
    return underscored.upper()
22
23
def c_lower(name):
    """Turn a spec identifier into C lower case, e.g. 'Foo-Bar' -> 'foo_bar'."""
    underscored = name.replace('-', '_')
    return underscored.lower()
26
27
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The name must have the form <sign><width>-<bound>, where sign is 'u'
    (unsigned) or 's' (signed), width is the bit width and bound is 'min'
    or 'max'.

    Raises ValueError for any other string. Previously, malformed names
    such as 'u32-foo' were silently treated as a maximum.
    """
    match = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if not match:
        raise ValueError(f'Invalid limit "{name}", expected e.g. u32-max or s64-min')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u':
        return 0 if bound == 'min' else (1 << width) - 1

    # Signed: one of the bits holds the sign
    max_val = (1 << (width - 1)) - 1
    return -max_val - 1 if bound == 'min' else max_val
41
42
class BaseNlLib:
    """Library-specific C snippets used by the generator."""

    def get_family_id(self):
        """Return the C expression used to refer to the netlink family ID."""
        return "ys->family_id"
46
47
class Type(SpecAttr):
    """
    Base class for C code generation of a single attribute.

    Stores the attribute spec (attr), its attr-set and the derived C names,
    and provides default hooks which subclasses override to emit struct
    members, kernel policies, user-space policy entries, put/parse code and
    setters for one attribute type. Hooks without a sensible default raise
    Exception.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped to True by set_request() / set_reply()
        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        nested = attr.get('nested-attributes')
        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # NOTE(review): the trailing '_' presumably avoids a clash with the
            # C struct rendered for a same-named entry in family.consts
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # C member name: '_' appended when the name is in _C_KW (presumably
        # the C keywords), '_' prepended when it starts with a digit
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests (also on the real attr for subsets)."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies (also on the real attr for subsets)."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numeric value of check `limit`, or `default` when absent.

        Integers are returned as-is, names of family consts resolve to their
        "value", anything else is parsed by limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return the C expression for check `limit`, or '' when absent.

        Integers render as decimal literals followed by `suffix`, consts as
        their macro name (family-prefixed unless the const sets 'header'),
        anything else upper-cased (e.g. u32-max -> U32_MAX).
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """
        Compute enum_name, the C constant naming this attribute type.

        Raises Exception if a subset attr declares checks differing from
        the real attr's.
        """
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        """Return truthy if the attr collects multiple values (falsy here)."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        """Return True if the attr is recursive and ri has no op."""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Return how presence is tracked: 'present' (bit), 'len', 'count' or ''."""
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return the presence-struct member declaration for this attr.

        Returns None unless presence_type() equals type_filter. In 'user'
        space the u32 type gets a '__' prefix.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'present':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() in {'len', 'count'}:
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name};"

    def _complex_member_type(self, ri):
        """Return the C type of a complex struct member, or None if not complex."""
        return None

    def free_needs_iter(self):
        """Return True if the free lines need an 'unsigned int i' iterator."""
        return False

    def _free_lines(self, ri, var, ref):
        """Return C lines freeing the member at var->ref<c_name>."""
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the C lines freeing this member."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """
        Return C parameter declarations for passing this attr to a setter.

        Raises Exception when the type has no complex member type.
        """
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attr."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the kernel policy initializer for the given NLA_* policy."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's entry in the kernel policy table."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the body of this attr's user-space policy entry."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's entry in the user-space policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit `line`, guarded by the presence field for 'present'/'len' attrs."""
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit C code adding this attr from var to the message."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """
        Return (lines, init_lines, local_vars) parsing this attr into var.

        lines and init_lines may each be a single string or a list.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the parse-loop branch handling this attr; returns True.

        `first` selects `if` over `else if`. Non-multi-val attrs are
        validated first, and marked present when they use bit presence.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return C lines storing the setter arguments into member."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a static inline C setter for this attr on the 'req' struct.

        ref lists the enclosing nest member names. Every enclosing layer gets
        its presence bit set; the attr itself too when it uses bit presence,
        otherwise `presence` names its len/count field for _setter_lines().
        The old value is freed before the new one is stored. Setters that
        free but never allocate get a '__' prefix (NOTE(review): presumably
        marking them as internal helpers).
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, name in enumerate(ref):
            prefix = '.'.join(ref[:i] + [''])
            presence = f"{var}->{prefix}_present.{name}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{var}->{prefix}_{self.presence_type()}.{name}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
305
306
class TypeUnused(Type):
    """
    Attribute of type 'unused'.

    It gets no struct member, no kernel policy entry, and no put, parse or
    setter code. Its user-space policy entry is YNL_PT_REJECT.
    """

    def presence_type(self):
        # Nothing to track
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        reject = ['return YNL_PARSE_CB_ERROR;']
        return reject, None, None

    # The hooks below deliberately emit nothing.

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
331
332
class TypePad(Type):
    """
    Padding attribute.

    It has no struct member, no kernel policy entry, and no put, parse or
    setter code. The user-space policy marks it YNL_PT_IGNORE.
    """

    def presence_type(self):
        # Padding is never tracked
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The hooks below deliberately emit nothing.

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
354
355
class TypeScalar(Type):
    """
    Scalar integer attribute.

    Derives kernel policy checks (ranges, masks, sparse-enum validation)
    from the spec and picks the C type used for the struct member.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Run the parent resolve(), then set is_bitfield and type_name."""
        self.resolve_up(super())

        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """
        Fill in implied checks (min/max from the enum, range flags).

        Raises Exception when min exceeds max, or when an unsigned type
        has a negative limit.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                # No contiguous range, validate values individually
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    # A zero minimum is implicit for unsigned types
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside s16 need a separate range struct (NLA_POLICY_FULL_RANGE)
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
            else:
                enum = self.family.consts[self.checks['flags-mask']]
            # Build the mask from the entries' actual values, so flag sets
            # with a non-zero value-start or explicit values are correct
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed and unsigned types of the same width share a YNL_PT_U* entry
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
453
454
class TypeFlag(Type):
    """Flag attribute: a zero-length payload whose presence is the value."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter only needs to mark presence, so there is no value argument
        return []

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to copy; the generic attr_get() records presence
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        return []
470
471
class TypeString(Type):
    """String attribute, stored NUL-terminated with its length as presence."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        pieces = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            pieces.append(', .len = ' + self.get_limit_str('max-len'))
        pieces.append(', }')
        return ''.join(pieces)

    def attr_policy(self, cw):
        # Without 'unterminated-ok' the kernel requires a NUL terminator
        if self.checks.get('unterminated-ok', False):
            nla_type = 'NLA_STRING'
        else:
            nla_type = 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(nla_type)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        member = f"{var}->{self.c_name}"
        length_field = f"{var}->_len.{self.c_name}"
        # Copy at most the attr payload and always terminate the copy
        get_lines = [
            f"{length_field} = len;",
            f"{member} = malloc(len + 1);",
            f"memcpy({member}, ynl_attr_get_str(attr), len);",
            f"{member}[len] = 0;",
        ]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = strlen({self.c_name});",
            f"{member} = malloc({presence} + 1);",
            f"memcpy({member}, {self.c_name}, {presence});",
            f"{member}[{presence}] = 0;",
        ]
521
522
class TypeBinary(Type):
    """Opaque binary attribute, stored as a heap buffer with its length as presence."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """
        Return the kernel policy initializer.

        At most one of exact-len, min-len or max-len is supported; anything
        else raises Exception.
        """
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not self.checks:
            return '{ .type = NLA_BINARY, }'

        check = next(iter(self.checks))
        if check == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check == 'min-len':
            # No .type here; only the minimum length is enforced
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if check == 'max-len':
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        raise Exception('Unsupported check for binary type: ' + check)

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, {var}->_len.{self.c_name})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        member = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{member} = malloc(len);",
            f"memcpy({member}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = len;",
            f"{member} = malloc({presence});",
            f"memcpy({member}, {self.c_name}, {presence});",
        ]
573
574
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is the C struct named by its 'struct' property."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        member = f"{var}->{self.c_name}"
        length_field = f"{var}->_{self.presence_type()}.{self.c_name}"
        # A payload shorter than the struct gets a zeroed, full-size buffer,
        # so every struct field read stays within the allocation
        get_lines = [
            f"{length_field} = len;",
            f"if (len < {struct_sz})",
            f"{member} = calloc(1, {struct_sz});",
            "else",
            f"{member} = malloc(len);",
            f"memcpy({member}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
590
591
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute holding a packed array of 'sub-type' scalars; presence is an element count."""

    def presence_type(self):
        return 'count'

    def arg_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        return [f'{elem_type} *{self.c_name}', 'size_t count']

    def struct_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem_type} *{self.c_name};')

    def attr_put(self, ri, var):
        count = f"{var}->_{self.presence_type()}.{self.c_name}"
        ri.cw.block_start(line=f"if ({count})")
        # 'i' is reused to hold the payload size in bytes
        ri.cw.p(f"i = {count} * sizeof(__{self.get('sub-type')});")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        count = f"{var}->_count.{self.c_name}"
        member = f"{var}->{self.c_name}"
        get_lines = [
            f"{count} = len / {elem_sz};",
            # Only whole elements are copied
            f"len = {count} * {elem_sz};",
            f"{member} = malloc(len);",
            f"memcpy({member}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        return [
            f"{presence} = count;",
            f"count *= {elem_sz};",
            f"{member} = malloc(count);",
            f"memcpy({member}, {self.c_name}, count);",
        ]
624
625
class TypeBitfield32(Type):
    """Attribute carried as struct nla_bitfield32; the valid bits come from its 'enum'."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """
        Return the NLA_POLICY_BITFIELD32 initializer.

        Raises Exception when the attr does not reference an enum.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
649
650
class TypeNest(Type):
    """Nested attribute set, rendered as an embedded struct (a pointer when recursive for the op)."""

    def is_recursive(self):
        nest = self.family.pure_nested_structs[self.nested_attrs]
        return nest.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        member = f"{var}->{ref}{self.c_name}"
        free_fn = f"{self.nested_render_name}_free"
        if self.is_recursive_for_op(ri):
            # Stored by pointer, which may be NULL
            return [f'if ({member})', f'{free_fn}({member});']
        return [f'{free_fn}(&{member});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        put_call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, {addr_of}{var}->{self.c_name})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({self.nested_render_name}_parse(&parg, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for each non-recursive member of the nest, under this member's path."""
        path = [*(ref or []), self.c_name]
        nest = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nest.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
692
693
class TypeMultiAttr(Type):
    """
    Attribute which may appear multiple times; occurrences are collected
    into an array with an element count as presence.

    base_type is the Type for a single occurrence; policy entries are
    delegated to it.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        sub_type = self.attr['type']
        if sub_type == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if sub_type == 'string':
            return 'struct ynl_string *'
        if sub_type in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def arg_member(self, ri):
        if self.type == 'binary' and 'struct' in self.attr:
            struct_name = c_lower(self.attr["struct"])
            return [f'struct {struct_name} *{self.c_name}',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def free_needs_iter(self):
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        sub_type = self.attr['type']
        member = f"{var}->{ref}{self.c_name}"
        count = f"{var}->{ref}_count.{self.c_name}"
        if sub_type in scalars or sub_type == 'binary':
            return [f"free({member});"]
        if sub_type == 'string':
            # Each element is a separate allocation
            return [f"for (i = 0; i < {count}; i++)",
                    f"free({member}[i]);",
                    f"free({member});"]
        if sub_type == 'nest':
            return [f"for (i = 0; i < {count}; i++)",
                    f'{self.nested_render_name}_free(&{member}[i]);',
                    f"free({member});"]
        raise Exception(f"Free of MultiAttr sub-type {sub_type} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Occurrences are only counted in this branch
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        sub_type = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        member = f"{var}->{self.c_name}"
        if sub_type in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {member}[i]);")
        elif sub_type == 'binary' and 'struct' in self.attr:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{member}[i], sizeof(struct {c_lower(self.attr['struct'])}));")
        elif sub_type == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {member}[i]->str);")
        elif sub_type == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{member}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {sub_type} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
780
781
class TypeArrayNest(Type):
    """
    Indexed array: a nest whose entries use their array index as the
    attribute type and hold values of the attr's 'sub-type'.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # A missing sub-type is treated as a nest
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def arg_member(self, ri):
        if self.sub_type == 'binary' and 'exact-len' in self.checks:
            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        # self.sub_type (rather than self.attr['sub-type']) avoids a KeyError
        # when sub-type is absent; that case falls through to the nest entry,
        # consistent with _complex_member_type()
        if self.sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Entries are only validated and counted in this branch
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->_count.{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        """
        Emit the nest with one entry per element, using the index as type.

        Raises Exception for unsupported sub-types.
        """
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            put_type = self.sub_type
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            # self.sub_type keeps the intended error when sub-type is absent
            raise Exception(f"Put for ArrayNest sub-type {self.sub_type} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
844
845
class TypeNestTypeValue(Type):
    """Nest whose outer attribute types carry values ('type-value' levels).

    Parsing descends through one attribute per 'type-value' level.  The type
    of each attribute is captured into a __u32 local, and those locals are
    passed on to the nested parse function.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return (get_lines, init_lines, local_vars) for parsing this nest."""
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        tv_args = ''

        if 'type-value' in self.attr:
            tv_names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append('const struct nlattr *attr_' + ', *attr_'.join(tv_names) + ';')
            local_vars.append('__u32 ' + ', '.join(tv_names) + ';')
            # Step one nesting level per name, recording each level's type
            for name in tv_names:
                get_lines.append(f'attr_{name} = ynl_attr_data({cursor});')
                get_lines.append(f'{name} = ynl_attr_type(attr_{name});')
                cursor = 'attr_' + name
            tv_args = ', ' + ', '.join(tv_names)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{tv_args});")
        return get_lines, init_lines, local_vars
874
875
class TypeSubMessage(TypeUnused):
    """Sub-message attribute; inherits all behavior from TypeUnused unchanged."""
878
879
class Struct:
    """C struct rendering data for (a subset of) one attribute set.

    type_list selects the member attribute names.  When it is None the
    struct covers every attribute of the set and is treated as nested.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Default to a list (not a set) so set_inherited() against an
        # empty set is caught as a mismatch
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)
        self.struct_name = 'struct ' + self.render_name
        # Nested structs whose set name is also in family.consts get a '_'
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by MultiAttr and legacy arrays

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]

        # Index members by name and remember the attribute with the highest
        # value (later attributes win ties)
        self.attrs = {}
        self.attr_max_val = None
        top = 0
        for name, attr in self.attr_list:
            if attr.value >= top:
                top = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        """Iterate over member attribute names."""
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Confirm the inherited member names and store them sorted, in C case.

        Raises if new_inherited differs from what the struct was created with.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def free_needs_iter(self):
        """True if any member needs an iterator variable in the free helper."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
942
943
class EnumEntry(SpecEnumEntry):
    """Enum entry carrying the extra data the C renderer needs."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # Value the entry would get by plain auto-increment from prev (or 0)
        implicit = prev.value + 1 if prev else 0
        # Set when the value departs from that sequence, and always for
        # flags (presumably so the renderer emits an explicit value)
        self.value_change = self.value != implicit or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        """Resolve the base entry, then compute the upper-case C name."""
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
962
963
class EnumSet(SpecEnumSet):
    """Enum / flags set extended with the names the C renderer needs."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                # Explicitly empty enum-name: no named C enum type
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # Without a named enum type, values are handled as plain int
        if self.enum_name:
            self.user_type = self.enum_name
        else:
            self.user_type = 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        # Fields above are set before the base constructor runs, which
        # presumably creates entries via new_entry(); EnumEntry reads this set
        # (enum_set['type'], value_pfx).
        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) of the entry values if they look contiguous.

        Returns (None, None) when the number of entries differs from
        high - low + 1, or when the set has no entries at all (min()/max()
        would otherwise raise ValueError).
        """
        if not self.entries:
            return None, None

        values = [x.value for x in self.entries.values()]
        low = min(values)
        high = max(values)

        if high - low + 1 != len(self.entries):
            return None, None

        return low, high
999
1000
class AttrSet(SpecAttrSet):
    """Attribute set with C naming: enum prefix, max and count names."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the enum names of the set they are a subset of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        """Compute the C name.

        A trailing '_' is added to names that are C keywords, and the name is
        empty when it equals the family's C name.
        """
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        if name == self.family.c_name:
            name = ''
        self.c_name = name

    def new_attr(self, elem, value):
        """Build the Type object for one attribute spec elem.

        Multi-attr attributes are wrapped in TypeMultiAttr.
        """
        kind = elem['type']
        # Types that map directly to one class without further inspection
        direct = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
            'sub-message': TypeSubMessage,
        }

        if kind in scalars:
            cls = TypeScalar
        elif kind in direct:
            cls = direct[kind]
        elif kind == 'binary':
            if 'struct' in elem:
                cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                cls = TypeBinaryScalarArray
            else:
                cls = TypeBinary
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        t = cls(self.family, self, elem, value)
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)
        return t
1069
1070
class Operation(SpecOperation):
    """Netlink operation with C naming and policy information."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # every request/reply section that exists gets an attribute list.
        for mode in ('do', 'dump', 'event'):
            for direction in ('request', 'reply'):
                try:
                    section = yaml[mode][direction]
                except KeyError:
                    continue
                section.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # Both 'do' and 'dump' define a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        del self.enum_name

    def resolve(self):
        """Resolve the base op, then build its C enum name."""
        self.resolve_up(super())

        if self.is_async:
            prefix = self.family.async_op_prefix
        else:
            prefix = self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that some message names this op as its 'notify' target."""
        self.has_ntf = True
1104
1105
class SubMessage(SpecSubMessage):
    """Sub-message definition together with its C render name."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        self.render_name = c_lower('-'.join((family.ident_name, yaml['name'])))

    def resolve(self):
        self.resolve_up(super())
1114
1115
class Family(SpecFamily):
    """Netlink family spec plus the derived state the C code generator needs.

    On top of SpecFamily this computes:
    - C names and op enum prefixes, and the uapi header to include;
    - the structs for root and pure-nested attribute sets;
    - which attributes are used in requests and replies;
    - per-op pre/post hooks;
    - for kernel-policy 'global', one shared request policy.
    """

    def __init__(self, file_name, exclude_ops):
        """Load the family spec from file_name.

        exclude_ops is passed through to SpecFamily unchanged.
        """
        # Added by resolve:
        # (each field is created and immediately deleted, so it only exists
        # once it has actually been assigned during resolution)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        # NOTE(review): consts is not assigned in this class; presumably the
        # base class sets it during resolve_up() -- confirm
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        # NOTE(review): resolve() is presumably invoked from here by
        # SpecFamily, re-creating the fields deleted above -- confirm
        super().__init__(file_name, exclude_ops=exclude_ops)

        # Upper-case C names for the family name / version constants,
        # overridable via c-family-name / c-version-name
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        # Default to an empty definitions list when the spec has none
        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" -> "<name>"; other paths fall back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Compute generator state once the base spec is resolved.

        The _load_*() steps depend on each other and must run in this order.
        _load_root_sets() fills root_sets, which _load_nested_sets() walks to
        build pure_nested_structs.  _load_attr_use() then reads both.
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        # Prefixes for op enum names; async ops use async-prefix when given
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] holds a 'set' (for de-duplication) and a
        # 'list' (first-seen order) of hook names, filled by _load_hooks()
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    # Factory hooks, presumably called by SpecFamily while parsing the spec

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """True for 'netlink-raw' (classic, non-genetlink) families."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Flag every op that some message names as its 'notify' target."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    # (note: this replaces any 'do' section such an op already had)
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect request/reply attribute names for each root attribute set.

        A root set is one named by a message's 'attribute-set'.  Attribute
        names from do/dump requests go under 'request'; names from do/dump
        replies and events go under 'reply'.  Results are merged across all
        messages using the set.
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so structs follow the sets they nest.

        Each round pops a pending name.  If one of its non-recursive nested
        sets has not been placed yet, the struct is moved to the end of the
        dict and the name is re-queued; otherwise the name is marked as
        placed.  Recursive nests are skipped.  The number of rounds is capped
        at len**2.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register the struct for spec's nested set and its inherited members.

        Inherited members are the 'type-value' names, or 'idx' for indexed
        arrays.  Raises if the nested set is also used as a root set, or if
        uses of the same set disagree on the inherited members.

        Returns the nested attribute-set name.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        # Struct keeps a reference to 'inherit'.  For a struct created just
        # above this compares the set with itself; for an existing struct it
        # checks that this use inherits the same members as the first one.
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_sets(self):
        """Discover all pure nested sets and derive their usage flags.

        The steps are:
        1. Breadth-first walk from the root sets, registering each nested set.
        2. Mark direct children of root sets as request / reply / in_multi_val.
        3. Sort, then propagate those flags from parents to children while
           collecting child_nests; a struct containing itself is marked
           recursive.
        4. Sort again, since recursion info may have changed.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # (after sorting, containing structs come later, so walking in reverse
        # roughly visits parents before their children)
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Push request/reply usage down to individual attribute objects.

        Every member of a request/reply pure nested struct is marked.  For
        root sets, only attributes listed in root_sets are marked.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_global_policy(self):
        """Build the single policy used with kernel-policy 'global'.

        global_policy lists, in attribute-set order, every attribute used in
        any do/dump request.  All ops that name an attribute-set must name the
        same one, otherwise this raises.
        NOTE(review): if no op names an attribute-set, attr_set_name stays None
        and the self.attr_sets[None] lookup below will fail.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per do/dump mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1402
1403
class RenderInfo:
    """Everything needed to render code for one op, or for one attribute set.

    Parameters:
    - cw: the CodeWriter.
    - ku_space: 'user' selects user-space rendering (see _complex_member_type).
    - op, op_mode: the operation and its mode ('do', 'dump', 'notify', 'event').
    - attr_set: the attribute set to render when op is falsy; defaults to the
      op's attribute-set.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op
        self.cw = cw

        # A per-op fixed header that differs from the family's is only
        # allowed for classic families, whose length is then sizeof(hdr)
        self.fixed_hdr = None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)
            if op.fixed_header != family.fixed_header:
                if not family.is_classic():
                    raise Exception("Per-op fixed header not supported, yet")
                self.fixed_hdr_len = f"sizeof({self.fixed_hdr})"

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                # dump-only op
                self.type_oneside = True
            elif ('reply' in op['do']) != ('reply' in op['dump']):
                self.type_consistent = False
            elif 'reply' in op['do']:
                self.type_consistent = op['do']['reply'] == op['dump']['reply']

        self.attr_set = attr_set if attr_set else op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            # set name also present in family.consts
            self.type_name_conflict = attr_set in family.consts

        # Notifications reuse the attribute lists of their 'do' (else 'dump') mode
        struct_mode = op_mode
        if struct_mode == 'notify':
            struct_mode = 'do' if 'do' in op else 'dump'

        self.struct = {}
        if op:
            for op_dir in ('request', 'reply'):
                msg = op[struct_mode]
                type_list = msg[op_dir]['attributes'] if op_dir in msg else []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if struct_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])

    def type_empty(self, key):
        """True when struct[key] has no attributes and there is no fixed header."""
        if self.struct[key].attr_list:
            return False
        return self.fixed_hdr is None

    def needs_nlflags(self, direction):
        """Whether nlflags are rendered: only for 'do' requests of classic families."""
        if self.op_mode != 'do' or direction != 'request':
            return False
        return self.family.is_classic()
1467
1468
class CodeWriter:
    """Indentation-aware writer for generated C code.

    Output goes to stdout.  When out_file is given it goes to a temporary
    file instead, which close_out_file() (also run from __del__) copies to
    out_file.  With overwrite=False an existing out_file whose contents are
    identical is left untouched.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # '}' pending (may merge with 'else')
        self._silent_block = False  # previous line was a brace-less if/for/while/else
        self._ind = 0
        self._ifdef_block = None
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        # __init__ may have failed before _out was assigned
        if hasattr(self, '_out'):
            self.close_out_file()

    def close_out_file(self):
        """Copy buffered output to out_file; no-op when writing to stdout.

        A closing brace still pending from block_end() is written first, so
        the file does not end inside an unterminated block.  The temporary
        file is always closed and output reverts to stdout, so a second call
        (e.g. from __del__) does nothing.
        """
        if self._out is sys.stdout:
            return
        if self._block_end:
            self._block_end = False
            self._out.write('\t' * self._ind + '}\n')
        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        self._out.close()
        self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        - A pending '}' is emitted first, or merged as '} else ...'.
        - Lines ending in ':' (labels) are outdented by one level.
        - The line after a brace-less if/for/while/else is indented one extra level.
        - Lines starting with '#' are written at column 0.
        - An empty line produces a bare newline.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        text = '\t' * ind + line if line else ''
        self._out.write(text + '\n')

    def nl(self):
        self._nl = True

    def block_start(self, line=''):
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        A copy is sorted, so the caller's list is left in its original order.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1665
1666
1667scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}
1668
1669direction_to_suffix = {
1670    'reply': '_rsp',
1671    'request': '_req',
1672    '': ''
1673}
1674
1675op_mode_to_wrapper = {
1676    'do': '',
1677    'dump': '_list',
1678    'notify': '_ntf',
1679    'event': '',
1680}
1681
1682_C_KW = {
1683    'auto',
1684    'bool',
1685    'break',
1686    'case',
1687    'char',
1688    'const',
1689    'continue',
1690    'default',
1691    'do',
1692    'double',
1693    'else',
1694    'enum',
1695    'extern',
1696    'float',
1697    'for',
1698    'goto',
1699    'if',
1700    'inline',
1701    'int',
1702    'long',
1703    'register',
1704    'return',
1705    'short',
1706    'signed',
1707    'sizeof',
1708    'static',
1709    'struct',
1710    'switch',
1711    'typedef',
1712    'union',
1713    'unsigned',
1714    'void',
1715    'volatile',
1716    'while'
1717}
1718
1719
def rdir(direction):
    """Return the opposite message direction.

    'reply' <-> 'request'; any other value (e.g. '' for nested types) is
    returned unchanged.
    """
    if direction in ('reply', 'request'):
        return 'request' if direction == 'reply' else 'reply'
    return direction
1726
1727
def op_prefix(ri, direction, deref=False):
    """Return the C identifier prefix `<family>_<type_name><suffix>` for a type.

    For "do" (or no op mode) the suffix is the plain direction suffix
    ('_req', '_rsp', or '' for nested types).  In other modes:
      * a request is '_req', plus '_dump' unless ri.type_oneside;
      * a reply of a type_consistent op uses the direction suffix when
        *deref* is set (the inner object) and the mode's wrapper suffix
        otherwise;
      * any other reply is '_rsp_dump' (deref) or '_rsp_list'.
    """
    base = f"{ri.family.c_name}_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        return base + direction_to_suffix[direction]

    if direction == 'request':
        return base + ('_req' if ri.type_oneside else '_req_dump')

    if not ri.type_consistent:
        return base + ('_rsp_dump' if deref else '_rsp_list')

    if deref:
        return base + direction_to_suffix[direction]
    return base + op_mode_to_wrapper[ri.op_mode]
1749
1750
def type_name(ri, direction, deref=False):
    """Return the full C type name, `struct <op_prefix(...)>`."""
    prefix = op_prefix(ri, direction, deref=deref)
    return "struct " + prefix
1753
1754
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-space request function for the op.

    The function is named after the op (with '_dump' for dump mode) and
    takes the socket plus, if the mode has a request, a pointer to the
    request struct.  It returns a pointer to the reply struct if the mode
    has a reply, `int` otherwise.  *terminate* appends ';' (declaration);
    *doc* is forwarded to write_func_prot().
    """
    mode_spec = ri.op[ri.op_mode]

    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1771
1772
def print_req_prototype(ri):
    """Emit the declaration of the "do" request helper with the op's doc comment.

    Uses .get() (as elsewhere in this file for optional op keys) so that an
    operation without a 'doc' key gets no comment instead of crashing the
    generator with KeyError.
    """
    print_prototype(ri, "request", doc=ri.op.get('doc'))
1775
1776
def print_dump_prototype(ri):
    """Emit the declaration of the dump helper (no doc comment is attached)."""
    print_prototype(ri, direction="request")
1779
1780
def put_typol_fwd(cw, struct):
    """Declare the extern `<struct>_nest` policy object that put_typol() defines."""
    decl = f'extern const struct ynl_policy_nest {struct.render_name}_nest;'
    cw.p(decl)
1783
1784
def put_typol(cw, struct):
    """Emit the user-space policy table for *struct* and its nest descriptor.

    Writes `<name>_policy[MAX + 1]`, filled by each member's attr_typol(),
    followed by `<name>_nest`, which points at that table and records the
    attribute set's max value.
    """
    name = struct.render_name
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for field in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(field)
    cw.block_end(line=';')
    cw.nl()
1800
1801
1802def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1803    args = [f'int {arg_name}']
1804    if enum:
1805        args = [enum.user_type + ' ' + arg_name]
1806    cw.write_func_prot('const char *', f'{render_name}_str', args)
1807    cw.block_start()
1808    if enum and enum.type == 'flags':
1809        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1810    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
1811    cw.p('return NULL;')
1812    cw.p(f'return {map_name}[{arg_name}];')
1813    cw.block_end()
1814    cw.nl()
1815
1816
def put_op_name_fwd(family, cw):
    """Declare `const char *<family>_op_str(int op)`."""
    fname = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', fname, ['int op'], suffix=';')
1819
1820
def put_op_name(family, cw):
    """Emit `<family>_op_strmap` and its `<family>_op_str()` lookup function.

    Only messages with a reply value get an entry.  The entry is indexed by
    the op's enum name when request and reply values match, and by the raw
    reply value otherwise.
    """
    strmap = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {strmap}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Legacy families may have several commands sharing one reply value;
        # only the op registered in rsp_by_value gets the entry.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue

        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', strmap, 'op')
1840
1841
def put_enum_to_str_fwd(family, cw, enum):
    """Declare `const char *<enum>_str(<enum user type> value)`.

    *family* is unused; it is kept for signature symmetry with put_enum_to_str().
    """
    fname = f'{enum.render_name}_str'
    cw.write_func_prot('const char *', fname, [enum.user_type + ' value'], suffix=';')
1845
1846
def put_enum_to_str(family, cw, enum):
    """Emit `<enum>_strmap` (entry value -> entry name) and its `<enum>_str()` lookup."""
    strmap = f'{enum.render_name}_strmap'
    entries = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {strmap}[] =")
    for line in entries:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, strmap, 'value', enum=enum)
1856
1857
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of `<struct>_put(nlh, attr_type, obj)`.

    suffix: ';' for a declaration, '' when a function body follows.
    """
    params = ['struct nlmsghdr *nlh', 'unsigned int attr_type']
    params.append(f'{struct.ptr_name}obj')
    ri.cw.write_func_prot('int', f'{struct.render_name}_put', params, suffix=suffix)
1865
1866
def put_req_nested(ri, struct):
    """Emit `<struct>_put()`, which writes *struct* into *nlh* as a nest.

    The generated function opens a nest of type attr_type, lets each
    member emit its own put code (reading from `obj`), closes the nest and
    returns 0.  Extra locals are declared only when needed: `array` for
    indexed-array members and the counter `i` for members with 'count'
    presence.
    """
    member_kinds = [(member.type, member.presence_type())
                    for _, member in struct.member_list()]

    local_vars = ['struct nlattr *nest;']
    if any(kind == 'indexed-array' for kind, _ in member_kinds):
        local_vars.append('struct nlattr *array;')
    if any(presence == 'count' for _, presence in member_kinds):
        local_vars.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    cw = ri.cw
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1900
1901
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body (braces included) of a generated C parse function for *struct*.

    The caller has already written the prototype.  *init_lines* and
    *local_vars* are caller-provided lists of C lines; both are extended in
    place with whatever extra setup/declarations this parser needs.

    The generated code walks the attributes in up to three passes:
      1. one loop where each member emits its own parsing code (attr_get());
      2. for every indexed-array member: calloc() n_<name> entries and walk
         the saved nest attr_<name> again to fill them;
      3. for every multi-attr member: calloc() n_<name> entries and walk
         the whole message/nest again, picking attributes of that type.
    NOTE(review): filling the n_<name> counters / attr_<name> pointers is
    presumably done by the members' attr_get() code, which is not visible here.

    The generated function returns 0 for nested structs and
    YNL_PARSE_CB_OK for top-level messages.
    NOTE(review): the emitted calloc()/malloc() results are not NULL-checked.

    Raises Exception for unsupported configurations (per-op fixed header on
    a non-classic family, unsupported indexed-array sub-types, unsupported
    member types).
    """
    if struct.nested:
        # Nested struct: iterate the attributes inside the `nested` attribute.
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        # Top-level message: attributes start after the family's header.
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            # The op overrides the family's fixed header; classic families use
            # its size directly as the attribute offset.
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({ri.fixed_hdr}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Collect members that need a second pass (arrays / repeated attrs)
    # and whether any member parses into a nested struct (needs parg).
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # sorted(): sets have no stable order, keep generated code deterministic.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Values inherited from the parent nest are passed in as parameters.
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        # Copy the fixed header (classic: right at the payload start;
        # genetlink: after struct genlmsghdr).
        if ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Refuse to parse into a struct whose array members are already filled.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # Pass 1: one loop over all attributes, each member emits its own branch.
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Pass 2: allocate and fill indexed arrays from their saved nest.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            # Unlike multi-attrs below, the entry's attr type is passed as an
            # extra argument (presumably the nested struct's inherited index).
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Pass 3: allocate and fill repeated (multi-attr) members by re-walking
    # the same attribute stream and matching on the attribute type.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most one element; a shorter attribute leaves the rest
            # of the calloc()ed element zeroed.
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Allocate a length-prefixed, NUL-terminated copy of the string,
            # bounded by the attribute's payload length.
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2047
2048
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of `<struct>_parse(yarg, nested, ...)`.

    Each name in struct.inherited adds a trailing `__u32 <name>` parameter.
    suffix: ';' for a declaration, '' when a body follows.
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params += ['__u32 ' + name for name in struct.inherited]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params, suffix=suffix)
2057
2058
def parse_rsp_nested(ri, struct):
    """Emit `<struct>_parse()`, the parser for a nested attribute set.

    A struct without members gets a trivial body returning 0.  Otherwise
    the body comes from _multi_parse(), with `attr` and `dst` (= yarg->data)
    as base locals.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    local_vars = ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest: nothing to parse.
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, struct, [], local_vars)
2074
2075
def parse_rsp_msg(ri, deref=False):
    """Emit the `<reply prefix>_parse()` callback for a reply message.

    Nothing is emitted when the op mode has no reply, except for 'event'
    ops.  A reply without members gets a body that just returns
    YNL_PARSE_CB_OK; otherwise _multi_parse() generates the body, with
    `dst` taken from yarg->data.
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
                  'const struct nlattr *attr;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse',
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: nothing to parse.
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'], local_vars)
2097
2098
def print_req(ri):
    """Emit the user-space C function that sends a "do" request.

    The prototype comes from print_prototype().  The generated function
    starts the message (classic or genetlink header), copies the optional
    fixed header from `req`, emits each request attribute and runs the
    request through ynl_exec().  If the op has a reply, it returns a
    calloc()ed reply struct filled by the generated `<reply>_parse`
    callback, or frees it and returns NULL on error.  Ops without a reply
    return 0 on success and -1 on error.
    NOTE(review): the emitted calloc() result is not NULL-checked.
    """
    # C return values for the success / error paths (pointer ops use rsp / NULL).
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # A counter is needed as soon as any request member has 'count' presence.
    for _, attr in ri.struct["request"].member_list():
        if attr.presence_type() == 'count':
            local_vars += ['unsigned int i;']
            break

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Expected reply command: the op's own value, or its reply value
        # when the op has no value of its own.
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        # Same value as ret_err for reply-less ops.
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    # Error path: release the partially filled reply.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
2172
2173
def print_dump(ri):
    """Emit the user-space C function that runs a dump request.

    The generated function configures a ynl_dump_state (reply policy, node
    size = sizeof(reply wrapper type), per-message parse callback, expected
    reply command), starts a dump message, copies the optional fixed header,
    emits request attributes if the mode has a request, and runs
    ynl_exec_dump().  It returns the list head `yds.first`, or frees the
    partial list and returns NULL on error.

    NOTE(review): the fixed-header path references `req->_hdr`, but
    print_prototype() only adds a `req` parameter when the mode has a
    'request' section; presumably every op with a fixed header has one —
    confirm.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    # Each list node is the wrapper type; the callback parses into the
    # deref'd inner object (presumably one node per reply message).
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    # Error path: free whatever part of the list was already built.
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2228
2229
def call_free(ri, direction, var):
    """Return the C statement `<prefix>_free(<var>);` for the given direction."""
    return op_prefix(ri, direction) + f"_free({var});"
2232
2233
def free_arg_name(direction):
    """Return the parameter name used by generated free helpers.

    'req'/'rsp' for request/reply directions; 'obj' for direction-less
    (nested) types.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2238
2239
def print_alloc_wrapper(ri, direction):
    """Emit a static inline `<type>_alloc(void)` that calloc()s one struct."""
    struct_name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot(f'static inline struct {struct_name} *', f"{struct_name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    cw.block_end()
2246
2247
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of `<type>_free(struct <type> *<arg>)`.

    When ri.type_name_conflict is set, the struct type gets a trailing
    '_' (the function name does not).  suffix: ';' for a declaration,
    '' when a body follows.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
2255
2256
def print_nlflags_set(ri, direction):
    """Emit a static inline `<type>_set_nlflags()` setter for `_nlmsg_flags`."""
    type_ = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot('static inline void', f"{type_}_set_nlflags",
                       [f"struct {type_} *req", "__u16 nl_flags"])
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2265
2266
def _print_type(ri, direction, struct):
    """Emit the C struct definition for *struct* in the given direction.

    The struct is named `<family>_<type_name><direction suffix>`, with a
    trailing '_' for direction-less types when ri.type_name_conflict is set,
    and '_dump' appended for dump ops unless ri.type_oneside.  Members, in
    order: optional `_nlmsg_flags` and fixed header `_hdr`; anonymous
    `_present` / `_len` / `_count` bookkeeping sub-structs (only those that
    have entries); inherited `__u32` values; then each attribute's member.
    """
    name = f'{ri.family.c_name}_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        name += '_'
    if ri.op_mode == 'dump' and not ri.type_oneside:
        name += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {name}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if ri.fixed_hdr:
        cw.p(ri.fixed_hdr + ' _hdr;')
        cw.nl()

    for kind in ('present', 'len', 'count'):
        fields = [attr.presence_member(ri.ku_space, kind)
                  for _, attr in struct.member_list()]
        fields = [ln for ln in fields if ln]
        if not fields:
            continue
        cw.block_start(line="struct")
        for ln in fields:
            cw.p(ln)
        cw.block_end(line=f'_{kind};')
    cw.nl()

    for arg in struct.inherited:
        cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2305
2306
def print_type(ri, direction):
    """Emit the C struct for the op's *direction* ('request' or 'reply') message."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2309
2310
def print_type_full(ri, struct):
    """Emit the C struct for a direction-less (nested) attribute set."""
    _print_type(ri, direction="", struct=struct)
2313
2314
def print_type_helpers(ri, direction, deref=False):
    """Emit the helpers that accompany the op's *direction* type.

    Always the free prototype; the nlflags setter when ri.needs_nlflags();
    and, for user-space requests only, one setter per attribute.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    emit_setters = ri.ku_space == 'user' and direction == 'request'
    if emit_setters:
        for _, attr in ri.struct[direction].member_list():
            attr.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2326
2327
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and type helpers for a non-empty request type."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2333
2334
def print_rsp_type_helpers(ri):
    """Emit the reply type helpers when the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2339
2340
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of `<op>_{rsp|req}_parse(tb, req)`.

    '_rsp' is used only for the 'reply' direction, '_req' otherwise;
    *terminate* appends ';'.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
2349
2350
def print_req_type(ri):
    """Emit the request struct unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2355
2356
def print_req_free(ri):
    """Emit `<req>_free()` when the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2361
2362
def print_rsp_type(ri):
    """Emit the reply struct for do/dump ops that have a reply, and for events."""
    mode = ri.op_mode
    has_reply = mode in ('do', 'dump') and 'reply' in ri.op[mode]
    if not has_reply and mode != 'event':
        return
    print_type(ri, 'reply')
2371
2372
def print_wrapped_type(ri):
    """Emit the wrapper struct around the reply object, plus its free prototype.

    Dump wrappers carry a `next` list pointer; notify/event wrappers carry
    family/cmd, a generic `next` pointer and a `free` callback.  Every
    wrapper ends with the 8-byte-aligned `obj` member holding the reply.
    """
    reply_type = type_name(ri, 'reply')
    cw = ri.cw

    cw.block_start(line=reply_type)
    if ri.op_mode == 'dump':
        cw.p(f"{reply_type} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for member in ('__u16 family;',
                       '__u8 cmd;',
                       'struct ynl_ntf_base_type *next;',
                       f"void (*free)({reply_type} *ntf);"):
            cw.p(member)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()

    print_free_prototype(ri, 'reply')
    cw.nl()
2387
2388
2389def _free_type_members_iter(ri, struct):
2390    if struct.free_needs_iter():
2391        ri.cw.p('unsigned int i;')
2392        ri.cw.nl()
2393
2394
2395def _free_type_members(ri, var, struct, ref=''):
2396    for _, attr in struct.member_list():
2397        attr.free(ri, var, ref)
2398
2399
def _free_type(ri, direction, struct):
    """Emit a `<type>_free()` function that frees all members of *struct*.

    Message types (non-empty *direction*) also free() the struct itself.
    Direction-less (nested) types only free their members; presumably they
    are embedded in their parent.
    """
    var = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, var, struct)
    if direction:
        cw.p(f'free({var});')
    cw.block_end()
    cw.nl()
2411
2412
def free_rsp_nested_prototype(ri):
    """Emit the `;`-terminated free prototype for a nested (direction-less) type."""
    print_free_prototype(ri, "")
2415
2416
def free_rsp_nested(ri, struct):
    """Emit the free function for a nested (direction-less) *struct*."""
    _free_type(ri, direction="", struct=struct)
2419
2420
def print_rsp_free(ri):
    """Emit `<rsp>_free()` when the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2425
2426
def print_dump_type_free(ri):
    """Emit the free function for a dump reply list.

    The generated code walks the list until YNL_LIST_END.  For each node it
    frees the members of `obj.` and then the node itself.
    """
    node_type = type_name(ri, 'reply')
    reply = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2445
2446
def print_ntf_type_free(ri):
    """Emit the free function for a notification: frees the members of `obj.`, then the wrapper."""
    reply = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2455
2456
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the kernel `nla_policy` array declaration or definition opener.

    terminate=True emits an `extern ...;` declaration, which is skipped
    entirely when the policy is static.  terminate=False emits the
    definition header ending in ` = {`, prefixed with `static ` for
    static policies.  A policy can only be static when *ri* is given.
    With *ri* the array is named after the op (plus `_<op_mode>` for
    dual-policy ops), otherwise after the struct.
    """
    static_policy = bool(ri) and policy_should_be_static(struct.family)
    if terminate:
        if static_policy:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if static_policy else ''
        suffix = ' = {'

    if ri:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    else:
        name = struct.render_name

    max_attr = struct.attr_max_val
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2479
2480
def print_req_policy(cw, struct, ri=None):
    """Emit the full kernel `nla_policy` array for *struct*.

    For ops, the array is wrapped in the op's `config-cond` #ifdef, if any.
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2490
2491
def kernel_can_gen_family_struct(family):
    """Return True if the kernel family struct can be generated (genetlink families only)."""
    proto = family.proto
    return proto == 'genetlink'
2494
2495
def policy_should_be_static(family):
    """Return True if per-op policies should be file-local (`static`).

    This holds for split kernel policies, and for any family whose family
    struct is generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2498
2499
def print_kernel_policy_ranges(family, cw):
    """Emit `netlink_range_validation[_signed]` objects for 'full-range' checks.

    One object is emitted per request attribute carrying a 'full-range'
    check; attributes of subset attr-sets are skipped.  Each object is
    initialized with whichever of its 'min'/'max' limits are present.
    Types starting with 'u' use the unsigned struct and ULL literals;
    others use the _signed struct and LL literals.  A header comment
    precedes the first object.
    """
    header_done = False
    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_done:
                cw.p('/* Integer value ranges */')
                header_done = True

            is_unsigned = attr.type[0] == 'u'
            sign = '' if is_unsigned else '_signed'
            suffix = 'ULL' if is_unsigned else 'LL'
            var = f'{c_lower(attr.enum_name)}_range'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {var} =')
            limits = [(which, attr.get_limit_str(which, suffix=suffix))
                      for which in ('min', 'max') if which in attr.checks]
            cw.write_struct_init(limits)
            cw.block_end(line=';')
            cw.nl()
2527
2528
def print_kernel_policy_sparse_enum_validates(family, cw):
    """Emit a `<attr>_validate()` callback for each sparse-enum request attribute.

    Only request attributes that have an enum_name and carry the 'sparse'
    check get a callback; attributes of subset attr-sets are skipped.  The
    generated callback switches on the attribute value, chaining one case
    label per entry of the attribute's enum into a single `return 0;`.  Any
    other value gets an extack message and -EINVAL.  A header comment
    precedes the first callback.

    (The unused `sign`/`suffix` locals copied from
    print_kernel_policy_ranges() have been dropped.)
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request:
                continue
            if not attr.enum_name:
                continue
            if 'sparse' not in attr.checks:
                continue

            if first:
                cw.p('/* Sparse enums validation callbacks */')
                first = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # All valid values share one `return 0;`: separate consecutive
            # case labels with the kernel's fallthrough annotation.
            first_entry = True
            for entry in enum.entries.values():
                if first_entry:
                    first_entry = False
                else:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2567
2568
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops-table declaration/opening and, for headers, prototypes.

    When *terminate* is False, this opens the ops array definition
    (``... _nl_ops[] =``) for the caller to fill in.

    When *terminate* is True (header mode), it emits an ``extern``
    declaration, but only if the table is exported.  It then emits
    prototypes for the pre/post hooks and every do/dump handler.

    Args:
        family: the parsed family spec.
        cw: the CodeWriter receiving the generated C.
        terminate: True for the header flavor, False for the source flavor.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        # The element count is only spelled out for exported tables;
        # otherwise the array is sized by its initializer.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split ops carry one entry per do handler and one per dump handler.
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if not terminate:
            cw.block_start(line=decl + ' =')
        else:
            cw.p(f"extern {decl};")

    if not terminate:
        return

    cw.nl()
    do_hook_args = ['const struct genl_split_ops *ops',
                    'struct sk_buff *skb', 'struct genl_info *info']
    dump_hook_args = ['struct netlink_callback *cb']
    hook_kinds = (('pre', 'do', 'int', do_hook_args),
                  ('post', 'do', 'void', do_hook_args),
                  ('pre', 'dump', 'int', dump_hook_args),
                  ('post', 'dump', 'int', dump_hook_args))
    for when, mode, ret_type, arg_list in hook_kinds:
        for hook in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(arg_list), suffix=';')

    cw.nl()

    handler_args = {'do': ['struct sk_buff *skb', 'struct genl_info *info'],
                    'dump': ['struct sk_buff *skb', 'struct netlink_callback *cb']}
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, arg_list in handler_args.items():
            if mode not in op:
                continue
            handler = c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it")
            cw.write_func_prot('int', handler, list(arg_list), suffix=';')
    cw.nl()
2635
def print_kernel_op_table_hdr(family, cw):
    """Emit the header flavor of the ops table (extern decl + prototypes)."""
    print_kernel_op_table_fwd(family, cw, True)
2638
2639
def print_kernel_op_table(family, cw):
    """Emit the kernel ops table definition for the family.

    Opens the array via print_kernel_op_table_fwd(terminate=False) and then
    writes the initializers.  For 'global' and 'per-op' policies there is
    one initializer per op.  For the 'split' policy there is one per do and
    one per dump handler.  Each initializer is preceded by
    cw.ifdef_block() with the op's optional 'config-cond'.  The array is
    closed with "};".

    Args:
        family: the parsed family spec.
        cw: the CodeWriter receiving the generated C.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            # Async ops do not get an entry in the table.
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): op['do']['request'] is read unconditionally, so
                # a per-op family with a dump-only op (or a do without a
                # request) would raise KeyError here -- confirm specs rule
                # that out.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # The pre/post hook member names differ between do and dump entries.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the validation flags relevant to this entry:
                    # dump-specific flags are dropped from the do entry and
                    # 'strict' is dropped from the dump entry.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # With a dual policy, do and dump get separate policy names.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry also carries GENL_CMD_CAP_DO or GENL_CMD_CAP_DUMP.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Presumably closes any conditional block left open by the last entry --
    # ifdef_block() is defined outside this view.
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2718
2719
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum of multicast group IDs (header side).

    Nothing is written when the family defines no multicast groups.
    """
    groups = family.mcgrps['list']
    if groups:
        cw.block_start('enum')
        for group in groups:
            cw.p(c_upper(f"{family.ident_name}-nlgrp-{group['name']},"))
        cw.block_end(';')
        cw.nl()
2730
2731
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array indexed by the group enum.

    Nothing is written when the family defines no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for group in groups:
        grp_name = group['name']
        enum_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2743
2744
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated genl_family and optional sock-priv hooks.

    Only applies when kernel_can_gen_family_struct() allows generating the
    family struct.  If the spec's kernel-family section names a 'sock-priv'
    type, prototypes for its init/destroy callbacks are declared as well.
    """
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
        cw.nl()
        if 'sock-priv' in family.kernel_family:
            priv_type = family.kernel_family["sock-priv"]
            for action in ('init', 'destroy'):
                cw.p(f'void {family.c_name}_nl_sock_priv_{action}({priv_type} *priv);')
            cw.nl()
2755
2756
def print_kernel_family_struct_src(family, cw):
    """Define the generated genl_family (plus sock-priv trampolines).

    Does nothing unless kernel_can_gen_family_struct() allows it.  The
    struct is named ``<c_name>_nl_family``.  This matches the
    ``extern struct genl_family <c_name>_nl_family;`` emitted by
    print_kernel_family_struct_hdr() and the ``<c_name>_nl_ops`` /
    ``<c_name>_nl_mcgrps`` arrays referenced below.

    Args:
        family: the parsed family spec.
        cw: the CodeWriter receiving the generated C.
    """
    if not kernel_can_gen_family_struct(family):
        return

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Use c_name (not ident_name) so the definition always matches the
    # extern declaration in the generated header.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2792
2793
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a C enum block for *obj*, choosing its tag name.

    Tag selection, in order:
      * *enum_name* key present with a truthy value: use that name.
      * *enum_name* key present but empty: anonymous enum.
      * otherwise, if *ckey* is given and present: ``<family c_name>_<obj[ckey]>``.
      * else: anonymous enum.
    """
    header = 'enum'
    if enum_name not in obj:
        if ckey and ckey in obj:
            header = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    elif obj[enum_name]:
        header = f'enum {c_lower(obj[enum_name])}'
    cw.block_start(line=header)
2802
2803
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Emit the single command enum used by the 'unified' message-ID model.

    Entries are written in message order.  An explicit "= value" is added
    only when a message's value breaks the implicit sequence.  When
    *separate_ntf* is set, notify/event messages are left out, since they go
    to their own enum.  The enum ends with a count sentinel and a MAX alias,
    either as enumerators or, with *max_by_define*, as a #define.
    """
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    next_val = 0
    for msg in family.msgs.values():
        is_ntf = 'notify' in msg or 'event' in msg
        if separate_ntf and is_ntf:
            continue

        if msg.value == next_val:
            cw.p(f"{msg.enum_name},")
        else:
            cw.p(f"{msg.enum_name} = {msg.value},")
            next_val = msg.value
        next_val += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
2829
2830
def render_uapi_directional(family, cw, max_by_define):
    """Emit the USER and KERNEL command enums of the 'directional' model.

    The USER enum lists messages that have a 'do' and are not events.  The
    KERNEL enum lists messages with a do-reply (suffixed _REPLY) plus
    notify/event messages.  Each enum starts with a *_NONE = 0 entry and
    ends with a count sentinel and MAX alias, emitted as enumerators or,
    with *max_by_define*, as a #define.
    """
    def _close_enum(cnt_name, max_name):
        # Blank line, count sentinel, MAX alias, then close the enum.
        max_value = f"({cnt_name} - 1)"
        cw.nl()
        if max_by_define:
            cw.p(cnt_name)
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {max_name} {max_value}")
        cw.nl()

    # Userspace -> kernel messages.
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    expected = 0
    for msg in family.msgs.values():
        if 'do' not in msg or 'event' in msg:
            continue
        suffix = ','
        if msg.value and msg.value != expected:
            suffix = f" = {msg.value},"
            expected = msg.value
        cw.p(msg.enum_name + suffix)
        expected += 1
    _close_enum(f"__{family.op_prefix}USER_CNT", f"{family.op_prefix}USER_MAX")

    # Kernel -> userspace messages.
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    expected = 0
    for msg in family.msgs.values():
        is_async = 'notify' in msg or 'event' in msg
        has_reply = 'do' in msg and 'reply' in msg['do']
        if not (has_reply or is_async):
            continue

        enum_name = msg.enum_name if is_async else f'{msg.enum_name}_REPLY'
        suffix = ','
        if msg.value and msg.value != expected:
            suffix = f" = {msg.value},"
            expected = msg.value
        cw.p(enum_name + suffix)
        expected += 1
    _close_enum(f"__{family.op_prefix}KERNEL_CNT", f"{family.op_prefix}KERNEL_MAX")
2883
2884
def render_uapi(family, cw):
    """Render the full UAPI header for *family*.

    Output order:
      1. include guard and the family name/version #defines;
      2. spec definitions that are not provided by an external header:
         'const' entries batched into #defines, 'enum'/'flags' entries as C
         enums with optional kdoc and an optional "render-max" trailer;
      3. one enum per non-subset attribute set, ending in count/MAX entries;
      4. the command enum(s), according to the message-ID model;
      5. a separate async (notify/event) enum when 'async-prefix' is set;
      6. multicast group name #defines and the closing #endif.

    Raises:
        Exception: if the family uses an unsupported message enum-model.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    # '/' (e.g. from a header in a subdirectory) is not valid in a macro name.
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are accumulated in `defines` and
    # flushed whenever a non-const definition is reached (and at the end).
    defines = []
    for const in family['definitions']:
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                # Per-entry docs get a kdoc ("/**") block; an enum-level doc
                # alone gets a plain comment.
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                # Flags get a MASK of all bits; enums get a count plus MAX.
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # NOTE(review): the default name is per-const, but the
            # 'c-define-name' override is looked up on `family`, not on
            # `const`.  If a family-level key were ever set, every const would
            # share one name -- confirm whether `const.get` was intended.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    # Attribute-set enums; subsets reuse their parent's enum.
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # `val` tracks the implicit next enumerator value; an explicit
        # "= value" is written only when an attribute breaks the sequence.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    # With 'async-prefix', notify/event messages get their own enum below.
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            # `separate_ntf` is always true here; only notify/event
            # messages are emitted into this enum.
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    # Group names are exported as #defines (a blank line is always written).
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3025
3026
def _render_user_ntf_entry(ri, op):
    """Write one designated initializer of the ynl_ntf_info table for *op*.

    Non-classic families index the entry by the message's own enum name.
    Classic families index it by the request whose value equals
    ``op.rsp_value``.
    """
    cw = ri.cw
    if ri.family.is_classic():
        index_op = ri.family.req_by_value[op.rsp_value]
    else:
        index_op = op
    cw.block_start(line=f"[{index_op.enum_name}] = ")
    cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    cw.block_end(line=',')
3038
3039
def render_user_family(family, cw, prototype):
    """Emit the user-space ``ynl_<family>_family`` descriptor.

    With *prototype* set, only the extern declaration is written.
    Otherwise, this writes the notification info table (when the family has
    notifications) followed by the ynl_family initializer itself.

    Raises:
        Exception: for a notification that is neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_name, ntf in family.ntfs.items():
            if 'notify' in ntf:
                # A 'notify' message reuses the types of the op it names.
                ri = RenderInfo(cw, family, "user", family.ops[ntf['notify']], "notify")
            elif 'event' in ntf:
                ri = RenderInfo(cw, family, "user", ntf, "event")
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(ri, ntf)
        # NOTE(review): events are also handled by the 'event' branch above;
        # confirm this loop cannot emit duplicate designated initializers.
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3081
3082
def family_contains_bitfield32(family):
    """Return True if any non-subset attribute set has a bitfield32 attribute."""
    return any(attr.type == "bitfield32"
               for attr_set in family.attr_sets.values()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
3091
3092
def find_kernel_root(full_path):
    """Locate the source tree root (the directory holding MAINTAINERS).

    Walks upward from *full_path* one component at a time until it finds a
    directory containing a MAINTAINERS file.

    Args:
        full_path: path of a file inside the tree (absolute or relative).

    Returns:
        ``(root, sub_path)``.  *root* is the directory containing MAINTAINERS,
        derived textually from *full_path*, so relative input yields a
        relative root.  *sub_path* is *full_path* relative to *root*.

    Raises:
        FileNotFoundError: if no ancestor of *full_path* contains MAINTAINERS.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Drop the trailing separator left by the final join().
            return parent, sub_path[:-1]
        # dirname() has reached its fixed point ('/' or ''); walking further
        # would loop forever without ever finding the file.
        if parent == full_path:
            raise FileNotFoundError(f"no MAINTAINERS file found above {start!r}")
        full_path = parent
3101
3102
3103def main():
3104    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
3105    parser.add_argument('--mode', dest='mode', type=str, required=True,
3106                        choices=('user', 'kernel', 'uapi'))
3107    parser.add_argument('--spec', dest='spec', type=str, required=True)
3108    parser.add_argument('--header', dest='header', action='store_true', default=None)
3109    parser.add_argument('--source', dest='header', action='store_false')
3110    parser.add_argument('--user-header', nargs='+', default=[])
3111    parser.add_argument('--cmp-out', action='store_true', default=None,
3112                        help='Do not overwrite the output file if the new output is identical to the old')
3113    parser.add_argument('--exclude-op', action='append', default=[])
3114    parser.add_argument('-o', dest='out_file', type=str, default=None)
3115    args = parser.parse_args()
3116
3117    if args.header is None:
3118        parser.error("--header or --source is required")
3119
3120    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
3121
3122    try:
3123        parsed = Family(args.spec, exclude_ops)
3124        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
3125            print('Spec license:', parsed.license)
3126            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
3127            os.sys.exit(1)
3128    except yaml.YAMLError as exc:
3129        print(exc)
3130        os.sys.exit(1)
3131        return
3132
3133    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
3134
3135    _, spec_kernel = find_kernel_root(args.spec)
3136    if args.mode == 'uapi' or args.header:
3137        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
3138    else:
3139        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
3140    cw.p("/* Do not edit directly, auto-generated from: */")
3141    cw.p(f"/*\t{spec_kernel} */")
3142    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
3143    if args.exclude_op or args.user_header:
3144        line = ''
3145        line += ' --user-header '.join([''] + args.user_header)
3146        line += ' --exclude-op '.join([''] + args.exclude_op)
3147        cw.p(f'/* YNL-ARG{line} */')
3148    cw.nl()
3149
3150    if args.mode == 'uapi':
3151        render_uapi(parsed, cw)
3152        return
3153
3154    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
3155    if args.header:
3156        cw.p('#ifndef ' + hdr_prot)
3157        cw.p('#define ' + hdr_prot)
3158        cw.nl()
3159
3160    if args.out_file:
3161        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
3162    else:
3163        hdr_file = "generated_header_file.h"
3164
3165    if args.mode == 'kernel':
3166        cw.p('#include <net/netlink.h>')
3167        cw.p('#include <net/genetlink.h>')
3168        cw.nl()
3169        if not args.header:
3170            if args.out_file:
3171                cw.p(f'#include "{hdr_file}"')
3172            cw.nl()
3173        headers = ['uapi/' + parsed.uapi_header]
3174        headers += parsed.kernel_family.get('headers', [])
3175    else:
3176        cw.p('#include <stdlib.h>')
3177        cw.p('#include <string.h>')
3178        if args.header:
3179            cw.p('#include <linux/types.h>')
3180            if family_contains_bitfield32(parsed):
3181                cw.p('#include <linux/netlink.h>')
3182        else:
3183            cw.p(f'#include "{hdr_file}"')
3184            cw.p('#include "ynl.h"')
3185        headers = []
3186    for definition in parsed['definitions'] + parsed['attribute-sets']:
3187        if 'header' in definition:
3188            headers.append(definition['header'])
3189    if args.mode == 'user':
3190        headers.append(parsed.uapi_header)
3191    seen_header = []
3192    for one in headers:
3193        if one not in seen_header:
3194            cw.p(f"#include <{one}>")
3195            seen_header.append(one)
3196    cw.nl()
3197
3198    if args.mode == "user":
3199        if not args.header:
3200            cw.p("#include <linux/genetlink.h>")
3201            cw.nl()
3202            for one in args.user_header:
3203                cw.p(f'#include "{one}"')
3204        else:
3205            cw.p('struct ynl_sock;')
3206            cw.nl()
3207            render_user_family(parsed, cw, True)
3208        cw.nl()
3209
3210    if args.mode == "kernel":
3211        if args.header:
3212            for _, struct in sorted(parsed.pure_nested_structs.items()):
3213                if struct.request:
3214                    cw.p('/* Common nested types */')
3215                    break
3216            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3217                if struct.request:
3218                    print_req_policy_fwd(cw, struct)
3219            cw.nl()
3220
3221            if parsed.kernel_policy == 'global':
3222                cw.p(f"/* Global operation policy for {parsed.name} */")
3223
3224                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3225                print_req_policy_fwd(cw, struct)
3226                cw.nl()
3227
3228            if parsed.kernel_policy in {'per-op', 'split'}:
3229                for op_name, op in parsed.ops.items():
3230                    if 'do' in op and 'event' not in op:
3231                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
3232                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
3233                        cw.nl()
3234
3235            print_kernel_op_table_hdr(parsed, cw)
3236            print_kernel_mcgrp_hdr(parsed, cw)
3237            print_kernel_family_struct_hdr(parsed, cw)
3238        else:
3239            print_kernel_policy_ranges(parsed, cw)
3240            print_kernel_policy_sparse_enum_validates(parsed, cw)
3241
3242            for _, struct in sorted(parsed.pure_nested_structs.items()):
3243                if struct.request:
3244                    cw.p('/* Common nested types */')
3245                    break
3246            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3247                if struct.request:
3248                    print_req_policy(cw, struct)
3249            cw.nl()
3250
3251            if parsed.kernel_policy == 'global':
3252                cw.p(f"/* Global operation policy for {parsed.name} */")
3253
3254                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3255                print_req_policy(cw, struct)
3256                cw.nl()
3257
3258            for op_name, op in parsed.ops.items():
3259                if parsed.kernel_policy in {'per-op', 'split'}:
3260                    for op_mode in ['do', 'dump']:
3261                        if op_mode in op and 'request' in op[op_mode]:
3262                            cw.p(f"/* {op.enum_name} - {op_mode} */")
3263                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
3264                            print_req_policy(cw, ri.struct['request'], ri=ri)
3265                            cw.nl()
3266
3267            print_kernel_op_table(parsed, cw)
3268            print_kernel_mcgrp_src(parsed, cw)
3269            print_kernel_family_struct_src(parsed, cw)
3270
3271    if args.mode == "user":
3272        if args.header:
3273            cw.p('/* Enums */')
3274            put_op_name_fwd(parsed, cw)
3275
3276            for name, const in parsed.consts.items():
3277                if isinstance(const, EnumSet):
3278                    put_enum_to_str_fwd(parsed, cw, const)
3279            cw.nl()
3280
3281            cw.p('/* Common nested types */')
3282            for attr_set, struct in parsed.pure_nested_structs.items():
3283                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3284                print_type_full(ri, struct)
3285                if struct.request and struct.in_multi_val:
3286                    free_rsp_nested_prototype(ri)
3287                    cw.nl()
3288
3289            for op_name, op in parsed.ops.items():
3290                cw.p(f"/* ============== {op.enum_name} ============== */")
3291
3292                if 'do' in op and 'event' not in op:
3293                    cw.p(f"/* {op.enum_name} - do */")
3294                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3295                    print_req_type(ri)
3296                    print_req_type_helpers(ri)
3297                    cw.nl()
3298                    print_rsp_type(ri)
3299                    print_rsp_type_helpers(ri)
3300                    cw.nl()
3301                    print_req_prototype(ri)
3302                    cw.nl()
3303
3304                if 'dump' in op:
3305                    cw.p(f"/* {op.enum_name} - dump */")
3306                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
3307                    print_req_type(ri)
3308                    print_req_type_helpers(ri)
3309                    if not ri.type_consistent or ri.type_oneside:
3310                        print_rsp_type(ri)
3311                    print_wrapped_type(ri)
3312                    print_dump_prototype(ri)
3313                    cw.nl()
3314
3315                if op.has_ntf:
3316                    cw.p(f"/* {op.enum_name} - notify */")
3317                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3318                    if not ri.type_consistent:
3319                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3320                    print_wrapped_type(ri)
3321
3322            for op_name, op in parsed.ntfs.items():
3323                if 'event' in op:
3324                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
3325                    cw.p(f"/* {op.enum_name} - event */")
3326                    print_rsp_type(ri)
3327                    cw.nl()
3328                    print_wrapped_type(ri)
3329            cw.nl()
3330        else:
3331            cw.p('/* Enums */')
3332            put_op_name(parsed, cw)
3333
3334            for name, const in parsed.consts.items():
3335                if isinstance(const, EnumSet):
3336                    put_enum_to_str(parsed, cw, const)
3337            cw.nl()
3338
3339            has_recursive_nests = False
3340            cw.p('/* Policies */')
3341            for struct in parsed.pure_nested_structs.values():
3342                if struct.recursive:
3343                    put_typol_fwd(cw, struct)
3344                    has_recursive_nests = True
3345            if has_recursive_nests:
3346                cw.nl()
3347            for struct in parsed.pure_nested_structs.values():
3348                put_typol(cw, struct)
3349            for name in parsed.root_sets:
3350                struct = Struct(parsed, name)
3351                put_typol(cw, struct)
3352
3353            cw.p('/* Common nested types */')
3354            if has_recursive_nests:
3355                for attr_set, struct in parsed.pure_nested_structs.items():
3356                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3357                    free_rsp_nested_prototype(ri)
3358                    if struct.request:
3359                        put_req_nested_prototype(ri, struct)
3360                    if struct.reply:
3361                        parse_rsp_nested_prototype(ri, struct)
3362                cw.nl()
3363            for attr_set, struct in parsed.pure_nested_structs.items():
3364                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3365
3366                free_rsp_nested(ri, struct)
3367                if struct.request:
3368                    put_req_nested(ri, struct)
3369                if struct.reply:
3370                    parse_rsp_nested(ri, struct)
3371
3372            for op_name, op in parsed.ops.items():
3373                cw.p(f"/* ============== {op.enum_name} ============== */")
3374                if 'do' in op and 'event' not in op:
3375                    cw.p(f"/* {op.enum_name} - do */")
3376                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3377                    print_req_free(ri)
3378                    print_rsp_free(ri)
3379                    parse_rsp_msg(ri)
3380                    print_req(ri)
3381                    cw.nl()
3382
3383                if 'dump' in op:
3384                    cw.p(f"/* {op.enum_name} - dump */")
3385                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
3386                    if not ri.type_consistent or ri.type_oneside:
3387                        parse_rsp_msg(ri, deref=True)
3388                    print_req_free(ri)
3389                    print_dump_type_free(ri)
3390                    print_dump(ri)
3391                    cw.nl()
3392
3393                if op.has_ntf:
3394                    cw.p(f"/* {op.enum_name} - notify */")
3395                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3396                    if not ri.type_consistent:
3397                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3398                    print_ntf_type_free(ri)
3399
3400            for op_name, op in parsed.ntfs.items():
3401                if 'event' in op:
3402                    cw.p(f"/* {op.enum_name} - event */")
3403
3404                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3405                    parse_rsp_msg(ri)
3406
3407                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
3408                    print_ntf_type_free(ri)
3409            cw.nl()
3410            render_user_family(parsed, cw, False)
3411
3412    if args.header:
3413        cw.p(f'#endif /* {hdr_prot} */')
3414
3415
3416if __name__ == "__main__":
3417    main()
3418