xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision cf5869977702b1d51e3b4d58b6c559a98a366114)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
def c_upper(name):
    """Return *name* upper-cased with every '-' turned into '_' (C macro style)."""
    upper_name = name.upper()
    return upper_name.translate(str.maketrans('-', '_'))
22
23
def c_lower(name):
    """Return *name* lower-cased with every '-' turned into '_' (C identifier style)."""
    pieces = name.lower().split('-')
    return '_'.join(pieces)
26
27
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value.

    The expected format is <sign><width>-<bound>:
      sign  - 'u' (unsigned) or 's' (signed)
      width - bit width in decimal digits
      bound - 'min' or 'max'

    Returns the integer limit for that type.
    Raises ValueError if *name* does not follow the format above (previously
    malformed names such as 'u32-foo' silently produced a value).
    """
    import re

    match = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit name "{name}", expected e.g. u32-max or s64-min')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u' and bound == 'min':
        return 0
    # Signed types lose one bit of magnitude to the sign
    if sign == 's':
        width -= 1
    value = (1 << width) - 1
    if sign == 's' and bound == 'min':
        value = -value - 1
    return value
41
42
class BaseNlLib:
    """Names of library-side values that the generated C code refers to."""

    def get_family_id(self):
        """Return the C expression used by generated code for the family id."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
46
47
class Type(SpecAttr):
    """
    Base class for C code generation of one attribute of an attribute set.

    Subclasses specialise hooks such as _complex_member_type(), _attr_get(),
    _attr_typol(), _setter_lines() and attr_put() for their attribute type.
    Hooks a subclass does not implement raise an Exception naming the type.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped to True by set_request() / set_reply()
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # NOTE(review): the trailing '_' presumably avoids clashing with a
            # C type generated for a same-named definition in family.consts.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make the member name usable in C: suffix names found in _C_KW
        # (presumably the C keywords) and prefix names starting with a digit.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        """Return the same-named attr of the set this subset derives from (one level, no recursion)."""
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests; subsets also mark their real attr."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies; subsets also mark their real attr."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return check *limit* as a number, or None if neither it nor *default* is set.

        Integers are returned as-is, names of family consts resolve to the
        const's value, anything else goes through limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None or isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check *limit* as C source text, or '' if neither it nor *default* is set.

        Integers get *suffix* appended; consts and symbolic limits are
        rendered as upper-case C names (family-prefixed unless the const
        comes from a header).
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name and reject check overrides in subsets."""
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Name of the presence struct tracking this attr ('present', 'len', 'count' or '')."""
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return the C presence-struct member declaration for this attr.

        Returns None when the attr's presence type differs from *type_filter*
        or has no member. 'present' is a 1-bit field, 'len'/'count' a full u32;
        user space types get the '__' prefix.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'present':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def _free_lines(self, ri, var, ref):
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the C function argument declarations used to pass this attr."""
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the C struct member(s) holding this attr."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            # Recursive nests must be held by pointer
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's kernel nla_policy entry."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's user space type-policy initializer."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        # 'present' and 'len' attrs are only emitted when set
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the parse branch for this attr; *first* selects 'if' vs 'else if'.

        _attr_get() returns (lines, init_lines, local_vars); the first two may
        be a single string. Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a static inline C setter for this attr.

        *ref* is the list of enclosing nest member names. Every enclosing
        level's presence bit is set; the attr's own presence is set here
        only when it is a plain 'present' bit, otherwise _setter_lines()
        handles it. Setters that free without allocating get a '__' prefix.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, part in enumerate(ref):
            prefix = f"{var}->{'.'.join(ref[:i] + [''])}"
            presence = f"{prefix}_present.{part}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{prefix}_{self.presence_type()}.{part}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
311
312
class TypeUnused(Type):
    """
    Attribute type that emits no struct member, policy, put, parse or setter
    code; its user space type policy rejects the attribute.
    """

    def presence_type(self):
        # Nothing is stored, so there is no presence tracking
        return ''

    def arg_member(self, ri):
        return list()

    def _attr_get(self, ri, var):
        reject = ['return YNL_PARSE_CB_ERROR;']
        return reject, None, None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        """No kernel policy entry is generated."""
        return None

    def attr_put(self, ri, var):
        """Nothing is emitted when building messages."""
        return None

    def attr_get(self, ri, var, first):
        """No parse branch is generated."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter is generated."""
        return None
337
338
class TypePad(Type):
    """
    Padding attribute: ignored by the user space type policy and producing
    no struct member, put, parse, kernel policy or setter code.
    """

    def presence_type(self):
        # Padding is never stored, so no presence tracking either
        return ''

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_put(self, ri, var):
        """Nothing is emitted when building messages."""
        return None

    def attr_get(self, ri, var, first):
        """No parse branch is generated."""
        return None

    def attr_policy(self, cw):
        """No kernel policy entry is generated."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter is generated."""
        return None
360
361
class TypeScalar(Type):
    """
    Integer scalar attribute.

    Besides put/parse/setter code it derives kernel policy checks
    (mask, min/max/range, full-range, sparse enum validation).
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Appended to the argument declaration, e.g. " /* big-endian */"
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Decide whether the value is a flags bitfield and pick its C type name."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Auto-sized scalars are exposed as 64-bit values
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """Fill in self.checks with min/max/range/sparse/full-range entries."""
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                # NOTE(review): (None, None) presumably means the enum values
                # are not a contiguous range; a validate function is used instead.
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Inline min/max in a kernel nla_policy are s16; wider limits need
        # a separate full-range struct.
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
            else:
                enum = self.family.consts[self.checks['flags-mask']]
            # Build the mask from the actual flag values rather than the
            # entry count, so flag sets not starting at bit 0 are correct.
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types deliberately map onto the unsigned YNL_PT_U* kinds
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
459
460
class TypeFlag(Type):
    """Zero-length flag attribute: only its presence carries information."""

    def arg_member(self, ri):
        # No value to pass; the setter only marks presence
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # The presence bit is set by the generic attr_get(); nothing to copy
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        return list()
476
477
class TypeString(Type):
    """String attribute stored as a malloc'ed, NUL-terminated copy with its length in _len."""

    def arg_member(self, ri):
        return ["const char *" + self.c_name]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p("char *" + self.c_name + ";")

    def _attr_typol(self):
        selector = '.is_selector = 1, ' if self.is_selector else ''
        return '.type = YNL_PT_NUL_STR, ' + selector

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.get_limit_str('exact-len')})"
        max_len = ''
        if 'max-len' in self.checks:
            max_len = f", .len = {self.get_limit_str('max-len')}"
        return f"{{ .type = {policy}{max_len}, }}"

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the kernel check to NLA_STRING
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_len.{self.c_name}"
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, ynl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        # Length is bounded by the attribute payload size
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        src = self.c_name
        return [f"{presence} = strlen({src});",
                f"{member} = malloc({presence} + 1);",
                f'memcpy({member}, {src}, {presence});',
                f'{member}[{presence}] = 0;']
530
531
class TypeBinary(Type):
    """Opaque binary attribute stored as a malloc'ed copy with its length in _len."""

    # Checks supported by _attr_policy(); at most one may be present
    _LEN_CHECKS = ('exact-len', 'min-len', 'max-len')

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # Trailing ', ' matches every other typol fragment so the generated
        # initializer is "{ .name = ..., .type = YNL_PT_BINARY, }".
        return '.type = YNL_PT_BINARY, '

    def _attr_policy(self, policy):
        """
        Return the kernel policy for the binary attr.

        Raises Exception for more than one check or an unsupported check.
        """
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not self.checks:
            return '{ .type = NLA_BINARY, }'

        check_name = next(iter(self.checks))
        if check_name not in self._LEN_CHECKS:
            raise Exception('Unsupported check for binary type: ' + check_name)

        limit = self.get_limit_str(check_name)
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + limit + ')'
        if check_name == 'min-len':
            return '{ .len = ' + limit + ', }'
        return 'NLA_POLICY_MAX_LEN(' + limit + ')'

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->{self.c_name}, {var}->_len.{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_len.' + self.c_name
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [f"{presence} = len;",
                f"{member} = malloc({presence});",
                f'memcpy({member}, {self.c_name}, {presence});']
582
583
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is a C struct named by the spec's 'struct' key."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        # A payload shorter than the struct goes into a zeroed allocation of
        # the full struct size, so every field of the struct can be read.
        get_lines = [f"{len_mem} = len;",
                     f"if (len < {struct_sz})",
                     f"{dst} = calloc(1, {struct_sz});",
                     "else",
                     f"{dst} = malloc(len);",
                     f"memcpy({dst}, ynl_attr_data(attr), len);"]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars
598               ['unsigned int len;']
599
600
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute carrying a packed array of 'sub-type' scalars, tracked by element count."""

    def arg_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        return [f'{elem_type} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem_type} *{self.c_name};')

    def attr_put(self, ri, var):
        # Uses the caller's local 'i' to hold the payload size in bytes
        count_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        ri.cw.block_start(line=f"if ({count_mem})")
        ri.cw.p(f"i = {count_mem} * {elem_sz};")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        count_mem = f"{var}->_count.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        dst = f"{var}->{self.c_name}"
        # Trailing bytes that do not form a whole element are dropped
        get_lines = [f"{count_mem} = len / {elem_sz};",
                     f"len = {count_mem} * {elem_sz};",
                     f"{dst} = malloc(len);",
                     f"memcpy({dst}, ynl_attr_data(attr), len);"]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        return [f"{presence} = count;",
                f"count *= {elem_sz};",
                f"{member} = malloc(count);",
                f'memcpy({member}, {self.c_name}, count);']
633
634
class TypeBitfield32(Type):
    """struct nla_bitfield32 attribute; the kernel policy mask comes from its 'enum'."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        valid_mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({valid_mask})"

    def attr_put(self, ri, var):
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, "
                    "sizeof(struct nla_bitfield32))")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        copy_line = (f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), "
                     "sizeof(struct nla_bitfield32));")
        return copy_line, None, None

    def _setter_lines(self, ri, member, presence):
        copy_line = f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"
        return [copy_line]
