xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision c9c048993d4c0b8a188ae717ca4a1dadd02d29b5)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
def c_upper(name):
    """Return name upper-cased with every '-' replaced by '_'."""
    underscored = name.replace('-', '_')
    return underscored.upper()
21
22
def c_lower(name):
    """Return name lower-cased with every '-' replaced by '_'."""
    underscored = name.replace('-', '_')
    return underscored.lower()
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The name must have the form <u|s><width>-<min|max>, where 'u' selects
    an unsigned and 's' a signed (two's complement) integer of <width> bits.

    Raises ValueError if the name does not have that form, instead of
    failing with an unrelated conversion error or silently treating an
    unknown suffix as "-max".
    """
    match = re.fullmatch(r'([su])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit name "{name}", expected a name like u32-max or s64-min')

    sign, width, bound = match.groups()
    width = int(width)
    if sign == 'u':
        # Unsigned types always start at zero
        return 0 if bound == 'min' else (1 << width) - 1

    # Signed types spend one bit on the sign
    largest = (1 << (width - 1)) - 1
    if bound == 'min':
        return -largest - 1
    return largest
40
41
class BaseNlLib:
    """Provides the strings used to refer to library state in emitted code."""

    def get_family_id(self):
        """Return the string 'ys->family_id'."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Base class describing how a single attribute is rendered into C.

    Keeps the naming, checks and presence information shared by all
    attribute kinds and provides default implementations of the code
    emitters. Emitters without a sensible default raise an Exception
    and have to be overridden by the type-specific subclasses.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped to True by set_request() / set_reply()
        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # A trailing underscore is added to the struct name when the
            # nested set shares its name with one of the family's consts
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make sure the C member name is a usable identifier
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr (and the attr it is a subset of) as used in requests."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr (and the attr it is a subset of) as used in replies."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numerical value of the check called limit.

        Integers are returned as is, names of family consts resolve to
        the const's "value", any other string is converted with
        limit_to_number(). Returns default (None unless given) when the
        check is not present.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return the check called limit as a string for use in C code.

        Integers are rendered as digits followed by suffix. Names are
        upper-cased; names of family consts which do not have a 'header'
        are additionally prefixed with the family name. Returns an empty
        string when the check is not present and no default was given.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """
        Compute enum_name, from the attr's own 'name-prefix' if it has
        one, otherwise from the prefix of the attribute set.

        Raises an Exception if the attr belongs to a subset and its
        checks differ from the checks of the real attr.
        """
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        # Falsy by default, multi-value subclasses return True
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Return how presence is tracked: 'present', 'len', 'count' or ''."""
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return the C declaration of the presence member for this attr.

        Returns None if the presence type of the attr does not match
        type_filter, or is not one of 'present', 'len' and 'count'.
        The type is spelled with a '__' prefix when space is 'user'.
        """
        if self.presence_type() != type_filter:
            return None

        if self.presence_type() == 'present':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() in {'len', 'count'}:
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name};"

        return None

    def _complex_member_type(self, ri):
        # Subclasses return the C type name of the member, if it has one
        return None

    def free_needs_iter(self):
        """Return True if the lines freeing the attr use the iterator 'i'."""
        return False

    def _free_lines(self, ri, var, ref):
        # Only members with 'count' / 'len' presence and multi-value
        # members get a free() call by default
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the lines which free the member."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """
        Return the list of C argument declarations used to pass the attr.

        Raises an Exception if the type has no complex member type and
        the subclass did not override this method.
        """
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) holding the attr."""
        member = self._complex_member_type(ri)
        if member:
            # Multi-value and recursive members are stored via a pointer
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the policy table entry for the attr."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the type policy table entry for the attr."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with a presence check for 'present' / 'len' members
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the "if (type == ...)" block which parses the attr.

        The lines come from _attr_get(), which returns a tuple of
        (lines, init_lines, local_vars); lines and init_lines may be
        a single string instead of a list. The block is opened with
        "if" when first is set and with "else if" otherwise.

        Always returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit the static inline setter function for the attr.

        ref is the list of member names of the nests enclosing the attr,
        the attr's own name gets appended to it. The setter marks every
        level as present, releases the previous value (if the type has
        lines for freeing it) and stores the new one.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, level in enumerate(ref):
            # Path to the struct holding this level, with a trailing '.'
            # unless we are at the top level
            parent = '.'.join(ref[:i] + [''])
            presence = f"{var}->{parent}_present.{level}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{var}->{parent}_{self.presence_type()}.{level}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters which free the old value but do not allocate a new one
        # get a '__' name prefix
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
299
300
class TypeUnused(Type):
    """Attribute type which emits no members, policy, put, get or setter code."""

    def presence_type(self):
        """No presence is tracked."""
        return ''

    def arg_member(self, ri):
        """There are no arguments to pass."""
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        get_lines = ['return YNL_PARSE_CB_ERROR;']
        return get_lines, None, None

    def attr_policy(self, cw):
        """Emit nothing."""
        return None

    def attr_put(self, ri, var):
        """Emit nothing."""
        return None

    def attr_get(self, ri, var, first):
        """Emit nothing."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit nothing."""
        return None
325
326
class TypePad(Type):
    """Attribute type which emits no members, policy, put, get or setter code."""

    def presence_type(self):
        """No presence is tracked."""
        return ''

    def arg_member(self, ri):
        """There are no arguments to pass."""
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """Emit nothing."""
        return None

    def attr_put(self, ri, var):
        """Emit nothing."""
        return None

    def attr_get(self, ri, var, first):
        """Emit nothing."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit nothing."""
        return None
348
349
class TypeScalar(Type):
    """Code generation for scalar (integer) attributes."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Appended to the C member declaration by arg_member()
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """
        Compute is_bitfield and type_name, in addition to what the base
        class resolves.

        The attr is a bitfield if it sets 'enum-as-flags' or if its enum
        is of type 'flags'. Non-bitfield attrs with an enum use the user
        type of the enum, all other attrs use a '__' prefixed integer
        type (64 bit wide for auto scalars).
        """
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """
        Derive implicit checks from the enum and validate the limits.

        Adds 'sparse' for enums which report no value range, fills in
        missing 'min' / 'max' from the range of the enum, sets 'range'
        when both limits are present, and sets 'full-range' when a limit
        is smaller than -32768 or larger than 32767.

        Raises an Exception if min is larger than max, or if an unsigned
        type has a negative limit.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    # Zero is the implicit minimum of unsigned types
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # NOTE(review): presumably the plain range policies can only hold
        # limits which fit in 16 bits - confirm
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        # The first matching kind of check decides which macro is used
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                # One mask bit per entry of the flags const
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # The first character of the type (sign) is dropped, the type
        # policy always uses the U variant
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
447
448
class TypeFlag(Type):
    """Code generation for flag attributes, which are put without a payload."""

    def arg_member(self, ri):
        """The setter takes no value argument."""
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # No lines besides the common ones added by attr_get()
        get_lines = []
        return get_lines, None, None

    def _setter_lines(self, ri, member, presence):
        # No lines besides the presence handling added by setter()
        return list()
464
465
class TypeString(Type):
    """Code generation for string attributes, tracked by their length."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        declaration = f"const char *{self.c_name}"
        return [declaration]

    def struct_member(self, ri):
        declaration = f"char *{self.c_name};"
        ri.cw.p(declaration)

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        # An exact length takes precedence over a maximum length
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'

        pieces = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            pieces.append(', .len = ' + self.get_limit_str('max-len'))
        pieces.append(', }')
        return ''.join(pieces)

    def attr_policy(self, cw):
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'

        entry = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {entry},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        # The emitted code copies the string into a freshly allocated,
        # NUL-terminated buffer
        dst = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        src = self.c_name
        return [
            f"{presence} = strlen({src});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {src}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
515
516
class TypeBinary(Type):
    """Code generation for binary attributes, tracked by their length."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        data_arg = f"const void *{self.c_name}"
        return [data_arg, 'size_t len']

    def struct_member(self, ri):
        declaration = f"void *{self.c_name};"
        ri.cw.p(declaration)

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # At most one length check is supported
        check_cnt = len(self.checks)
        if check_cnt == 0:
            return '{ .type = NLA_BINARY, }'
        if check_cnt != 1:
            raise Exception('More than one check for binary type not implemented, yet')

        check_name = next(iter(self.checks))
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if check_name == 'max-len':
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        raise Exception('Unsupported check for binary type: ' + check_name)

    def attr_put(self, ri, var):
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                    f"{var}->{self.c_name}, {var}->_len.{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # The emitted code copies the payload into a freshly allocated buffer
        dst = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = ynl_attr_data_len(attr);']
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        src = self.c_name
        return [
            f"{presence} = len;",
            f"{member} = malloc({presence});",
            f'memcpy({member}, {src}, {presence});',
        ]
