xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision a66a170b68af0d588f4eadb9e5813c4edefe24f5)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
def c_upper(name):
    """Return *name* as an upper-case C identifier.

    Every character is upper-cased and each ``-`` becomes ``_``,
    so ``'foo-bar'`` turns into ``'FOO_BAR'``.
    """
    return name.upper().translate({ord('-'): '_'})
22
23
def c_lower(name):
    """Return *name* as a lower-case C identifier.

    Every character is lower-cased and each ``-`` becomes ``_``,
    so ``'Foo-Bar'`` turns into ``'foo_bar'``.
    """
    return name.lower().translate({ord('-'): '_'})
26
27
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The first character selects signedness (``s`` means signed, anything
    else is treated as unsigned) and the text between it and the four-char
    ``-max``/``-min`` suffix is the bit width.  Unsigned minimums are 0;
    signed limits use two's complement bounds.
    """
    signed = name[0] == 's'
    is_min = name.endswith('-min')
    if is_min and name[0] == 'u':
        return 0

    # A signed type spends one bit on the sign.
    bits = int(name[1:-4]) - (1 if signed else 0)
    largest = (1 << bits) - 1
    if signed and is_min:
        return -largest - 1
    return largest
41
42
class BaseNlLib:
    """Library-specific naming hooks used when emitting C code."""

    def get_family_id(self):
        """Return the C expression that references the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