658
659
class TypeNest(Type):
    """Nested attribute set, rendered as an embedded struct (or a pointer when recursive)."""

    def is_recursive(self):
        return self.family.pure_nested_structs[self.nested_attrs].recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        member = f"{var}->{ref}{self.c_name}"
        free_fn = f"{self.nested_render_name}_free"
        if self.is_recursive_for_op(ri):
            # Held by pointer here, so check it before freeing
            return [f'if ({member})', f'{free_fn}({member});']
        return [f'{free_fn}(&{member});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        # Selectors living outside the nest are passed down to its parser
        parse_args = ["&parg", "attr"] + [f'{var}->{sel.name}'
                                          for sel in nested.external_selectors()]
        get_lines = [f"if ({self.nested_render_name}_parse({', '.join(parse_args)}))",
                     "return YNL_PARSE_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for every non-recursive member of the nest, one level deeper."""
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
705
706
class TypeMultiAttr(Type):
    """
    Attribute that may appear several times; occurrences are kept as an
    array plus a count.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        # Type object for a single entry; policy and typol are delegated to it
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        sub = self.attr['type']
        if sub == 'binary' and 'struct' in self.attr:
            return None  # the declaration comes from arg_member()
        if sub == 'string':
            return 'struct ynl_string *'
        if sub in scalars:
            return ('__' if ri.ku_space == 'user' else '') + sub
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def arg_member(self, ri):
        if self.type != 'binary' or 'struct' not in self.attr:
            return super().arg_member(ri)
        struct_name = c_lower(self.attr["struct"])
        return [f'struct {struct_name} *{self.c_name}',
                f'unsigned int n_{self.c_name}']

    def free_needs_iter(self):
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        sub = self.attr['type']
        member = f"{var}->{ref}{self.c_name}"
        loop = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if sub in scalars or sub == 'binary':
            return [f"free({member});"]
        if sub == 'string':
            return [loop, f"free({member}[i]);", f"free({member});"]
        if sub == 'nest':
            return [loop,
                    f'{self.nested_render_name}_free(&{member}[i]);',
                    f"free({member});"]
        raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only counts occurrences here
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        sub = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        entry = f"{var}->{self.c_name}[i]"
        if sub in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {entry});")
        elif sub == 'binary' and 'struct' in self.attr:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{entry}, sizeof(struct {c_lower(self.attr['struct'])}));")
        elif sub == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {entry}->str);")
        elif sub == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{entry})")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
793
794
class TypeArrayNest(Type):
    """
    Array carried as a nest whose inner attributes use the entry index as
    their type; elements may be scalars, fixed-length binaries or nests.
    """

    def _elem_type(self):
        """Return the spec 'sub-type', or None when absent (treated as 'nest')."""
        return self.attr.get('sub-type')

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def arg_member(self, ri):
        if self.sub_type == 'binary' and 'exact-len' in self.checks:
            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        # A missing sub-type falls through to the nest typol, consistent with
        # _complex_member_type(), instead of raising KeyError.
        elem_type = self._elem_type()
        if elem_type in scalars:
            return f'.type = YNL_PT_U{c_upper(elem_type[1:])}, '
        elif elem_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Validates each entry and counts them; 'attr_<name>' is declared by the caller
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->_count.{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            put_type = self.sub_type
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            # .get() so the intended error is raised even without a sub-type
            raise Exception(f"Put for ArrayNest sub-type {self._elem_type()} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
857
858
859class TypeNestTypeValue(Type):
860    def _complex_member_type(self, ri):
861        return self.nested_struct_type
862
863    def _attr_typol(self):
864        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
865
866    def _attr_get(self, ri, var):
867        prev = 'attr'
868        tv_args = ''
869        get_lines = []
870        local_vars = []
871        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
872                      f"parg.data = &{var}->{self.c_name};"]
873        if 'type-value' in self.attr:
874            tv_names = [c_lower(x) for x in self.attr["type-value"]]
875            local_vars += [f'const struct nlattr *attr_{", *attr_".join(tv_names)};']
876            local_vars += [f'__u32 {", ".join(tv_names)};']
877            for level in self.attr["type-value"]:
878                level = c_lower(level)
879                get_lines += [f'attr_{level} = ynl_attr_data({prev});']
880                get_lines += [f'{level} = ynl_attr_type(attr_{level});']
881                prev = 'attr_' + level
882
883            tv_args = f", {', '.join(tv_names)}"
884
885        get_lines += [f"{self.nested_render_name}_parse(&parg, {prev}{tv_args});"]
886        return get_lines, init_lines, local_vars
887
888
889class TypeSubMessage(TypeNest):
890    def __init__(self, family, attr_set, attr, value):
891        super().__init__(family, attr_set, attr, value)
892
893        self.selector = Selector(attr, attr_set)
894
895    def _attr_typol(self):
896        typol = f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
897        typol += '.is_submsg = 1, '
898        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
899        # support external selectors. No family uses sub-messages with external
900        # selector for requests so this is fine for now.
901        if not self.selector.is_external():
902            typol += f'.selector_type = {self.attr_set[self["selector"]].value} '
903        return typol
904
905    def _attr_get(self, ri, var):
906        sel = c_lower(self['selector'])
907        if self.selector.is_external():
908            sel_var = f"_sel_{sel}"
909        else:
910            sel_var = f"{var}->{sel}"
911        get_lines = [f'if (!{sel_var})',
912                     f'return ynl_submsg_failed(yarg, "%s", "%s");' %
913                        (self.name, self['selector']),
914                    f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
915                     "return YNL_PARSE_CB_ERROR;"]
916        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
917                      f"parg.data = &{var}->{self.c_name};"]
918        return get_lines, init_lines, None
919
920
921class Selector:
922    def __init__(self, msg_attr, attr_set):
923        self.name = msg_attr["selector"]
924
925        if self.name in attr_set:
926            self.attr = attr_set[self.name]
927            self.attr.is_selector = True
928            self._external = False
929        else:
930            # The selector will need to get passed down thru the structs
931            self.attr = None
932            self._external = True
933
934    def set_attr(self, attr):
935        self.attr = attr
936
937    def is_external(self):
938        return self._external
939
940
941class Struct:
942    def __init__(self, family, space_name, type_list=None, fixed_header=None,
943                 inherited=None, submsg=None):
944        self.family = family
945        self.space_name = space_name
946        self.attr_set = family.attr_sets[space_name]
947        # Use list to catch comparisons with empty sets
948        self._inherited = inherited if inherited is not None else []
949        self.inherited = []
950        self.fixed_header = None
951        if fixed_header:
952            self.fixed_header = 'struct ' + c_lower(fixed_header)
953        self.submsg = submsg
954
955        self.nested = type_list is None
956        if family.name == c_lower(space_name):
957            self.render_name = c_lower(family.ident_name)
958        else:
959            self.render_name = c_lower(family.ident_name + '-' + space_name)
960        self.struct_name = 'struct ' + self.render_name
961        if self.nested and space_name in family.consts:
962            self.struct_name += '_'
963        self.ptr_name = self.struct_name + ' *'
964        # All attr sets this one contains, directly or multiple levels down
965        self.child_nests = set()
966
967        self.request = False
968        self.reply = False
969        self.recursive = False
970        self.in_multi_val = False  # used by a MultiAttr or and legacy arrays
971
972        self.attr_list = []
973        self.attrs = dict()
974        if type_list is not None:
975            for t in type_list:
976                self.attr_list.append((t, self.attr_set[t]),)
977        else:
978            for t in self.attr_set:
979                self.attr_list.append((t, self.attr_set[t]),)
980
981        max_val = 0
982        self.attr_max_val = None
983        for name, attr in self.attr_list:
984            if attr.value >= max_val:
985                max_val = attr.value
986                self.attr_max_val = attr
987            self.attrs[name] = attr
988
989    def __iter__(self):
990        yield from self.attrs
991
992    def __getitem__(self, key):
993        return self.attrs[key]
994
995    def member_list(self):
996        return self.attr_list
997
998    def set_inherited(self, new_inherited):
999        if self._inherited != new_inherited:
1000            raise Exception("Inheriting different members not supported")
1001        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
1002
1003    def external_selectors(self):
1004        sels = []
1005        for name, attr in self.attr_list:
1006            if isinstance(attr, TypeSubMessage) and attr.selector.is_external():
1007                sels.append(attr.selector)
1008        return sels
1009
1010    def free_needs_iter(self):
1011        for _, attr in self.attr_list:
1012            if attr.free_needs_iter():
1013                return True
1014        return False
1015
1016
1017class EnumEntry(SpecEnumEntry):
1018    def __init__(self, enum_set, yaml, prev, value_start):
1019        super().__init__(enum_set, yaml, prev, value_start)
1020
1021        if prev:
1022            self.value_change = (self.value != prev.value + 1)
1023        else:
1024            self.value_change = (self.value != 0)
1025        self.value_change = self.value_change or self.enum_set['type'] == 'flags'
1026
1027        # Added by resolve:
1028        self.c_name = None
1029        delattr(self, "c_name")
1030
1031    def resolve(self):
1032        self.resolve_up(super())
1033
1034        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1035
1036
1037class EnumSet(SpecEnumSet):
1038    def __init__(self, family, yaml):
1039        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])
1040
1041        if 'enum-name' in yaml:
1042            if yaml['enum-name']:
1043                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
1044                self.user_type = self.enum_name
1045            else:
1046                self.enum_name = None
1047        else:
1048            self.enum_name = 'enum ' + self.render_name
1049
1050        if self.enum_name:
1051            self.user_type = self.enum_name
1052        else:
1053            self.user_type = 'int'
1054
1055        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
1056        self.header = yaml.get('header', None)
1057        self.enum_cnt_name = yaml.get('enum-cnt-name', None)
1058
1059        super().__init__(family, yaml)
1060
1061    def new_entry(self, entry, prev_entry, value_start):
1062        return EnumEntry(self, entry, prev_entry, value_start)
1063
1064    def value_range(self):
1065        low = min([x.value for x in self.entries.values()])
1066        high = max([x.value for x in self.entries.values()])
1067
1068        if high - low + 1 != len(self.entries):
1069            return None, None
1070
1071        return low, high
1072
1073
1074class AttrSet(SpecAttrSet):
1075    def __init__(self, family, yaml):
1076        super().__init__(family, yaml)
1077
1078        if self.subset_of is None:
1079            if 'name-prefix' in yaml:
1080                pfx = yaml['name-prefix']
1081            elif self.name == family.name:
1082                pfx = family.ident_name + '-a-'
1083            else:
1084                pfx = f"{family.ident_name}-a-{self.name}-"
1085            self.name_prefix = c_upper(pfx)
1086            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
1087            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))
1088        else:
1089            self.name_prefix = family.attr_sets[self.subset_of].name_prefix
1090            self.max_name = family.attr_sets[self.subset_of].max_name
1091            self.cnt_name = family.attr_sets[self.subset_of].cnt_name
1092
1093        # Added by resolve:
1094        self.c_name = None
1095        delattr(self, "c_name")
1096
1097    def resolve(self):
1098        self.c_name = c_lower(self.name)
1099        if self.c_name in _C_KW:
1100            self.c_name += '_'
1101        if self.c_name == self.family.c_name:
1102            self.c_name = ''
1103
1104    def new_attr(self, elem, value):
1105        if elem['type'] in scalars:
1106            t = TypeScalar(self.family, self, elem, value)
1107        elif elem['type'] == 'unused':
1108            t = TypeUnused(self.family, self, elem, value)
1109        elif elem['type'] == 'pad':
1110            t = TypePad(self.family, self, elem, value)
1111        elif elem['type'] == 'flag':
1112            t = TypeFlag(self.family, self, elem, value)
1113        elif elem['type'] == 'string':
1114            t = TypeString(self.family, self, elem, value)
1115        elif elem['type'] == 'binary':
1116            if 'struct' in elem:
1117                t = TypeBinaryStruct(self.family, self, elem, value)
1118            elif elem.get('sub-type') in scalars:
1119                t = TypeBinaryScalarArray(self.family, self, elem, value)
1120            else:
1121                t = TypeBinary(self.family, self, elem, value)
1122        elif elem['type'] == 'bitfield32':
1123            t = TypeBitfield32(self.family, self, elem, value)
1124        elif elem['type'] == 'nest':
1125            t = TypeNest(self.family, self, elem, value)
1126        elif elem['type'] == 'indexed-array' and 'sub-type' in elem:
1127            if elem["sub-type"] in ['binary', 'nest', 'u32']:
1128                t = TypeArrayNest(self.family, self, elem, value)
1129            else:
1130                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
1131        elif elem['type'] == 'nest-type-value':
1132            t = TypeNestTypeValue(self.family, self, elem, value)
1133        elif elem['type'] == 'sub-message':
1134            t = TypeSubMessage(self.family, self, elem, value)
1135        else:
1136            raise Exception(f"No typed class for type {elem['type']}")
1137
1138        if 'multi-attr' in elem and elem['multi-attr']:
1139            t = TypeMultiAttr(self.family, self, elem, value, t)
1140
1141        return t
1142
1143
1144class Operation(SpecOperation):
1145    def __init__(self, family, yaml, req_value, rsp_value):
1146        # Fill in missing operation properties (for fixed hdr-only msgs)
1147        for mode in ['do', 'dump', 'event']:
1148            for direction in ['request', 'reply']:
1149                try:
1150                    yaml[mode][direction].setdefault('attributes', [])
1151                except KeyError:
1152                    pass
1153
1154        super().__init__(family, yaml, req_value, rsp_value)
1155
1156        self.render_name = c_lower(family.ident_name + '_' + self.name)
1157
1158        self.dual_policy = ('do' in yaml and 'request' in yaml['do']) and \
1159                         ('dump' in yaml and 'request' in yaml['dump'])
1160
1161        self.has_ntf = False
1162
1163        # Added by resolve:
1164        self.enum_name = None
1165        delattr(self, "enum_name")
1166
1167    def resolve(self):
1168        self.resolve_up(super())
1169
1170        if not self.is_async:
1171            self.enum_name = self.family.op_prefix + c_upper(self.name)
1172        else:
1173            self.enum_name = self.family.async_op_prefix + c_upper(self.name)
1174
1175    def mark_has_ntf(self):
1176        self.has_ntf = True
1177
1178
1179class SubMessage(SpecSubMessage):
1180    def __init__(self, family, yaml):
1181        super().__init__(family, yaml)
1182
1183        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])
1184
1185    def resolve(self):
1186        self.resolve_up(super())
1187
1188
1189class Family(SpecFamily):
1190    def __init__(self, file_name, exclude_ops):
1191        # Added by resolve:
1192        self.c_name = None
1193        delattr(self, "c_name")
1194        self.op_prefix = None
1195        delattr(self, "op_prefix")
1196        self.async_op_prefix = None
1197        delattr(self, "async_op_prefix")
1198        self.mcgrps = None
1199        delattr(self, "mcgrps")
1200        self.consts = None
1201        delattr(self, "consts")
1202        self.hooks = None
1203        delattr(self, "hooks")
1204
1205        super().__init__(file_name, exclude_ops=exclude_ops)
1206
1207        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
1208        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
1209
1210        if 'definitions' not in self.yaml:
1211            self.yaml['definitions'] = []
1212
1213        if 'uapi-header' in self.yaml:
1214            self.uapi_header = self.yaml['uapi-header']
1215        else:
1216            self.uapi_header = f"linux/{self.ident_name}.h"
1217        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
1218            self.uapi_header_name = self.uapi_header[6:-2]
1219        else:
1220            self.uapi_header_name = self.ident_name
1221
1222    def resolve(self):
1223        self.resolve_up(super())
1224
1225        self.c_name = c_lower(self.ident_name)
1226        if 'name-prefix' in self.yaml['operations']:
1227            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
1228        else:
1229            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
1230        if 'async-prefix' in self.yaml['operations']:
1231            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
1232        else:
1233            self.async_op_prefix = self.op_prefix
1234
1235        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
1236
1237        self.hooks = dict()
1238        for when in ['pre', 'post']:
1239            self.hooks[when] = dict()
1240            for op_mode in ['do', 'dump']:
1241                self.hooks[when][op_mode] = dict()
1242                self.hooks[when][op_mode]['set'] = set()
1243                self.hooks[when][op_mode]['list'] = []
1244
1245        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
1246        self.root_sets = dict()
1247        # dict space-name -> Struct
1248        self.pure_nested_structs = dict()
1249
1250        self._mark_notify()
1251        self._mock_up_events()
1252
1253        self._load_root_sets()
1254        self._load_nested_sets()
1255        self._load_attr_use()
1256        self._load_selector_passing()
1257        self._load_hooks()
1258
1259        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
1260        if self.kernel_policy == 'global':
1261            self._load_global_policy()
1262
1263    def new_enum(self, elem):
1264        return EnumSet(self, elem)
1265
1266    def new_attr_set(self, elem):
1267        return AttrSet(self, elem)
1268
1269    def new_operation(self, elem, req_value, rsp_value):
1270        return Operation(self, elem, req_value, rsp_value)
1271
1272    def new_sub_message(self, elem):
1273        return SubMessage(self, elem)
1274
1275    def is_classic(self):
1276        return self.proto == 'netlink-raw'
1277
1278    def _mark_notify(self):
1279        for op in self.msgs.values():
1280            if 'notify' in op:
1281                self.ops[op['notify']].mark_has_ntf()
1282
1283    # Fake a 'do' equivalent of all events, so that we can render their response parsing
1284    def _mock_up_events(self):
1285        for op in self.yaml['operations']['list']:
1286            if 'event' in op:
1287                op['do'] = {
1288                    'reply': {
1289                        'attributes': op['event']['attributes']
1290                    }
1291                }
1292
1293    def _load_root_sets(self):
1294        for op_name, op in self.msgs.items():
1295            if 'attribute-set' not in op:
1296                continue
1297
1298            req_attrs = set()
1299            rsp_attrs = set()
1300            for op_mode in ['do', 'dump']:
1301                if op_mode in op and 'request' in op[op_mode]:
1302                    req_attrs.update(set(op[op_mode]['request']['attributes']))
1303                if op_mode in op and 'reply' in op[op_mode]:
1304                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
1305            if 'event' in op:
1306                rsp_attrs.update(set(op['event']['attributes']))
1307
1308            if op['attribute-set'] not in self.root_sets:
1309                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
1310            else:
1311                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
1312                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
1313
1314    def _sort_pure_types(self):
1315        # Try to reorder according to dependencies
1316        pns_key_list = list(self.pure_nested_structs.keys())
1317        pns_key_seen = set()
1318        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
1319        for _ in range(rounds):
1320            if len(pns_key_list) == 0:
1321                break
1322            name = pns_key_list.pop(0)
1323            finished = True
1324            for _, spec in self.attr_sets[name].items():
1325                if 'nested-attributes' in spec:
1326                    nested = spec['nested-attributes']
1327                elif 'sub-message' in spec:
1328                    nested = spec.sub_message
1329                else:
1330                    continue
1331
1332                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
1333                if self.pure_nested_structs[nested].recursive:
1334                    continue
1335                if nested not in pns_key_seen:
1336                    # Dicts are sorted, this will make struct last
1337                    struct = self.pure_nested_structs.pop(name)
1338                    self.pure_nested_structs[name] = struct
1339                    finished = False
1340                    break
1341            if finished:
1342                pns_key_seen.add(name)
1343            else:
1344                pns_key_list.append(name)
1345
1346    def _load_nested_set_nest(self, spec):
1347        inherit = set()
1348        nested = spec['nested-attributes']
1349        if nested not in self.root_sets:
1350            if nested not in self.pure_nested_structs:
1351                self.pure_nested_structs[nested] = \
1352                    Struct(self, nested, inherited=inherit,
1353                           fixed_header=spec.get('fixed-header'))
1354        else:
1355            raise Exception(f'Using attr set as root and nested not supported - {nested}')
1356
1357        if 'type-value' in spec:
1358            if nested in self.root_sets:
1359                raise Exception("Inheriting members to a space used as root not supported")
1360            inherit.update(set(spec['type-value']))
1361        elif spec['type'] == 'indexed-array':
1362            inherit.add('idx')
1363        self.pure_nested_structs[nested].set_inherited(inherit)
1364
1365        return nested
1366
1367    def _load_nested_set_submsg(self, spec):
1368        # Fake the struct type for the sub-message itself
1369        # its not a attr_set but codegen wants attr_sets.
1370        submsg = self.sub_msgs[spec["sub-message"]]
1371        nested = submsg.name
1372
1373        attrs = []
1374        for name, fmt in submsg.formats.items():
1375            attr = {
1376                "name": name,
1377                "parent-sub-message": spec,
1378            }
1379            if 'attribute-set' in fmt:
1380                attr |= {
1381                    "type": "nest",
1382                    "nested-attributes": fmt['attribute-set'],
1383                }
1384                if 'fixed-header' in fmt:
1385                    attr |= { "fixed-header": fmt["fixed-header"] }
1386            elif 'fixed-header' in fmt:
1387                attr |= {
1388                    "type": "binary",
1389                    "struct": fmt["fixed-header"],
1390                }
1391            else:
1392                attr["type"] = "flag"
1393            attrs.append(attr)
1394
1395        self.attr_sets[nested] = AttrSet(self, {
1396            "name": nested,
1397            "name-pfx": self.name + '-' + spec.name + '-',
1398            "attributes": attrs
1399        })
1400
1401        if nested not in self.pure_nested_structs:
1402            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)
1403
1404        return nested
1405
1406    def _load_nested_sets(self):
1407        attr_set_queue = list(self.root_sets.keys())
1408        attr_set_seen = set(self.root_sets.keys())
1409
1410        while len(attr_set_queue):
1411            a_set = attr_set_queue.pop(0)
1412            for attr, spec in self.attr_sets[a_set].items():
1413                if 'nested-attributes' in spec:
1414                    nested = self._load_nested_set_nest(spec)
1415                elif 'sub-message' in spec:
1416                    nested = self._load_nested_set_submsg(spec)
1417                else:
1418                    continue
1419
1420                if nested not in attr_set_seen:
1421                    attr_set_queue.append(nested)
1422                    attr_set_seen.add(nested)
1423
1424        for root_set, rs_members in self.root_sets.items():
1425            for attr, spec in self.attr_sets[root_set].items():
1426                if 'nested-attributes' in spec:
1427                    nested = spec['nested-attributes']
1428                elif 'sub-message' in spec:
1429                    nested = spec.sub_message
1430                else:
1431                    nested = None
1432
1433                if nested:
1434                    if attr in rs_members['request']:
1435                        self.pure_nested_structs[nested].request = True
1436                    if attr in rs_members['reply']:
1437                        self.pure_nested_structs[nested].reply = True
1438
1439                    if spec.is_multi_val():
1440                        child = self.pure_nested_structs.get(nested)
1441                        child.in_multi_val = True
1442
1443        self._sort_pure_types()
1444
1445        # Propagate the request / reply / recursive
1446        for attr_set, struct in reversed(self.pure_nested_structs.items()):
1447            for _, spec in self.attr_sets[attr_set].items():
1448                if attr_set in struct.child_nests:
1449                    struct.recursive = True
1450
1451                if 'nested-attributes' in spec:
1452                    child_name = spec['nested-attributes']
1453                elif 'sub-message' in spec:
1454                    child_name = spec.sub_message
1455                else:
1456                    continue
1457
1458                struct.child_nests.add(child_name)
1459                child = self.pure_nested_structs.get(child_name)
1460                if child:
1461                    if not child.recursive:
1462                        struct.child_nests.update(child.child_nests)
1463                    child.request |= struct.request
1464                    child.reply |= struct.reply
1465                    if spec.is_multi_val():
1466                        child.in_multi_val = True
1467
1468        self._sort_pure_types()
1469
1470    def _load_attr_use(self):
1471        for _, struct in self.pure_nested_structs.items():
1472            if struct.request:
1473                for _, arg in struct.member_list():
1474                    arg.set_request()
1475            if struct.reply:
1476                for _, arg in struct.member_list():
1477                    arg.set_reply()
1478
1479        for root_set, rs_members in self.root_sets.items():
1480            for attr, spec in self.attr_sets[root_set].items():
1481                if attr in rs_members['request']:
1482                    spec.set_request()
1483                if attr in rs_members['reply']:
1484                    spec.set_reply()
1485
1486    def _load_selector_passing(self):
1487        def all_structs():
1488            for k, v in reversed(self.pure_nested_structs.items()):
1489                yield k, v
1490            for k, _ in self.root_sets.items():
1491                yield k, None  # we don't have a struct, but it must be terminal
1492
1493        for attr_set, struct in all_structs():
1494            for _, spec in self.attr_sets[attr_set].items():
1495                if 'nested-attributes' in spec:
1496                    child_name = spec['nested-attributes']
1497                elif 'sub-message' in spec:
1498                    child_name = spec.sub_message
1499                else:
1500                    continue
1501
1502                child = self.pure_nested_structs.get(child_name)
1503                for selector in child.external_selectors():
1504                    if selector.name in self.attr_sets[attr_set]:
1505                        sel_attr = self.attr_sets[attr_set][selector.name]
1506                        selector.set_attr(sel_attr)
1507                    else:
1508                        raise Exception("Passing selector thru more than one layer not supported")
1509
1510    def _load_global_policy(self):
1511        global_set = set()
1512        attr_set_name = None
1513        for op_name, op in self.ops.items():
1514            if not op:
1515                continue
1516            if 'attribute-set' not in op:
1517                continue
1518
1519            if attr_set_name is None:
1520                attr_set_name = op['attribute-set']
1521            if attr_set_name != op['attribute-set']:
1522                raise Exception('For a global policy all ops must use the same set')
1523
1524            for op_mode in ['do', 'dump']:
1525                if op_mode in op:
1526                    req = op[op_mode].get('request')
1527                    if req:
1528                        global_set.update(req.get('attributes', []))
1529
1530        self.global_policy = []
1531        self.global_policy_set = attr_set_name
1532        for attr in self.attr_sets[attr_set_name]:
1533            if attr in global_set:
1534                self.global_policy.append(attr)
1535
1536    def _load_hooks(self):
1537        for op in self.ops.values():
1538            for op_mode in ['do', 'dump']:
1539                if op_mode not in op:
1540                    continue
1541                for when in ['pre', 'post']:
1542                    if when not in op[op_mode]:
1543                        continue
1544                    name = op[op_mode][when]
1545                    if name in self.hooks[when][op_mode]['set']:
1546                        continue
1547                    self.hooks[when][op_mode]['set'].add(name)
1548                    self.hooks[when][op_mode]['list'].append(name)
1549
1550
1551class RenderInfo:
1552    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
1553        self.family = family
1554        self.nl = cw.nlib
1555        self.ku_space = ku_space
1556        self.op_mode = op_mode
1557        self.op = op
1558
1559        fixed_hdr = op.fixed_header if op else None
1560        self.fixed_hdr_len = 'ys->family->hdr_len'
1561        if op and op.fixed_header:
1562            if op.fixed_header != family.fixed_header:
1563                if family.is_classic():
1564                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
1565                else:
1566                    raise Exception(f"Per-op fixed header not supported, yet")
1567
1568
1569        # 'do' and 'dump' response parsing is identical
1570        self.type_consistent = True
1571        self.type_oneside = False
1572        if op_mode != 'do' and 'dump' in op:
1573            if 'do' in op:
1574                if ('reply' in op['do']) != ('reply' in op["dump"]):
1575                    self.type_consistent = False
1576                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
1577                    self.type_consistent = False
1578            else:
1579                self.type_consistent = True
1580                self.type_oneside = True
1581
1582        self.attr_set = attr_set
1583        if not self.attr_set:
1584            self.attr_set = op['attribute-set']
1585
1586        self.type_name_conflict = False
1587        if op:
1588            self.type_name = c_lower(op.name)
1589        else:
1590            self.type_name = c_lower(attr_set)
1591            if attr_set in family.consts:
1592                self.type_name_conflict = True
1593
1594        self.cw = cw
1595
1596        self.struct = dict()
1597        if op_mode == 'notify':
1598            op_mode = 'do' if 'do' in op else 'dump'
1599        for op_dir in ['request', 'reply']:
1600            if op:
1601                type_list = []
1602                if op_dir in op[op_mode]:
1603                    type_list = op[op_mode][op_dir]['attributes']
1604                self.struct[op_dir] = Struct(family, self.attr_set,
1605                                             fixed_header=fixed_hdr,
1606                                             type_list=type_list)
1607        if op_mode == 'event':
1608            self.struct['reply'] = Struct(family, self.attr_set,
1609                                          fixed_header=fixed_hdr,
1610                                          type_list=op['event']['attributes'])
1611
1612    def type_empty(self, key):
1613        return len(self.struct[key].attr_list) == 0 and \
1614            self.struct['request'].fixed_header is None
1615
1616    def needs_nlflags(self, direction):
1617        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1618
1619
class CodeWriter:
    """Write generated C code with kernel-style tab indentation.

    Output goes to stdout or, when *out_file* is given, to a temporary file
    that close_out_file() copies over *out_file*. Closing braces are emitted
    lazily so that a following "else" can be joined onto them ("} else").
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False
        self._block_end = False
        self._silent_block = False
        self._ind = 0
        self._ifdef_block = None
        # Pending temporary output file; None when writing to stdout or
        # once close_out_file() has run.
        self._tmp = None
        if out_file is None:
            self._out = sys.stdout
        else:
            self._tmp = tempfile.NamedTemporaryFile('w+')
            self._out = self._tmp
            self._out_file = out_file

    def __del__(self):
        # getattr(): __init__ may have raised before _tmp was assigned.
        if getattr(self, '_tmp', None) is not None:
            self.close_out_file()

    def close_out_file(self):
        """Copy the buffered output to the destination file.

        With overwrite=False an existing destination whose contents already
        match is left untouched, so its mtime does not change. The temporary
        file is closed on every path, including the "unchanged" one, and any
        later output goes to stdout. Calling this again, or on a writer that
        targets stdout, does nothing.
        """
        tmp = self._tmp
        if tmp is None:
            return
        try:
            tmp.flush()
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(tmp.name, self._out_file, shallow=False))
            if not unchanged:
                with open(self._out_file, 'w+') as out_file:
                    tmp.seek(0)
                    shutil.copyfileobj(tmp, out_file)
        finally:
            tmp.close()
            self._tmp = None
            self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        """Return True if *line* starts an if/while/for statement."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Emit one line at the current indentation.

        A pending closing brace is flushed first (merged into "} else ..."
        when the line starts with "else"), then a pending blank line.
        Lines ending in ':' (e.g. goto labels) are outdented one level,
        lines starting with '#' go to column 0, and the single statement
        after a brace-less if/while/for/else is indented one extra level.
        *add_ind* adds further indentation levels.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith()/startswith() instead of indexing, so that an empty
        # line does not raise IndexError.
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next emitted line."""
        self._nl = True

    def block_start(self, line=''):
        """Emit *line* followed by '{' and increase the indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Decrease the indentation and close the current block.

        With no *line*, the '}' is deferred to the next p() call so that an
        "else" can be joined onto it. Otherwise '}' is emitted immediately,
        followed by *line* (e.g. ';' or a declarator).
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Emit *doc* as ' *'-prefixed comment lines wrapped before column 79."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Emit a function prototype, optionally preceded by a doc comment.

        The prototype stays on one line if it fits in 80 columns. Otherwise
        a return type longer than 3 characters goes on its own line and the
        arguments are wrapped, aligned under the opening parenthesis.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Emit local variable declarations, longest first, then a blank line.

        Accepts a single declaration string or a list of them. Ties in
        length keep their input order. The caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Emit a complete function from its prototype parts and body lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Emit (name, value) pairs as tab-aligned #define lines.

        int values are written verbatim and str values quoted. Other values
        produce a #define with no value.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Emit tab-aligned '.field = value,' designated initializers.

        An empty *members* sequence emits nothing.
        """
        longest = max((len(x[0]) for x in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the surrounding '#ifdef CONFIG_<config>' guard.

        Closes the currently open guard (if different) and opens a new one
        unless *config* is falsy. Repeated calls with the same config are
        no-ops.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1816
1817
# Integer attribute types that have a matching ynl_attr_get_<type>() helper.
scalars = set('u8 u16 u32 u64 s8 s16 s32 s64 uint sint'.split())

# Name suffix appended for each message direction ('' means no direction).
direction_to_suffix = {'reply': '_rsp', 'request': '_req', '': ''}

# Name suffix of the reply type per operation mode, used when do and dump
# share one reply type.
op_mode_to_wrapper = {'do': '', 'dump': '_list', 'notify': '_ntf', 'event': ''}

# C keywords. NOTE(review): presumably consulted elsewhere in the file to
# avoid emitting identifiers that collide with them; confirm.
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1869
1870
def rdir(direction):
    """Return the opposite message direction ('request' <-> 'reply').

    Any other value (including '') is returned unchanged.
    """
    if direction not in ('request', 'reply'):
        return direction
    return 'reply' if direction == 'request' else 'request'
1877
1878
def op_prefix(ri, direction, deref=False):
    """Return the C identifier prefix of the type for *ri* in *direction*.

    The result is "<family c_name>_<type_name>" plus a mode-dependent tail:
      - no op_mode: nothing;
      - 'do': the direction suffix ('_req' / '_rsp');
      - other modes, request: '_req', plus '_dump' unless the type is
        one-sided;
      - other modes, reply with a type shared with 'do': the direction
        suffix when *deref* is set, else the mode's wrapper suffix;
      - other modes, reply not shared: '_rsp_dump' when *deref* is set,
        else '_rsp_list'.
    """
    mode = ri.op_mode
    if not mode:
        tail = ''
    elif mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req' if ri.type_oneside else '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[mode]
    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1902
1903
def type_name(ri, direction, deref=False):
    """Return the full C struct type, "struct <op_prefix(...)>"."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1906
1907
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the user-space call for ri.op.

    The function is named after the op ('_dump' appended in dump mode). It
    takes the socket plus, when the op has a request, a pointer to the
    request struct. It returns a pointer to the reply type when the op has
    a reply, and int otherwise. *terminate* appends ';' for a declaration.
    """
    mode_spec = ri.op[ri.op_mode]
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_type = type_name(ri, direction)
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{arg_type} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1924
1925
def print_req_prototype(ri):
    """Write the prototype of the 'do' request call, led by the op's doc.

    An op without a 'doc' key gets no doc comment instead of raising
    KeyError.
    """
    doc = ri.op['doc'] if 'doc' in ri.op else None
    print_prototype(ri, "request", doc=doc)
1928
1929
def print_dump_prototype(ri):
    """Write the prototype of the dump call (without a doc comment)."""
    print_prototype(ri, direction="request")
1932
1933
def put_typol_submsg(cw, struct):
    """Emit the policy table and nest descriptor for a sub-message struct.

    Each member becomes a YNL_PT_SUBMSG entry indexed by its position in
    member_list(). 'nest' members also point at their nest descriptor.
    max_attr is the index of the last entry.
    """
    rname = struct.render_name
    cw.block_start(line=f'const struct ynl_policy_attr {rname}_policy[] =')

    count = 0
    for name, arg in struct.member_list():
        nest_ref = f" .nest = &{arg.nested_render_name}_nest," if arg.type == 'nest' else ""
        cw.p(f'[{count}] = {{ .type = YNL_PT_SUBMSG, .name = "{name}",{nest_ref} }},')
        count += 1

    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {rname}_nest =')
    cw.p(f'.max_attr = {count - 1},')
    cw.p(f'.table = {rname}_policy,')
    cw.block_end(line=';')
    cw.nl()
1954
1955
def put_typol_fwd(cw, struct):
    """Emit an extern declaration of the nest descriptor for *struct*."""
    nest_name = f'{struct.render_name}_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest_name};')
1958
1959
def put_typol(cw, struct):
    """Emit the attribute policy table and nest descriptor for *struct*.

    Sub-message structs are delegated to put_typol_submsg(). Otherwise the
    table is sized by the attribute set's max attribute name, and each
    member writes its own entry through attr_typol().
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    max_name = struct.attr_set.max_name
    rname = struct.render_name

    cw.block_start(line=f'const struct ynl_policy_attr {rname}_policy[{max_name} + 1] =')
    for _name, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {rname}_nest =')
    cw.p(f'.max_attr = {max_name},')
    cw.p(f'.table = {rname}_policy,')
    cw.block_end(line=';')
    cw.nl()
1979
1980
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit "<render_name>_str()", which looks a value up in *map_name*.

    The argument is typed with the enum's user type when *enum* is given,
    and as int otherwise. For flags enums the value is first turned into
    the zero-based index of its lowest set bit with ffs() - 1. Values
    outside the table yield NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()

    is_flags = bool(enum) and enum.type == 'flags'
    if is_flags:
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')

    bounds = f'{arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name})'
    cw.p(f'if ({bounds})')
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')

    cw.block_end()
    cw.nl()
1994
1995
def put_op_name_fwd(family, cw):
    """Declare "<family>_op_str()", the op value to op name lookup."""
    fname = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', fname, ['int op'], suffix=';')
1998
1999
def put_op_name(family, cw):
    """Emit the op-name string table and its "<family>_op_str()" lookup.

    Only ops with a (truthy) reply value get an entry. When several ops
    share a reply value (legacy families), only the op recorded in
    family.rsp_by_value is listed and the others leave a comment. The
    entry index is the op's enum name when its request and reply values
    are equal, and the raw reply value otherwise.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2019
2020
def put_enum_to_str_fwd(family, cw, enum):
    """Declare "<enum>_str()". *family* is unused and kept for symmetry."""
    fname = f'{enum.render_name}_str'
    cw.write_func_prot('const char *', fname, [f'{enum.user_type} value'], suffix=';')
2024
2025
def put_enum_to_str(family, cw, enum):
    """Emit the value-to-name table for *enum* and its "<enum>_str()" lookup.

    The table is indexed by each entry's value. *family* is unused.
    """
    strmap = enum.render_name + '_strmap'
    cw.block_start(line=f"static const char * const {strmap}[] =")
    rows = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]
    for row in rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, strmap, 'value', enum=enum)
