xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 092b34b937353eaa6164ceb1f520b1f640aee54f)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
def c_upper(name):
    """Render *name* as an upper-case C identifier ('-' becomes '_')."""
    shouted = name.upper()
    return shouted.replace('-', '_')
22
23
def c_lower(name):
    """Render *name* as a lower-case C identifier ('-' becomes '_')."""
    lowered = name.lower()
    return lowered.replace('-', '_')
26
27
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The expected form is "<u|s><width>-<min|max>", for example:
    "u8-max" -> 255, "s16-min" -> -32768, "u64-min" -> 0.

    Raises ValueError if *name* is not in that form (previously any
    4-character suffix was silently treated as "-max").
    """
    sign, suffix = name[:1], name[-4:]
    if sign not in ('u', 's') or suffix not in ('-min', '-max'):
        raise ValueError(f"Invalid limit name: {name!r}")

    width = int(name[1:-4])
    if sign == 'u' and suffix == '-min':
        return 0
    # Signed types lose one bit of magnitude to the sign.
    if sign == 's':
        width -= 1
    value = (1 << width) - 1
    if sign == 's' and suffix == '-min':
        value = -value - 1
    return value
41
42
class BaseNlLib:
    """Library-specific naming hooks used when rendering C code."""

    def get_family_id(self):
        # C expression that yields the netlink family ID in generated code.
        family_id_expr = 'ys->family_id'
        return family_id_expr
46
47
class Type(SpecAttr):
    """Base class for one attribute as seen by the C code generator.

    Subclasses override the hooks below (policy, typol, put/get, setter,
    struct member, free) per attribute type.  The base implementations
    either provide the shared behavior or raise to flag a combination the
    generator does not implement.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply().
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        # Nested attribute sets and sub-messages are rendered the same way.
        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # If a definition in family.consts shares this name, suffix the
            # struct name with '_' (presumably to avoid a C name clash).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Keep the C member name clear of _C_KW entries and leading digits.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve(): declared here, then removed so that any access
        # before resolve() raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests (and its real attr, for subsets)."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies (and its real attr, for subsets)."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return check *limit* as a number, or None if absent without default.

        Integers are returned as-is, names of family constants resolve to
        the constant's "value", anything else goes to limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check *limit* rendered as C text ('' if absent).

        *suffix* is appended only to literal integer values.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            # Constants with a 'header' key keep their bare name; others
            # get the family name as a prefix.
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name and verify a subset attr does not override checks."""
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        # Recursive handling applies only when ri.op is not set.
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Which presence field tracks this attr: 'present', 'len' or 'count'."""
        return 'present'

    def presence_member(self, space, type_filter):
        """Return the C declaration of this attr's presence field, or None.

        Only produced when presence_type() equals *type_filter*; user-space
        code uses the '__u32' spelling.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'present':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        # True when the free code loops with 'i' (setter() then declares it).
        return False

    def _free_lines(self, ri, var, ref):
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the C statements releasing this attr's memory."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the C parameter declarations used to pass this attr to a setter."""
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the member declaration(s) for this attr in a generated struct."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's entry in the kernel nla_policy table."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's entry in the user-space type-policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the presence field when presence is bit/length based.
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attr and return True.

        *first* selects 'if' vs 'else if' so callers can chain branches.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # _attr_get() may return a single line as a plain string.
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Multi-value attrs skip the single-attr validation/presence step here.
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a static inline C setter for this attr.

        *ref* is the list of parent member names when the attr sits inside
        nests; every parent gets its presence bit set.  A setter that frees
        but never allocates is named with a '__' prefix.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        for i in range(0, len(ref)):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'present':
                presence = f"{var}->{'.'.join(ref[:i] + [''])}_{self.presence_type()}.{ref[i]}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
311
312
class TypeUnused(Type):
    """Attribute slot that must not be used: rejected on parse, never emitted."""

    # Code-emitting hooks: an unused attr produces nothing.

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None

    # Hooks that still report data to the generator.

    def presence_type(self):
        # No presence tracking for this attr.
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        return ['return YNL_PARSE_CB_ERROR;'], None, None
337
338
class TypePad(Type):
    """Padding attribute: ignored by the type policy and never emitted."""

    # Code-emitting hooks: padding produces nothing.

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None

    def attr_policy(self, cw):
        return None

    def attr_get(self, ri, var, first):
        return None

    def attr_put(self, ri, var):
        return None

    # Hooks that still report data to the generator.

    def presence_type(self):
        # No presence tracking for padding.
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '
360
361
class TypeScalar(Type):
    """Integer (scalar) attribute."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Appended to the C argument declaration as a comment.
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve(): declared, then removed so that any access
        # before resolve() raises AttributeError.
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve is_bitfield and the C type name used for this attr."""
        self.resolve_up(super())

        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """Derive range checks (min/max/range/full-range/sparse) for the policy.

        Note: this writes into self.checks.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                # No usable range: validated by a function instead (see
                # the 'sparse' branch of _attr_policy()).
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range use NLA_POLICY_FULL_RANGE (presumably
        # because the plain range policy only holds 16-bit bounds).
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        """Pick the policy macro: mask, full range, range, min, max or sparse."""
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                # Build the mask from the flag values themselves rather than
                # from the entry count, so flags that do not start at bit 0
                # or have gaps still yield the correct mask.
                flags = self.family.consts[self.checks['flags-mask']]
                mask = flags.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed and unsigned scalars share the YNL_PT_U<width> type policy.
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
459
460
class TypeFlag(Type):
    """Flag attribute: no payload, it is put with a NULL/0-length body."""

    def arg_member(self, ri):
        # A flag has no value to pass to its setter.
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Presence is recorded by the generic Type.attr_get(); nothing else to parse.
        return list(), None, None

    def _setter_lines(self, ri, member, presence):
        return list()
476
477
class TypeString(Type):
    """String attribute; its length is tracked in the _len presence field."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        parts = ['.type = YNL_PT_NUL_STR, ']
        if self.is_selector:
            parts.append('.is_selector = 1, ')
        return ''.join(parts)

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        fields = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            fields.append(', .len = ' + self.get_limit_str('max-len'))
        fields.append(', }')
        return ''.join(fields)

    def attr_policy(self, cw):
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        # Copy at most the attr payload and always NUL-terminate the copy.
        member = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{member} = malloc(len + 1);",
            f"memcpy({member}, ynl_attr_get_str(attr), len);",
            f"{member}[len] = 0;",
        ]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = strlen({self.c_name});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {self.c_name}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