567
568
class TypeBinaryStruct(TypeBinary):
    """Code generation for binary attributes which carry a C struct."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        # The emitted code allocates a zeroed buffer of the full struct
        # size when the payload is shorter than the struct
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        dst = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = ynl_attr_data_len(attr);']
        get_lines = [
            f"{var}->_{self.presence_type()}.{self.c_name} = len;",
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, init_lines, local_vars
584
585
class TypeBinaryScalarArray(TypeBinary):
    """Code generation for binary attributes holding an array of scalars."""

    def presence_type(self):
        return 'count'

    def arg_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        return [f'{elem_type} *{self.c_name}', 'size_t count']

    def struct_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem_type} *{self.c_name};')

    def attr_put(self, ri, var):
        # The emitted code uses 'i' to hold the size of the array in bytes
        counter = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_size = f"sizeof(__{self.get('sub-type')})"
        ri.cw.block_start(line=f"if ({counter})")
        ri.cw.p(f"i = {counter} * {elem_size};")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        # The emitted code rounds len down to a whole number of elements
        counter = f"{var}->_count.{self.c_name}"
        elem_size = f"sizeof(__{self.get('sub-type')})"
        dst = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = ynl_attr_data_len(attr);']
        get_lines = [
            f"{counter} = len / {elem_size};",
            f"len = {counter} * {elem_size};",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        elem_size = f"sizeof(__{self.get('sub-type')})"
        return [
            f"{presence} = count;",
            f"count *= {elem_size};",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
618
619
class TypeBitfield32(Type):
    """Code generation for attributes stored as struct nla_bitfield32."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        # The policy mask is computed from the enum, so one is required
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum_name = self.attr['enum']
        mask = self.family.consts[enum_name].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                    f"&{var}->{self.c_name}, sizeof(struct nla_bitfield32))")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        copy_line = (f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), "
                     "sizeof(struct nla_bitfield32));")
        return copy_line, None, None

    def _setter_lines(self, ri, member, presence):
        copy_line = f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"
        return [copy_line]
643
644
class TypeNest(Type):
    """Code generation for attributes which hold a nested attribute set."""

    def is_recursive(self):
        nested_struct = self.family.pure_nested_structs[self.nested_attrs]
        return nested_struct.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        member = f'{var}->{ref}{self.c_name}'
        if self.is_recursive_for_op(ri):
            # The member is already a pointer, and is checked before freeing
            return [f'if ({member})',
                    f'{self.nested_render_name}_free({member});']
        return [f'{self.nested_render_name}_free(&{member});']

    def _attr_typol(self):
        nest = self.nested_render_name
        return f'.type = YNL_PT_NEST, .nest = &{nest}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        # Take the address of the member unless it is already a pointer
        if self.is_recursive_for_op(ri):
            at = ''
        else:
            at = '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {at}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        nest = self.nested_render_name
        init_lines = [
            f"parg.rsp_policy = &{nest}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({nest}_parse(&parg, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit the setters of all non-recursive members of the nested struct."""
        path = list(ref) if ref else []
        path.append(self.c_name)

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested_struct.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
686
687
class TypeMultiAttr(Type):
    """
    Code generation for attributes which may appear multiple times.

    Policy and type policy are delegated to base_type, the type object
    of a single instance of the attribute.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        if self.attr['type'] == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if self.attr['type'] == 'string':
            return 'struct ynl_string *'
        if self.attr['type'] in scalars:
            if ri.ku_space == 'user':
                return '__' + self.attr['type']
            return '' + self.attr['type']
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def arg_member(self, ri):
        is_struct = self.type == 'binary' and 'struct' in self.attr
        if not is_struct:
            return super().arg_member(ri)
        return [f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
                f'unsigned int n_{self.c_name}']

    def free_needs_iter(self):
        # Nests and strings are freed element by element
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        if self.attr['type'] in scalars or self.attr['type'] == 'binary':
            return [f"free({var}->{ref}{self.c_name});"]

        member = f"{var}->{ref}{self.c_name}"
        loop = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if self.attr['type'] == 'string':
            return [loop,
                    f"free({member}[i]);",
                    f"free({member});"]
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return [loop,
                    f'{self.nested_render_name}_free(&{member}[i]);',
                    f"free({member});"]
        raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # The emitted line only counts the instances of the attribute
        count_line = f'n_{self.c_name}++;'
        return count_line, None, None

    def attr_put(self, ri, var):
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        elem = f"{var}->{self.c_name}[i]"

        if self.attr['type'] in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {elem});")
            return
        if self.attr['type'] == 'binary' and 'struct' in self.attr:
            struct_sz = f"sizeof(struct {c_lower(self.attr['struct'])})"
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{elem}, {struct_sz});")
            return
        if self.attr['type'] == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {elem}->str);")
            return
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(loop)
            put_call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{elem})"
            self._attr_put_line(ri, var, put_call)
            return
        raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        store_ptr = f"{member} = {self.c_name};"
        store_cnt = f"{presence} = n_{self.c_name};"
        return [store_ptr, store_cnt]
774
775
class TypeArrayNest(Type):
    """Code generation for attributes which hold an array inside a nest."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        if self.attr['sub-type'] in scalars:
            if ri.ku_space == 'user':
                return '__' + self.attr['sub-type']
            return '' + self.attr['sub-type']
        if self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def arg_member(self, ri):
        fixed_binary = self.sub_type == 'binary' and 'exact-len' in self.checks
        if not fixed_binary:
            return super().arg_member(ri)
        # Pointer to arrays of exact-len bytes, plus the element count
        return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                f'unsigned int n_{self.c_name}']

    def _attr_typol(self):
        if self.attr['sub-type'] in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        if self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # The emitted code remembers the attribute and counts its entries
        get_lines = [f'attr_{self.c_name} = attr;']
        get_lines.append('ynl_attr_for_each_nested(attr2, attr) {')
        get_lines.append('\tif (ynl_attr_validate(yarg, attr2))')
        get_lines.append('\t\treturn YNL_PARSE_CB_ERROR;')
        get_lines.append(f'\t{var}->_count.{self.c_name}++;')
        get_lines.append('}')
        return get_lines, None, ['const struct nlattr *attr2;']

    def attr_put(self, ri, var):
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')

        # Entries are put with the loop index as their attribute type
        loop = f'for (i = 0; i < {var}->_count.{self.c_name}; i++)'
        elem = f"{var}->{self.c_name}[i]"
        if self.sub_type in scalars:
            ri.cw.block_start(line=loop)
            ri.cw.p(f"ynl_attr_put_{self.sub_type}(nlh, i, {elem});")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, i, {elem}, {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(loop)
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{elem});")
        else:
            raise Exception(f"Put for ArrayNest sub-type {self.attr['sub-type']} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        store_ptr = f"{member} = {self.c_name};"
        store_cnt = f"{presence} = n_{self.c_name};"
        return [store_ptr, store_cnt]
838
839
class TypeNestTypeValue(Type):
    """
    Nest wrapped in extra attribute levels listed under 'type-value'.

    The attribute type of every wrapping level is extracted while parsing
    and handed to the nested parser as an additional argument.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """
        Return (get lines, init lines, local vars) for parsing the nest.
        """
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        src = 'attr'        # attribute the next level is unwrapped from
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            ptr_names = ", *attr_".join(levels)
            val_names = ", ".join(levels)
            local_vars.append(f'const struct nlattr *attr_{ptr_names};')
            local_vars.append(f'__u32 {val_names};')
            for level in levels:
                # Step into the payload of the previous level and record
                # the type of the attribute found there
                get_lines.append(f'attr_{level} = ynl_attr_data({src});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                src = 'attr_' + level

            extra_args = f", {val_names}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {src}{extra_args});")
        return get_lines, init_lines, local_vars
868
869
class Struct:
    """
    Members and naming of one C struct rendered from an attribute set.

    When type_list is given only those attributes of the set become
    members; without it the struct covers the whole set and is considered
    nested.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        # A set named like the family renders without the extra name part
        if family.name == c_lower(space_name):
            ident = family.ident_name
        else:
            ident = family.ident_name + '-' + space_name
        self.render_name = c_lower(ident)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or legacy arrays

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]
        self.attrs = dict()

        # Track the member with the highest attribute value, last one wins
        self.attr_max_val = None
        highest = 0
        for name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        return self.attr_list

    def set_inherited(self, new_inherited):
        """
        Record the inherited member names, sorted and in C spelling.

        Raises Exception if new_inherited differs from what the struct was
        created with.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = list(map(c_lower, sorted(self._inherited)))

    def free_needs_iter(self):
        """Return True if any member reports free_needs_iter()."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