2035
2036
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of "<struct>_put()", which serializes *struct*."""
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', f'{struct.render_name}_put', params, suffix=suffix)
2044
2045
def put_req_nested(ri, struct):
    """Emit "<struct>_put()", which writes *struct* into a netlink message.

    Regular structs are wrapped in ynl_attr_nest_start()/_end(), while
    sub-message structs (struct.submsg set) are written without a nest. A
    fixed header, if present, is copied in before the attributes.
    """
    is_nest = struct.submsg is None
    decls = []
    prologue = []

    if is_nest:
        decls.append('struct nlattr *nest;')
        prologue.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        hdr_size = f'sizeof({struct.fixed_header})'
        decls.append('void *hdr;')
        prologue += [f"hdr = ynl_nlmsg_put_extra_header(nlh, {hdr_size});",
                     f"memcpy(hdr, &obj->_hdr, {hdr_size});"]

    # (type, presence) of every member decides which helper locals to add.
    kinds = [(member.type, member.presence_type())
             for _, member in struct.member_list()]
    if any(mtype == 'indexed-array' for mtype, _ in kinds):
        decls.append('struct nlattr *array;')
    if any(presence == 'count' for _, presence in kinds):
        decls.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    ri.cw.block_start()
    ri.cw.write_func_lvar(decls)

    for stmt in prologue:
        ri.cw.p(stmt)

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    if is_nest:
        ri.cw.p("ynl_attr_nest_end(nlh, nest);")

    ri.cw.nl()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2086
2087
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parse function for *struct*.

    The caller has already written the prototype; this emits everything
    from the opening brace to the closing brace. *init_lines* and
    *local_vars* hold caller-supplied statements and declarations, and both
    lists are extended in place with what this function needs.

    The generated code walks the attributes once, emitting each member's
    attr_get() handling inside the loop. Indexed-array and multi-attr
    members are then allocated from n_<name> counters and filled by a
    second walk. NOTE(review): the counters (and the attr_<name> pointers
    for indexed arrays) are presumably set by the code attr_get() emits
    during the first walk; confirm.

    Raises Exception for an unsupported indexed-array sub-type, an
    unsupported element type, or a per-op fixed header in a non-classic
    family.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    # Pick the attribute iterator: nests walk the nested attribute (after
    # the fixed header, if any); whole messages skip the family header, or
    # the op's own fixed header when it differs in a classic family.
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Classify members: indexed arrays get an attr_<name> pointer and are
    # decoded after the main loop, multi-attr members are likewise decoded
    # afterwards, and any nested/sub-message member needs a ynl_parse_arg.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per array-like member; sorted for stable output.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Values inherited from the parent arrive as extra __u32 parameters.
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # Copy the fixed header out of the nest or message into dst->_hdr.
    if struct.fixed_header:
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Refuse to parse into a dst whose array members are already set.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # First walk over all attributes.
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate n_<name> elements and decode
    # each entry inside the attr_<name> nest.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            # NOTE(review): unlike the multi-attr case below, the nested
            # parser gets ynl_attr_type(attr) as an extra argument —
            # presumably an inherited index value; confirm.
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: re-walk all attributes and copy
    # every occurrence of this attribute type into the allocated array.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most one element's worth; a shorter payload leaves the
            # calloc()ed tail zeroed.
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Allocate a ynl_string with room for the terminating NUL.
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message-level parsers return a callback code.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2239
2240
def parse_rsp_submsg(ri, struct):
    """Emit "<struct>_parse()" for a sub-message struct.

    The generated function picks the member to decode by comparing its
    *sel* string argument against each member name, and returns 0. The
    local declarations needed by all members are de-duplicated and
    emitted in a deterministic order.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    local_vars = {'const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;'}

    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        local_vars.update(l_vars or ())

    ri.cw.block_start()
    # Sort before handing over: write_func_lvar() orders only by length,
    # stably, so equal-length declarations would otherwise keep set
    # iteration order. That order depends on string hash randomization and
    # made the generated C differ from run to run.
    ri.cw.write_func_lvar(sorted(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    for idx, (name, arg) in enumerate(struct.member_list()):
        kw = 'else if' if idx else 'if'

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2275
2276
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of "<struct>_parse()".

    Parameters, in order: the parse arg; for sub-messages the selector
    string; the nested attribute; one "_sel_<name>" string per external
    selector; one __u32 per inherited value.
    """
    func_args = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        func_args.append('const char *sel')
    func_args.append('const struct nlattr *nested')
    func_args += [f'const char *_sel_{sel.name}' for sel in struct.external_selectors()]
    func_args += [f'__u32 {arg}' for arg in struct.inherited]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args, suffix=suffix)
2289
2290
def parse_rsp_nested(ri, struct):
    """Emit "<struct>_parse()" for a nested attribute set.

    Sub-message structs are handled by parse_rsp_submsg(). A struct with
    no members gets a stub that returns 0.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest: nothing to decode.
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return None

    _multi_parse(ri, struct, [], decls)
    return None
2309
2310
def parse_rsp_msg(ri, deref=False):
    """Emit the message-level reply parser "<prefix>_parse()" for ri.op.

    Nothing is emitted for ops without a reply, except events. An empty
    reply struct gets a stub returning YNL_PARSE_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    decls = [f'{type_name(ri, "reply", deref=deref)} *dst;',
             'const struct nlattr *attr;']

    prefix = op_prefix(ri, "reply", deref=deref)
    ri.cw.write_func_prot('int', f'{prefix}_parse',
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: nothing to decode.
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'], decls)
2332
2333
def print_req(ri):
    """Emit the user-space 'do' call for ri.op.

    The generated function builds the request (optional fixed header, then
    attributes) and runs it with ynl_exec(). With a reply, it allocates the
    reply, returns it on success, and frees it and returns NULL on error.
    Without a reply, it returns 0 or -1.
    """
    rsp_dir = rdir("request")
    has_reply = 'reply' in ri.op[ri.op_mode]
    req_struct = ri.struct["request"]
    ret_ok, ret_err = ('rsp', 'NULL') if has_reply else ('0', '-1')

    decls = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
             'struct nlmsghdr *nlh;',
             'int err;']
    if has_reply:
        decls.append(f'{type_name(ri, rsp_dir)} *rsp;')
    if req_struct.fixed_header:
        decls.extend(['size_t hdr_len;', 'void *hdr;'])
    if any(attr.presence_type() == 'count' for _, attr in req_struct.member_list()):
        decls.append('unsigned int i;')

    print_prototype(ri, "request", terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(decls)

    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{req_struct.render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if has_reply:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if req_struct.fixed_header:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(stmt)
        ri.cw.nl()

    for _, attr in req_struct.member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if has_reply:
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        ri.cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto err_free;' if has_reply else f'return {ret_err};')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('err_free:')
        ri.cw.p(call_free(ri, rsp_dir, 'rsp'))
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
2407
2408
def print_dump(ri):
    """Emit the user-space dump call for ri.op.

    The generated function fills a ynl_dump_state, builds the dump request
    and runs ynl_exec_dump(). It returns the head of the reply list, or
    frees the partial list and returns NULL on error.
    """
    rsp_dir = rdir("request")
    req_struct = ri.struct['request']

    print_prototype(ri, "request", terminate=False)
    ri.cw.block_start()

    decls = ['struct ynl_dump_state yds = {};',
             'struct nlmsghdr *nlh;',
             'int err;']
    if req_struct.fixed_header:
        decls += ['size_t hdr_len;', 'void *hdr;']
    ri.cw.write_func_lvar(decls)

    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    setup = [
        'yds.yarg.ys = ys;',
        f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;",
        "yds.yarg.data = NULL;",
        f"yds.alloc_sz = sizeof({type_name(ri, rsp_dir)});",
        f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;",
        f'yds.rsp_cmd = {rsp_cmd};',
    ]
    for stmt in setup:
        ri.cw.p(stmt)
    ri.cw.nl()

    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # NOTE(review): this references req, which print_prototype() only
    # declares when the dump has a request; presumably every dump with a
    # fixed header has one — confirm.
    if req_struct.fixed_header:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{req_struct.render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in req_struct.member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rsp_dir, 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2463
2464
def call_free(ri, direction, var):
    """Return the C statement that frees *var* with the type's _free()."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
2467
2468
def free_arg_name(direction):
    """Name of a _free() argument: the direction suffix without '_', or 'obj'."""
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2473
2474
def print_alloc_wrapper(ri, direction):
    """Emit a static inline "<prefix>_alloc()" returning a calloc()ed struct."""
    struct_name = op_prefix(ri, direction)
    ri.cw.write_func_prot(f'static inline struct {struct_name} *',
                          struct_name + '_alloc', ['void'])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    ri.cw.block_end()
2481
2482
def print_free_prototype(ri, direction, suffix=';'):
    """Write the prototype of "<prefix>_free()".

    When ri.type_name_conflict is set, the struct name carries a trailing
    underscore, and the argument type uses that name.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
2490
2491
def print_nlflags_set(ri, direction):
    """Emit "<prefix>_set_nlflags()", which stores caller-chosen nlmsg flags."""
    prefix = op_prefix(ri, direction)
    params = [f"struct {prefix} *req", "__u16 nl_flags"]
    ri.cw.write_func_prot('static inline void', f"{prefix}_set_nlflags", params)
    ri.cw.block_start()
    ri.cw.p('req->_nlmsg_flags = nl_flags;')
    ri.cw.block_end()
    ri.cw.nl()
2500
2501
def _print_type(ri, direction, struct):
    """Emit the C struct definition for *struct*.

    *direction* selects the name suffix; an empty direction is used for
    direction-less types (see print_type_full()).  The struct contains, in
    order: ``_nlmsg_flags`` when needed, the fixed header as ``_hdr``,
    anonymous ``_present``/``_len``/``_count`` bookkeeping sub-structs (each
    emitted only if some member contributes a line), one ``__u32`` per
    inherited argument, and finally the attribute members themselves.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        # Disambiguate a direction-less type whose name clashes with another
        suffix += '_'

    if ri.op_mode == 'dump' and not ri.type_oneside:
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.needs_nlflags(direction):
        ri.cw.p('__u16 _nlmsg_flags;')
        ri.cw.nl()
    if struct.fixed_header:
        ri.cw.p(struct.fixed_header + ' _hdr;')
        ri.cw.nl()

    for type_filter in ['present', 'len', 'count']:
        # Open the sub-struct lazily so empty ones are never emitted
        meta_started = False
        for _, attr in struct.member_list():
            line = attr.presence_member(ri.ku_space, type_filter)
            if line:
                if not meta_started:
                    ri.cw.block_start(line="struct")
                    meta_started = True
                ri.cw.p(line)
        if meta_started:
            ri.cw.block_end(line=f'_{type_filter};')
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
2540
2541
def print_type(ri, direction):
    """Emit the struct for the *direction* ('request'/'reply') type of the op."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2544
2545
def print_type_full(ri, struct):
    """Emit a direction-less struct type.

    Types that are used on the request side inside multi-value attributes
    also get a free() prototype.
    """
    _print_type(ri, "", struct)

    if not (struct.request and struct.in_multi_val):
        return
    free_rsp_nested_prototype(ri)
    ri.cw.nl()
2552
2553
def print_type_helpers(ri, direction, deref=False):
    """Emit helper declarations for the *direction* type.

    This always emits the free() prototype.  When needed it adds the
    nlflags setter.  For user-space requests it also adds one setter per
    attribute member (*deref* is forwarded to each setter).
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        members = ri.struct[direction].member_list()
        for _, member in members:
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2565
2566
def print_req_type_helpers(ri):
    """Emit alloc and free/setter helpers for the request type.

    Nothing is emitted when ``ri.type_empty("request")`` is true.
    """
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2572
2573
def print_rsp_type_helpers(ri):
    """Emit free/setter helpers for the reply type if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2578
2579
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of ``<op>_rsp_parse()`` / ``<op>_req_parse()``.

    'reply' selects the ``_rsp`` variant and anything else selects ``_req``.
    With *terminate* the prototype ends in ';'.  Otherwise nothing is
    appended, so a body can follow.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    args = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", args,
                          suffix=';' if terminate else '')
2588
2589
def print_req_type(ri):
    """Emit the request struct unless ``ri.type_empty("request")`` is true."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2594
2595
def print_req_free(ri):
    """Emit the request free() implementation if the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2600
2601
def print_rsp_type(ri):
    """Emit the reply struct.

    This applies to do/dump ops that define a reply, and to events.
    """
    if ri.op_mode in ('do', 'dump'):
        has_reply = 'reply' in ri.op[ri.op_mode]
    else:
        has_reply = ri.op_mode == 'event'
    if has_reply:
        print_type(ri, 'reply')
2610
2611
def print_wrapped_type(ri):
    """Emit the wrapper struct around the reply object.

    Dump replies are chained through ``next``.  Notifications and events
    carry family/cmd, a ``ynl_ntf_base_type`` link and a free callback.
    The reply itself is embedded as the 8-byte aligned ``obj`` member.
    A free() prototype for the wrapper follows.
    """
    wrapper = type_name(ri, 'reply')
    ri.cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        ri.cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        ntf_members = ['__u16 family;',
                       '__u8 cmd;',
                       'struct ynl_ntf_base_type *next;',
                       f"void (*free)({wrapper} *ntf);"]
        for member in ntf_members:
            ri.cw.p(member)
    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()
    print_free_prototype(ri, 'reply')
    ri.cw.nl()
2626
2627
2628def _free_type_members_iter(ri, struct):
2629    if struct.free_needs_iter():
2630        ri.cw.p('unsigned int i;')
2631        ri.cw.nl()
2632
2633
2634def _free_type_members(ri, var, struct, ref=''):
2635    for _, attr in struct.member_list():
2636        attr.free(ri, var, ref)
2637
2638
def _free_type(ri, direction, struct):
    """Emit a complete ``<name>_free()`` function for *struct*.

    The members are always released.  The object itself is free()d only
    for a non-empty *direction*.  Direction-less types are presumably
    embedded in a parent rather than allocated on their own.
    """
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        ri.cw.p(f'free({arg});')
    ri.cw.block_end()
    ri.cw.nl()
2650
2651
def free_rsp_nested_prototype(ri):
    """Emit the free() prototype for a direction-less type."""
    print_free_prototype(ri, "")
2654
2655
def free_rsp_nested(ri, struct):
    """Emit the free() implementation for a direction-less type."""
    _free_type(ri, direction="", struct=struct)
2658
2659
def print_rsp_free(ri):
    """Emit the reply free() implementation if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2664
2665
def print_dump_type_free(ri):
    """Emit the free() function for a dump reply list.

    The generated code walks the list through ``next`` until YNL_LIST_END.
    For each node it frees the members (through ``obj.``) and then the
    node itself.
    """
    sub_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{sub_type} *next = rsp;")
    ri.cw.nl()
    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    # Fetch the successor before rsp is freed at the end of this iteration
    ri.cw.p('rsp = next;')
    ri.cw.p('next = rsp->next;')
    ri.cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.block_end()
    ri.cw.nl()
2684
2685
def print_ntf_type_free(ri):
    """Emit the free() function for a single notification object ``rsp``.

    The members are released through the ``obj.`` sub-struct, then the
    wrapper itself is freed.
    """
    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, ri.struct['reply'])
    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2694
2695
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the declaration line, or the definition head, of an nla_policy array.

    With *terminate* an ``extern ...;`` declaration is written.  Per-op
    (*ri*) policies that should be static are skipped, since there is
    nothing to declare.  Without *terminate* the definition is opened with
    ``... = {``, prefixed by ``static`` for such per-op policies.  The array
    is named after the op (plus the op mode for dual policies) when *ri* is
    given, else after *struct*.
    """
    static_policy = bool(ri) and policy_should_be_static(struct.family)
    if terminate:
        if static_policy:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if static_policy else ''
        suffix = ' = {'

    if ri:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    else:
        name = struct.render_name
    max_attr = struct.attr_max_val
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2718
2719
def print_req_policy(cw, struct, ri=None):
    """Emit a complete nla_policy array definition for *struct*.

    When *ri* carries an op, the array is guarded by that op's
    'config-cond' via cw.ifdef_block().  ifdef_block(None) is always called
    afterwards to reset the guard.
    """
    op = ri.op if ri else None
    if op:
        cw.ifdef_block(op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2729
2730
def kernel_can_gen_family_struct(family):
    """Return True if the generator emits the genl_family struct itself.

    Only families using the plain 'genetlink' protocol qualify.
    """
    proto = family.proto
    return proto == 'genetlink'
2733
2734
def policy_should_be_static(family):
    """Return True when the generated nla_policy arrays should be file-local.

    That is the case for split kernel policies, or when the generator also
    emits the family struct.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2737
2738
def print_kernel_policy_ranges(family, cw):
    """Emit a netlink_range_validation struct for each ranged request attribute.

    Only request attributes with a 'full-range' check in non-subset
    attribute sets are considered.  Unsigned types use the plain struct and
    ``ULL`` literals.  Other types use the ``_signed`` variant and ``LL``.
    A heading comment precedes the first struct.
    """
    def _ranged_attrs():
        for _, attr_set in family.attr_sets.items():
            if attr_set.subset_of:
                continue
            for _, candidate in attr_set.items():
                if candidate.request and 'full-range' in candidate.checks:
                    yield candidate

    for idx, attr in enumerate(_ranged_attrs()):
        if idx == 0:
            cw.p('/* Integer value ranges */')

        is_unsigned = attr.type[0] == 'u'
        sign = '' if is_unsigned else '_signed'
        suffix = 'ULL' if is_unsigned else 'LL'
        cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
        members = [(limit, attr.get_limit_str(limit, suffix=suffix))
                   for limit in ('min', 'max') if limit in attr.checks]
        cw.write_struct_init(members)
        cw.block_end(line=';')
        cw.nl()
2766
2767
def print_kernel_policy_sparse_enum_validates(family, cw):
    """Emit validation callbacks for request attributes with a 'sparse' check.

    For each such attribute (request-side, with an enum_name, in a non-subset
    attribute set) a static ``<enum_name>_validate()`` function is written.
    It returns 0 for every entry of the enum referenced by the attribute's
    'enum' key.  Any other value gets an extack message and -EINVAL.
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request:
                continue
            if not attr.enum_name:
                continue
            if 'sparse' not in attr.checks:
                continue

            if first:
                cw.p('/* Sparse enums validation callbacks */')
                first = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            first_entry = True
            for entry in enum.entries.values():
                # Stack every valid value as a case label sharing one return
                if first_entry:
                    first_entry = False
                else:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2806
2807
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops-table head and, in header mode, handler prototypes.

    Without *terminate* (source mode) this only opens the ops array
    definition (``<qual> struct <type> <name>_nl_ops[<cnt>] =`` plus a block
    start) and returns, so the caller can fill in the entries.

    With *terminate* (header mode) it writes an ``extern`` declaration of
    the array when the table is exported.  It then writes prototypes for the
    family's pre/post do/dump hooks and for the doit/dumpit handler of every
    non-async op.
    """
    # The table is 'static const' when the generator also emits the
    # genl_family struct; otherwise it is 'const' and declared extern.
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        # kernel-policy mode -> genl ops struct flavour
        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Static tables are left unsized.  Exported ones get an explicit
        # size, presumably so the extern declaration has a complete type.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split ops have one entry per do/dump mode of each op
            cnt = 0
            for op in family.ops.values():
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            # NOTE(review): family.ops can contain async ops (see the
            # is_async skip below), and print_kernel_op_table() skips those
            # for global/per-op tables. Confirm this count matches the
            # number of emitted entries.
            cnt = len(family.ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    # Prototypes for the hooks named in the spec: do-hooks take the split
    # ops, skb and genl_info (post_doit returns void); dump hooks take the
    # netlink_callback.
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # One doit/dumpit prototype per mode of every non-async op
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2873
2874
def print_kernel_op_table_hdr(family, cw):
    """Header-side ops table output: declarations and prototypes only."""
    print_kernel_op_table_fwd(family, cw, terminate=True)
2877
2878
def print_kernel_op_table(family, cw):
    """Emit the kernel ops array definition for *family*.

    For 'global' and 'per-op' kernel policies, one entry is written per
    non-async op.  'per-op' also points each entry at the op's own policy.
    For 'split', one entry is written per do/dump mode of each non-async op.
    Such an entry includes its pre/post callbacks, the mode's policy (when
    the mode has a request) and a ``GENL_CMD_CAP_<MODE>`` flag.  Each entry
    is guarded by the op's 'config-cond' via cw.ifdef_block().
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # Per-op tables take the policy from the op's 'do' request
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Spec 'pre'/'post' hooks map onto these genl_split_ops fields
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags relevant to this mode: drop
                    # dump-only flags for 'do' and 'strict' for 'dump'
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2957
2958
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the kernel-side enum of multicast group indices, if there are any groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{group['name']},"))
    cw.block_end(';')
    cw.nl()
2969
2970
def print_kernel_mcgrp_src(family, cw):
    """Emit the static genl_multicast_group array for the family.

    Each entry is indexed by the group's enum value and holds its name.
    Nothing is emitted when there are no groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] =')
    for group in groups:
        grp_name = group['name']
        enum_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2982
2983
def print_kernel_family_struct_hdr(family, cw):
    """Emit header declarations for the generated genl_family.

    Nothing is written unless the generator owns the family struct.  When
    the family declares a 'sock-priv' type, prototypes for the per-socket
    init/destroy hooks taking that type are added.
    """
    if not kernel_can_gen_family_struct(family):
        return

    prefix = family.c_name
    cw.p(f"extern struct genl_family {prefix}_nl_family;")
    cw.nl()
    if 'sock-priv' not in family.kernel_family:
        return
    priv = family.kernel_family["sock-priv"]
    for hook in ('init', 'destroy'):
        cw.p(f'void {prefix}_nl_sock_priv_{hook}({priv} *priv);')
    cw.nl()
2994
2995
def print_kernel_family_struct_src(family, cw):
    """Emit the ``struct genl_family`` definition for kernel source.

    This is skipped when the generator does not own the family struct.  If
    the family declares a 'sock-priv' type, static ``void *``-taking
    trampolines around the init/destroy hooks are emitted first and wired
    into the struct.  The ops and multicast groups are referenced according
    to the kernel policy and group list.
    """
    if not kernel_can_gen_family_struct(family):
        return

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Name the symbol with c_name, exactly as the extern declaration emitted
    # by print_kernel_family_struct_hdr() does, so header and source agree.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
3031
3032
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uapi enum block and choose its tag.

    1. If *obj* has the *enum_name* key, its value is the tag.  An empty
       value yields an anonymous enum.
    2. Otherwise, if *ckey* is given and present in *obj*, the tag is
       ``<family c_name>_<obj[ckey]>``.
    3. Otherwise the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        tag = c_lower(explicit) if explicit else None
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])
    else:
        tag = None
    cw.block_start(line='enum' if tag is None else 'enum ' + tag)