530
531
class TypeBinary(Type):
    """Opaque binary attribute; its length is tracked in the _len presence field."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # At most one of exact-len / min-len / max-len is supported.
        n_checks = len(self.checks)
        if n_checks > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if n_checks == 0:
            return '{ .type = NLA_BINARY, }'

        (check_name,) = self.checks
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if check_name == 'max-len':
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        raise Exception('Unsupported check for binary type: ' + check_name)

    def attr_put(self, ri, var):
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                    f"{var}->{self.c_name}, {var}->_len.{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        member = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{member} = malloc(len);",
            f"memcpy({member}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = len;",
            f"{member} = malloc({presence});",
            f'memcpy({member}, {self.c_name}, {presence});',
        ]
582
583
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is the C struct named by the 'struct' key."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        member = f"{var}->{self.c_name}"
        # A payload shorter than the struct gets a zeroed struct-sized buffer
        # (presumably so the whole struct can be read safely).
        get_lines = [
            f"{var}->_{self.presence_type()}.{self.c_name} = len;",
            f"if (len < {struct_sz})",
            f"{member} = calloc(1, {struct_sz});",
            "else",
            f"{member} = malloc(len);",
            f"memcpy({member}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
599
600
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute holding a packed array of 'sub-type' scalars."""

    def arg_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        return [f'{elem} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem} *{self.c_name};')

    def attr_put(self, ri, var):
        count = f"{var}->_{self.presence_type()}.{self.c_name}"
        ri.cw.block_start(line=f"if ({count})")
        ri.cw.p(f"i = {count} * sizeof(__{self.get('sub-type')});")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        count_mem = f"{var}->_count.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        # Round len down to a whole number of elements before copying.
        get_lines = [
            f"{count_mem} = len / {elem_sz};",
            f"len = {count_mem} * {elem_sz};",
            f"{var}->{self.c_name} = malloc(len);",
            f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        return [
            f"{presence} = count;",
            f"count *= {elem_sz};",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
633
634
class TypeBitfield32(Type):
    """Attribute stored as a struct nla_bitfield32; its mask comes from 'enum'."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        flags_enum = self.family.consts[self.attr['enum']]
        return f"NLA_POLICY_BITFIELD32({flags_enum.get_mask(as_flags=True)})"

    def attr_put(self, ri, var):
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, "
                    "sizeof(struct nla_bitfield32))")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        copy_line = (f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), "
                     "sizeof(struct nla_bitfield32));")
        return copy_line, None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
658
659
class TypeNest(Type):
    """Attribute holding a nested attribute set (self.nested_attrs)."""

    def is_recursive(self):
        return self.family.pure_nested_structs[self.nested_attrs].recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        member = f"{var}->{ref}{self.c_name}"
        if not self.is_recursive_for_op(ri):
            return [f'{self.nested_render_name}_free(&{member});']
        # Recursive-for-op members are pointers (see Type.struct_member()),
        # so guard the free with a NULL check.
        return [f'if ({member})',
                f'{self.nested_render_name}_free({member});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        parse_args = ["&parg", "attr"]
        parse_args += [f'{var}->{sel.name}' for sel in nested.external_selectors()]
        get_lines = [
            f"if ({self.nested_render_name}_parse({', '.join(parse_args)}))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # A nest has no setter of its own; emit one per non-recursive member.
        child_ref = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member_attr in nested.member_list():
            if not member_attr.is_recursive():
                member_attr.setter(ri, self.nested_attrs, direction,
                                   deref=deref, ref=child_ref)
705
706
class TypeMultiAttr(Type):
    """Attribute that may repeat; all instances are stored as one C array."""

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        # Single-instance type object; policy and typol are delegated to it.
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        attr_type = self.attr['type']
        if attr_type == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if attr_type == 'string':
            return 'struct ynl_string *'
        if attr_type in scalars:
            return ('__' if ri.ku_space == 'user' else '') + attr_type
        raise Exception(f"Sub-type {attr_type} not supported yet")

    def arg_member(self, ri):
        if self.type != 'binary' or 'struct' not in self.attr:
            return super().arg_member(ri)
        return [f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
                f'unsigned int n_{self.c_name}']

    def free_needs_iter(self):
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        attr_type = self.attr['type']
        member = f"{var}->{ref}{self.c_name}"
        loop = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if attr_type in scalars or attr_type == 'binary':
            return [f"free({member});"]
        if attr_type == 'string':
            # Each element is a separate allocation.
            return [loop, f"free({member}[i]);", f"free({member});"]
        if attr_type == 'nest':
            return [loop,
                    f'{self.nested_render_name}_free(&{member}[i]);',
                    f"free({member});"]
        raise Exception(f"Free of MultiAttr sub-type {attr_type} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences at this point.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        attr_type = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        item = f"{var}->{self.c_name}[i]"
        if attr_type in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {item});")
        elif attr_type == 'binary' and 'struct' in self.attr:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{item}, sizeof(struct {c_lower(self.attr['struct'])}));")
        elif attr_type == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {item}->str);")
        elif attr_type == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, "
                                f"{self.enum_name}, &{item})")
        else:
            raise Exception(f"Put of MultiAttr sub-type {attr_type} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
792                f"{presence} = n_{self.c_name};"]
793
794
class TypeArrayNest(Type):
    """Indexed-array attribute: a nest whose children are put with their index as type."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr:
            return self.nested_struct_type
        sub_type = self.attr['sub-type']
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            return ('__' if ri.ku_space == 'user' else '') + sub_type
        if sub_type == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def arg_member(self, ri):
        if self.sub_type != 'binary' or 'exact-len' not in self.checks:
            return super().arg_member(ri)
        elem_len = self.checks["exact-len"]
        return [f'unsigned char (*{self.c_name})[{elem_len}]',
                f'unsigned int n_{self.c_name}']

    def _attr_typol(self):
        sub_type = self.attr['sub-type']
        if sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        if sub_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Saves the nest in attr_<name> and only validates/counts its children here.
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr) {',
            '\tif (ynl_attr_validate(yarg, attr2))',
            '\t\treturn YNL_PARSE_CB_ERROR;',
            f'\t{var}->_count.{self.c_name}++;',
            '}',
        ]
        return get_lines, None, ['const struct nlattr *attr2;']

    def attr_put(self, ri, var):
        loop = f'for (i = 0; i < {var}->_count.{self.c_name}; i++)'
        item = f"{var}->{self.c_name}[i]"
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            ri.cw.block_start(line=loop)
            ri.cw.p(f"ynl_attr_put_{self.sub_type}(nlh, i, {item});")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, i, {item}, {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(loop)
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{item});")
        else:
            raise Exception(f"Put for ArrayNest sub-type {self.attr['sub-type']} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
856                f"{presence} = n_{self.c_name};"]
857
858
class TypeNestTypeValue(Type):
    """Nest whose intermediate levels encode values in their attr types.

    For each name listed in 'type-value', parsing descends one nesting
    level and records that level's attribute type in a __u32 local, which
    is then passed to the nested parse function.
    """
    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        parent = 'attr'
        extra_args = ''
        if 'type-value' in self.attr:
            levels = [c_lower(lvl) for lvl in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({parent});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                parent = f'attr_{level}'
            extra_args = f", {', '.join(levels)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {parent}{extra_args});")
        return get_lines, init_lines, local_vars
887
888
class TypeSubMessage(TypeNest):
    """Attribute whose payload format is chosen by a selector attribute.

    The selector either lives in the same attribute set (internal) or must
    be supplied by an enclosing struct (external, passed in as _sel_<name>).
    """
    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        parts = [f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
                 '.is_submsg = 1, ']
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            parts.append(f'.selector_type = {self.attr_set[self["selector"]].value} ')
        return ''.join(parts)

    def _attr_get(self, ri, var):
        sel_name = self['selector']
        sel = c_lower(sel_name)
        # External selectors arrive as a parse-function argument, internal
        # ones are read from the struct being filled in.
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"
        get_lines = [
            f'if (!{sel_var})',
            f'return ynl_submsg_failed(yarg, "{self.name}", "{sel_name}");',
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None
919
920
class Selector:
    """The attribute whose value picks a sub-message's format.

    A selector is internal when it is a member of the same attribute set
    as the sub-message attribute; otherwise it is external and its
    attribute is bound later via set_attr().
    """
    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]
        self._external = self.name not in attr_set
        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
        else:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        """Bind the attribute (from an enclosing set) that holds the selector."""
        self.attr = attr

    def is_external(self):
        """Return True if the selector is not in the sub-message's own set."""
        return self._external
939
940
class Struct:
    """An attribute set (or a listed subset of it) rendered as one C struct.

    Holds the ordered (name, attr) members, the C naming (render_name,
    struct_name, ptr_name) and flags that Family fills in while walking
    the spec (request/reply usage, recursion, multi-value use, child nests).
    """
    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.fixed_header = 'struct ' + c_lower(fixed_header) if fixed_header else None
        self.submsg = submsg

        # Without an explicit member list the struct covers a whole nested set
        self.nested = type_list is None
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)
        self.struct_name = 'struct ' + self.render_name
        # Nested struct names that match an entry in family.consts get a '_' suffix
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or and legacy arrays

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]

        # Index members by name and remember the one with the highest value
        # (on ties the later member wins).
        self.attrs = {}
        self.attr_max_val = None
        top_value = 0
        for name, attr in self.attr_list:
            self.attrs[name] = attr
            if attr.value >= top_value:
                top_value = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        for name in self.attrs:
            yield name

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the ordered list of (name, attr) members."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record the inherited members; all users must agree on them."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]

    def external_selectors(self):
        """Return the Selectors of sub-message members not found in this set."""
        return [attr.selector for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        """Return True if freeing any member requires a loop iterator."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1015
1016
class EnumEntry(SpecEnumEntry):
    """Enum entry with C naming and a flag for non-sequential values."""
    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value is not the previous entry's value + 1 (or 0 for
        # the first entry); entries of 'flags' sets always count as changed.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
1035
1036
class EnumSet(SpecEnumSet):
    """Enum/flags definition with C naming for the type and its values."""
    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            # Explicitly empty 'enum-name': no named C enum type is used
            self.enum_name = None
        # Values are passed around as the named enum type, or plain int
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        # Base init creates the entries, which read the attributes set above
        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if entry values are contiguous, else (None, None)."""
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            return None, None

        return low, high
1072
1073
class AttrSet(SpecAttrSet):
    """Attribute set with C enum naming (prefix, max and count names).

    Subsets (subset-of) reuse the naming of the set they are carved from.
    """
    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            default_max = f"{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', default_max))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{default_max}"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        # Avoid clashing with C keywords
        if name in _C_KW:
            name += '_'
        # The set named after the family itself gets an empty c_name
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Build the typed attribute object for one YAML attribute spec."""
        kind = elem['type']
        simple_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
            'sub-message': TypeSubMessage,
        }

        if kind in scalars:
            attr_cls = TypeScalar
        elif kind == 'binary':
            if 'struct' in elem:
                attr_cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                attr_cls = TypeBinaryScalarArray
            else:
                attr_cls = TypeBinary
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            attr_cls = TypeArrayNest
        elif kind in simple_types:
            attr_cls = simple_types[kind]
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        t = attr_cls(self.family, self, elem, value)

        # Multi-attr wraps whatever base type was chosen above
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1142
1143
class Operation(SpecOperation):
    """Operation spec with C naming (render_name, enum_name) and flags for
    dual request policies and notifications."""
    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # every existing request/reply section gets an 'attributes' list.
        sections = [(m, d) for m in ('do', 'dump', 'event')
                    for d in ('request', 'reply')]
        for mode, direction in sections:
            try:
                msg = yaml[mode][direction]
            except KeyError:
                continue
            msg.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True when both 'do' and 'dump' define a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that another operation notifies using this one."""
        self.has_ntf = True
1177
1178
class SubMessage(SpecSubMessage):
    """Sub-message spec with a C render_name (<family>-<name>, lowercased)."""
    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        self.render_name = c_lower('-'.join((family.ident_name, yaml['name'])))

    def resolve(self):
        self.resolve_up(super())
1187
1188
class Family(SpecFamily):
    """YNL family spec specialized for C code generation.

    Adds C-level naming (c_name, op prefixes, uapi header name), the
    "root" attribute sets referenced directly by operations, Struct objects
    for every attribute set reachable only through nesting or sub-messages,
    and the per-mode pre/post hook lists.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # (each attribute is declared and immediately deleted; the real values
        # are set later - resolve() below sets all but consts, which
        # presumably comes from SpecFamily - TODO confirm)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Identifiers for the family name / version, overridable via
        # c-family-name / c-version-name
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # For "linux/<name>.h" use <name>; otherwise fall back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Derive C naming, hooks and the root / nested struct tables.

        The _load_* steps below run in a fixed order: root sets, nested
        structs, per-attribute request/reply usage, external selector
        binding, hooks, and (for kernel-policy 'global') the global policy.
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] = {'set': unique names, 'list': first-seen order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """Return True for 'netlink-raw' (classic, non-genetlink) families."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Flag every op that another message names in its 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    # (note: this replaces any existing 'do' key of the event op's yaml)
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per attribute set used directly by an op, the attributes
        referenced in requests and in replies / events (unioned over ops)."""
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct follows the
        non-recursive nested structs it contains.

        A struct whose children have not all been placed yet is moved to
        the end of the dict and retried; the number of rounds is bounded
        by len(structs) ** 2.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register a Struct for spec's 'nested-attributes' set and record
        the members it inherits from its parent.

        Inherited members are the 'type-value' levels, or 'idx' for
        indexed arrays.  Returns the nested set's name.  Raises if that set
        is also used as a root set.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                # The Struct keeps a reference to 'inherit', which is
                # filled in below
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Synthesize an attribute set and Struct for the sub-message that
        spec references, with one 'nest' attribute per format.

        Returns the sub-message name, which is also the new set's name.
        """
        # Fake the struct type for the sub-message itself
        # it's not an attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attrs.append({
                "name": name,
                "type": "nest",
                "parent-sub-message": spec,
                "nested-attributes": fmt['attribute-set']
            })

        # NOTE(review): AttrSet.__init__ only looks for 'name-prefix', so the
        # "name-pfx" key below appears to be ignored and the default
        # "<ident>-a-<set>-" prefix is used instead - confirm whether intended.
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Build pure_nested_structs for every set reachable from a root set
        and propagate request / reply / recursive / multi-value flags."""
        # Breadth-first walk over nested sets, starting from the root sets
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        # Seed usage flags from the root-set attributes the ops reference
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # (reversed order visits containing structs before the ones they nest)
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark attributes as used in requests and/or replies.

        Nested structs mark all their members; root sets mark only the
        members that operations reference.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Bind each external sub-message selector to the attribute of the
        same name in the directly enclosing attribute set.

        Raises if the selector is not found one level up.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """For kernel-policy 'global': collect the request attributes of all
        do/dump ops into one policy, in attribute-set order.

        All ops that name an attribute set must use the same one.
        NOTE(review): if no op names an attribute set, attr_set_name stays
        None and the attr_sets lookup below fails.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per do/dump mode, keeping
        first-seen order in the 'list' entry."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1536