932
933
class EnumEntry(SpecEnumEntry):
    """Enum entry which also knows its C name and value continuity."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change is set when the value does not directly follow the
        # previous entry (or is not 0 for the first one), and always for
        # 'flags' enums
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve (assign-then-delete leaves the attribute unset):
        self.c_name = None
        del self.c_name

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
952
953
class EnumSet(SpecEnumSet):
    """Enum definition with the names needed to render it in C."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # 'enum-name' overrides the generated name; an empty value means
        # the enum has no name at all
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        # Values of an unnamed enum are carried as plain int
        self.user_type = self.enum_name if self.enum_name else 'int'

        default_pfx = f"{family.ident_name}-{yaml['name']}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)
        self.header = yaml.get('header')
        self.enum_cnt_name = yaml.get('enum-cnt-name')

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """
        Return (lowest, highest) entry value, or (None, None) when the
        number of entries does not match the size of that range.
        """
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 == len(self.entries):
            return low, high
        return None, None
989
990
class AttrSet(SpecAttrSet):
    """Attribute set with C naming and typed attribute construction."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are carved out of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve (assign-then-delete leaves the attribute unset):
        self.c_name = None
        del self.c_name

    def resolve(self):
        """
        Set c_name: C keywords get a trailing underscore and a set named
        like the family gets an empty name.
        """
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """
        Create the Type object matching the 'type' of the attribute spec.

        Raises Exception for types, or indexed-array sub-types, without
        a matching class.
        """
        # Types which map to a class without looking at any other property
        plain_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
        }

        kind = elem['type']
        if kind in scalars:
            cls = TypeScalar
        elif kind in plain_types:
            cls = plain_types[kind]
        elif kind == 'binary':
            if 'struct' in elem:
                cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                cls = TypeBinaryScalarArray
            else:
                cls = TypeBinary
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        t = cls(self.family, self, elem, value)

        # Attributes which may repeat get wrapped, keeping the base type
        if 'multi-attr' in elem and elem['multi-attr']:
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1057
1058
class Operation(SpecOperation):
    """Operation with the names and flags needed for rendering."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs)
        for mode in ['do', 'dump', 'event']:
            if mode not in yaml:
                continue
            for direction in ['request', 'reply']:
                if direction in yaml[mode]:
                    yaml[mode][direction].setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        def has_request(mode):
            return mode in yaml and 'request' in yaml[mode]

        # Both do and dump take a request
        self.dual_policy = has_request('do') and has_request('dump')

        self.has_ntf = False

        # Added by resolve (assign-then-delete leaves the attribute unset):
        self.enum_name = None
        del self.enum_name

    def resolve(self):
        self.resolve_up(super())

        family = self.family
        prefix = family.async_op_prefix if self.is_async else family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
1092
1093
class Family(SpecFamily):
    """
    Spec family extended with the bookkeeping needed to render C code.

    Besides naming (C name, operation prefixes, uAPI header) it tracks
    which attribute sets are used directly by operations (root_sets),
    which are only reached through nesting (pure_nested_structs), the
    direction each of them is used in, and the pre / post hooks.
    """

    def __init__(self, file_name, exclude_ops):
        """
        Load the spec and derive the family-level define and header names.
        """
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Names of the family name / version defines, overridable by the spec
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # Header name without the "linux/" directory and the ".h" extension
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """
        Resolve the spec, then compute the names, hooks and attribute set
        usage information.
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        # Async operations share the regular prefix unless one is specified
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] holds the hook names both as a 'set' (for
        # membership tests) and as a 'list' (in order of first appearance)
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def is_classic(self):
        """Return True if the family uses the 'netlink-raw' protocol."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Flag every operation which some message names in 'notify'."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """
        Fill root_sets with the attributes each attribute set has in
        requests and in replies, merged over all messages using the set.
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            # Event attributes count as reply attributes
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """
        Reorder pure_nested_structs so that a struct comes after the
        non-recursive structs it nests. Gives up after a bounded number
        of rounds.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                # Retry once the dependency had a chance to be placed
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """
        Make sure the attribute set nested by spec has a Struct in
        pure_nested_structs and record its inherited members.

        Returns the name of the nested set. Raises Exception if the set is
        also used as a root set.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        # The set object handed to Struct above is filled in only here
        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_sets(self):
        """
        Discover all nested attribute sets starting from the root sets and
        work out how each of them is used (request, reply, multi-value,
        recursive).
        """
        # Breadth-first walk over the sets reachable from the roots
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        # Seed the usage of structs nested directly in the root sets
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True
                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                        if spec.is_multi_val():
                            child.in_multi_val = True
                # A set which ends up among its own children is recursive
                if attr_set in struct.child_nests:
                    struct.recursive = True

        # Sort again, now that the recursive flags are known
        self._sort_pure_types()

    def _load_attr_use(self):
        """
        Mark each attribute as used in requests and / or replies, based on
        the usage of the nested struct or root set it belongs to.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_global_policy(self):
        """
        Collect the request attributes of all operations into
        global_policy, ordered as in the attribute set.

        Raises Exception if the operations use more than one attribute set.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """
        Record each distinct 'pre' / 'post' hook name used by do and dump
        operations.
        """
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1365
1366
class RenderInfo:
    """
    State needed to render the code for one operation in one mode, or for
    one attribute set when no operation is given.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)
            if op.fixed_header != family.fixed_header:
                # Only classic families may use a header of their own per op
                if not family.is_classic():
                    raise Exception("Per-op fixed header not supported, yet")
                self.fixed_hdr_len = f"sizeof({self.fixed_hdr})"

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                # Dump-only operation
                self.type_oneside = True
            else:
                do_has_reply = 'reply' in op['do']
                dump_has_reply = 'reply' in op["dump"]
                if do_has_reply != dump_has_reply:
                    self.type_consistent = False
                elif do_has_reply and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        self.struct = dict()
        # Notifications are rendered from the do (or else dump) definition
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        if op:
            for op_dir in ['request', 'reply']:
                if op_dir in op[op_mode]:
                    members = op[op_mode][op_dir]['attributes']
                else:
                    members = []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=members)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])

    def type_empty(self, key):
        """
        Return True if the struct for key has no members and there is no
        fixed header.
        """
        if self.struct[key].attr_list:
            return False
        return self.fixed_hdr is None

    def needs_nlflags(self, direction):
        """
        Return True only for 'do' requests of classic families.
        """
        if self.op_mode != 'do' or direction != 'request':
            return False
        return self.family.is_classic()
1430
1431
class CodeWriter:
    """
    Indentation-aware writer for the generated C code.

    Output goes to stdout or, when out_file is given, to a temporary file
    which close_out_file() copies over out_file. With overwrite=False
    out_file is left untouched if the new contents are identical.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # closing bracket pending, see block_end()
        self._silent_block = False  # previous line opened a brace-less body
        self._ind = 0
        self._ifdef_block = None
        self._out_file = out_file
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """
        Copy the rendered output to out_file and release the temporary file.

        Does nothing when writing to stdout or when already closed. Anything
        printed afterwards goes to stdout.
        """
        if self._out is sys.stdout:
            return
        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and
                     os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        # Release the temporary file on both paths, so that it does not
        # linger until garbage collection and repeated calls are no-ops
        self._out.close()
        self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """
        Write one line at the current indentation.

        Takes care of a closing bracket delayed by block_end() (merged with
        a following "else"), of a pending blank line, of indenting the body
        of brace-less if / while / for / else, of outdenting lines ending
        with ':' and of keeping lines starting with '#' at column 0.
        An empty line is written as a bare newline.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        if not line:
            # Nothing to indent, and no first / last character to inspect
            self._silent_block = False
            self._out.write('\n')
            return

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next line written by p()."""
        self._nl = True

    def block_start(self, line=''):
        """Write line followed by an opening bracket and indent."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """
        Unindent and close the block, line is appended after the bracket.

        Without line the bracket is written by the next p() call, so that
        it can share the line with an "else". Pending blank lines are
        dropped.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """
        Write doc as lines of a C block comment, wrapped before column 79.

        With indent continuation lines get two extra spaces.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """
        Write a function prototype, split over multiple lines if too long.

        qual_ret is the return type with its qualifiers, args the list of
        parameter declarations (void if empty), doc an optional comment
        written above and suffix is appended after the closing parenthesis.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if not qual_ret.endswith('*'):
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        # Anything but a very short return type goes on a line of its own
        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif not qual_ret.endswith('*'):
            v += ' '
        v += name + '('
        # Continuation lines are aligned with the first argument
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v[0] == '\t':
                # Account for the tabs being wider than one character
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """
        Write local variable declarations, longest first, then a blank line.

        local_vars may be a single string or a list, the list is sorted in
        place.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        local_vars.sort(key=len, reverse=True)
        for var in local_vars:
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, local variables and body."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """
        Write #defines from (name, value) pairs with the values aligned.

        int values are written as is, str values in double quotes, values
        of any other type are left out.
        """
        longest = max((len(define[0]) for define in defines), default=0)
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """
        Write designated initializers from (member, value) pairs with the
        equal signs aligned. Nothing is written for an empty list.
        """
        longest = max((len(one[0]) for one in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """
        Switch to the #ifdef CONFIG_<config> block, closing the one
        currently open if it differs. A false config only closes.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1628