3041
3042
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Render the single command enum of the 'unified' message-ID model.

    An entry's value is spelled out (``= N``) only when it breaks the
    running sequence.  With *separate_ntf*, notify/event messages are left
    out.  The enum ends with count and max members.  With *max_by_define*,
    max is a #define after the enum instead.
    """
    prefix = family.op_prefix
    max_name = c_upper(family.get('cmd-max-name', f"{prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(op.enum_name + ',')
        else:
            cw.p(op.enum_name + f" = {op.value},")
            expected = op.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
3068
3069
def render_uapi_directional(family, cw, max_by_define):
    """Render separate USER and KERNEL message enums ('directional' model).

    The USER enum lists ops that have 'do' and are not events.  The KERNEL
    enum lists do-ops with a reply (suffixed ``_REPLY``), notifications and
    events.  Each enum opens with ``*_NONE = 0`` and ends with a count/max
    pair; with *max_by_define*, max is a #define after the enum instead.
    """
    def emit_entry(name, value, expected):
        # Spell out "= value" only when it is set and breaks the sequence;
        # returns the next expected value.
        if value and value != expected:
            cw.p(f"{name} = {value},")
            return value + 1
        cw.p(name + ',')
        return expected + 1

    def close_enum(max_name, cnt_name):
        max_value = f"({cnt_name} - 1)"
        cw.nl()
        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    # user -> kernel messages
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    val = 0
    for op in family.msgs.values():
        if 'do' in op and 'event' not in op:
            val = emit_entry(op.enum_name, op.value, val)
    close_enum(f"{family.op_prefix}USER_MAX", f"__{family.op_prefix}USER_CNT")

    # kernel -> user messages
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    val = 0
    for op in family.msgs.values():
        if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
            entry_name = op.enum_name
            if 'event' not in op and 'notify' not in op:
                entry_name = f'{entry_name}_REPLY'
            val = emit_entry(entry_name, op.value, val)
    close_enum(f"{family.op_prefix}KERNEL_MAX", f"__{family.op_prefix}KERNEL_CNT")
3122
3123
def render_uapi(family, cw):
    """Render the complete uapi header for *family* through *cw*.

    Output order:
    - the include guard;
    - the family name/version defines;
    - the spec's 'definitions' (consecutive constants batched into #define
      groups, enums/flags with optional kdoc);
    - one attribute enum per non-subset attribute set;
    - the command enum(s) for the family's message-ID model;
    - an optional separate notification enum;
    - multicast group name defines;
    - the closing #endif.

    Raises Exception for an unsupported message enum-model.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions accumulate into one #define group;
    # any other definition type flushes the pending group first.
    defines = []
    for const in family['definitions']:
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # 'c-define-name' is a per-constant override, like the one on
            # multicast groups below. A family-wide lookup would give every
            # constant the same name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        # Notifications/events get their own enum (render_uapi_unified()
        # is told to skip them via separate_ntf)
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3264
3265
def _render_user_ntf_entry(ri, op):
    """Emit one designated initializer of the ynl_ntf_info table for *op*.

    For classic families, the entry is keyed by the request op whose value
    matches *op*'s reply value (via ``req_by_value``).  Otherwise it is keyed
    by *op*'s own enum name.
    """
    family = ri.family
    if family.is_classic():
        key = family.req_by_value[op.rsp_value].enum_name
    else:
        key = op.enum_name
    cw = ri.cw
    cw.block_start(line=f"[{key}] = ")
    cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    cw.block_end(line=',')