1537
class RenderInfo:
    """Context for rendering one operation (or a bare attribute set).

    Bundles the family, kernel/user space target, operation and mode, the
    CodeWriter, and the request/reply Struct objects built from the
    operation's attribute lists.
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            if op.fixed_header != family.fixed_header:
                # Only classic families may use a per-op fixed header
                if family.is_classic():
                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
                else:
                    raise Exception("Per-op fixed header not supported, yet")

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_consistent = True
                self.type_oneside = True

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = {}
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """Return True if the *key* ('request' or 'reply') struct has no
        attributes and no fixed header.

        Checks the fixed header of the struct being asked about; previously
        the 'request' struct's header was checked for both directions.
        """
        struct = self.struct[key]
        return len(struct.attr_list) == 0 and struct.fixed_header is None

    def needs_nlflags(self, direction):
        """Return True for 'do' requests of classic (netlink-raw) families."""
        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1605
1606
class CodeWriter:
    """Write C source text with automatic indentation and brace handling.

    Output goes to stdout when no out_file is given.  Otherwise it is staged
    in a temporary file and copied to out_file by close_out_file() (also run
    from __del__).  With overwrite=False, an existing out_file whose content
    is identical to the staged output is left untouched.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # closing '}' pending (may merge with 'else')
        self._silent_block = False  # previous line was a brace-less if/for/while/else
        self._ind = 0
        self._ifdef_block = None    # CONFIG_ option of the currently open #ifdef
        # Target path; None means output goes straight to stdout and no
        # staging file is owned.
        self._out_file = None
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy staged output to the target file and switch back to stdout.

        When overwrite is False and the target already exists with identical
        content it is not rewritten (so its mtime is preserved).  The staging
        file is closed in every case, and later calls are no-ops.
        """
        out = getattr(self, '_out', None)
        # Nothing to do when writing to stdout, already closed, or when
        # __init__ did not get far enough to create a staging file.
        if out is None or getattr(self, '_out_file', None) is None:
            return
        try:
            out.flush()
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(out.name, self._out_file, shallow=False))
            if not unchanged:
                with open(self._out_file, 'w+') as out_file:
                    out.seek(0)
                    shutil.copyfileobj(out, out_file)
        finally:
            out.close()
            self._out = sys.stdout
            self._out_file = None

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Any pending '}' is flushed first (merged into the line if it starts
        with 'else'), then any pending blank line.  Lines ending in ':' are
        outdented one level, the line after a brace-less if/while/for/else is
        indented one extra level, and lines starting with '#' go to column 0.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith()/startswith() instead of indexing so '' is accepted
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        self._nl = True

    def block_start(self, line=''):
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        The caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write tab-aligned designated initializers `.name = value,`."""
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1803
1804
# Integer attribute types; each name is also the ynl_attr_get_<type>()
# accessor suffix used when emitting parse code.
scalars = {f'{sign}{width}' for sign in 'us' for width in (8, 16, 32, 64)} | {'uint', 'sint'}