1629
# Attribute types (and sub-types) handled as scalars
scalars = {
    'u8', 'u16', 'u32', 'u64',
    's8', 's16', 's32', 's64',
    'uint', 'sint',
}

# Name suffix for each message direction, see op_prefix()
direction_to_suffix = {'reply': '_rsp', 'request': '_req', '': ''}

# Name suffix for each operation mode, see op_prefix()
op_mode_to_wrapper = {'do': '', 'dump': '_list', 'notify': '_ntf', 'event': ''}

# C keywords, attribute set names matching one get a '_' appended
_C_KW = {
    'auto', 'bool', 'break', 'case', 'char', 'const', 'continue',
    'default', 'do', 'double', 'else', 'enum', 'extern', 'float',
    'for', 'goto', 'if', 'inline', 'int', 'long', 'register',
    'return', 'short', 'signed', 'sizeof', 'static', 'struct',
    'switch', 'typedef', 'union', 'unsigned', 'void', 'volatile',
    'while',
}
1681
1682
def rdir(direction):
    """
    Reverse a message direction: 'request' <-> 'reply'.
    Any other value is returned unchanged.
    """
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1689
1690
def op_prefix(ri, direction, deref=False):
    """
    Build the C identifier prefix for the type/function of an operation.

    The result is '<family c_name>_<type name>' plus a suffix which depends
    on the op mode, the @direction and, for non-'do' replies, on whether the
    caller wants the wrapper type or the wrapped object (@deref).
    """
    base = f"_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        # Requests of one-sided types share the struct with 'do'
        tail = '_req' if ri.type_oneside else '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}{base}{tail}"
1712
1713
def type_name(ri, direction, deref=False):
    """Return the C struct type ('struct <prefix>') for an op and direction."""
    prefix = op_prefix(ri, direction, deref=deref)
    return "struct " + prefix
1716
1717
def print_prototype(ri, direction, terminate=True, doc=None):
    """
    Emit the prototype of the user-facing function for an operation.

    The function always takes the socket; a request struct pointer is added
    when the op mode has a request. The return type is a pointer to the
    reverse-direction type when the op mode has a reply, 'int' otherwise.
    @terminate appends ';' (declaration rather than definition).
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        args.append(f"{type_name(ri, direction)} *{direction_to_suffix[direction][1:]}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1734
1735
def print_req_prototype(ri):
    """Emit the declaration of the request function, with the op's doc."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1738
1739
def print_dump_prototype(ri):
    """Emit the declaration of the dump function (no doc comment)."""
    print_prototype(ri, direction="request")
1742
1743
def put_typol_fwd(cw, struct):
    """Emit the extern declaration of the policy nest for @struct."""
    nest = f'{struct.render_name}_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest};')
1746
1747
def put_typol(cw, struct):
    """
    Emit the policy table for @struct followed by the policy nest
    which points at that table.
    """
    max_name = struct.attr_set.max_name
    policy = f'{struct.render_name}_policy'

    # Per-attribute policy table, each member renders its own entry
    cw.block_start(line=f'const struct ynl_policy_attr {policy}[{max_name} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    # The nest ties the table together with its size
    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {max_name},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1763
1764
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """
    Emit '<render_name>_str()', a C function which looks its argument up in
    the string table @map_name and returns NULL for out-of-range values.

    When @enum is given the argument uses the enum's user type, and for
    'flags' enums the value is first converted to a bit index with ffs().
    """
    if enum:
        proto_args = [enum.user_type + ' ' + arg_name]
    else:
        proto_args = [f'int {arg_name}']

    cw.write_func_prot('const char *', f'{render_name}_str', proto_args)
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    bounds_check = f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))'
    cw.p(bounds_check)
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')
    cw.block_end()
    cw.nl()
1778
1779
def put_op_name_fwd(family, cw):
    """Emit the declaration of the family's op-to-string function."""
    func = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', func, ['int op'], suffix=';')
1782
1783
def put_op_name(family, cw):
    """
    Emit the op name string table for @family and the function which
    maps an op value to its name.
    """
    strmap = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {strmap}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue

        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue

        # Use the enum name when request and reply share a value
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', strmap, 'op')
1803
1804
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the declaration of the enum-to-string function for @enum."""
    value_arg = enum.user_type + ' value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [value_arg],
                       suffix=';')
1808
1809
def put_enum_to_str(family, cw, enum):
    """
    Emit the string table for @enum (indexed by entry value) and
    the function which maps a value to its name.
    """
    strmap = f'{enum.render_name}_strmap'
    cw.block_start(line=f"static const char * const {strmap}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, strmap, 'value', enum=enum)
1819
1820
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of the '<struct>_put()' function of a nested type."""
    put_args = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    func = f'{struct.render_name}_put'

    ri.cw.write_func_prot('int', func, put_args, suffix=suffix)
1828
1829
def put_req_nested(ri, struct):
    """
    Emit the '<struct>_put()' function which writes a nested type into
    a message: start the nest, put every member, end the nest, return 0.
    """
    needs_array = False
    needs_index = False
    for _, member in struct.member_list():
        if member.type == 'indexed-array':
            needs_array = True
        if member.presence_type() == 'count':
            needs_index = True

    local_vars = ['struct nlattr *nest;']
    if needs_array:
        local_vars.append('struct nlattr *array;')
    if needs_index:
        local_vars.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    ri.cw.p("ynl_attr_nest_end(nlh, nest);")

    ri.cw.nl()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