3277
3278
def render_user_family(family, cw, prototype):
    """Emit the user-space ``ynl_<family>_family`` descriptor.

    With *prototype* only the extern declaration is written.  Otherwise the
    notification info table is emitted first (if the family has
    notifications), followed by the descriptor initializer.

    Raises Exception for a notification that is neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op in family.ops.values():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        # For classic families hdr_len covers only the optional fixed header
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3320
3321
def family_contains_bitfield32(family):
    """Return True if any non-subset attribute set has a 'bitfield32' attribute."""
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
3330
3331
def find_kernel_root(full_path):
    """Locate the kernel source root that contains *full_path*.

    Walks up the directory tree from *full_path* until it finds a directory
    containing a MAINTAINERS file.

    :returns: ``(root, rel)``, where *root* is that directory and *rel* is
        *full_path* relative to it.
    :raises FileNotFoundError: if no ancestor directory contains MAINTAINERS.
        Reaching the filesystem root, or the start of a relative path, ends
        the walk.
    """
    start = full_path
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            return parent, sub_path[:-1]
        # dirname() is a fixed point at '/' (or '' for relative paths):
        # without this check the walk would repeat forever.
        if parent == full_path:
            raise FileNotFoundError(f"no MAINTAINERS file found above {start!r}")
        full_path = parent
3340
3341
def main():
    """Command-line entry point: generate C code from a YNL spec.

    Parses the command line and loads the spec given by --spec. It then
    emits one of the following, selected by --mode and --header/--source:

    - a uapi header;
    - a kernel header or source file;
    - a user header or source file.

    Output goes to -o; when -o is omitted, CodeWriter decides the
    destination.

    Exits with status 1 (after printing a diagnostic to stderr) if the spec
    fails to parse as YAML, or if its license is not the required
    dual license. argparse reports a usage error if neither --header nor
    --source is given.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share the 'header' destination; the None default
    # lets us detect that neither flag was given.
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    # Only specs under this exact license may be turned into code.
    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print(f'License must be: {required_license}', file=sys.stderr)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # ---- File banner: license, origin and generator invocation ----
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Record the extra generator arguments in the output (presumably so
        # that regeneration can reuse them).
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # The companion header shares the output file's base name with a .h
    # extension (robust to output extensions other than two characters).
    if args.out_file:
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"
    else:
        hdr_file = "generated_header_file.h"

    # ---- Includes ----
    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Emit each header once, keeping first-seen order.
    seen_header = set()
    for one in headers:
        if one not in seen_header:
            cw.p(f"#include <{one}>")
            seen_header.add(one)
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    # ---- Kernel mode: policies, op tables, multicast groups, family ----
    if args.mode == "kernel":
        if args.header:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)
            print_kernel_policy_sparse_enum_validates(parsed, cw)

            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    # ---- User mode: types and prototypes (header) or implementations ----
    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent or ri.type_oneside:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            # Recursive nests need their policies forward-declared first.
            has_recursive_nests = False
            cw.p('/* Policies */')
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for struct in parsed.pure_nested_structs.values():
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            if has_recursive_nests:
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent or ri.type_oneside:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    # The message parser is rendered with "do" render info,
                    # while the free helper uses "event" render info.
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3651
# Run the generator only when this file is executed as a script; importing
# the module does not trigger code generation.
if __name__ == "__main__":
    main()
3654