# Name suffix for each message direction ('' = no direction).
direction_to_suffix = dict([
    ('reply', '_rsp'),
    ('request', '_req'),
    ('', ''),
])

# Type-name suffix per op mode (op_prefix uses it for consistent,
# non-dereferenced reply types).
op_mode_to_wrapper = dict(do='', dump='_list', notify='_ntf', event='')

# C keywords, plus bool.  Presumably used elsewhere in the generator to avoid
# emitting reserved identifiers; the user is outside this chunk.
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1856
1857
def rdir(direction):
    """Return the opposite message direction.

    'request' and 'reply' are swapped; anything else (e.g. '') is returned
    unchanged.
    """
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1864
1865
def op_prefix(ri, direction, deref=False):
    """Return the C identifier prefix `<family>_<type><tail>` for ri.

    - 'do' ops (or no op mode): the tail is the plain direction suffix.
    - Other modes, request: '_req', plus '_dump' unless the type is one-sided.
    - Other modes, reply, with consistent do/dump reply types: the direction
      suffix when dereferenced, else the op mode's wrapper suffix.
    - Other modes, reply, inconsistent types: '_rsp_dump' when dereferenced,
      else '_rsp_list'.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req' if ri.type_oneside else '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]
    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1887
1888
def type_name(ri, direction, deref=False):
    """Return the C struct type, `struct <op_prefix>`, for ri and direction."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1891
1892
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-facing request function for ri.

    The function takes the socket and, if the op has a request, a pointer to
    the request struct.  It returns a pointer to the reply struct when the op
    has a reply, otherwise int.  Dump ops get a '_dump' name suffix.  With
    `terminate` the prototype ends in ';'.
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    params = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        params.append(f"{type_name(ri, direction)} *{direction_to_suffix[direction][1:]}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, params, doc=doc, suffix=';' if terminate else '')
1909
1910
def print_req_prototype(ri):
    """Declare the 'do' request function, preceded by the op's doc comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, 'request', terminate=True, doc=op_doc)
1913
1914
def print_dump_prototype(ri):
    """Declare the dump request function (no doc comment)."""
    print_prototype(ri, direction='request')
1917
1918
def put_typol_submsg(cw, struct):
    """Emit the policy table and nest descriptor for a sub-message struct.

    Every member becomes a YNL_PT_SUBMSG entry pointing at its nested
    struct's _nest descriptor.  Entries are numbered from 0, and max_attr is
    the last index.
    """
    policy = f'{struct.render_name}_policy'
    members = list(struct.member_list())

    cw.block_start(line=f'const struct ynl_policy_attr {policy}[] =')
    for idx, (name, arg) in enumerate(members):
        cw.p(f'[{idx}] = {{ .type = YNL_PT_SUBMSG, .name = "{name}", '
             f'.nest = &{arg.nested_render_name}_nest, }},')
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {len(members) - 1},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1936
1937
def put_typol_fwd(cw, struct):
    """Emit an extern declaration of the struct's policy nest descriptor."""
    nest_name = f'{struct.render_name}_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest_name};')
1940
1941
def put_typol(cw, struct):
    """Emit the policy table and nest descriptor for a struct.

    Sub-message structs are delegated to put_typol_submsg().  Otherwise the
    table is sized by the attribute set's max_name, and each member writes
    its own entry via attr_typol().
    """
    if not struct.submsg:
        max_name = struct.attr_set.max_name
        policy = f'{struct.render_name}_policy'

        cw.block_start(line=f'const struct ynl_policy_attr {policy}[{max_name} + 1] =')
        for _, member in struct.member_list():
            member.attr_typol(cw)
        cw.block_end(line=';')
        cw.nl()

        cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
        cw.p(f'.max_attr = {max_name},')
        cw.p(f'.table = {policy},')
        cw.block_end(line=';')
        cw.nl()
        return
    put_typol_submsg(cw, struct)
1961
1962
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit `const char *<render_name>_str(<type> <arg_name>)`.

    The function looks the value up in the string table `map_name` and
    returns NULL when it is out of range.  The parameter type is
    enum.user_type when an enum is given, else int.  For flags enums the
    value is first turned into a bit index with ffs() - 1.
    """
    if enum:
        param = enum.user_type + ' ' + arg_name
    else:
        param = f'int {arg_name}'

    cw.write_func_prot('const char *', f'{render_name}_str', [param])
    cw.block_start()
    is_flags = bool(enum) and enum.type == 'flags'
    if is_flags:
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    for line in (f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
                 'return NULL;',
                 f'return {map_name}[{arg_name}];'):
        cw.p(line)
    cw.block_end()
    cw.nl()
1976
1977
def put_op_name_fwd(family, cw):
    """Declare `const char *<family>_op_str(int op);`."""
    cw.write_func_prot('const char *', f'{family.c_name}_op_str', args=['int op'], suffix=';')
1980
1981
def put_op_name(family, cw):
    """Emit the reply-value -> op-name string table and its _op_str() lookup.

    Only ops with a (truthy) rsp_value get an entry.  The entry is indexed by
    the op's enum name when the request and reply values match, else by the
    numeric reply value.  If family.rsp_by_value maps the reply value to a
    different op, a comment is emitted instead of a duplicate entry.
    """
    strmap = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {strmap}[] =")
    for name, msg in family.msgs.items():
        if not msg.rsp_value:
            continue
        # Legacy families may have several commands producing the same
        # response; only the op recorded in rsp_by_value gets the slot.
        if family.rsp_by_value[msg.rsp_value] != msg:
            cw.p(f'// skip "{name}", duplicate reply value')
            continue
        index = msg.enum_name if msg.req_value == msg.rsp_value else msg.rsp_value
        cw.p(f'[{index}] = "{name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', strmap, 'op')
2001
2002
def put_enum_to_str_fwd(family, cw, enum):
    """Declare `const char *<enum>_str(<user_type> value);`."""
    proto_args = [f'{enum.user_type} value']
    cw.write_func_prot('const char *', f'{enum.render_name}_str', proto_args, suffix=';')
2006
2007
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string table for an enum and its _str() lookup."""
    strmap = f'{enum.render_name}_strmap'
    entries = [f'[{e.value}] = "{e.name}",' for e in enum.entries.values()]

    cw.block_start(line=f"static const char * const {strmap}[] =")
    for entry_line in entries:
        cw.p(entry_line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, strmap, 'value', enum=enum)