1863
1864
def _multi_parse(ri, struct, init_lines, local_vars):
    """
    Emit the body (from the opening to the closing brace) of a C parsing
    function for @struct. The caller is expected to have already emitted
    the function prototype.

    @init_lines are statements emitted right after the local variables,
    @local_vars are the declarations the caller needs. Both lists are
    extended in place with whatever the parsing code itself requires.

    Raises Exception for a per-op fixed header in a non-classic family and
    for array sub-types / multi-attr types the generator can't parse.
    """
    # Pick the attribute iterator: nested types walk the nest, messages walk
    # the attributes which follow the (family or op specific) header
    if struct.nested:
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({ri.fixed_hdr}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Work out which members need a second pass (arrays and multi-attrs)
    # and whether any member is a nest, which requires a parse argument
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per array / multi-attr member, sorted so that
    # the output is stable
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited values are passed in as function arguments, copy them over
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        if ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Generated code refuses to parse into an object which already has
    # the array member set
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # First pass: main loop over the attributes, each member emits
    # its own handling
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate n_<member> entries and
    # walk the saved attr_<member> nest to fill them in
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attrs: allocate n_<member> entries and re-walk
    # the attributes, picking out the ones of the matching type
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most the size of our struct, the attribute may be longer
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Strings are stored as a struct ynl_string, NUL-terminated
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return plain 0, message parsers a callback status
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2010
2011
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """
    Emit the prototype of the '<struct>_parse()' function of a nested type.
    Each inherited value becomes an extra '__u32' argument.
    """
    parse_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    parse_args += ['__u32 ' + inherited for inherited in struct.inherited]

    func = f'{struct.render_name}_parse'
    ri.cw.write_func_prot('int', func, parse_args, suffix=suffix)
2020
2021
def parse_rsp_nested(ri, struct):
    """
    Emit the '<struct>_parse()' function of a nested type.
    A type without members gets a trivial body which returns 0.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    local_vars = [
        'const struct nlattr *attr;',
        f'{struct.ptr_name}dst = yarg->data;',
    ]

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, struct, [], local_vars)
2037
2038
def parse_rsp_msg(ri, deref=False):
    """
    Emit the message parsing callback for the reply of an operation.
    Nothing is emitted when the op mode has no reply, unless it is an event.
    A reply without members gets a body which just returns YNL_PARSE_CB_OK.
    """
    is_event = ri.op_mode == 'event'
    if 'reply' not in ri.op[ri.op_mode] and not is_event:
        return

    cb_args = ['const struct nlmsghdr *nlh', 'struct ynl_parse_arg *yarg']
    local_vars = [
        f'{type_name(ri, "reply", deref=deref)} *dst;',
        'const struct nlattr *attr;',
    ]
    init_lines = ['dst = yarg->data;']

    cb_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', cb_name, cb_args)

    reply = ri.struct["reply"]
    if reply.member_list():
        _multi_parse(ri, reply, init_lines, local_vars)
        return

    # Empty reply
    ri.cw.block_start()
    ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2060
2061
def print_req(ri):
    """
    Emit the user-facing function which performs a request ('do').

    The generated function starts the message, copies the fixed header
    (if any), puts the request attributes and executes the request.
    With a reply it allocates and returns the response (NULL on error),
    otherwise it returns 0 / -1.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]
    ret_ok = 'rsp' if has_reply else '0'

    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if has_reply:
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if ri.fixed_hdr:
        local_vars.extend(['size_t hdr_len;', 'void *hdr;'])
    # A loop index is needed if any member is an array with a count
    if any(attr.presence_type() == 'count'
           for _, attr in ri.struct["request"].member_list()):
        local_vars.append('unsigned int i;')

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    if ri.family.is_classic():
        start = f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);"
    else:
        start = f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);"
    ri.cw.p(start)

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if has_reply:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        for hdr_line in ("hdr_len = sizeof(req->_hdr);",
                         "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                         "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(hdr_line)
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        ri.cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto err_free;' if has_reply else 'return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if has_reply:
        # Error path, release the partially filled response
        ri.cw.p('err_free:')
        ri.cw.p(call_free(ri, rdir(direction), 'rsp'))
        ri.cw.p("return NULL;")

    ri.cw.block_end()
2135
2136
def print_dump(ri):
    """
    Emit the user-facing function which performs a dump.

    The generated function sets up the dump state, starts the message,
    copies the fixed header (if any), puts the request attributes (if the
    dump has a request) and executes the dump. It returns the list of
    responses, or NULL after freeing the partial list on error.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()

    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if ri.fixed_hdr:
        local_vars.extend(['size_t hdr_len;', 'void *hdr;'])
    ri.cw.write_func_lvar(local_vars)

    # Dump state: how to parse and how much to allocate for each response
    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    ri.cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    ri.cw.nl()

    if ri.family.is_classic():
        start = f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});"
    else:
        start = f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);"
    ri.cw.p(start)

    if ri.fixed_hdr:
        for hdr_line in ("hdr_len = sizeof(req->_hdr);",
                         "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                         "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(hdr_line)
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        request = ri.struct['request']
        ri.cw.p(f"ys->req_policy = &{request.render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    # Error path, release whatever part of the list was received
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2191
2192
def call_free(ri, direction, var):
    """Return the C statement which frees @var with the op's free function."""
    free_func = op_prefix(ri, direction) + '_free'
    return f"{free_func}({var});"
2195
2196
def free_arg_name(direction):
    """
    Return the argument name used by free functions: the direction suffix
    without its leading underscore, or 'obj' when there is no direction.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2201
2202
def print_alloc_wrapper(ri, direction):
    """Emit the static inline '<type>_alloc()' helper, a zeroing allocator."""
    name = op_prefix(ri, direction)
    struct_type = f'struct {name}'

    ri.cw.write_func_prot(f'static inline {struct_type} *', name + '_alloc', ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof({struct_type}));')
    ri.cw.block_end()
2209
2210
def print_free_prototype(ri, direction, suffix=';'):
    """
    Emit the prototype of the '<type>_free()' function.
    The struct name gets a trailing underscore if the type name conflicts.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    arg = free_arg_name(direction)

    free_args = [f"struct {struct_name} *{arg}"]
    ri.cw.write_func_prot('void', f"{name}_free", free_args, suffix=suffix)
2218
2219
def print_nlflags_set(ri, direction):
    """Emit the static inline setter for the request's netlink message flags."""
    name = op_prefix(ri, direction)
    setter_args = [f"struct {name} *req", "__u16 nl_flags"]

    ri.cw.write_func_prot('static inline void', name + '_set_nlflags', setter_args)
    ri.cw.block_start()
    ri.cw.p('req->_nlmsg_flags = nl_flags;')
    ri.cw.block_end()
    ri.cw.nl()
2228
2229
def _print_type(ri, direction, struct):
    """
    Emit the C struct definition for @struct.

    Layout: optional netlink flags, optional fixed header, the
    '_present' / '_len' / '_count' meta structs (only those which have
    members), inherited values and finally the attribute members.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'

    if ri.op_mode == 'dump' and not ri.type_oneside:
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.needs_nlflags(direction):
        ri.cw.p('__u16 _nlmsg_flags;')
        ri.cw.nl()
    if ri.fixed_hdr:
        ri.cw.p(ri.fixed_hdr + ' _hdr;')
        ri.cw.nl()

    for meta_kind in ('present', 'len', 'count'):
        meta_lines = []
        for _, attr in struct.member_list():
            member_line = attr.presence_member(ri.ku_space, meta_kind)
            if member_line:
                meta_lines.append(member_line)
        if not meta_lines:
            continue

        ri.cw.block_start(line="struct")
        for member_line in meta_lines:
            ri.cw.p(member_line)
        ri.cw.block_end(line=f'_{meta_kind};')
    ri.cw.nl()

    for inherited in struct.inherited:
        ri.cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
2268
2269
def print_type(ri, direction):
    """Emit the struct of the operation for the given @direction."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2272
2273
def print_type_full(ri, struct):
    """Emit the struct definition for @struct, with no direction suffix."""
    no_direction = ""
    _print_type(ri, no_direction, struct)
2276
2277
def print_type_helpers(ri, direction, deref=False):
    """
    Emit the helpers which accompany a type: the free prototype, the
    netlink flags setter when needed, and - for user space requests -
    the attribute setters.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2289
2290
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and helpers of the request type, unless empty."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2296
2297
def print_rsp_type_helpers(ri):
    """Emit the helpers of the reply type, if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2302
2303
def print_parse_prototype(ri, direction, terminate=True):
    """
    Emit the prototype of the '<op>_rsp_parse()' / '<op>_req_parse()'
    function, which takes the attribute table and the struct to fill.
    """
    if direction == "reply":
        suffix = "_rsp"
    else:
        suffix = "_req"
    type_base = f"{ri.op.render_name}{suffix}"

    parse_args = ['const struct nlattr **tb', f"struct {type_base} *req"]
    ri.cw.write_func_prot('void', f"{type_base}_parse", parse_args,
                          suffix=';' if terminate else '')
2312
2313
def print_req_type(ri):
    """Emit the request struct, unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2318
2319
def print_req_free(ri):
    """Emit the free function of the request type, if the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2324
2325
def print_rsp_type(ri):
    """
    Emit the reply struct for 'do' / 'dump' modes which have a reply,
    and always for events. Other cases emit nothing.
    """
    mode = ri.op_mode
    if mode == 'do' or mode == 'dump':
        wanted = 'reply' in ri.op[mode]
    else:
        wanted = mode == 'event'

    if wanted:
        print_type(ri, 'reply')
2334
2335
def print_wrapped_type(ri):
    """
    Emit the wrapper struct around a reply object: a list element for
    dumps, a notification header for notifications and events. The wrapped
    object itself is the 8-byte aligned 'obj' member. The prototype of the
    matching free function follows.
    """
    wrapper = type_name(ri, 'reply')

    ri.cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        ri.cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        ri.cw.p('__u16 family;')
        ri.cw.p('__u8 cmd;')
        ri.cw.p('struct ynl_ntf_base_type *next;')
        ri.cw.p(f"void (*free)({wrapper} *ntf);")
    wrapped = type_name(ri, 'reply', deref=True)
    ri.cw.p(f"{wrapped} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()
    print_free_prototype(ri, 'reply')
    ri.cw.nl()
2350
2351
def _free_type_members_iter(ri, struct):
    """Emit the loop index declaration if freeing @struct needs to iterate."""
    if not struct.free_needs_iter():
        return
    ri.cw.p('unsigned int i;')
    ri.cw.nl()
2356
2357
def _free_type_members(ri, var, struct, ref=''):
    """Let every member of @struct emit the code which frees it."""
    for _, member in struct.member_list():
        member.free(ri, var, ref)
2361
2362
def _free_type(ri, direction, struct):
    """
    Emit the free function for @struct. The object itself is only freed
    when a @direction is given; with an empty direction just the members
    are released.
    """
    obj = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, obj, struct)
    if direction:
        ri.cw.p(f'free({obj});')
    ri.cw.block_end()
    ri.cw.nl()
2374
2375
def free_rsp_nested_prototype(ri):
    """Emit the prototype of the free function of a nested type."""
    print_free_prototype(ri, direction="")
2378
2379
def free_rsp_nested(ri, struct):
    """Emit the free function of a nested type (members only, no direction)."""
    no_direction = ""
    _free_type(ri, no_direction, struct)
2382
2383
def print_rsp_free(ri):
    """Emit the free function of the reply type, if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2388
2389
def print_dump_type_free(ri):
    """
    Emit the free function for a dump result list. The generated code
    walks the list until YNL_LIST_END, freeing the members of each
    wrapped object and then the list element itself.
    """
    list_type = type_name(ri, 'reply')
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{list_type} *next = rsp;")
    ri.cw.nl()
    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    ri.cw.p('rsp = next;')
    ri.cw.p('next = rsp->next;')
    ri.cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.block_end()
    ri.cw.nl()
2408
2409
def print_ntf_type_free(ri):
    """
    Emit the free function for a notification: free the members of
    the wrapped object, then the notification itself.
    """
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2418
2419
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """
    Emit the first line of a kernel policy table for @struct.

    With @terminate this is an 'extern' declaration, skipped entirely for
    per-op policies which should be static. Without @terminate it is the
    opening line of the definition ('... = {'), marked 'static' when
    appropriate. With @ri the table is named after the operation (plus the
    op mode for dual policies), otherwise after the struct.
    """
    if terminate:
        if ri and policy_should_be_static(struct.family):
            return
        prefix = 'extern '
        suffix = ';'
    else:
        prefix = 'static ' if ri and policy_should_be_static(struct.family) else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2442
2443
def print_req_policy(cw, struct, ri=None):
    """
    Emit the kernel policy table for @struct. For per-op policies the
    table is wrapped in the #ifdef of the op's 'config-cond', if any.
    """
    if ri and ri.op:
        config_cond = ri.op.get('config-cond', None)
        cw.ifdef_block(config_cond)

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    # Close the conditional section, if one was opened
    cw.ifdef_block(None)
    cw.nl()
2453
2454
def kernel_can_gen_family_struct(family):
    """Return True if the family's protocol is 'genetlink'."""
    proto = family.proto
    return proto == 'genetlink'
2457
2458
def policy_should_be_static(family):
    """
    Return True if policies should be static: the family uses 'split'
    kernel policies or the family struct can be generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2461
2462
def print_kernel_policy_ranges(family, cw):
    """
    Emit a range validation struct for every request attribute which has
    the 'full-range' check. Attribute sets which are subsets are skipped.
    The section comment is emitted once, before the first struct.
    """
    header_pending = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if header_pending:
                cw.p('/* Integer value ranges */')
                header_pending = False

            # Unsigned types use the plain struct and ULL literals
            if attr.type[0] == 'u':
                sign, suffix = '', 'ULL'
            else:
                sign, suffix = '_signed', 'LL'

            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(limit, attr.get_limit_str(limit, suffix=suffix))
                       for limit in ('min', 'max') if limit in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2490
2491
def print_kernel_policy_sparse_enum_validates(family, cw):
    """
    Emit a validation callback for every request attribute which has the
    'sparse' check. Attribute sets which are subsets are skipped.

    Each generated function switches on the attribute value, returns 0 for
    any value listed in the attribute's enum, and otherwise sets an extack
    message and returns -EINVAL. The section comment is emitted once,
    before the first callback.
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request:
                continue
            if not attr.enum_name:
                continue
            if 'sparse' not in attr.checks:
                continue

            if first:
                cw.p('/* Sparse enums validation callbacks */')
                first = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            for idx, entry in enumerate(enum.entries.values()):
                # All valid values share one 'return 0', chain the cases
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2530
2531
def print_kernel_op_table_fwd(family, cw, terminate):
    """
    Emit the leading part of the kernel ops table, or its header declarations.

    When kernel_can_gen_family_struct() reports false the table is treated as
    exported: it is non-static and gets an explicit element count.

    With terminate=False only the definition line of the table is opened via
    cw.block_start(); filling in and closing the block is left to the caller.
    With terminate=True an ``extern`` declaration is printed (exported tables
    only), followed by prototypes for every hook listed in family.hooks and
    for the doit/dumpit handler of each non-async operation.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {
            'global': 'genl_small_ops',
            'per-op': 'genl_ops',
            'split': 'genl_split_ops',
        }[family.kernel_policy]

        if not exported:
            # Static tables are sized by their initializer
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split tables carry one entry per op mode
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    doit_args = ('const struct genl_split_ops *ops',
                 'struct sk_buff *skb', 'struct genl_info *info')
    dump_args = ('struct netlink_callback *cb',)
    # (hook stage, op mode, C return type, C arguments), in output order
    hook_protos = (
        ('pre', 'do', 'int', doit_args),
        ('post', 'do', 'void', doit_args),
        ('pre', 'dump', 'int', dump_args),
        ('post', 'dump', 'int', dump_args),
    )

    cw.nl()
    for stage, mode, ret_type, c_args in hook_protos:
        for hook in family.hooks[stage][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(c_args), suffix=';')

    cw.nl()

    handler_last_arg = (('do', 'struct genl_info *info'),
                        ('dump', 'struct netlink_callback *cb'))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, last_arg in handler_last_arg:
            if mode not in op:
                continue
            handler = c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it")
            cw.write_func_prot('int', handler,
                               ['struct sk_buff *skb', last_arg], suffix=';')
    cw.nl()
2597
2598
def print_kernel_op_table_hdr(family, cw):
    """
    Emit the header-file side of the kernel ops table.

    Thin wrapper which calls print_kernel_op_table_fwd() with terminate=True.
    """
    print_kernel_op_table_fwd(family, cw, terminate=True)
2601
2602
def print_kernel_op_table(family, cw):
    """
    Render the complete definition of the family's kernel ops table.

    The table is opened by print_kernel_op_table_fwd(terminate=False).  For the
    'global' and 'per-op' kernel policies one initializer is written per
    non-async operation; for the 'split' policy one initializer is written per
    op mode ('do' / 'dump') of each non-async operation.  Each entry is wrapped
    with cw.ifdef_block() using the operation's 'config-cond' (if any).
    """
    def or_names(prefix, names):
        # C expression OR-ing together the upper-cased, prefixed names
        return ' | '.join(c_upper(prefix + item) for item in names)

    print_kernel_op_table_fwd(family, cw, terminate=False)

    if family.kernel_policy in ('global', 'per-op'):
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                or_names('genl-dont-validate-', op['dont-validate'])))
            members += [(mode + 'it',
                         c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it"))
                        for mode in ('do', 'dump') if mode in op]
            if family.kernel_policy == 'per-op':
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                members.append(('policy',
                                c_lower(f"{family.ident_name}-{op_name}-nl-policy")))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', or_names('genl-', op['flags'])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}
        # dont-validate values which are dropped for a given op mode
        dropped = {'do': ('dump', 'dump-strict'),
                   'dump': ('strict',)}

        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            for op_mode in ('do', 'dump'):
                if op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    kept = [x for x in op['dont-validate']
                            if x not in dropped[op_mode]]
                    if kept:
                        members.append(('validate',
                                        or_names('genl-dont-validate-', kept)))

                mode_spec = op[op_mode]
                handler = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in mode_spec:
                    members.append((cb_names[op_mode]['pre'],
                                    c_lower(mode_spec['pre'])))
                members.append((op_mode + 'it', handler))
                if 'post' in mode_spec:
                    members.append((cb_names[op_mode]['post'],
                                    c_lower(mode_spec['post'])))
                if 'request' in mode_spec:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=mode_spec['request']['attributes'])

                    # Ops with dual policy get a separate policy per op mode
                    if op.dual_policy:
                        policy = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        policy = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', policy))
                    members.append(('maxattr', struct.attr_max_val.enum_name))

                mode_cap = ['cmd-cap-' + op_mode]
                flags = op['flags'] + mode_cap if 'flags' in op else mode_cap
                members.append(('flags', or_names('genl-', flags)))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2681
2682
def print_kernel_mcgrp_hdr(family, cw):
    """
    Emit the anonymous enum of multicast group indexes.

    One enumerator named <FAMILY>_NLGRP_<GROUP> is written per entry of
    family.mcgrps['list'].  Nothing is emitted when the list is empty.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2693
2694
def print_kernel_mcgrp_src(family, cw):
    """
    Emit the static array of struct genl_multicast_group for the family.

    Each entry of family.mcgrps['list'] becomes a designated initializer
    indexed by its <FAMILY>_NLGRP_<GROUP> name.  Nothing is emitted when the
    list is empty.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        index = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{index}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2706
2707
def print_kernel_family_struct_hdr(family, cw):
    """
    Emit header declarations for the generated struct genl_family.

    Does nothing unless kernel_can_gen_family_struct() is true.  When the
    kernel family spec has a 'sock-priv' entry, prototypes of the socket
    private data init and destroy callbacks are declared as well.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()

    if 'sock-priv' not in family.kernel_family:
        return

    priv_type = family.kernel_family["sock-priv"]
    for action in ('init', 'destroy'):
        cw.p(f'void {family.c_name}_nl_sock_priv_{action}({priv_type} *priv);')
    cw.nl()
2718
2719
def print_kernel_family_struct_src(family, cw):
    """
    Emit the definition of the family's struct genl_family.

    Does nothing unless kernel_can_gen_family_struct() is true.  The ops,
    multicast group and socket private data members are only filled in when
    the family has the corresponding policy / groups / 'sock-priv' entry.
    """
    if not kernel_can_gen_family_struct(family):
        return

    has_sock_priv = 'sock-priv' in family.kernel_family
    if has_sock_priv:
        # Generate "trampolines" to make CFI happy
        for action in ('init', 'destroy'):
            cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_{action}",
                          [f"{family.c_name}_nl_sock_priv_{action}(priv);"],
                          ["void *priv"])
            cw.nl()

    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")

    fields = [
        f'.name\t\t= {family.fam_key},',
        f'.version\t= {family.ver_key},',
        '.netnsok\t= true,',
        '.parallel_ops\t= true,',
        '.module\t\t= THIS_MODULE,',
    ]
    if family.kernel_policy == 'per-op':
        fields += [
            f'.ops\t\t= {family.c_name}_nl_ops,',
            f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),',
        ]
    elif family.kernel_policy == 'split':
        fields += [
            f'.split_ops\t= {family.c_name}_nl_ops,',
            f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),',
        ]
    if family.mcgrps['list']:
        fields += [
            f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,',
            f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),',
        ]
    if has_sock_priv:
        fields += [
            f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),',
            f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,',
            f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,',
        ]

    for field in fields:
        cw.p(field)
    cw.block_end(';')