46
47
class Type(SpecAttr):
    """Code-generator view of a single netlink attribute.

    Wraps the YAML description *attr* of one member of *attr_set* and emits,
    through ``ri.cw`` / *cw*, the C fragments the generator needs for it:
    struct members and presence fields, policy table entries, put and parse
    code, and setter helpers.  Subclasses specialise the underscore-prefixed
    hooks (``_complex_member_type``, ``_attr_policy``, ``_attr_typol``,
    ``_attr_get``, ``_setter_lines``, ...); several of the defaults here
    raise ``Exception`` to flag a type that lacks an implementation.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE(review): when the spec provides 'checks' this is the spec's
        # own dict (not a copy); subclasses that add derived checks
        # (e.g. TypeScalar) therefore modify it in place.
        self.checks = attr.get('checks', {})

        # Flipped to True by set_request() / set_reply().
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # Presumably the trailing '_' avoids clashing with a definition
            # of the same name in family.consts -- confirm.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make the name usable as a C identifier: suffix names found in
        # _C_KW (defined elsewhere in this file, presumably C keywords) and
        # prefix names that start with a digit.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attribute as used in requests (and its real attr, for subsets)."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attribute as used in replies (and its real attr, for subsets)."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return check *limit* (or *default*) as a number, or None if unset.

        Integers are returned unchanged, names of family consts resolve to
        the const's ``value``, and anything else is parsed by
        limit_to_number() (e.g. ``'u32-max'``).
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check *limit* (or *default*) as C source text, '' if unset.

        Integers are printed with *suffix* appended; family consts become
        their upper-case C name, prefixed with the family name unless the
        const declares a 'header'; other values are simply upper-cased.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute ``enum_name``, the C enum constant naming this attribute.

        Raises:
            Exception: if an attribute of a subset set has checks that differ
                from those of the real attribute.
        """
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        # Falsy by default; array-like subclasses return True.
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        """True when the attribute is recursive and *ri* has no op."""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Name of the presence-tracking field group ('present', 'len', 'count' or '')."""
        return 'present'

    def presence_member(self, space, type_filter):
        """Return the C presence-field declaration if it belongs to *type_filter*.

        'present' tracking uses a one-bit field, 'len'/'count' a full u32;
        user-space declarations get the ``__`` prefix.  Returns None when the
        attribute's presence type does not match *type_filter*.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'present':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() in {'len', 'count'}:
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name};"

    def _complex_member_type(self, ri):
        # None means "no single C type"; arg_member() is used instead.
        return None

    def free_needs_iter(self):
        """True if the generated free code needs a loop index ``i``."""
        return False

    def _free_lines(self, ri, var, ref):
        # Heap-backed members (arrays, strings, binaries) need a free().
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the C statements releasing this attribute's storage."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the C parameter declarations a setter takes for this attribute.

        Raises:
            Exception: if the type has no complex member type and no override.
        """
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the C struct member declaration(s) for this attribute."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                # Recursive structs can only be embedded by pointer.
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry in the kernel ``NLA_*`` policy table."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's ``{ .name = ..., .type = ... }`` policy entry."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the presence field when one exists.
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        # Returns (lines, init_lines, local_vars); str is accepted for lines.
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attribute; *first* picks ``if`` vs ``else if``.

        Returns:
            True, always.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attribute on ``req``.

        *ref* is the list of member names leading from the request root to
        the enclosing nest (None for a top-level attribute).  The setter sets
        the presence bit of every enclosing nest (and of the attribute itself
        when it uses 'present' tracking), frees any previous value, then
        stores the new one.  Setters that free but never allocate get a
        ``__`` prefix.  *space* is not used here.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, part in enumerate(ref):
            prefix = f"{var}->{'.'.join(ref[:i] + [''])}"
            presence = f"{prefix}_present.{part}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{prefix}_{self.presence_type()}.{part}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
311
312
class TypeUnused(Type):
    """Attribute whose code generation is suppressed.

    No kernel policy, put code, parse branch or setter is emitted, and the
    user-space parse policy marks the attribute ``YNL_PT_REJECT``.
    """

    def presence_type(self):
        # No presence tracking at all.
        return ''

    def arg_member(self, ri):
        # Nothing to pass to a setter.
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def _attr_get(self, ri, var):
        # attr_get() is overridden below, so this hook is not reached through
        # it; if called directly it yields a parse error.
        return ['return YNL_PARSE_CB_ERROR;'], None, None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
337
338
class TypePad(Type):
    """Padding attribute: ignored by the parser, with no policy, put or setter."""

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def presence_type(self):
        # Padding is never tracked.
        return ''

    def arg_member(self, ri):
        return []

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
360
361
class TypeScalar(Type):
    """Integer attribute (u8 ... s64 and auto-sized scalars).

    After resolve(), ``type_name`` holds the C type used for the member:
    the enum's user type for non-flag enums, a 64-bit type for auto-sized
    scalars, or ``__<type>`` otherwise.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # C comment appended to declarations to record the declared byte order.
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Compute ``is_bitfield`` and ``type_name`` once all consts are known."""
        self.resolve_up(super())

        # An explicit enum-as-flags wins; otherwise a 'flags'-typed enum makes
        # the attribute a bitfield.
        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Auto-sized scalars are exposed as 64-bit values.
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """Derive kernel-policy checks from the spec.

        For enum-typed attributes: if the enum has no computable value range
        the attribute is marked 'sparse', otherwise missing 'min'/'max' are
        filled from the range (a 0 minimum on unsigned types is omitted).
        'range' is set when both bounds exist and 'full-range' when a bound
        falls outside [-32768, 32767] (presumably beyond what the compact
        range policy can hold -- confirm).  Note this writes into
        ``self.checks``, which may be the spec's own dict.

        Raises:
            Exception: on min > max, or a negative limit for an unsigned type.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        # Mask checks take priority, then range/min/max, then sparse enums.
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                # NOTE(review): assumes the flags occupy bits 0..n-1.
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the unsigned YNL_PT_U* parse type of equal width.
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
459
460
class TypeFlag(Type):
    """Flag attribute: its presence is the whole value, no payload is sent."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter takes no value; calling it just sets the presence bit.
        return []

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Type.attr_get() already records presence; nothing else to parse.
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        return []
476
477
class TypeString(Type):
    """String attribute stored as a heap-allocated, NUL-terminated ``char *``.

    The string length (without terminator) is tracked in the ``_len`` field.
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        parts = ['.type = YNL_PT_NUL_STR, ']
        if self.is_selector:
            parts.append('.is_selector = 1, ')
        return ''.join(parts)

    def _attr_policy(self, policy):
        # An exact length replaces the type-based entry entirely.
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.get_limit_str('exact-len')})"
        entry = f'{{ .type = {policy}'
        if 'max-len' in self.checks:
            entry += f", .len = {self.get_limit_str('max-len')}"
        return entry + ', }'

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the kernel check to NLA_STRING.
        if self.checks.get('unterminated-ok', False):
            nla_type = 'NLA_STRING'
        else:
            nla_type = 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(nla_type)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_field = f"{var}->_len.{self.c_name}"
        get_lines = [
            f"{len_field} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        # The length is bounded by the attribute payload size.
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        # *presence* names the _len field, reused as the length variable.
        return [
            f"{presence} = strlen({self.c_name});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {self.c_name}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
530
531
class TypeBinary(Type):
    """Opaque binary attribute stored as a heap copy plus a ``_len`` field."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Build the kernel policy; at most one length check is supported.

        Raises:
            Exception: for more than one check, or a check other than
                'exact-len', 'min-len' or 'max-len'.
        """
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not self.checks:
            return '{ .type = NLA_BINARY, }'

        check_name = next(iter(self.checks))
        if check_name not in {'exact-len', 'min-len', 'max-len'}:
            raise Exception('Unsupported check for binary type: ' + check_name)

        limit = self.get_limit_str(check_name)
        if check_name == 'exact-len':
            return f'NLA_POLICY_EXACT_LEN({limit})'
        if check_name == 'min-len':
            return f'{{ .len = {limit}, }}'
        return f'NLA_POLICY_MAX_LEN({limit})'

    def attr_put(self, ri, var):
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                    f"{var}->{self.c_name}, {var}->_len.{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = len;",
            f"{member} = malloc({presence});",
            f'memcpy({member}, {self.c_name}, {presence});',
        ]
582
583
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is a named C struct."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        # Payloads shorter than the struct get a zeroed, full-size buffer so
        # every field of the struct is readable.
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        dst = f"{var}->{self.c_name}"
        len_field = f"{var}->_{self.presence_type()}.{self.c_name}"
        get_lines = [
            f"{len_field} = len;",
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
599
600
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute holding a packed array of ``sub-type`` scalars.

    The element count (not the byte length) is tracked in ``_count``.
    """

    def presence_type(self):
        return 'count'

    def arg_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        return [f'{elem_type} *{self.c_name}', 'size_t count']

    def struct_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem_type} *{self.c_name};')

    def attr_put(self, ri, var):
        # 'i' is reused as the byte length of the payload.
        count_field = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_size = f"sizeof(__{self.get('sub-type')})"
        ri.cw.block_start(line=f"if ({count_field})")
        ri.cw.p(f"i = {count_field} * {elem_size};")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        # Trailing bytes that do not form a whole element are dropped.
        elem_size = f"sizeof(__{self.get('sub-type')})"
        count_field = f"{var}->_count.{self.c_name}"
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{count_field} = len / {elem_size};",
            f"len = {count_field} * {elem_size};",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem_size = f"sizeof(__{self.get('sub-type')})"
        return [
            f"{presence} = count;",
            f"count *= {elem_size};",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
633
634
class TypeBitfield32(Type):
    """Attribute carrying a ``struct nla_bitfield32`` by value."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """Return the kernel policy, masked by the attribute's enum (as flags).

        Raises:
            Exception: if the attribute does not reference an 'enum'.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        # The setter receives a pointer (see Type.arg_member) and copies it in.
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
658
659
class TypeNest(Type):
    """Nested attribute set, rendered as an embedded struct.

    When recursive for the current op the member is a pointer instead,
    so free and put code drop the ``&`` and guard against NULL.
    """

    def is_recursive(self):
        return self.family.pure_nested_structs[self.nested_attrs].recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        target = f"{var}->{ref}{self.c_name}"
        free_fn = f"{self.nested_render_name}_free"
        if self.is_recursive_for_op(ri):
            return [f'if ({target})', f'{free_fn}({target});']
        return [f'{free_fn}(&{target});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        addr = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {addr}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # External selectors of the nest are passed through to its parser.
        nest_info = self.family.pure_nested_structs[self.nested_attrs]
        args = ["&parg", "attr"] + [f'{var}->{sel.name}'
                                    for sel in nest_info.external_selectors()]
        get_lines = [f"if ({self.nested_render_name}_parse({', '.join(args)}))",
                     "return YNL_PARSE_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # One setter per (non-recursive) member of the nest, with the path
        # extended by this attribute's name.
        path = (ref or []) + [self.c_name]
        nest_info = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member_attr in nest_info.member_list():
            if not member_attr.is_recursive():
                member_attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
705
706
class TypeMultiAttr(Type):
    """Attribute that may repeat; occurrences are collected into an array.

    The parser counts occurrences in ``n_<name>``, and the element count is
    tracked in the ``_count`` field.  Policy hooks delegate to *base_type*.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        elem = self.attr['type']
        if elem == 'binary' and 'struct' in self.attr:
            return None  # arg_member() provides the declaration instead
        if elem == 'string':
            return 'struct ynl_string *'
        if elem in scalars:
            return ('__' if ri.ku_space == 'user' else '') + elem
        raise Exception(f"Sub-type {elem} not supported yet")

    def arg_member(self, ri):
        if self.type == 'binary' and 'struct' in self.attr:
            struct_name = c_lower(self.attr["struct"])
            return [f'struct {struct_name} *{self.c_name}',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def free_needs_iter(self):
        # Strings and nests are freed element by element.
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        elem = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"
        loop = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if elem in scalars or elem == 'binary':
            return [f"free({array});"]
        if elem == 'string':
            return [loop, f"free({array}[i]);", f"free({array});"]
        if elem == 'nest':
            return [loop,
                    f'{self.nested_render_name}_free(&{array}[i]);',
                    f"free({array});"]
        raise Exception(f"Free of MultiAttr sub-type {elem} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # First pass only counts; the array is filled elsewhere.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        elem = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        if elem in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif elem == 'binary' and 'struct' in self.attr:
            struct_name = c_lower(self.attr['struct'])
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i], sizeof(struct {struct_name}));")
        elif elem == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {var}->{self.c_name}[i]->str);")
        elif elem == 'nest':
            ri.cw.p(loop)
            put_call = (f"{self.nested_render_name}_put(nlh, "
                        f"{self.enum_name}, &{var}->{self.c_name}[i])")
            self._attr_put_line(ri, var, put_call)
        else:
            raise Exception(f"Put of MultiAttr sub-type {elem} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
793
794
class TypeArrayNest(Type):
    """Indexed-array attribute: a nest whose entries use their index as type.

    Entries are scalars, fixed-length binaries ('exact-len' check) or nests,
    selected by the spec's 'sub-type'.  A missing 'sub-type' is treated as
    'nest' everywhere, as _complex_member_type() always did.
    """

    def _array_sub_type(self):
        # Single source of truth for the element kind, so member type, policy
        # and put code agree when 'sub-type' is absent.
        return self.attr.get('sub-type', 'nest')

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self._array_sub_type()
        if sub_type == 'nest':
            return self.nested_struct_type
        elif sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {sub_type} not supported yet")

    def arg_member(self, ri):
        if self._array_sub_type() == 'binary' and 'exact-len' in self.checks:
            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        sub_type = self._array_sub_type()
        if sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(sub_type[1:])}, '
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Validate and count the entries; 'attr_<name>' is remembered for a
        # later pass that fills the array.
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->_count.{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        """Emit code writing every element as an entry of a new nest.

        Raises:
            Exception: for an unsupported sub-type.
        """
        sub_type = self._array_sub_type()
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if sub_type in scalars:
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{sub_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put for ArrayNest sub-type {sub_type} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
857
858
class TypeNestTypeValue(Type):
    """Nest whose children are keyed by 'type-value' levels.

    Each level named in 'type-value' is unwrapped in turn; the attribute type
    found at each level is passed to the nested parser as an extra argument.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'      # C expression for the attribute at the current level
        extra_args = ''      # ", lvl1, lvl2, ..." appended to the parse call

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            attr_decls = ", *attr_".join(levels)
            local_vars.append(f'const struct nlattr *attr_{attr_decls};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({cursor});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                cursor = f'attr_{level}'
            extra_args = ", " + ", ".join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
887
888
class TypeSubMessage(TypeNest):
    """Sub-message attribute: its nested format is picked by a selector value."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        parts = [f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
                 '.is_submsg = 1, ']
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            sel_value = self.attr_set[self["selector"]].value
            parts.append(f'.selector_type = {sel_value} ')
        return ''.join(parts)

    def _attr_get(self, ri, var):
        selector_name = self['selector']
        sel = c_lower(selector_name)
        # External selectors arrive as a _sel_<name> argument; local ones are
        # read from the struct being filled in.
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"
        get_lines = [
            f'if (!{sel_var})',
            'return ynl_submsg_failed(yarg, "%s", "%s");' % (self.name, selector_name),
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None
919
920
class Selector:
    """Tracks the attribute that selects a sub-message's format.

    If the selector is in the same attribute set as the sub-message
    attribute, it is bound immediately and flagged with is_selector.
    Otherwise it is "external": attr stays None until set_attr() binds it.
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]
        self._external = self.name not in attr_set
        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
        else:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        self.attr = attr

    def is_external(self):
        return self._external
939
940
class Struct:
    """Describes the C struct rendered for an attribute set.

    If type_list is given, the struct holds only those attributes (an
    operation's request or reply). If type_list is None, it holds every
    attribute of the set and is treated as nested (self.nested is True).
    """
    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        # (a default [] never equals a set, so set_inherited() raises unless
        # an inherited set was given here; the object is kept by reference,
        # not copied, so later updates by the caller are visible)
        self._inherited = inherited if inherited is not None else []
        # Sorted C names of inherited members; filled by set_inherited()
        self.inherited = []
        self.fixed_header = None
        if fixed_header:
            self.fixed_header = 'struct ' + c_lower(fixed_header)
        self.submsg = submsg

        self.nested = type_list is None
        # A set named after the family renders without the set-name suffix.
        if family.name == c_lower(space_name):
            self.render_name = c_lower(family.ident_name)
        else:
            self.render_name = c_lower(family.ident_name + '-' + space_name)
        self.struct_name = 'struct ' + self.render_name
        # NOTE(review): the trailing '_' presumably avoids a clash with a
        # same-named entry in family.consts — confirm.
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or and legacy arrays

        # attr_list: ordered (name, attr) pairs; attrs: name -> attr
        self.attr_list = []
        self.attrs = dict()
        if type_list is not None:
            for t in type_list:
                self.attr_list.append((t, self.attr_set[t]),)
        else:
            for t in self.attr_set:
                self.attr_list.append((t, self.attr_set[t]),)

        # attr_max_val: the attribute with the highest value (the last one on
        # ties; values below 0 are never picked). None if attr_list is empty.
        max_val = 0
        self.attr_max_val = None
        for name, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        # Iterates over attribute names.
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the ordered list of (name, attr) pairs."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record members inherited from the parent nest.

        Raises if new_inherited differs from the set given at construction
        (a struct built without one always raises; see __init__).
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def external_selectors(self):
        """Return the Selectors of sub-message members whose selector lives outside this set."""
        sels = []
        for name, attr in self.attr_list:
            if isinstance(attr, TypeSubMessage) and attr.selector.is_external():
                sels.append(attr.selector)
        return sels

    def free_needs_iter(self):
        """True if any member's free code needs a loop iterator."""
        for _, attr in self.attr_list:
            if attr.free_needs_iter():
                return True
        return False
1015
1016
class EnumEntry(SpecEnumEntry):
    """Enum/flags entry with a C constant name.

    value_change is True if this entry's value is not simply the previous
    value + 1 (or 0 for the first entry). It is always True in a 'flags' set.
    """

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        implicit_value = prev.value + 1 if prev else 0
        self.value_change = (self.value != implicit_value) or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
1035
1036
class EnumSet(SpecEnumSet):
    """Enum or flags definition carrying its C rendering names.

    enum_name is the C 'enum <name>' type, or None when the spec gives an
    empty 'enum-name' (anonymous enum). user_type is the C type used for
    values: enum_name, or 'int' for anonymous enums.
    """

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # Anonymous enums are handled as plain ints.
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if the entry values form one contiguous range.

        Returns (None, None) when there are gaps (assuming distinct values)
        or when the set has no entries at all.
        """
        values = [x.value for x in self.entries.values()]
        if not values:
            # min()/max() would raise on an empty sequence.
            return None, None

        low, high = min(values), max(values)
        if high - low + 1 != len(values):
            return None, None

        return low, high
1072
1073
class AttrSet(SpecAttrSet):
    """Attribute set with C naming (enum prefix, max/count names) and typed members."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the enum naming of the set it was carved from.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        # A set named like the family maps to an empty c_name.
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Build the Type subclass instance matching elem['type'].

        Multi-attr elements are wrapped in TypeMultiAttr. Raises for
        unsupported types and indexed-array sub-types.
        """
        attr_type = elem['type']
        # Types that map 1:1 to a class, with no further inspection of elem.
        plain_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
            'sub-message': TypeSubMessage,
        }

        if attr_type in scalars:
            cls = TypeScalar
        elif attr_type in plain_types:
            cls = plain_types[attr_type]
        elif attr_type == 'binary':
            if 'struct' in elem:
                cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                cls = TypeBinaryScalarArray
            else:
                cls = TypeBinary
        elif attr_type == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        attr = cls(self.family, self, elem, value)
        if 'multi-attr' in elem and elem['multi-attr']:
            attr = TypeMultiAttr(self.family, self, elem, value, attr)
        return attr
1142
1143
class Operation(SpecOperation):
    """Operation with a C enum name and policy/notification flags."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # every present request/reply section gets an 'attributes' list.
        for mode in ('do', 'dump', 'event'):
            for direction in ('request', 'reply'):
                try:
                    section = yaml[mode][direction]
                except KeyError:
                    continue
                section.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True when both 'do' and 'dump' define a request.
        has_do_req = 'do' in yaml and 'request' in yaml['do']
        self.dual_policy = has_do_req and \
            ('dump' in yaml and 'request' in yaml['dump'])

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
1177
1178
class SubMessage(SpecSubMessage):
    """Sub-message definition with its C render name."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        name_parts = (family.ident_name, yaml['name'])
        self.render_name = c_lower('-'.join(name_parts))

    def resolve(self):
        self.resolve_up(super())
1187
1188
class Family(SpecFamily):
    """Family spec extended with the bookkeeping needed for C code generation.

    resolve() derives the C naming prefixes and fills in:
      root_sets           - attr set name -> {'request': set, 'reply': set}
                            of attribute names used directly by operations
      pure_nested_structs - attr set name -> Struct, for sets (and
                            sub-messages) reachable only through nesting
      hooks               - pre/post do/dump hook names, deduplicated
      global_policy(_set) - only when kernel-policy is 'global'
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # Each attribute is created and then deleted, so reading one before
        # it has been assigned raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # C macro names for the family name/version; overridable in the spec.
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" -> "<name>"; any other header path falls back to
        # the family's ident_name.
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Derive C names and build the root/nested struct tables.

        The _load_* steps depend on each other in the order called:
        nested sets are discovered from root_sets, attribute use is derived
        from pure_nested_structs, and selector passing walks the final
        struct list.
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] = {'set': names seen, 'list': names in order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    # Factory hooks returning the C-generator subclasses of the spec
    # elements (presumably invoked by SpecFamily while parsing the spec).
    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """True for 'netlink-raw' (classic netlink) families."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Flag every op that another message names in its 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    # (an existing 'do' on such an op would be replaced)
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per attribute set, the attributes ops use in requests and replies."""
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct follows the nests it embeds.

        A struct that references a nested set not yet placed is moved to the
        end of the (insertion-ordered) dict and retried later. Recursive
        nests are skipped because they are referenced by pointer. The number
        of rounds is capped at n**2.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register a Struct for the set named by spec's 'nested-attributes'.

        Members inherited from the parent ('type-value' levels, or 'idx' for
        indexed arrays) are recorded on the Struct. Returns the nested set
        name. Raises if that set is also used as a root set, or if different
        parents would make it inherit different members.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                # Struct keeps a reference to `inherit` (not a copy), so the
                # updates below are visible to it and the set_inherited() call
                # at the end compares equal for the first registration.
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            # NOTE(review): root-set nests already raised above, so this
            # check looks unreachable — confirm before relying on it.
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Synthesize an attr set and Struct for spec's sub-message.

        Each sub-message format becomes a 'nest' attribute pointing at that
        format's attribute set. Returns the sub-message name.
        """
        # Fake the struct type for the sub-message itself
        # it's not an attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attrs.append({
                "name": name,
                "type": "nest",
                "parent-sub-message": spec,
                "nested-attributes": fmt['attribute-set']
            })

        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            # NOTE(review): AttrSet.__init__ reads 'name-prefix', not
            # 'name-pfx'; unless SpecAttrSet uses this key, the value
            # appears to be ignored — confirm intent.
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Discover every nested set reachable from the root sets.

        1. Breadth-first walk from the root sets, creating Structs.
        2. Mark direct children of root sets as request/reply (and
           in_multi_val) according to which root attributes ops use.
        3. Propagate request/reply/in_multi_val and child_nests down the
           tree, flagging recursive structs.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # After sorting, containers come after their children, so walking in
        # reverse visits a parent before the structs it embeds.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark attributes as used in requests and/or replies.

        Nested struct members inherit the struct's flags; root set
        attributes are marked per root_sets.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Bind external sub-message selectors to the attribute in the parent set.

        Only one level of passing is supported: the selector must live in
        the immediate parent set, otherwise this raises.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                # Relies on _load_nested_sets() having created a Struct for
                # every nested child of these sets (child is not None).
                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """Build the single request policy used when kernel-policy is 'global'.

        All ops with an attribute set must share the same one; global_policy
        lists that set's attributes used in any do/dump request, in set
        order. NOTE(review): if no op has an 'attribute-set',
        attr_set_name stays None and the final lookup is expected to fail.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per do/dump mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1536
1537
class RenderInfo:
    """Context for rendering one operation (or one bare attribute set).

    Holds the family, the operation and mode being rendered, the code writer,
    and the 'request'/'reply' Structs built from the operation's attribute
    lists. When op is falsy, attr_set must be given and no structs are built.
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            if op.fixed_header != family.fixed_header:
                if family.is_classic():
                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
                else:
                    raise Exception("Per-op fixed header not supported, yet")

        # 'do' and 'dump' response parsing is identical
        # type_consistent: False when the 'do' and 'dump' replies differ.
        # type_oneside: rendering a non-'do' mode of an op with 'dump' but no 'do'.
        self.type_consistent = True
        self.type_oneside = False
        # Guard on op: a falsy op (no operation) has no modes to compare.
        if op and op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_oneside = True

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        # 'request'/'reply' -> Struct; stays empty when op is falsy.
        self.struct = dict()
        if op_mode == 'notify':
            # Notifications use the 'do' (or, failing that, 'dump') attribute lists.
            op_mode = 'do' if 'do' in op else 'dump'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """True if the struct for key ('request'/'reply') has no attributes and no fixed header."""
        struct = self.struct[key]
        return len(struct.attr_list) == 0 and struct.fixed_header is None

    def needs_nlflags(self, direction):
        """True for the request of a 'do' on a classic (netlink-raw) family."""
        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1605
1606
class CodeWriter:
    """Emit C source text with kernel-style tab indentation and bracing.

    Output goes to stdout when ``out_file`` is None.  Otherwise it is
    buffered in a temporary file and copied to ``out_file`` by
    close_out_file(), which is also invoked from __del__.  With
    ``overwrite=False`` an existing destination whose contents already match
    the generated text is left untouched.

    ``nlib`` is stored as-is for callers (presumably a BaseNlLib-like helper;
    not used inside this class).
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        # State consumed by the next p() call
        self._nl = False            # emit a blank line first
        self._block_end = False     # emit a deferred closing '}' first
        self._silent_block = False  # previous line was a brace-less condition
        self._ind = 0               # current indentation depth, in tabs
        self._ifdef_block = None    # CONFIG_* option of the open #ifdef, if any
        # Destination path; None while writing to stdout or once finalized
        self._out_file = out_file
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy the buffered output to the destination file and release it.

        A no-op when writing to stdout or when already finalized.  When
        overwrite is disabled and the destination exists with identical
        contents, it is left untouched.  In every case the temporary file is
        closed and further output reverts to stdout, so repeated calls
        (including the one from __del__) are harmless.
        """
        out_file_name = getattr(self, '_out_file', None)
        if out_file_name is None:
            return
        self._out.flush()
        try:
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and
                         os.path.isfile(out_file_name) and
                         filecmp.cmp(self._out.name, out_file_name, shallow=False))
            if not unchanged:
                with open(out_file_name, 'w+') as out_file:
                    self._out.seek(0)
                    shutil.copyfileobj(self._out, out_file)
        finally:
            # Release the temp file even when the destination was unchanged
            self._out.close()
            self._out = os.sys.stdout
            self._out_file = None

    @classmethod
    def _is_cond(cls, line):
        """True if `line` starts a C if/while/for statement."""
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        A pending '}' from block_end() is flushed first; if `line` starts
        with 'else' the brace is joined onto it ('} else ...').  Labels
        (lines ending in ':') are outdented by one level, preprocessor lines
        ('#...') go to column 0, and the line after a brace-less
        if/while/for or bare 'else' is indented one extra level.
        `add_ind` adds further tabs for this line only.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next printed line."""
        self._nl = True

    def block_start(self, line=''):
        """Print `line` followed by ' {' (or just '{') and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the innermost block.

        Without `line` the '}' is deferred to the next p() so that an 'else'
        can share its line; a second bare block_end() while one is pending
        writes the earlier brace immediately.  With `line`, '}' and `line`
        are printed together (no space before ';' or ',').
        A pending blank line is discarded.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write `doc` as ' *'-prefixed comment lines wrapped before 79
        columns; continuation lines get two extra spaces when `indent`."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype/definition header.

        `args` defaults to ['void']; `doc` is emitted as a preceding block
        comment; `suffix` (e.g. ';') follows the closing parenthesis.  If the
        one-line form reaches 80 columns, a return type longer than 3 chars
        goes on its own line and arguments wrap past 76 visual columns,
        continuation lines aligned under the opening parenthesis.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Alignment prefix using tabs (8 columns) + spaces
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        # Visual columns a tab-indented line has beyond its character count
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations longest-first, then request a
        blank line.  A list argument is sorted in place."""
        if not local_vars:
            return

        if type(local_vars) is str:
            local_vars = [local_vars]

        local_vars.sort(key=len, reverse=True)
        for var in local_vars:
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: header, locals, then `body` lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define NAME<tabs>VALUE' lines with values aligned on a
        tab stop past the longest name.  int values are printed bare, str
        values quoted; any other value type yields a define with no value."""
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write designated initializers '.name<tabs>= value,' with the '='
        aligned on a common tab stop.  `members` must be non-empty."""
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active '#ifdef CONFIG_<CONFIG>' guard.

        No-op if that guard is already open; otherwise closes the current
        guard (if any) and opens the new one.  A falsy `config` just closes
        the current guard.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1802        self._ifdef_block = config_option
1803
1804
1805scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}
1806
1807direction_to_suffix = {
1808    'reply': '_rsp',
1809    'request': '_req',
1810    '': ''
1811}
1812
1813op_mode_to_wrapper = {
1814    'do': '',
1815    'dump': '_list',
1816    'notify': '_ntf',
1817    'event': '',
1818}
1819
1820_C_KW = {
1821    'auto',
1822    'bool',
1823    'break',
1824    'case',
1825    'char',
1826    'const',
1827    'continue',
1828    'default',
1829    'do',
1830    'double',
1831    'else',
1832    'enum',
1833    'extern',
1834    'float',
1835    'for',
1836    'goto',
1837    'if',
1838    'inline',
1839    'int',
1840    'long',
1841    'register',
1842    'return',
1843    'short',
1844    'signed',
1845    'sizeof',
1846    'static',
1847    'struct',
1848    'switch',
1849    'typedef',
1850    'union',
1851    'unsigned',
1852    'void',
1853    'volatile',
1854    'while'
1855}
1856
1857
1858def rdir(direction):
1859    if direction == 'reply':
1860        return 'request'
1861    if direction == 'request':
1862        return 'reply'
1863    return direction
1864
1865
1866def op_prefix(ri, direction, deref=False):
1867    suffix = f"_{ri.type_name}"
1868
1869    if not ri.op_mode or ri.op_mode == 'do':
1870        suffix += f"{direction_to_suffix[direction]}"
1871    else:
1872        if direction == 'request':
1873            suffix += '_req'
1874            if not ri.type_oneside:
1875                suffix += '_dump'
1876        else:
1877            if ri.type_consistent:
1878                if deref:
1879                    suffix += f"{direction_to_suffix[direction]}"
1880                else:
1881                    suffix += op_mode_to_wrapper[ri.op_mode]
1882            else:
1883                suffix += '_rsp'
1884                suffix += '_dump' if deref else '_list'
1885
1886    return f"{ri.family.c_name}{suffix}"
1887
1888
1889def type_name(ri, direction, deref=False):
1890    return f"struct {op_prefix(ri, direction, deref=deref)}"
1891
1892
1893def print_prototype(ri, direction, terminate=True, doc=None):
1894    suffix = ';' if terminate else ''
1895
1896    fname = ri.op.render_name
1897    if ri.op_mode == 'dump':
1898        fname += '_dump'
1899
1900    args = ['struct ynl_sock *ys']
1901    if 'request' in ri.op[ri.op_mode]:
1902        args.append(f"{type_name(ri, direction)} *" + f"{direction_to_suffix[direction][1:]}")
1903
1904    ret = 'int'
1905    if 'reply' in ri.op[ri.op_mode]:
1906        ret = f"{type_name(ri, rdir(direction))} *"
1907
1908    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=suffix)
1909
1910
1911def print_req_prototype(ri):
1912    print_prototype(ri, "request", doc=ri.op['doc'])
1913
1914
1915def print_dump_prototype(ri):
1916    print_prototype(ri, "request")
1917
1918
1919def put_typol_submsg(cw, struct):
1920    cw.block_start(line=f'const struct ynl_policy_attr {struct.render_name}_policy[] =')
1921
1922    i = 0
1923    for name, arg in struct.member_list():
1924        cw.p('[%d] = { .type = YNL_PT_SUBMSG, .name = "%s", .nest = &%s_nest, },' %
1925             (i, name, arg.nested_render_name))
1926        i += 1
1927
1928    cw.block_end(line=';')
1929    cw.nl()
1930
1931    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
1932    cw.p(f'.max_attr = {i - 1},')
1933    cw.p(f'.table = {struct.render_name}_policy,')
1934    cw.block_end(line=';')
1935    cw.nl()
1936
1937
1938def put_typol_fwd(cw, struct):
1939    cw.p(f'extern const struct ynl_policy_nest {struct.render_name}_nest;')
1940
1941
1942def put_typol(cw, struct):
1943    if struct.submsg:
1944        put_typol_submsg(cw, struct)
1945        return
1946
1947    type_max = struct.attr_set.max_name
1948    cw.block_start(line=f'const struct ynl_policy_attr {struct.render_name}_policy[{type_max} + 1] =')
1949
1950    for _, arg in struct.member_list():
1951        arg.attr_typol(cw)
1952
1953    cw.block_end(line=';')
1954    cw.nl()
1955
1956    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
1957    cw.p(f'.max_attr = {type_max},')
1958    cw.p(f'.table = {struct.render_name}_policy,')
1959    cw.block_end(line=';')
1960    cw.nl()
1961
1962
1963def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1964    args = [f'int {arg_name}']
1965    if enum:
1966        args = [enum.user_type + ' ' + arg_name]
1967    cw.write_func_prot('const char *', f'{render_name}_str', args)
1968    cw.block_start()
1969    if enum and enum.type == 'flags':
1970        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1971    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
1972    cw.p('return NULL;')
1973    cw.p(f'return {map_name}[{arg_name}];')
1974    cw.block_end()
1975    cw.nl()
1976
1977
1978def put_op_name_fwd(family, cw):
1979    cw.write_func_prot('const char *', f'{family.c_name}_op_str', ['int op'], suffix=';')
1980
1981
1982def put_op_name(family, cw):
1983    map_name = f'{family.c_name}_op_strmap'
1984    cw.block_start(line=f"static const char * const {map_name}[] =")
1985    for op_name, op in family.msgs.items():
1986        if op.rsp_value:
1987            # Make sure we don't add duplicated entries, if multiple commands
1988            # produce the same response in legacy families.
1989            if family.rsp_by_value[op.rsp_value] != op:
1990                cw.p(f'// skip "{op_name}", duplicate reply value')
1991                continue
1992
1993            if op.req_value == op.rsp_value:
1994                cw.p(f'[{op.enum_name}] = "{op_name}",')
1995            else:
1996                cw.p(f'[{op.rsp_value}] = "{op_name}",')
1997    cw.block_end(line=';')
1998    cw.nl()
1999
2000    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2001
2002
2003def put_enum_to_str_fwd(family, cw, enum):
2004    args = [enum.user_type + ' value']
2005    cw.write_func_prot('const char *', f'{enum.render_name}_str', args, suffix=';')
2006
2007
2008def put_enum_to_str(family, cw, enum):
2009    map_name = f'{enum.render_name}_strmap'
2010    cw.block_start(line=f"static const char * const {map_name}[] =")
2011    for entry in enum.entries.values():
2012        cw.p(f'[{entry.value}] = "{entry.name}",')
2013    cw.block_end(line=';')
2014    cw.nl()
2015
2016    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
2017
2018
2019def put_req_nested_prototype(ri, struct, suffix=';'):
2020    func_args = ['struct nlmsghdr *nlh',
2021                 'unsigned int attr_type',
2022                 f'{struct.ptr_name}obj']
2023
2024    ri.cw.write_func_prot('int', f'{struct.render_name}_put', func_args,
2025                          suffix=suffix)
2026
2027
2028def put_req_nested(ri, struct):
2029    local_vars = []
2030    init_lines = []
2031
2032    if struct.submsg is None:
2033        local_vars.append('struct nlattr *nest;')
2034        init_lines.append("nest = ynl_attr_nest_start(nlh, attr_type);")
2035
2036    has_anest = False
2037    has_count = False
2038    for _, arg in struct.member_list():
2039        has_anest |= arg.type == 'indexed-array'
2040        has_count |= arg.presence_type() == 'count'
2041    if has_anest:
2042        local_vars.append('struct nlattr *array;')
2043    if has_count:
2044        local_vars.append('unsigned int i;')
2045
2046    put_req_nested_prototype(ri, struct, suffix='')
2047    ri.cw.block_start()
2048    ri.cw.write_func_lvar(local_vars)
2049
2050    for line in init_lines:
2051        ri.cw.p(line)
2052
2053    for _, arg in struct.member_list():
2054        arg.attr_put(ri, "obj")
2055
2056    if struct.submsg is None:
2057        ri.cw.p("ynl_attr_nest_end(nlh, nest);")
2058
2059    ri.cw.nl()
2060    ri.cw.p('return 0;')
2061    ri.cw.block_end()
2062    ri.cw.nl()
2063
2064
2065def _multi_parse(ri, struct, init_lines, local_vars):
2066    if struct.nested:
2067        iter_line = "ynl_attr_for_each_nested(attr, nested)"
2068    else:
2069        if struct.fixed_header:
2070            local_vars += ['void *hdr;']
2071        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
2072        if ri.op.fixed_header != ri.family.fixed_header:
2073            if ri.family.is_classic():
2074                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
2075            else:
2076                raise Exception(f"Per-op fixed header not supported, yet")
2077
2078    array_nests = set()
2079    multi_attrs = set()
2080    needs_parg = False
2081    for arg, aspec in struct.member_list():
2082        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
2083            if aspec["sub-type"] in {'binary', 'nest'}:
2084                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
2085                array_nests.add(arg)
2086            elif aspec['sub-type'] in scalars:
2087                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
2088                array_nests.add(arg)
2089            else:
2090                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
2091        if 'multi-attr' in aspec:
2092            multi_attrs.add(arg)
2093        needs_parg |= 'nested-attributes' in aspec
2094        needs_parg |= 'sub-message' in aspec
2095    if array_nests or multi_attrs:
2096        local_vars.append('int i;')
2097    if needs_parg:
2098        local_vars.append('struct ynl_parse_arg parg;')
2099        init_lines.append('parg.ys = yarg->ys;')
2100
2101    all_multi = array_nests | multi_attrs
2102
2103    for anest in sorted(all_multi):
2104        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
2105
2106    ri.cw.block_start()
2107    ri.cw.write_func_lvar(local_vars)
2108
2109    for line in init_lines:
2110        ri.cw.p(line)
2111    ri.cw.nl()
2112
2113    for arg in struct.inherited:
2114        ri.cw.p(f'dst->{arg} = {arg};')
2115
2116    if struct.fixed_header:
2117        if ri.family.is_classic():
2118            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
2119        else:
2120            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
2121        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
2122    for anest in sorted(all_multi):
2123        aspec = struct[anest]
2124        ri.cw.p(f"if (dst->{aspec.c_name})")
2125        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
2126
2127    ri.cw.nl()
2128    ri.cw.block_start(line=iter_line)
2129    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
2130    ri.cw.nl()
2131
2132    first = True
2133    for _, arg in struct.member_list():
2134        good = arg.attr_get(ri, 'dst', first=first)
2135        # First may be 'unused' or 'pad', ignore those
2136        first &= not good
2137
2138    ri.cw.block_end()
2139    ri.cw.nl()
2140
2141    for anest in sorted(array_nests):
2142        aspec = struct[anest]
2143
2144        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
2145        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
2146        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
2147        ri.cw.p('i = 0;')
2148        if 'nested-attributes' in aspec:
2149            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
2150        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
2151        if 'nested-attributes' in aspec:
2152            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
2153            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
2154            ri.cw.p('return YNL_PARSE_CB_ERROR;')
2155        elif aspec.sub_type in scalars:
2156            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
2157        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
2158            # Length is validated by typol
2159            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
2160        else:
2161            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
2162        ri.cw.p('i++;')
2163        ri.cw.block_end()
2164        ri.cw.block_end()
2165    ri.cw.nl()
2166
2167    for anest in sorted(multi_attrs):
2168        aspec = struct[anest]
2169        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
2170        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
2171        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
2172        ri.cw.p('i = 0;')
2173        if 'nested-attributes' in aspec:
2174            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
2175        ri.cw.block_start(line=iter_line)
2176        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
2177        if 'nested-attributes' in aspec:
2178            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
2179            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
2180            ri.cw.p('return YNL_PARSE_CB_ERROR;')
2181        elif aspec.type in scalars:
2182            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
2183        elif aspec.type == 'binary' and 'struct' in aspec:
2184            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
2185            ri.cw.nl()
2186            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
2187            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
2188            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
2189        elif aspec.type == 'string':
2190            ri.cw.p('unsigned int len;')
2191            ri.cw.nl()
2192            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
2193            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
2194            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
2195            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
2196            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
2197        else:
2198            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
2199        ri.cw.p('i++;')
2200        ri.cw.block_end()
2201        ri.cw.block_end()
2202        ri.cw.block_end()
2203    ri.cw.nl()
2204
2205    if struct.nested:
2206        ri.cw.p('return 0;')
2207    else:
2208        ri.cw.p('return YNL_PARSE_CB_OK;')
2209    ri.cw.block_end()
2210    ri.cw.nl()
2211
2212
2213def parse_rsp_submsg(ri, struct):
2214    parse_rsp_nested_prototype(ri, struct, suffix='')
2215
2216    var = 'dst'
2217
2218    ri.cw.block_start()
2219    ri.cw.write_func_lvar(['const struct nlattr *attr = nested;',
2220                          f'{struct.ptr_name}{var} = yarg->data;',
2221                          'struct ynl_parse_arg parg;'])
2222
2223    ri.cw.p('parg.ys = yarg->ys;')
2224    ri.cw.nl()
2225
2226    first = True
2227    for name, arg in struct.member_list():
2228        kw = 'if' if first else 'else if'
2229        first = False
2230
2231        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
2232        get_lines, init_lines, _ = arg._attr_get(ri, var)
2233        for line in init_lines:
2234            ri.cw.p(line)
2235        for line in get_lines:
2236            ri.cw.p(line)
2237        if arg.presence_type() == 'present':
2238            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
2239        ri.cw.block_end()
2240    ri.cw.p('return 0;')
2241    ri.cw.block_end()
2242    ri.cw.nl()
2243
2244
2245def parse_rsp_nested_prototype(ri, struct, suffix=';'):
2246    func_args = ['struct ynl_parse_arg *yarg',
2247                 'const struct nlattr *nested']
2248    for sel in struct.external_selectors():
2249        func_args.append('const char *_sel_' + sel.name)
2250    if struct.submsg:
2251        func_args.insert(1, 'const char *sel')
2252    for arg in struct.inherited:
2253        func_args.append('__u32 ' + arg)
2254
2255    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
2256                          suffix=suffix)
2257
2258
2259def parse_rsp_nested(ri, struct):
2260    if struct.submsg:
2261        return parse_rsp_submsg(ri, struct)
2262
2263    parse_rsp_nested_prototype(ri, struct, suffix='')
2264
2265    local_vars = ['const struct nlattr *attr;',
2266                  f'{struct.ptr_name}dst = yarg->data;']
2267    init_lines = []
2268
2269    if struct.member_list():
2270        _multi_parse(ri, struct, init_lines, local_vars)
2271    else:
2272        # Empty nest
2273        ri.cw.block_start()
2274        ri.cw.p('return 0;')
2275        ri.cw.block_end()
2276        ri.cw.nl()
2277
2278
2279def parse_rsp_msg(ri, deref=False):
2280    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
2281        return
2282
2283    func_args = ['const struct nlmsghdr *nlh',
2284                 'struct ynl_parse_arg *yarg']
2285
2286    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
2287                  'const struct nlattr *attr;']
2288    init_lines = ['dst = yarg->data;']
2289
2290    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', func_args)
2291
2292    if ri.struct["reply"].member_list():
2293        _multi_parse(ri, ri.struct["reply"], init_lines, local_vars)
2294    else:
2295        # Empty reply
2296        ri.cw.block_start()
2297        ri.cw.p('return YNL_PARSE_CB_OK;')
2298        ri.cw.block_end()
2299        ri.cw.nl()
2300
2301
2302def print_req(ri):
2303    ret_ok = '0'
2304    ret_err = '-1'
2305    direction = "request"
2306    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
2307                  'struct nlmsghdr *nlh;',
2308                  'int err;']
2309
2310    if 'reply' in ri.op[ri.op_mode]:
2311        ret_ok = 'rsp'
2312        ret_err = 'NULL'
2313        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']
2314
2315    if ri.struct["request"].fixed_header:
2316        local_vars += ['size_t hdr_len;',
2317                       'void *hdr;']
2318
2319    for _, attr in ri.struct["request"].member_list():
2320        if attr.presence_type() == 'count':
2321            local_vars += ['unsigned int i;']
2322            break
2323
2324    print_prototype(ri, direction, terminate=False)
2325    ri.cw.block_start()
2326    ri.cw.write_func_lvar(local_vars)
2327
2328    if ri.family.is_classic():
2329        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
2330    else:
2331        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
2332
2333    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
2334    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
2335    if 'reply' in ri.op[ri.op_mode]:
2336        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
2337    ri.cw.nl()
2338
2339    if ri.struct['request'].fixed_header:
2340        ri.cw.p("hdr_len = sizeof(req->_hdr);")
2341        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
2342        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
2343        ri.cw.nl()
2344
2345    for _, attr in ri.struct["request"].member_list():
2346        attr.attr_put(ri, "req")
2347    ri.cw.nl()
2348
2349    if 'reply' in ri.op[ri.op_mode]:
2350        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
2351        ri.cw.p('yrs.yarg.data = rsp;')
2352        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
2353        if ri.op.value is not None:
2354            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
2355        else:
2356            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
2357        ri.cw.nl()
2358    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
2359    ri.cw.p('if (err < 0)')
2360    if 'reply' in ri.op[ri.op_mode]:
2361        ri.cw.p('goto err_free;')
2362    else:
2363        ri.cw.p('return -1;')
2364    ri.cw.nl()
2365
2366    ri.cw.p(f"return {ret_ok};")
2367    ri.cw.nl()
2368
2369    if 'reply' in ri.op[ri.op_mode]:
2370        ri.cw.p('err_free:')
2371        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
2372        ri.cw.p(f"return {ret_err};")
2373
2374    ri.cw.block_end()
2375
2376
2377def print_dump(ri):
2378    direction = "request"
2379    print_prototype(ri, direction, terminate=False)
2380    ri.cw.block_start()
2381    local_vars = ['struct ynl_dump_state yds = {};',
2382                  'struct nlmsghdr *nlh;',
2383                  'int err;']
2384
2385    if ri.struct['request'].fixed_header:
2386        local_vars += ['size_t hdr_len;',
2387                       'void *hdr;']
2388
2389    ri.cw.write_func_lvar(local_vars)
2390
2391    ri.cw.p('yds.yarg.ys = ys;')
2392    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
2393    ri.cw.p("yds.yarg.data = NULL;")
2394    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
2395    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
2396    if ri.op.value is not None:
2397        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
2398    else:
2399        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
2400    ri.cw.nl()
2401    if ri.family.is_classic():
2402        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
2403    else:
2404        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
2405
2406    if ri.struct['request'].fixed_header:
2407        ri.cw.p("hdr_len = sizeof(req->_hdr);")
2408        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
2409        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
2410        ri.cw.nl()
2411
2412    if "request" in ri.op[ri.op_mode]:
2413        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
2414        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
2415        ri.cw.nl()
2416        for _, attr in ri.struct["request"].member_list():
2417            attr.attr_put(ri, "req")
2418    ri.cw.nl()
2419
2420    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
2421    ri.cw.p('if (err < 0)')
2422    ri.cw.p('goto free_list;')
2423    ri.cw.nl()
2424
2425    ri.cw.p('return yds.first;')
2426    ri.cw.nl()
2427    ri.cw.p('free_list:')
2428    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
2429    ri.cw.p('return NULL;')
2430    ri.cw.block_end()
2431
2432
2433def call_free(ri, direction, var):
2434    return f"{op_prefix(ri, direction)}_free({var});"
2435
2436
2437def free_arg_name(direction):
2438    if direction:
2439        return direction_to_suffix[direction][1:]
2440    return 'obj'
2441
2442
2443def print_alloc_wrapper(ri, direction):
2444    name = op_prefix(ri, direction)
2445    ri.cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", [f"void"])
2446    ri.cw.block_start()
2447    ri.cw.p(f'return calloc(1, sizeof(struct {name}));')
2448    ri.cw.block_end()
2449
2450
2451def print_free_prototype(ri, direction, suffix=';'):
2452    name = op_prefix(ri, direction)
2453    struct_name = name
2454    if ri.type_name_conflict:
2455        struct_name += '_'
2456    arg = free_arg_name(direction)
2457    ri.cw.write_func_prot('void', f"{name}_free", [f"struct {struct_name} *{arg}"], suffix=suffix)
2458
2459
2460def print_nlflags_set(ri, direction):
2461    name = op_prefix(ri, direction)
2462    ri.cw.write_func_prot(f'static inline void', f"{name}_set_nlflags",
2463                          [f"struct {name} *req", "__u16 nl_flags"])
2464    ri.cw.block_start()
2465    ri.cw.p('req->_nlmsg_flags = nl_flags;')
2466    ri.cw.block_end()
2467    ri.cw.nl()
2468
2469
2470def _print_type(ri, direction, struct):
2471    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
2472    if not direction and ri.type_name_conflict:
2473        suffix += '_'
2474
2475    if ri.op_mode == 'dump' and not ri.type_oneside:
2476        suffix += '_dump'
2477
2478    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")
2479
2480    if ri.needs_nlflags(direction):
2481        ri.cw.p('__u16 _nlmsg_flags;')
2482        ri.cw.nl()
2483    if struct.fixed_header:
2484        ri.cw.p(struct.fixed_header + ' _hdr;')
2485        ri.cw.nl()
2486
2487    for type_filter in ['present', 'len', 'count']:
2488        meta_started = False
2489        for _, attr in struct.member_list():
2490            line = attr.presence_member(ri.ku_space, type_filter)
2491            if line:
2492                if not meta_started:
2493                    ri.cw.block_start(line=f"struct")
2494                    meta_started = True
2495                ri.cw.p(line)
2496        if meta_started:
2497            ri.cw.block_end(line=f'_{type_filter};')
2498    ri.cw.nl()
2499
2500    for arg in struct.inherited:
2501        ri.cw.p(f"__u32 {arg};")
2502
2503    for _, attr in struct.member_list():
2504        attr.struct_member(ri)
2505
2506    ri.cw.block_end(line=';')
2507    ri.cw.nl()
2508
2509
def print_type(ri, direction):
    """Emit the struct describing the op's @direction ('request' or 'reply')."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2512
2513
def print_type_full(ri, struct):
    """Emit the struct definition for @struct with an empty direction."""
    _print_type(ri, direction="", struct=struct)
2516
2517
def print_type_helpers(ri, direction, deref=False):
    """Emit the helper declarations for the @direction type.

    This is always the ``_free()`` prototype, plus the nlflags setter when
    the type needs one. User-space requests also get one setter per
    attribute member.
    """
    cw = ri.cw
    print_free_prototype(ri, direction)
    cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    cw.nl()
2529
2530
def print_req_type_helpers(ri):
    """Emit the allocator and helpers for the request type, unless it is empty."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2536