2017
2018
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of `int <struct>_put(nlh, attr_type, obj)`."""
    params = ['struct nlmsghdr *nlh', 'unsigned int attr_type']
    params.append(f'{struct.ptr_name}obj')
    ri.cw.write_func_prot('int', f'{struct.render_name}_put', params, suffix=suffix)
2026
2027
def put_req_nested(ri, struct):
    """Emit `<struct>_put()`, which serializes a nested struct into nlh.

    Regular nests are wrapped in ynl_attr_nest_start()/ynl_attr_nest_end().
    Sub-message structs (struct.submsg set) are written without a wrapping
    nest.  The generated function returns 0.
    """
    wraps_nest = struct.submsg is None

    # One (is indexed-array, uses a count) pair per member.
    member_traits = [(m.type == 'indexed-array', m.presence_type() == 'count')
                     for _, m in struct.member_list()]

    decls = ['struct nlattr *nest;'] if wraps_nest else []
    if any(is_array for is_array, _ in member_traits):
        decls.append('struct nlattr *array;')
    if any(counted for _, counted in member_traits):
        decls.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    cw = ri.cw
    cw.block_start()
    cw.write_func_lvar(decls)

    if wraps_nest:
        cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    if struct.submsg is None:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2063
2064
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parse function for `struct` (a message or a nest).

    The caller has already emitted the prototype.  This writes the local
    declarations and init lines, copies any fixed header, then a loop over
    all attributes that dispatches via each member's attr_get().  It then
    adds extra passes that allocate and fill the arrays for indexed-array
    and multi-attr members.  Each array is sized by an n_<member> counter
    initialised to 0 here; presumably the code emitted by attr_get()
    increments it (TODO confirm in the attr_get implementations).

    Args:
        ri: render info for the op being generated.
        struct: the Struct to parse.  struct.nested selects iterating inside
            `nested` rather than over the whole message `nlh`.
        init_lines: C statements emitted right after the declarations;
            extended in place.
        local_vars: C declarations for the function; extended in place.

    Raises:
        Exception: for a per-op fixed header in a non-classic family, or an
            unsupported indexed-array sub-type / multi-attr type.
    """
    # Nests iterate the attributes inside `nested`; messages iterate nlh,
    # skipping the header.
    if struct.nested:
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if struct.fixed_header:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        # An op-specific fixed header moves where the attributes start.
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Collect members that need a second pass (indexed arrays, multi-attrs),
    # and note whether any member needs a ynl_parse_arg for nested parsers.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # sorted() keeps the generated code stable; set iteration order is not.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if struct.fixed_header:
        if ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Refuse to parse into a dst whose array member is already populated.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # First pass: per-attribute dispatch, one chunk emitted per member.
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate n_<member> entries and fill
    # them by walking the nest saved in attr_<member>.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: walk all attributes again and copy
    # each occurrence of the member's attribute type into the array.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most sizeof() one element; a shorter payload is copied as-is.
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Each string is copied into its own ynl_string allocation and
            # NUL-terminated.
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message parse callbacks return YNL_PARSE_CB_OK.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2211
2212
def parse_rsp_submsg(ri, struct):
    """Emit <struct>_parse() for a sub-message struct.

    The generated parser dispatches on the selector string `sel`.  Each
    member gets an `if (!strcmp(sel, "<name>"))` branch that runs the
    member's parse code and, for 'present' members, sets its _present bit.
    An unmatched selector falls through to `return 0`.

    Local declarations contributed by the members are de-duplicated while
    keeping first-seen order.  This keeps the generated C identical from run
    to run; iterating a set of strings is not, because str hashing is
    randomized per process.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    local_vars = ['const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;']

    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        if l_vars:
            local_vars.extend(l_vars)
    # dict keeps insertion order, so this de-duplicates deterministically.
    local_vars = list(dict.fromkeys(local_vars))

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    first = True
    for name, arg in struct.member_list():
        kw = 'if' if first else 'else if'
        first = False

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2247
2248
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of `int <struct>_parse(...)` for a nested struct.

    Parameters, in order: yarg, `sel` (sub-messages only), nested, one
    `_sel_<name>` string per external selector, and one __u32 per inherited
    value.
    """
    selector_args = ['const char *_sel_' + sel.name for sel in struct.external_selectors()]
    leading = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        leading.append('const char *sel')
    func_args = (leading + ['const struct nlattr *nested'] + selector_args +
                 ['__u32 ' + inherited for inherited in struct.inherited])

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
                          suffix=suffix)