2755
2756
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """
    Open a C enum block, picking the enum's name from the spec object.

    If obj has the @enum_name key, its value names the enum; an empty value
    produces an anonymous enum.  Otherwise, if @ckey is given and present in
    obj, the enum is named <family c_name>_<obj[ckey]>.  In all remaining
    cases the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = ('enum ' + c_lower(explicit)) if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = 'enum ' + family.c_name + '_' + c_lower(obj[ckey])
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2765
2766
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """
    Render the uAPI command enum for the 'unified' message ID model.

    All of family.msgs go into a single enum; when @separate_ntf is set the
    messages carrying 'notify' or 'event' are left out.  An enumerator value
    is only spelled out when it differs from the implicit C numbering.  The
    max value is rendered either as the last enumerator or, when
    @max_by_define is set, as a #define following the enum.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(op.enum_name + ',')
        else:
            cw.p(op.enum_name + f" = {op.value},")
        expected = op.value + 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
2792
2793
def render_uapi_directional(family, cw, max_by_define):
    """
    Render the uAPI command enums for the 'directional' message ID model.

    Two anonymous enums are written.  The USER enum holds messages which have
    'do' and no 'event'.  The KERNEL enum holds messages with a 'do' reply
    (suffixed with _REPLY) as well as those carrying 'notify' or 'event'.
    Each enum ends with its count / max value, rendered as an enumerator or,
    when @max_by_define is set, as a #define.
    """
    def emit_entry(name, value, expected):
        # Print one enumerator, spelling its value out only when it is set and
        # breaks the running sequence.  Returns the next expected value.
        if value and value != expected:
            cw.p(name + f" = {value},")
            return value + 1
        cw.p(name + ',')
        return expected + 1

    def finish_enum(cnt_name, max_name):
        # Print the count and max entries and close the enum block
        max_value = f"({cnt_name} - 1)"
        cw.nl()
        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    # USER enum
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    expected = 0
    for op in family.msgs.values():
        if 'do' not in op or 'event' in op:
            continue
        expected = emit_entry(op.enum_name, op.value, expected)
    finish_enum(f"__{family.op_prefix}USER_CNT", f"{family.op_prefix}USER_MAX")

    # KERNEL enum
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    expected = 0
    for op in family.msgs.values():
        is_ntf = 'notify' in op or 'event' in op
        if not is_ntf and not ('do' in op and 'reply' in op['do']):
            continue

        enum_name = op.enum_name if is_ntf else f'{op.enum_name}_REPLY'
        expected = emit_entry(enum_name, op.value, expected)
    finish_enum(f"__{family.op_prefix}KERNEL_CNT", f"{family.op_prefix}KERNEL_MAX")