2537
def print_rsp_type_helpers(ri):
    """Emit reply helpers when the current op mode defines a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2542
2543
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of ``<op>_rsp_parse`` / ``<op>_req_parse``.

    A 'reply' @direction selects the ``_rsp`` flavour and anything else
    selects ``_req``. With @terminate the prototype ends in ';'; otherwise
    it is left open for a body.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=(';' if terminate else ''))
2552
2553
def print_req_type(ri):
    """Emit the request struct unless the request carries nothing."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2558
2559
def print_req_free(ri):
    """Emit the request ``_free()`` definition when the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2564
2565
def print_rsp_type(ri):
    """Emit the reply struct.

    Events always have one. do/dump modes have one only when they declare
    a reply, and other modes never do.
    """
    mode = ri.op_mode
    if mode == 'event':
        has_reply = True
    elif mode in ('do', 'dump'):
        has_reply = 'reply' in ri.op[mode]
    else:
        has_reply = False

    if has_reply:
        print_type(ri, 'reply')
2574
2575
def print_wrapped_type(ri):
    """Emit the wrapper struct around a reply object, then its free() prototype.

    Dumps get a ``next`` list pointer. Notifications and events get the
    notification bookkeeping (family, cmd, next, free callback). In every
    case the reply is embedded as an 8-byte aligned ``obj`` member.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')
    mode = ri.op_mode

    cw.block_start(line=f"{wrapper}")
    if mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p('struct ynl_ntf_base_type *next;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()

    print_free_prototype(ri, 'reply')
    cw.nl()
2590
2591
2592def _free_type_members_iter(ri, struct):
2593    if struct.free_needs_iter():
2594        ri.cw.p('unsigned int i;')
2595        ri.cw.nl()
2596
2597
2598def _free_type_members(ri, var, struct, ref=''):
2599    for _, attr in struct.member_list():
2600        attr.free(ri, var, ref)
2601
2602
def _free_type(ri, direction, struct):
    """Emit a full ``_free()`` definition for @struct.

    All members are released first. The struct pointer itself is freed only
    when @direction is non-empty; for direction-less types (``""``) only the
    members are released.
    """
    arg = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2614
2615
def free_rsp_nested_prototype(ri):
    """Emit the free() prototype for a direction-less (nested) type."""
    print_free_prototype(ri, direction="")
2618
2619
def free_rsp_nested(ri, struct):
    """Emit the free() definition for direction-less type @struct.

    Only its members are released; the pointer itself is not freed.
    """
    _free_type(ri, direction="", struct=struct)
2622
2623
def print_rsp_free(ri):
    """Emit the reply ``_free()`` definition when the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2628
2629
def print_dump_type_free(ri):
    """Emit the free function for a dump reply list.

    The generated code walks the ``next`` chain until ``YNL_LIST_END``. For
    each element it releases the members reachable via ``rsp->obj.`` and
    then frees the element itself.
    """
    cw = ri.cw
    reply = ri.struct['reply']
    elem_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{elem_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2648
2649
def print_ntf_type_free(ri):
    """Emit the free function for a notification: release the members of
    ``rsp->obj`` and then the wrapper itself."""
    reply = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2658
2659
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit a policy array declaration, or the opening of its definition.

    @terminate True produces an ``extern`` declaration; False produces the
    ``... = {`` opening of the definition.

    For an op-specific policy (@ri given) of a family whose policies are
    static, no declaration is emitted at all and the definition is
    ``static``. The array is named after the op (with the op mode appended
    for dual-policy ops), or after @struct when no @ri is given.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    if ri:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name = f"{name}_{ri.op_mode}"
    else:
        name = struct.render_name

    max_attr = struct.attr_max_val
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2682
2683
def print_req_policy(cw, struct, ri=None):
    """Emit the full policy array definition for @struct.

    When rendering for an op, the definition is wrapped in that op's
    ``config-cond`` ifdef.
    """
    in_op = ri and ri.op
    if in_op:
        cw.ifdef_block(ri.op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2693
2694
def kernel_can_gen_family_struct(family):
    """Return True when the ``struct genl_family`` itself can be generated.

    That is only the case for generic netlink ('genetlink') families.
    """
    is_genl = family.proto == 'genetlink'
    return is_genl
2697
2698
def policy_should_be_static(family):
    """Return truthy when the family's policy arrays should be file-local.

    This holds for split policies. It also holds when the family struct is
    generated, because the ops table that references the policies is then
    ``static`` as well.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2701
2702
def print_kernel_policy_ranges(family, cw):
    """Emit a ``netlink_range_validation[_signed]`` struct per request
    attribute that has a ``full-range`` check.

    A one-time header comment precedes the first struct. Each struct carries
    the min and/or max limits that are present among the attribute's checks.
    """
    header_done = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_done:
                cw.p('/* Integer value ranges */')
                header_done = True

            is_unsigned = attr.type[0] == 'u'
            sign = '' if is_unsigned else '_signed'
            suffix = 'ULL' if is_unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} '
                                f'{c_lower(attr.enum_name)}_range =')
            limits = [(bound, attr.get_limit_str(bound, suffix=suffix))
                      for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(limits)
            cw.block_end(line=';')
            cw.nl()
2730
2731
def print_kernel_policy_sparse_enum_validates(family, cw):
    """Emit a validation callback for every request attribute with a
    ``sparse`` enum check.

    Each generated ``<attr>_validate()`` switches on the attribute value.
    It returns 0 for every entry of the referenced enum; any other value
    sets an extack message and returns -EINVAL. A one-time header comment
    precedes the first callback.
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request:
                continue
            if not attr.enum_name:
                continue
            if 'sparse' not in attr.checks:
                continue

            if first:
                cw.p('/* Sparse enums validation callbacks */')
                first = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # All valid entries share one 'return 0', so every case label
            # after the first is preceded by an explicit fallthrough.
            for idx, entry in enumerate(enum.entries.values()):
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2770
2771
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the forward part of the kernel ops table.

    With @terminate False (source side) this writes a comment, opens the ops
    array definition and returns.

    With @terminate True (header side) it declares the array ``extern`` only
    when the family struct is not generated here (the table is then
    exported). It then prototypes the family's pre/post hooks and the
    doit/dumpit handler of every synchronous op.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            # Left empty: the C array is sized by its initializer.
            cnt = ""
        elif family.kernel_policy == 'split':
            # One split-op entry per do handler and per dump handler.
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if not terminate:
            cw.block_start(line=decl + ' =')
            return
        cw.p(f"extern {decl};")

    cw.nl()
    doit_hook_args = ['const struct genl_split_ops *ops',
                      'struct sk_buff *skb', 'struct genl_info *info']
    dump_hook_args = ['struct netlink_callback *cb']
    hook_kinds = (('pre', 'do', 'int', doit_hook_args),
                  ('post', 'do', 'void', doit_hook_args),
                  ('pre', 'dump', 'int', dump_hook_args),
                  ('post', 'dump', 'int', dump_hook_args))
    for stage, mode, ret_type, hook_args in hook_kinds:
        for hook in family.hooks[stage][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(hook_args), suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        handler_base = f"{family.ident_name}-nl-{op_name}"
        if 'do' in op:
            cw.write_func_prot('int', c_lower(f"{handler_base}-doit"),
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
        if 'dump' in op:
            cw.write_func_prot('int', c_lower(f"{handler_base}-dumpit"),
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2837
2838
def print_kernel_op_table_hdr(family, cw):
    """Header side of the ops table: the extern declaration (when exported)
    plus the hook and handler prototypes."""
    print_kernel_op_table_fwd(family, cw, terminate=True)
2841
2842
def print_kernel_op_table(family, cw):
    """Emit the kernel ops table definition for @family.

    For 'global' and 'per-op' kernel policies there is one entry per
    synchronous op. For 'split' there is one entry per (op, do/dump) pair.
    Each entry is wrapped in the op's optional 'config-cond' ifdef.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): op['do']['request'] is read unconditionally,
                # so a dump-only op (or a do without a request) would raise
                # KeyError here under the per-op policy. Confirm specs using
                # per-op never hit this.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Member names of the optional pre/post callbacks, per op mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags that apply to this mode: the do
                    # entry drops 'dump'/'dump-strict', the dump entry drops
                    # 'strict'.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry carries GENL_CMD_CAP_DO or GENL_CMD_CAP_DUMP
                # for its mode, in addition to the op's own flags.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Presumably closes any config-cond block still open from the last entry.
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2921
2922
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the multicast group ID enum (one ``<FAMILY>_NLGRP_<NAME>`` per
    group). Nothing is emitted when the family has no groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2933
2934
def print_kernel_mcgrp_src(family, cw):
    """Emit the ``genl_multicast_group`` array, indexed by the group ID enum
    and holding each group's name. Nothing is emitted when there are no
    groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2946
2947
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated ``genl_family``. When the family has socket
    private data, also prototype its init/destroy hooks."""
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()
    if 'sock-priv' in family.kernel_family:
        priv_type = family.kernel_family["sock-priv"]
        for action in ('init', 'destroy'):
            cw.p(f'void {family.c_name}_nl_sock_priv_{action}({priv_type} *priv);')
        cw.nl()
2958
2959
def print_kernel_family_struct_src(family, cw):
    """Emit the ``struct genl_family`` definition for generic netlink families.

    The symbol is named ``<c_name>_nl_family`` so that it matches the
    ``extern`` declaration written by print_kernel_family_struct_hdr(). The
    ops table, multicast groups and sock-priv hooks it points at use the
    same ``c_name``-based identifiers.
    """
    if not kernel_can_gen_family_struct(family):
        return

    has_sock_priv = 'sock-priv' in family.kernel_family
    if has_sock_priv:
        # Generate "trampolines" to make CFI happy: the user hooks take a
        # typed sock-priv pointer, while the wrappers stored in the family
        # struct below take a plain void *.
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Use c_name (not ident_name) so the definition matches the header's
    # extern declaration and the other identifiers referenced below.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    # NOTE(review): no ops are wired up for the 'global' kernel policy here;
    # confirm genetlink families never use it with a generated family struct.
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if has_sock_priv:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2995
2996
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI enum block, naming it when possible.

    The name is chosen as follows:
    - If @obj has an @enum_name key, its value names the enum. An empty
      value forces an anonymous enum, and @ckey is not consulted.
    - Otherwise, if @ckey is present, the enum is named
      ``<family>_<obj[ckey]>``.
    - Otherwise the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
3005
3006
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Emit the single command enum of the unified message-ID model.

    Values are written explicitly only where they break the running
    sequence. With @separate_ntf, notify/event messages are skipped here.
    @max_by_define selects a ``#define`` for the max instead of an enum
    entry.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for msg in family.msgs.values():
        if separate_ntf and ('notify' in msg or 'event' in msg):
            continue

        if msg.value == expected:
            cw.p(f"{msg.enum_name},")
        else:
            cw.p(f"{msg.enum_name} = {msg.value},")
            expected = msg.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(f"{cnt_name},")
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
3032
3033
def render_uapi_directional(family, cw, max_by_define):
    """Emit the two message enums of the directional message-ID model.

    - ``USER`` lists messages sent to the kernel: ops with ``do`` that are
      not events.
    - ``KERNEL`` lists messages the kernel sends: do replies (suffixed
      ``_REPLY``), notifications and events.

    Each enum starts with a ``..._NONE = 0`` entry and ends with the
    count/max markers.
    """
    def emit_msg_enum(side, msgs):
        # Write one enum for the (enum_name, value) pairs in @msgs.
        cnt_name = f"__{family.op_prefix}{side}_CNT"
        max_name = f"{family.op_prefix}{side}_MAX"
        max_value = f"({cnt_name} - 1)"

        cw.block_start(line='enum')
        cw.p(c_upper(f'{family.name}_MSG_{side}_NONE = 0,'))
        next_val = 0
        for msg_enum, value in msgs:
            if value and value != next_val:
                cw.p(f"{msg_enum} = {value},")
                next_val = value
            else:
                cw.p(f"{msg_enum},")
            next_val += 1
        cw.nl()

        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(f"{cnt_name},")
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    def kernel_side_msgs():
        for op in family.msgs.values():
            has_reply = 'do' in op and 'reply' in op['do']
            if not (has_reply or 'notify' in op or 'event' in op):
                continue
            is_ntf = 'event' in op or 'notify' in op
            yield (op.enum_name if is_ntf else f'{op.enum_name}_REPLY'), op.value

    user_side_msgs = ((op.enum_name, op.value) for op in family.msgs.values()
                      if 'do' in op and 'event' not in op)

    emit_msg_enum('USER', user_side_msgs)
    emit_msg_enum('KERNEL', kernel_side_msgs())
3086
3087
def render_uapi(family, cw):
    """Emit the complete uAPI header for @family.

    Sections, in order:
    - the include guard;
    - the family name and version defines;
    - the spec's definitions (consts as #defines, enums/flags with kdoc);
    - one enum per attribute set;
    - the command enum(s) and, when an async prefix is set, a separate
      notification enum;
    - the multicast group name defines.

    Raises:
        Exception: if the family's message ID model is neither 'unified'
            nor 'directional'.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    # The header name can contain '/', which is not valid in a macro name.
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are batched into one run of #defines.
    # The batch is flushed whenever a non-const definition is reached.
    defines = []
    for const in family['definitions']:
        # Definitions marked with 'header' are not re-emitted here.
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                # A kdoc comment ("/**") is used only when individual
                # entries are documented; otherwise a plain comment holds the
                # enum-level doc.
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # Presumably value_change marks entries whose value is not the
                # implicit successor; only those get an explicit value.
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    # Flags get a mask of all bits instead of a count/max pair.
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The optional 'c-define-name' override belongs to the definition
            # itself (as for multicast groups below), not to the family.
            # Looking it up on the family would give every const one name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    # With an async prefix, notifications and events get their own enum.
    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3228
3229
def _render_user_ntf_entry(ri, op):
    """Emit one ``ynl_ntf_info`` initializer, indexed by message ID.

    For classic families the index is the enum of the request op whose
    value matches the notification's ``rsp_value``.
    """
    fam = ri.family
    if fam.is_classic():
        index_op = fam.req_by_value[op.rsp_value]
    else:
        index_op = op

    cw = ri.cw
    cw.block_start(line=f"[{index_op.enum_name}] = ")
    cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    cw.block_end(line=',')
3241
3242
def render_user_family(family, cw, prototype):
    """Emit the ``ynl_<family>_family`` descriptor for the user-space library.

    With @prototype only the ``extern`` declaration is written. Otherwise
    the notification info table (when the family has notifications) comes
    first, followed by the descriptor.

    Raises:
        Exception: for a notification entry with neither 'notify' nor
            'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_name, ntf in family.ntfs.items():
            if 'notify' in ntf:
                ri = RenderInfo(cw, family, "user", family.ops[ntf['notify']], "notify")
            elif 'event' in ntf:
                ri = RenderInfo(cw, family, "user", ntf, "event")
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(ri, ntf)
        for _, op in family.ops.items():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3284
3285
def family_contains_bitfield32(family):
    """Return True if any attribute of a non-subset attribute set has type
    ``bitfield32``."""
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
3294
3295
def find_kernel_root(full_path):
    """Locate the kernel source root above @full_path.

    Walks up the directory tree until it finds a directory containing a
    ``MAINTAINERS`` file. Returns ``(root, sub_path)``, where ``sub_path``
    is @full_path relative to ``root``.

    If no such directory exists (e.g. the spec lives outside a kernel tree),
    the walk stops once ``os.path.dirname`` stops making progress: at ``/``
    for absolute paths, or ``''`` for relative ones. It then returns
    ``(None, full_path)`` instead of looping forever.
    """
    orig_path = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Drop the trailing separator left by the last join().
            return parent, sub_path[:-1]
        if parent == full_path:
            # dirname() has reached its fixed point; nothing left to climb.
            return None, orig_path
        full_path = parent
3304
3305
def main():
    """
    Command-line entry point: parse arguments, load the spec, emit C code.

    Depending on --mode and --header/--source this renders one of:
      * 'uapi'   - the uAPI header (via render_uapi()),
      * 'kernel' - the kernel policy / op-table header or source,
      * 'user'   - the user-space YNL library header or source.

    Output is written through CodeWriter to the file given by -o; with
    --cmp-out an existing output file is not overwritten when the new
    output is identical. The order of the cw.p()/print_*() calls below
    defines the layout of the generated file.

    Exits with status 1 (message on stderr) if the spec cannot be parsed
    as YAML or declares a license other than the required one.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination; None means neither given.
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    # Diagnostics go to stderr so they can never mix with generated code.
    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print(f'License must be: {required_license}', file=sys.stderr)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # Banner: license, spec path relative to the kernel tree, generator mode.
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Echo the extra command-line options into the banner.
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.out_file:
        # Companion header of the output file, e.g. "foo-user.c" -> "foo-user.h".
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, keeping first-seen order.
    seen_header = []
    for one in headers:
        if one not in seen_header:
            cw.p(f"#include <{one}>")
            seen_header.append(one)
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            if any(struct.request for struct in parsed.pure_nested_structs.values()):
                cw.p('/* Common nested types */')
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op in parsed.ops.values():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)
            print_kernel_policy_sparse_enum_validates(parsed, cw)

            if any(struct.request for struct in parsed.pure_nested_structs.values()):
                cw.p('/* Common nested types */')
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            # Per-op request policies exist only for these policy layouts.
            if parsed.kernel_policy in {'per-op', 'split'}:
                for op in parsed.ops.values():
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for const in parsed.consts.values():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)
                if struct.request and struct.in_multi_val:
                    free_rsp_nested_prototype(ri)
                    cw.nl()

            for op in parsed.ops.values():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent or ri.type_oneside:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op in parsed.ntfs.values():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for const in parsed.consts.values():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            # Recursive nests need their policies forward-declared first.
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for struct in parsed.pure_nested_structs.values():
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            if has_recursive_nests:
                # Prototypes first so mutually-referencing helpers compile.
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op in parsed.ops.values():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent or ri.type_oneside:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op in parsed.ntfs.values():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3618
if __name__ == "__main__":
    # Run the generator only when executed as a script, not on import.
    main()
3621