2261
2262
def parse_rsp_nested(ri, struct):
    """Emit <struct>_parse() for a nested struct.

    Sub-message structs are delegated to parse_rsp_submsg().  A nest without
    members gets a stub body that just returns 0.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest
        cw = ri.cw
        cw.block_start()
        cw.p('return 0;')
        cw.block_end()
        cw.nl()
        return None

    _multi_parse(ri, struct, [], decls)
    return None
2281
2282
def parse_rsp_msg(ri, deref=False):
    """Emit the `<prefix>_parse()` callback for a reply or event message.

    Nothing is emitted for non-event ops without a reply.  A reply with no
    members gets a stub body returning YNL_PARSE_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    decls = [f'{type_name(ri, "reply", deref=deref)} *dst;',
             'const struct nlattr *attr;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse',
                          ['const struct nlmsghdr *nlh', 'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return
    _multi_parse(ri, reply, ['dst = yarg->data;'], decls)
2304
2305
def print_req(ri):
    """Emit the blocking 'do' request function for ri.

    The generated function builds the request (optional fixed header plus
    attributes) and runs it with ynl_exec().  When the op has a reply it
    allocates the reply struct and returns it, freeing it and returning NULL
    on error.  Ops without a reply return 0 or -1.
    """
    cw = ri.cw
    direction = "request"
    req = ri.struct["request"]
    has_reply = 'reply' in ri.op[ri.op_mode]

    decls = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
             'struct nlmsghdr *nlh;',
             'int err;']
    if has_reply:
        decls.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if req.fixed_header:
        decls += ['size_t hdr_len;', 'void *hdr;']
    if any(attr.presence_type() == 'count' for _, attr in req.member_list()):
        decls.append('unsigned int i;')

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(decls)

    if ri.family.is_classic():
        cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{req.render_name}_nest;")
    cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    if req.fixed_header:
        cw.p("hdr_len = sizeof(req->_hdr);")
        cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        cw.nl()

    for _, attr in req.member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {'rsp' if has_reply else '0'};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(call_free(ri, rdir(direction), 'rsp'))
        cw.p("return NULL;")

    cw.block_end()
2379
2380
def print_dump(ri):
    """Emit the dump request function for ri.

    The generated function sets up a ynl_dump_state, builds the dump request
    (optional fixed header plus request attributes) and runs ynl_exec_dump().
    It returns the first reply list element, or frees the list and returns
    NULL on error.
    """
    cw = ri.cw
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    req = ri.struct['request']
    decls = ['struct ynl_dump_state yds = {};',
             'struct nlmsghdr *nlh;',
             'int err;']
    if req.fixed_header:
        decls += ['size_t hdr_len;', 'void *hdr;']
    cw.write_func_lvar(decls)

    cw.p('yds.yarg.ys = ys;')
    cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.p("yds.yarg.data = NULL;")
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.nl()
    if ri.family.is_classic():
        cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if req.fixed_header:
        cw.p("hdr_len = sizeof(req->_hdr);")
        cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        cw.nl()

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{req.render_name}_nest;")
        cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        cw.nl()
        for _, attr in req.member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    for line in ('err = ynl_exec_dump(ys, nlh, &yds);', 'if (err < 0)', 'goto free_list;'):
        cw.p(line)
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
2435
2436
def call_free(ri, direction, var):
    """Return the C statement freeing `var` with the direction's _free()."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
2439
2440
def free_arg_name(direction):
    """Return the parameter name for a _free() function.

    'obj' for a direction-less struct; otherwise the direction suffix
    without its leading underscore (e.g. 'req' / 'rsp').
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2445
2446
def print_alloc_wrapper(ri, direction):
    """Emit a static inline <prefix>_alloc() that returns a zeroed struct."""
    name = op_prefix(ri, direction)
    struct_type = f'struct {name}'
    cw = ri.cw
    cw.write_func_prot(f'static inline {struct_type} *', f"{name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof({struct_type}));')
    cw.block_end()
2453
2454
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of `void <prefix>_free(struct ... *arg)`.

    When ri.type_name_conflict is set, the struct name gets a trailing '_'
    but the function name does not.
    """
    name = op_prefix(ri, direction)
    struct_name = name + ('_' if ri.type_name_conflict else '')
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
2462
2463
def print_nlflags_set(ri, direction):
    """Emit a static inline <prefix>_set_nlflags() setter for _nlmsg_flags."""
    name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot('static inline void', f"{name}_set_nlflags",
                       [f"struct {name} *req", "__u16 nl_flags"])
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2472
2473
def _print_type(ri, direction, struct):
    """Emit the C struct definition for one direction, or for a standalone nest.

    Layout: optional _nlmsg_flags, optional fixed header (_hdr), then one
    anonymous sub-struct per presence kind (_present, _len, _count) holding
    the members that report one, inherited __u32 fields, and the members.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'
    if ri.op_mode == 'dump' and not ri.type_oneside:
        suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if struct.fixed_header:
        cw.p(struct.fixed_header + ' _hdr;')
        cw.nl()

    for kind in ('present', 'len', 'count'):
        meta_lines = []
        for _, attr in struct.member_list():
            line = attr.presence_member(ri.ku_space, kind)
            if line:
                meta_lines.append(line)
        if not meta_lines:
            continue
        cw.block_start(line="struct")
        for line in meta_lines:
            cw.p(line)
        cw.block_end(line=f'_{kind};')
    cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2512
2513
def print_type(ri, direction):
    """Write the struct for the given *direction* of *ri*'s operation."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2516
2517
def print_type_full(ri, struct):
    """Write *struct* as a direction-less type (empty direction string)."""
    _print_type(ri, direction="", struct=struct)
2520
2521
def print_type_helpers(ri, direction, deref=False):
    """Write the helper declarations that accompany a request/reply type.

    Always emits the free() prototype. Adds the nlflags setter when
    ``ri.needs_nlflags(direction)``, and per-attribute setters for user-space
    requests (*deref* is forwarded to each setter).
    """
    cw = ri.cw
    print_free_prototype(ri, direction)
    cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    cw.nl()
2533
2534
def print_req_type_helpers(ri):
    """Write the alloc wrapper and helpers for a non-empty request type."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2540
2541
def print_rsp_type_helpers(ri):
    """Write reply helpers, but only if the current op mode defines a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2546
2547
def print_parse_prototype(ri, direction, terminate=True):
    """Write the prototype of ``<op>_rsp_parse()`` / ``<op>_req_parse()``.

    'reply' selects the ``_rsp`` suffix; any other direction selects
    ``_req``. With *terminate* the prototype ends in ';' (a declaration),
    otherwise in nothing so a body can follow.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
2556
2557
def print_req_type(ri):
    """Write the request struct unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2562
2563
def print_req_free(ri):
    """Write the request free() function if the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2568
2569
def print_rsp_type(ri):
    """Write the reply struct for do/dump ops that have a reply, and for events."""
    mode = ri.op_mode
    has_reply = mode in ('do', 'dump') and 'reply' in ri.op[mode]
    if has_reply or mode == 'event':
        print_type(ri, 'reply')
2578
2579
def print_wrapped_type(ri):
    """Write the reply wrapper struct followed by its free() prototype.

    The wrapper embeds the reply as an 8-byte aligned ``obj`` member. Dump
    wrappers add a ``next`` pointer. Notify/event wrappers add ``family``,
    ``cmd``, a ``next`` list pointer and a ``free`` callback.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')
    cw.block_start(line=f"{wrapper}")
    mode = ri.op_mode
    if mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif mode in ('notify', 'event'):
        for line in ('__u16 family;',
                     '__u8 cmd;',
                     'struct ynl_ntf_base_type *next;',
                     f"void (*free)({wrapper} *ntf);"):
            cw.p(line)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2594
2595
def _free_type_members_iter(ri, struct):
    """Declare the loop index ``i`` when freeing *struct* needs to iterate."""
    if not struct.free_needs_iter():
        return
    ri.cw.p('unsigned int i;')
    ri.cw.nl()
2600
2601
def _free_type_members(ri, var, struct, ref=''):
    """Emit free code for every member of *struct*, accessed through *var*.

    *ref* is forwarded unchanged to each member's ``free()``. Callers pass
    ``'obj.'`` when the members live inside a wrapper's ``obj``.
    """
    for _member_name, member in struct.member_list():
        member.free(ri, var, ref)
2605
2606
def _free_type(ri, direction, struct):
    """Write a complete free() function for *struct*.

    The generated body frees each member. For a non-empty *direction* it then
    frees the object itself; a direction-less object is left to its owner.
    """
    arg = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2618
2619
def free_rsp_nested_prototype(ri):
    """Write the free() prototype for a direction-less type (empty direction)."""
    print_free_prototype(ri, "")
2622
2623
def free_rsp_nested(ri, struct):
    """Write the full free() function for the direction-less type *struct*."""
    _free_type(ri, direction="", struct=struct)
2626
2627
def print_rsp_free(ri):
    """Write the reply free() function if the op mode defines a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2632
2633
def print_dump_type_free(ri):
    """Write the free() for a dump reply list.

    The generated code walks ``next`` pointers until ``YNL_LIST_END``. For
    each element it frees the members (reached through ``obj.``) and then
    the element itself.
    """
    cw = ri.cw
    elem_type = type_name(ri, 'reply')
    reply_struct = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{elem_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply_struct)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    cw.p('free(rsp);')
    # Close the while loop, then the function body.
    cw.block_end()
    cw.block_end()
    cw.nl()
2652
2653
def print_ntf_type_free(ri):
    """Write the free() for a notification wrapper.

    The generated code releases the members stored in ``obj`` and then the
    wrapper itself.
    """
    cw = ri.cw
    reply_struct = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply_struct)
    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2662
2663
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write the ``nla_policy`` array declaration or definition opener.

    With *terminate* this emits an ``extern`` declaration; nothing is emitted
    when the policy will be static (file-local). Without *terminate* it emits
    the definition header ending in ``= {``, prefixed with ``static`` when
    applicable.

    The array is named after the op when *ri* is given, with a ``_<mode>``
    suffix for dual-policy ops; otherwise it is named after *struct*.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2686
2687
def print_req_policy(cw, struct, ri=None):
    """Write a full ``nla_policy`` array definition for *struct*.

    When *ri* carries an op, the op's optional 'config-cond' is passed to
    ``cw.ifdef_block``. ``ifdef_block(None)`` is always called after the
    closing brace.
    """
    has_op = bool(ri and ri.op)
    if has_op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2697
2698
def kernel_can_gen_family_struct(family):
    """Return True when the generator emits the genl_family struct itself.

    Only families whose protocol is 'genetlink' qualify.
    """
    proto = family.proto
    return proto == 'genetlink'
2701
2702
def policy_should_be_static(family):
    """Return True when generated nla_policy arrays should be ``static``.

    This is the case for the 'split' kernel policy, and for families whose
    genl_family struct is generated here.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2705
2706
def print_kernel_policy_ranges(family, cw):
    """Write ``netlink_range_validation`` structs for full-range checks.

    Covers every request attribute with a 'full-range' check in attribute
    sets that are not subsets. Unsigned types (type starting with 'u') use
    the plain struct and ``ULL`` literals; others use the ``_signed`` variant
    and ``LL``. A comment header precedes the first struct.
    """
    eligible = (attr
                for _, attr_set in family.attr_sets.items()
                if not attr_set.subset_of
                for _, attr in attr_set.items()
                if attr.request and 'full-range' in attr.checks)

    for idx, attr in enumerate(eligible):
        if idx == 0:
            cw.p('/* Integer value ranges */')

        if attr.type[0] == 'u':
            sign, suffix = '', 'ULL'
        else:
            sign, suffix = '_signed', 'LL'
        cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
        members = [(limit, attr.get_limit_str(limit, suffix=suffix))
                   for limit in ('min', 'max')
                   if limit in attr.checks]
        cw.write_struct_init(members)
        cw.block_end(line=';')
        cw.nl()
2734
2735
def print_kernel_policy_sparse_enum_validates(family, cw):
    """Write a validation callback for each sparse-enum request attribute.

    For every request attribute (outside subset attribute sets) that has an
    enum and a 'sparse' check, emits ``<attr>_validate()``. The generated
    function switches on the attribute value, returns 0 for any entry of the
    enum, and otherwise sets an extack message and returns -EINVAL.
    A comment header precedes the first callback.
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request:
                continue
            if not attr.enum_name:
                continue
            if 'sparse' not in attr.checks:
                continue

            if first:
                cw.p('/* Sparse enums validation callbacks */')
                first = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # All valid values share a single 'return 0;', so every case label
            # after the first is preceded by an explicit fallthrough marker.
            for idx, entry in enumerate(enum.entries.values()):
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2774
2775
def print_kernel_op_table_fwd(family, cw, terminate):
    """Write the ops-table declaration, or open its definition.

    With ``terminate=False`` this only opens the ops-array definition; the
    caller writes the entries and closes the block.

    With ``terminate=True`` (header output) it emits an ``extern``
    declaration of the array, but only when the table is exported, i.e. the
    genl_family struct is not generated here. It then declares prototypes
    for the pre/post hooks and for every non-async doit/dumpit handler.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            # Size is left to the initializer.
            cnt = ""
        elif family.kernel_policy == 'split':
            # One entry per do/dump mode of each op.
            cnt = sum(int('do' in op) + int('dump' in op)
                      for op in family.ops.values())
        else:
            # NOTE(review): counts every op, including async ones that
            # print_kernel_op_table skips for 'global'/'per-op' tables -
            # confirm the declared size matches the initializer count.
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if not terminate:
            cw.block_start(line=decl + ' =')
            return
        cw.p(f"extern {decl};")

    cw.nl()
    doit_hook_args = ['const struct genl_split_ops *ops',
                      'struct sk_buff *skb', 'struct genl_info *info']
    dump_hook_args = ['struct netlink_callback *cb']
    hook_protos = (('pre', 'do', 'int', doit_hook_args),
                   ('post', 'do', 'void', doit_hook_args),
                   ('pre', 'dump', 'int', dump_hook_args),
                   ('post', 'dump', 'int', dump_hook_args))
    for stage, mode, ret_type, hook_args in hook_protos:
        for hook in family.hooks[stage][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(hook_args), suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            cw.write_func_prot('int', c_lower(f"{family.ident_name}-nl-{op_name}-doit"),
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            cw.write_func_prot('int', c_lower(f"{family.ident_name}-nl-{op_name}-dumpit"),
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2841
2842
def print_kernel_op_table_hdr(family, cw):
    """Write the header-side ops-table declaration and handler prototypes."""
    print_kernel_op_table_fwd(family, cw, terminate=True)
2845
2846
def print_kernel_op_table(family, cw):
    """Write the kernel ops-table definition for *family*.

    Opens the array via print_kernel_op_table_fwd(terminate=False), then
    writes one initializer per non-async op ('global'/'per-op' policies) or
    one per (op, do/dump mode) pair ('split' policy), and closes the array.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            # Ops flagged is_async get no table entry.
            if op.is_async:
                continue

            # Optional per-op 'config-cond' is handed to ifdef_block
            # (presumably wrapping the entry in an #ifdef).
            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            # OR together GENL_DONT_VALIDATE_* flags from 'dont-validate'.
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            # One doit/dumpit handler member per mode the op defines.
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): assumes the op has a 'do' with a 'request';
                # an op lacking either would raise KeyError here - confirm
                # per-op specs always provide one.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Hook member names differ between do and dump entries.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only flags relevant to this mode: 'dump' and
                    # 'dump-strict' are dropped for do entries, 'strict'
                    # is dropped for dump entries.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have one policy per mode, so the mode
                    # is part of the policy name.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry also carries GENL_CMD_CAP_<MODE>.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # ifdef_block(None) presumably closes any #ifdef still open from the
    # last entry.
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2925
2926
def print_kernel_mcgrp_hdr(family, cw):
    """Write the enum of multicast-group indices (nothing if no groups)."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2937
2938
def print_kernel_mcgrp_src(family, cw):
    """Define the genl_multicast_group array, indexed by the nlgrp enum."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2950
2951
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated genl_family and its sock-priv hooks, if any."""
    if kernel_can_gen_family_struct(family):
        prefix = family.c_name
        cw.p(f"extern struct genl_family {prefix}_nl_family;")
        cw.nl()
        if 'sock-priv' in family.kernel_family:
            priv_type = family.kernel_family["sock-priv"]
            for hook in ('init', 'destroy'):
                cw.p(f'void {prefix}_nl_sock_priv_{hook}({priv_type} *priv);')
            cw.nl()
2962
2963
def print_kernel_family_struct_src(family, cw):
    """Define the genl_family struct for families that get one generated.

    With a 'sock-priv' type, ``void *`` trampolines are generated first so
    the struct can point at them. Per the original comment this keeps CFI
    happy.
    """
    if not kernel_can_gen_family_struct(family):
        return

    prefix = family.c_name
    has_sock_priv = 'sock-priv' in family.kernel_family
    if has_sock_priv:
        # Generate "trampolines" to make CFI happy
        for hook in ('init', 'destroy'):
            cw.write_func("static void", f"__{prefix}_nl_sock_priv_{hook}",
                          [f"{prefix}_nl_sock_priv_{hook}(priv);"],
                          ["void *priv"])
            cw.nl()

    # NOTE(review): the struct is named with ident_name here but declared
    # with c_name in print_kernel_family_struct_hdr - confirm they match.
    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")
    fields = [
        f'.name\t\t= {family.fam_key},',
        f'.version\t= {family.ver_key},',
        '.netnsok\t= true,',
        '.parallel_ops\t= true,',
        '.module\t\t= THIS_MODULE,',
    ]
    policy = family.kernel_policy
    if policy == 'per-op':
        fields += [f'.ops\t\t= {prefix}_nl_ops,',
                   f'.n_ops\t\t= ARRAY_SIZE({prefix}_nl_ops),']
    elif policy == 'split':
        fields += [f'.split_ops\t= {prefix}_nl_ops,',
                   f'.n_split_ops\t= ARRAY_SIZE({prefix}_nl_ops),']
    if family.mcgrps['list']:
        fields += [f'.mcgrps\t\t= {prefix}_nl_mcgrps,',
                   f'.n_mcgrps\t= ARRAY_SIZE({prefix}_nl_mcgrps),']
    if has_sock_priv:
        fields += [f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),',
                   f'.sock_priv_init\t= __{prefix}_nl_sock_priv_init,',
                   f'.sock_priv_destroy = __{prefix}_nl_sock_priv_destroy,']
    for field in fields:
        cw.p(field)
    cw.block_end(';')
2999
3000
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a UAPI ``enum`` block, choosing its name from *obj*.

    If *obj* has the *enum_name* key, its value names the enum; an empty
    value deliberately yields an anonymous enum. Otherwise, if *ckey* is
    present in *obj*, the enum is named ``<family c_name>_<obj[ckey]>``.
    Failing both, the enum is anonymous.
    """
    header = 'enum'
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            header = f'enum {c_lower(explicit)}'
    elif ckey and ckey in obj:
        header = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    cw.block_start(line=header)
3009
3010
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Write the single command enum used by the 'unified' message-id model.

    Explicit ``= value`` initializers are emitted only where a message id
    breaks the running sequence. With *separate_ntf*, notify/event messages
    are left out. The MAX name is an enum member, or a #define when
    *max_by_define* is set.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for msg in family.msgs.values():
        if separate_ntf and ('notify' in msg or 'event' in msg):
            continue

        if msg.value == expected:
            cw.p(f"{msg.enum_name},")
        else:
            cw.p(f"{msg.enum_name} = {msg.value},")
            expected = msg.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
3036
3037
def render_uapi_directional(family, cw, max_by_define):
    """Write the USER and KERNEL message enums for the 'directional' model.

    The USER enum lists 'do' messages that are not events. The KERNEL enum
    lists 'do' replies (with a ``_REPLY`` suffix) plus notifications and
    events (without the suffix).
    """
    def emit_msg_enum(side, entries):
        # *entries* yields (enum_name, value) pairs; a non-zero value that
        # breaks the running sequence gets an explicit initializer.
        max_name = f"{family.op_prefix}{side}_MAX"
        cnt_name = f"__{family.op_prefix}{side}_CNT"
        max_value = f"({cnt_name} - 1)"

        cw.block_start(line='enum')
        cw.p(c_upper(f'{family.name}_MSG_{side}_NONE = 0,'))
        expected = 0
        for enum_name, value in entries:
            if value and value != expected:
                cw.p(f"{enum_name} = {value},")
                expected = value
            else:
                cw.p(f"{enum_name},")
            expected += 1
        cw.nl()

        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    def kernel_entries():
        for op in family.msgs.values():
            if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
                enum_name = op.enum_name
                if 'event' not in op and 'notify' not in op:
                    enum_name = f'{enum_name}_REPLY'
                yield enum_name, op.value

    user_entries = ((op.enum_name, op.value)
                    for op in family.msgs.values()
                    if 'do' in op and 'event' not in op)
    emit_msg_enum('USER', user_entries)
    emit_msg_enum('KERNEL', kernel_entries())
3090
3091
def render_uapi(family, cw):
    """Write the complete UAPI header for *family* through *cw*.

    Output order:
    - the include guard;
    - the family name/version defines;
    - the ``definitions`` (consts, enums, flags);
    - one enum per non-subset attribute set;
    - the command enum(s);
    - a separate notification enum when the operations section has
      'async-prefix';
    - the multicast group name defines;
    - the closing ``#endif``.

    Raises:
        Exception: if ``family.msg_id_model`` is neither 'unified' nor
            'directional'.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are collected into one block of
    # defines. The pending block is flushed before every other kind of
    # definition, and once more after the loop.
    defines = []
    for const in family['definitions']:
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The optional 'c-define-name' override belongs to the constant
            # itself. Reading it from the family would give every constant
            # the same define name.
            define_name = const.get('c-define-name',
                                    f"{family.ident_name}-{const['name']}")
            defines.append([c_upper(define_name), const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # Explicit '= value' only where the attribute ids break the sequence.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        # Notify/event messages get their own enum; render_uapi_unified
        # leaves them out when separate_ntf is set.
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3232
3233
def _render_user_ntf_entry(ri, op):
    """Write one ``ynl_ntf_info`` initializer for notification *op*.

    Classic families index the entry by the enum of the request whose value
    equals ``op.rsp_value``; other families use the op's own enum.
    """
    if ri.family.is_classic():
        index = ri.family.req_by_value[op.rsp_value].enum_name
    else:
        index = op.enum_name
    cw = ri.cw
    cw.block_start(line=f"[{index}] = ")
    fields = (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    )
    for field in fields:
        cw.p(field)
    cw.block_end(line=',')
3245
3246
def render_user_family(family, cw, prototype):
    """Write the ``ynl_<family>_family`` descriptor, or its extern prototype.

    When the family has notifications, a ``<family>_ntf_info`` table is
    written first. It holds entries for ``family.ntfs`` followed by any op
    carrying 'event'.

    Raises:
        Exception: for a notification with neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_name, ntf in family.ntfs.items():
            if 'notify' in ntf:
                info = RenderInfo(cw, family, "user", family.ops[ntf['notify']], "notify")
            elif 'event' in ntf:
                info = RenderInfo(cw, family, "user", ntf, "event")
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(info, ntf)
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3288
3289
def family_contains_bitfield32(family):
    """Return True if a non-subset attribute set has a 'bitfield32' attribute."""
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
3298
3299
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains *full_path*.

    Walks up from *full_path* until it finds a directory holding a
    ``MAINTAINERS`` file.

    Returns:
        A ``(root, sub_path)`` tuple. *root* is the directory containing
        ``MAINTAINERS`` and *sub_path* is *full_path* relative to it.

    Raises:
        FileNotFoundError: if no ancestor directory contains ``MAINTAINERS``.
            ``os.path.dirname()`` stops changing at the filesystem root (or
            at '' for relative paths), so the walk must end there instead of
            looping forever.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # sub_path carries a trailing separator from the join above.
            return parent, sub_path[:-1]
        if parent == full_path:
            raise FileNotFoundError(
                f"no MAINTAINERS file found above {start!r}; "
                "is the spec inside a kernel source tree?")
        full_path = parent
3308
3309
def _gen_includes(parsed, cw, args, hdr_file):
    """Emit the #include section for kernel or user mode C output."""
    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            # The source includes its own generated header only when the
            # header's name is known (i.e. an output file was given).
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    # Headers requested by individual definitions / attribute sets in the spec.
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Emit each header once, keeping first-seen order.
    for one in dict.fromkeys(headers):
        cw.p(f"#include <{one}>")
    cw.nl()


def _gen_kernel_header(parsed, cw, mode):
    """Emit kernel-mode header contents: policy forward declarations and table/family decls."""
    if any(struct.request for struct in parsed.pure_nested_structs.values()):
        cw.p('/* Common nested types */')
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy_fwd(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy_fwd(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for op in parsed.ops.values():
            if 'do' in op and 'event' not in op:
                ri = RenderInfo(cw, parsed, mode, op, "do")
                print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                cw.nl()

    print_kernel_op_table_hdr(parsed, cw)
    print_kernel_mcgrp_hdr(parsed, cw)
    print_kernel_family_struct_hdr(parsed, cw)


def _gen_kernel_source(parsed, cw, mode):
    """Emit kernel-mode source contents: policies, op table, mcgrps and family struct."""
    print_kernel_policy_ranges(parsed, cw)
    print_kernel_policy_sparse_enum_validates(parsed, cw)

    if any(struct.request for struct in parsed.pure_nested_structs.values()):
        cw.p('/* Common nested types */')
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy(cw, struct)
        cw.nl()

    # The policy kind does not change per op, so test it once outside the loop.
    if parsed.kernel_policy in {'per-op', 'split'}:
        for op in parsed.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, mode, op, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _gen_user_header(parsed, cw, mode):
    """Emit user-mode header contents: family decl, enum helpers, types and prototypes."""
    cw.p('struct ynl_sock;')
    cw.nl()
    render_user_family(parsed, cw, True)
    cw.nl()

    cw.p('/* Enums */')
    put_op_name_fwd(parsed, cw)

    for const in parsed.consts.values():
        if isinstance(const, EnumSet):
            put_enum_to_str_fwd(parsed, cw, const)
    cw.nl()

    cw.p('/* Common nested types */')
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
        print_type_full(ri, struct)
        if struct.request and struct.in_multi_val:
            free_rsp_nested_prototype(ri)
            cw.nl()

    for op in parsed.ops.values():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, 'dump')
            print_req_type(ri)
            print_req_type_helpers(ri)
            if not ri.type_consistent or ri.type_oneside:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_wrapped_type(ri)

    for op in parsed.ntfs.values():
        if 'event' in op:
            ri = RenderInfo(cw, parsed, mode, op, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)
    cw.nl()


def _gen_user_source(parsed, cw, mode, user_headers):
    """Emit user-mode source contents: policies, parsers, request/dump/ntf code and family."""
    cw.p("#include <linux/genetlink.h>")
    cw.nl()
    for one in user_headers:
        cw.p(f'#include "{one}"')
    cw.nl()

    cw.p('/* Enums */')
    put_op_name(parsed, cw)

    for const in parsed.consts.values():
        if isinstance(const, EnumSet):
            put_enum_to_str(parsed, cw, const)
    cw.nl()

    has_recursive_nests = False
    cw.p('/* Policies */')
    for struct in parsed.pure_nested_structs.values():
        if struct.recursive:
            put_typol_fwd(cw, struct)
            has_recursive_nests = True
    if has_recursive_nests:
        cw.nl()
    for struct in parsed.pure_nested_structs.values():
        put_typol(cw, struct)
    for name in parsed.root_sets:
        struct = Struct(parsed, name)
        put_typol(cw, struct)

    cw.p('/* Common nested types */')
    # Recursive nests reference each other, so emit prototypes first.
    if has_recursive_nests:
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
            free_rsp_nested_prototype(ri)
            if struct.request:
                put_req_nested_prototype(ri, struct)
            if struct.reply:
                parse_rsp_nested_prototype(ri, struct)
        cw.nl()
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for op in parsed.ops.values():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_free(ri)
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, "dump")
            if not ri.type_consistent or ri.type_oneside:
                parse_rsp_msg(ri, deref=True)
            print_req_free(ri)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_ntf_type_free(ri)

    for op in parsed.ntfs.values():
        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")

            ri = RenderInfo(cw, parsed, mode, op, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, mode, op, "event")
            print_ntf_type_free(ri)
    cw.nl()
    render_user_family(parsed, cw, False)


def main():
    """Command line entry point: parse arguments, load the spec and emit C code.

    Writes a uapi header, or a kernel/user header or source file, through a
    CodeWriter targeting ``-o`` (or the CodeWriter default when ``-o`` is
    omitted).  Exits with status 1 on an unparsable spec or a wrong license.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    # Diagnostics go to stderr so they never mix with code written to stdout.
    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)
    if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
        print('Spec license:', parsed.license, file=sys.stderr)
        print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)', file=sys.stderr)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        if args.user_header:
            # --user-header uses nargs='+' with the default "store" action, so
            # repeating the flag would keep only the last header if this line
            # is passed back to the tool (presumably by regeneration scripts).
            # List all headers after a single flag instead.
            line += ' --user-header ' + ' '.join(args.user_header)
        line += ''.join(' --exclude-op ' + expr for expr in args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # Name of the matching generated header: the output name with its
    # two-character extension (".c" / ".h") replaced by ".h".
    if args.out_file:
        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
    else:
        hdr_file = "generated_header_file.h"

    _gen_includes(parsed, cw, args, hdr_file)

    if args.mode == "kernel":
        if args.header:
            _gen_kernel_header(parsed, cw, args.mode)
        else:
            _gen_kernel_source(parsed, cw, args.mode)
    elif args.mode == "user":
        if args.header:
            _gen_user_header(parsed, cw, args.mode)
        else:
            _gen_user_source(parsed, cw, args.mode, args.user_header)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3621
3622
if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()
3625