2846
2847
def render_uapi(family, cw):
    """
    Render the complete uAPI header for the family.

    The output consists of, in order: the include guard, the family name and
    version defines, the definitions (enums / flags / consts) from the spec,
    one enum per attribute set, the command enum(s), an optional separate
    enum for notifications, the multicast group name defines and the closing
    of the include guard.

    Raises Exception if family.msg_id_model is neither 'unified' nor
    'directional'.
    """
    # Include guard; the header name may contain '/', which can't be
    # part of a C identifier
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    # Family name and version (version defaults to 1 if not in the spec)
    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Definitions; 'const' entries are accumulated in @defines and written
    # out as a group once a non-const definition (or the end) is reached
    defines = []
    for const in family['definitions']:
        # Skip definitions which have a 'header' set
        if const.get('header'):
            continue

        if const['type'] != 'const':
            # Flush the consts collected so far
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                # Open with '/**' and a "name - doc" line when entries are
                # documented, otherwise with a plain '/*' comment
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # Only spell the value out for entries flagged with value_change
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    # Flags get a mask of all the values
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    # Enums get a count entry followed by max = count - 1
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.ident_name}-{const['name']}")),
                            const['value']])

    # Write out consts which were not followed by a non-const definition
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    # Attribute sets; sets which are a subset of another set are skipped
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # @val tracks the value the C compiler would assign implicitly
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        # Max is either the last enum entry or a define following the enum
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    # Notifications and events get their own enum if the operations
    # have an 'async-prefix'
    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2988
2989
def _render_user_ntf_entry(ri, op):
    """
    Render one designated initializer of the user space ntf_info array.

    The entry is indexed by the enum name of @op, except for classic families
    where the operation looked up in req_by_value by @op's rsp_value is used.
    """
    cw = ri.cw
    if ri.family.is_classic():
        index_op = ri.family.req_by_value[op.rsp_value]
    else:
        index_op = op

    cw.block_start(line=f"[{index_op.enum_name}] = ")
    cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    cw.block_end(line=',')
3001
3002
def render_user_family(family, cw, prototype):
    """
    Render the user space struct ynl_family for the family.

    With @prototype set only the extern declaration is printed.  Otherwise
    the notification info array is rendered first (if the family has any
    notifications), followed by the definition of the family struct.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    has_ntfs = bool(family.ntfs)
    if has_ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Notifications are rendered based on the op they refer to
                ri = RenderInfo(cw, family, "user", family.ops[ntf_op['notify']], "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')

    fixed_header = family.fixed_header
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        # No hdr_len is rendered for classic families without a fixed header
        if fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(fixed_header)}),')
    elif fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')

    if has_ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3044
3045
def family_contains_bitfield32(family):
    """
    Check whether any attribute of the family is of type bitfield32.

    Attribute sets which are a subset of another set are not inspected.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
3054
3055
def find_kernel_root(full_path):
    """
    Locate the kernel tree which contains @full_path.

    Walks up the parent directories of @full_path until one containing a
    MAINTAINERS file is found.

    Returns a tuple of (root directory, path of @full_path relative to root).

    Raises FileNotFoundError if no parent directory contains a MAINTAINERS
    file (previously this case looped forever).
    """
    orig_path = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # sub_path always carries a trailing separator, strip it
            return parent, sub_path[:-1]
        # dirname() of the top-most directory ('/' or '') is that directory
        # itself, there is nothing left to search
        if parent == full_path:
            raise FileNotFoundError(
                f"could not locate kernel root: no MAINTAINERS file found "
                f"in any parent directory of '{orig_path}'")
        full_path = parent
3064
3065
3066def main():
3067    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
3068    parser.add_argument('--mode', dest='mode', type=str, required=True,
3069                        choices=('user', 'kernel', 'uapi'))
3070    parser.add_argument('--spec', dest='spec', type=str, required=True)
3071    parser.add_argument('--header', dest='header', action='store_true', default=None)
3072    parser.add_argument('--source', dest='header', action='store_false')
3073    parser.add_argument('--user-header', nargs='+', default=[])
3074    parser.add_argument('--cmp-out', action='store_true', default=None,
3075                        help='Do not overwrite the output file if the new output is identical to the old')
3076    parser.add_argument('--exclude-op', action='append', default=[])
3077    parser.add_argument('-o', dest='out_file', type=str, default=None)
3078    args = parser.parse_args()
3079
3080    if args.header is None:
3081        parser.error("--header or --source is required")
3082
3083    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
3084
3085    try:
3086        parsed = Family(args.spec, exclude_ops)
3087        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
3088            print('Spec license:', parsed.license)
3089            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
3090            os.sys.exit(1)
3091    except yaml.YAMLError as exc:
3092        print(exc)
3093        os.sys.exit(1)
3094        return
3095
3096    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
3097
3098    _, spec_kernel = find_kernel_root(args.spec)
3099    if args.mode == 'uapi' or args.header:
3100        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
3101    else:
3102        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
3103    cw.p("/* Do not edit directly, auto-generated from: */")
3104    cw.p(f"/*\t{spec_kernel} */")
3105    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
3106    if args.exclude_op or args.user_header:
3107        line = ''
3108        line += ' --user-header '.join([''] + args.user_header)
3109        line += ' --exclude-op '.join([''] + args.exclude_op)
3110        cw.p(f'/* YNL-ARG{line} */')
3111    cw.nl()
3112
3113    if args.mode == 'uapi':
3114        render_uapi(parsed, cw)
3115        return
3116
3117    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
3118    if args.header:
3119        cw.p('#ifndef ' + hdr_prot)
3120        cw.p('#define ' + hdr_prot)
3121        cw.nl()
3122
3123    if args.out_file:
3124        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
3125    else:
3126        hdr_file = "generated_header_file.h"
3127
3128    if args.mode == 'kernel':
3129        cw.p('#include <net/netlink.h>')
3130        cw.p('#include <net/genetlink.h>')
3131        cw.nl()
3132        if not args.header:
3133            if args.out_file:
3134                cw.p(f'#include "{hdr_file}"')
3135            cw.nl()
3136        headers = ['uapi/' + parsed.uapi_header]
3137        headers += parsed.kernel_family.get('headers', [])
3138    else:
3139        cw.p('#include <stdlib.h>')
3140        cw.p('#include <string.h>')
3141        if args.header:
3142            cw.p('#include <linux/types.h>')
3143            if family_contains_bitfield32(parsed):
3144                cw.p('#include <linux/netlink.h>')
3145        else:
3146            cw.p(f'#include "{hdr_file}"')
3147            cw.p('#include "ynl.h"')
3148        headers = []
3149    for definition in parsed['definitions'] + parsed['attribute-sets']:
3150        if 'header' in definition:
3151            headers.append(definition['header'])
3152    if args.mode == 'user':
3153        headers.append(parsed.uapi_header)
3154    seen_header = []
3155    for one in headers:
3156        if one not in seen_header:
3157            cw.p(f"#include <{one}>")
3158            seen_header.append(one)
3159    cw.nl()
3160
3161    if args.mode == "user":
3162        if not args.header:
3163            cw.p("#include <linux/genetlink.h>")
3164            cw.nl()
3165            for one in args.user_header:
3166                cw.p(f'#include "{one}"')
3167        else:
3168            cw.p('struct ynl_sock;')
3169            cw.nl()
3170            render_user_family(parsed, cw, True)
3171        cw.nl()
3172
3173    if args.mode == "kernel":
3174        if args.header:
3175            for _, struct in sorted(parsed.pure_nested_structs.items()):
3176                if struct.request:
3177                    cw.p('/* Common nested types */')
3178                    break
3179            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3180                if struct.request:
3181                    print_req_policy_fwd(cw, struct)
3182            cw.nl()
3183
3184            if parsed.kernel_policy == 'global':
3185                cw.p(f"/* Global operation policy for {parsed.name} */")
3186
3187                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3188                print_req_policy_fwd(cw, struct)
3189                cw.nl()
3190
3191            if parsed.kernel_policy in {'per-op', 'split'}:
3192                for op_name, op in parsed.ops.items():
3193                    if 'do' in op and 'event' not in op:
3194                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
3195                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
3196                        cw.nl()
3197
3198            print_kernel_op_table_hdr(parsed, cw)
3199            print_kernel_mcgrp_hdr(parsed, cw)
3200            print_kernel_family_struct_hdr(parsed, cw)
3201        else:
3202            print_kernel_policy_ranges(parsed, cw)
3203            print_kernel_policy_sparse_enum_validates(parsed, cw)
3204
3205            for _, struct in sorted(parsed.pure_nested_structs.items()):
3206                if struct.request:
3207                    cw.p('/* Common nested types */')
3208                    break
3209            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3210                if struct.request:
3211                    print_req_policy(cw, struct)
3212            cw.nl()
3213
3214            if parsed.kernel_policy == 'global':
3215                cw.p(f"/* Global operation policy for {parsed.name} */")
3216
3217                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3218                print_req_policy(cw, struct)
3219                cw.nl()
3220
3221            for op_name, op in parsed.ops.items():
3222                if parsed.kernel_policy in {'per-op', 'split'}:
3223                    for op_mode in ['do', 'dump']:
3224                        if op_mode in op and 'request' in op[op_mode]:
3225                            cw.p(f"/* {op.enum_name} - {op_mode} */")
3226                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
3227                            print_req_policy(cw, ri.struct['request'], ri=ri)
3228                            cw.nl()
3229
3230            print_kernel_op_table(parsed, cw)
3231            print_kernel_mcgrp_src(parsed, cw)
3232            print_kernel_family_struct_src(parsed, cw)
3233
3234    if args.mode == "user":
3235        if args.header:
3236            cw.p('/* Enums */')
3237            put_op_name_fwd(parsed, cw)
3238
3239            for name, const in parsed.consts.items():
3240                if isinstance(const, EnumSet):
3241                    put_enum_to_str_fwd(parsed, cw, const)
3242            cw.nl()
3243
3244            cw.p('/* Common nested types */')
3245            for attr_set, struct in parsed.pure_nested_structs.items():
3246                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3247                print_type_full(ri, struct)
3248                if struct.request and struct.in_multi_val:
3249                    free_rsp_nested_prototype(ri)
3250                    cw.nl()
3251
3252            for op_name, op in parsed.ops.items():
3253                cw.p(f"/* ============== {op.enum_name} ============== */")
3254
3255                if 'do' in op and 'event' not in op:
3256                    cw.p(f"/* {op.enum_name} - do */")
3257                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3258                    print_req_type(ri)
3259                    print_req_type_helpers(ri)
3260                    cw.nl()
3261                    print_rsp_type(ri)
3262                    print_rsp_type_helpers(ri)
3263                    cw.nl()
3264                    print_req_prototype(ri)
3265                    cw.nl()
3266
3267                if 'dump' in op:
3268                    cw.p(f"/* {op.enum_name} - dump */")
3269                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
3270                    print_req_type(ri)
3271                    print_req_type_helpers(ri)
3272                    if not ri.type_consistent or ri.type_oneside:
3273                        print_rsp_type(ri)
3274                    print_wrapped_type(ri)
3275                    print_dump_prototype(ri)
3276                    cw.nl()
3277
3278                if op.has_ntf:
3279                    cw.p(f"/* {op.enum_name} - notify */")
3280                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3281                    if not ri.type_consistent:
3282                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3283                    print_wrapped_type(ri)
3284
3285            for op_name, op in parsed.ntfs.items():
3286                if 'event' in op:
3287                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
3288                    cw.p(f"/* {op.enum_name} - event */")
3289                    print_rsp_type(ri)
3290                    cw.nl()
3291                    print_wrapped_type(ri)
3292            cw.nl()
3293        else:
3294            cw.p('/* Enums */')
3295            put_op_name(parsed, cw)
3296
3297            for name, const in parsed.consts.items():
3298                if isinstance(const, EnumSet):
3299                    put_enum_to_str(parsed, cw, const)
3300            cw.nl()
3301
3302            has_recursive_nests = False
3303            cw.p('/* Policies */')
3304            for struct in parsed.pure_nested_structs.values():
3305                if struct.recursive:
3306                    put_typol_fwd(cw, struct)
3307                    has_recursive_nests = True
3308            if has_recursive_nests:
3309                cw.nl()
3310            for name in parsed.pure_nested_structs:
3311                struct = Struct(parsed, name)
3312                put_typol(cw, struct)
3313            for name in parsed.root_sets:
3314                struct = Struct(parsed, name)
3315                put_typol(cw, struct)
3316
3317            cw.p('/* Common nested types */')
3318            if has_recursive_nests:
3319                for attr_set, struct in parsed.pure_nested_structs.items():
3320                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3321                    free_rsp_nested_prototype(ri)
3322                    if struct.request:
3323                        put_req_nested_prototype(ri, struct)
3324                    if struct.reply:
3325                        parse_rsp_nested_prototype(ri, struct)
3326                cw.nl()
3327            for attr_set, struct in parsed.pure_nested_structs.items():
3328                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3329
3330                free_rsp_nested(ri, struct)
3331                if struct.request:
3332                    put_req_nested(ri, struct)
3333                if struct.reply:
3334                    parse_rsp_nested(ri, struct)
3335
3336            for op_name, op in parsed.ops.items():
3337                cw.p(f"/* ============== {op.enum_name} ============== */")
3338                if 'do' in op and 'event' not in op:
3339                    cw.p(f"/* {op.enum_name} - do */")
3340                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3341                    print_req_free(ri)
3342                    print_rsp_free(ri)
3343                    parse_rsp_msg(ri)
3344                    print_req(ri)
3345                    cw.nl()
3346
3347                if 'dump' in op:
3348                    cw.p(f"/* {op.enum_name} - dump */")
3349                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
3350                    if not ri.type_consistent or ri.type_oneside:
3351                        parse_rsp_msg(ri, deref=True)
3352                    print_req_free(ri)
3353                    print_dump_type_free(ri)
3354                    print_dump(ri)
3355                    cw.nl()
3356
3357                if op.has_ntf:
3358                    cw.p(f"/* {op.enum_name} - notify */")
3359                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3360                    if not ri.type_consistent:
3361                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3362                    print_ntf_type_free(ri)
3363
3364            for op_name, op in parsed.ntfs.items():
3365                if 'event' in op:
3366                    cw.p(f"/* {op.enum_name} - event */")
3367
3368                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3369                    parse_rsp_msg(ri)
3370
3371                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
3372                    print_ntf_type_free(ri)
3373            cw.nl()
3374            render_user_family(parsed, cw, False)
3375
3376    if args.header:
3377        cw.p(f'#endif /* {hdr_prot} */')
3378
3379
# Script entry guard: main() is called only when this module's __name__ is
# "__main__", so the call is skipped when __name__ holds any other value.
3380if __name__ == "__main__":
3381    main()
3382