xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 2c222dde61c4fcb8693d31acf5ef8e342fda4c26)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
def c_upper(name):
    """Return *name* upper-cased with every '-' turned into '_' (C macro style)."""
    dash_to_underscore = str.maketrans('-', '_')
    return name.upper().translate(dash_to_underscore)
22
23
def c_lower(name):
    """Return *name* lower-cased with every '-' turned into '_' (C identifier style)."""
    pieces = name.lower().split('-')
    return '_'.join(pieces)
26
27
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The name must have the form '<u|s><bits>-<min|max>'.  Unsigned minimums
    are 0 and unsigned maximums 2^bits - 1; signed limits are the two's
    complement range -2^(bits-1) .. 2^(bits-1) - 1.

    Raises ValueError for any other form, instead of silently treating an
    unknown suffix (e.g. a typo like 'u32-mx') as the type's maximum.
    """
    if name[:1] not in ('u', 's') or not name.endswith(('-min', '-max')):
        raise ValueError(f'Invalid limit name: "{name}"')
    if name[0] == 'u' and name.endswith('-min'):
        return 0
    # Strip the leading 'u'/'s' and the trailing '-min'/'-max'
    width = int(name[1:-4])
    if name[0] == 's':
        # One bit is taken by the sign
        width -= 1
    value = (1 << width) - 1
    if name[0] == 's' and name.endswith('-min'):
        value = -value - 1
    return value
41
42
class BaseNlLib:
    """Provides C expressions the generated code uses to reach library state."""

    def get_family_id(self):
        """Return the C expression (as a string) that yields the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
46
47
class Type(SpecAttr):
    """
    Code-generation model of a single attribute of an attribute set.

    Wraps the spec attribute and emits C for it: struct members, kernel
    policy entries, YNL type-policy ("typol") entries, put code, parse code
    and request setters.  Subclasses override the underscore-prefixed hooks
    (_complex_member_type, _attr_policy, _attr_typol, _attr_get,
    _setter_lines, _free_lines) for their attribute type; the base versions
    implement the common case or raise.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE(review): when the spec provides 'checks' this is the spec's
        # own dict, not a copy, so keys added to self.checks later are also
        # visible through attr['checks'].
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply()
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        # Nested attribute sets and sub-messages are both rendered as a C
        # struct named after the set they refer to.
        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            if self.nested_attrs in self.family.consts:
                # The set shares its name with a family definition; the
                # trailing '_' presumably keeps the two C struct names apart.
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make c_name a usable C identifier: suffix '_' when it is listed in
        # _C_KW (presumably the C keywords), prefix '_' if it starts with a digit.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        # (declared for readers, then deleted so that any use before
        # resolve() raises AttributeError rather than seeing None)
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests; subsets propagate to the real attr."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies; subsets propagate to the real attr."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return check *limit* as a number.

        The check may be an int, the name of a family constant (its "value"
        is returned) or a type limit name such as 'u32-max'.  Returns None
        when neither the check nor *default* is set.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check *limit* as C source text, or '' when it is not set.

        Ints are printed with *suffix* appended; family constants become
        their upper-case macro name (prefixed with the family name unless
        the constant has a 'header'); anything else is just upper-cased.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """
        Compute enum_name, the upper-case C enum identifier of the attr.

        Raises Exception if a subset attribute's checks differ from those
        of the real attribute.
        """
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        # Recursive attrs are special-cased only when ri has no op
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Presence tracking kind: 'present' (bit), 'len', 'count' or '' (none)."""
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return the C declaration of this attr's presence field, or None when
        presence_type() differs from *type_filter*.  'present' is a one-bit
        field, 'len'/'count' a full u32; the 'user' space uses __u32.
        """
        if self.presence_type() != type_filter:
            return None

        if self.presence_type() == 'present':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() in {'len', 'count'}:
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        """True when the generated free/setter code needs an 'i' loop counter."""
        return False

    def _free_lines(self, ri, var, ref):
        """C lines releasing this attr's heap storage inside *var*."""
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        lines = self._free_lines(ri, var, ref)
        for line in lines:
            ri.cw.p(line)

    def arg_member(self, ri):
        """
        Return the setter parameter declarations: a pointer to the complex
        member type, plus 'unsigned int n_<name>' for 'count' presence.
        """
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit this attr's member declaration(s) of the C struct."""
        member = self._complex_member_type(ri)
        if member:
            # Multi-value and (op-less) recursive members are pointers
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the kernel policy entry; big-endian u16/u32 use NLA_BE16/NLA_BE32."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the YNL type-policy entry (name plus type-specific fields)."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Attrs tracked by a presence bit or length are only put when set
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the 'if' / 'else if (type == ENUM)' branch that parses this attr.

        _attr_get() returns (lines, init_lines, local_vars); lines and
        init_lines may each be a single string.  Single-valued attrs get a
        ynl_attr_validate() check and, with 'present' presence, their bit
        set before the init and parse lines.  Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a 'static inline void <prefix>_set_<path>()' request setter.

        *ref* is the list of enclosing nest member names.  Every enclosing
        level gets its presence bit set; for the last level (this attr), a
        non-bit presence expression is handed to _setter_lines() instead.
        Setters that free but never allocate get a '__' name prefix
        (presumably because they take ownership of caller memory).
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        for i in range(len(ref)):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'present':
                presence = f"{var}->{'.'.join(ref[:i] + [''])}_{self.presence_type()}.{ref[i]}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
311
312
class TypeUnused(Type):
    """
    Attribute of the 'unused' type.

    Generates no struct member, kernel policy, put, setter or parse code;
    its YNL type-policy entry is YNL_PT_REJECT.
    """

    # -- struct / setter generation: nothing to emit --

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None

    # -- kernel policy and request serialization: skipped --

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    # -- response parsing --

    def attr_get(self, ri, var, first):
        return None

    def _attr_get(self, ri, var):
        parse_lines = ['return YNL_PARSE_CB_ERROR;']
        return parse_lines, None, None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '
337
338
class TypePad(Type):
    """
    Padding attribute.

    Generates no struct member, kernel policy, put, setter or parse code;
    its YNL type-policy entry is YNL_PT_IGNORE.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
360
361
class TypeScalar(Type):
    """
    Integer attribute stored as a plain C scalar member.

    Besides user-given checks, the kernel policy may carry checks derived
    from an associated 'enum' (see _init_checks()).
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Appended to the setter argument, e.g. " /* big-endian */"
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Set is_bitfield and type_name (the C type of the setter argument)."""
        self.resolve_up(super())

        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # auto scalars are carried as 64-bit (__u64 / __s64) in C
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """
        Add implicit checks to self.checks:
          - for enum-typed attrs: 'sparse' when the enum's value_range() is
            (None, None), otherwise default 'min'/'max' from that range;
          - 'range' when both 'min' and 'max' are present;
          - 'full-range' when a limit lies outside [-32768, 32767].

        Raises Exception if min > max or a limit is negative for an
        unsigned type.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    # A minimum of 0 is implicit for unsigned types
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range need the NLA_POLICY_FULL_RANGE form
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        """
        Pick the kernel policy form, in order of precedence: flag mask,
        full range, range, min, max, sparse-enum validator, plain type.
        """
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the YNL_PT_U<width> entry of the same width
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
459
460
class TypeFlag(Type):
    """Flag attribute: put with an empty payload, so only its presence matters."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter takes no value; calling it sets the presence bit
        return []

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        return []
476
477
class TypeString(Type):
    """String attribute held as a heap copy plus its length presence field."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        selector_part = '.is_selector = 1, ' if self.is_selector else ''
        return '.type = YNL_PT_NUL_STR, ' + selector_part

    def attr_policy(self, cw):
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.get_limit_str('exact-len')})"
        parts = ['{ .type = ', policy]
        if 'max-len' in self.checks:
            parts += [', .len = ', self.get_limit_str('max-len')]
        parts.append(', }')
        return ''.join(parts)

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_field = f"{var}->_len.{self.c_name}"
        # Copy at most the payload length and always NUL-terminate the copy
        parse_lines = [
            f"{len_field} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return parse_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = presence
        return [
            f"{length} = strlen({self.c_name});",
            f"{member} = malloc({length} + 1);",
            f'memcpy({member}, {self.c_name}, {length});',
            f'{member}[{length}] = 0;',
        ]
530
531
class TypeBinary(Type):
    """Opaque byte-array attribute held as a heap copy plus its length."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # Trailing space matches every other typol, so the generated
        # initializer ends in ", }" like the rest.
        return '.type = YNL_PT_BINARY, '

    def _attr_policy(self, policy):
        """
        Return the kernel policy entry.  At most one of 'exact-len',
        'min-len' or 'max-len' is supported; anything else raises.
        """
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not self.checks:
            return '{ .type = NLA_BINARY, }'

        check_name = next(iter(self.checks))
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            # Untyped (NLA_UNSPEC) entry: .len acts as the minimum length
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if check_name == 'max-len':
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        raise Exception('Unsupported check for binary type: ' + check_name)

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->{self.c_name}, {var}->_len.{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_len.' + self.c_name
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [f"{presence} = len;",
                f"{member} = malloc({presence});",
                f'memcpy({member}, {self.c_name}, {presence});']
582
583
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is the C struct named by 'struct'."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        dst = f"{var}->{self.c_name}"
        len_field = f"{var}->_{self.presence_type()}.{self.c_name}"
        # A payload shorter than the struct gets a zeroed full-size buffer;
        # a longer one keeps every received byte.
        parse_lines = [
            f"{len_field} = len;",
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return parse_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
598               ['unsigned int len;']
599
600
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute holding a packed array of 'sub-type' scalars, tracked by count."""

    def arg_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        return [f'{elem} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem} *{self.c_name};')

    def attr_put(self, ri, var):
        elem = f"__{self.get('sub-type')}"
        count = f"{var}->_{self.presence_type()}.{self.c_name}"
        ri.cw.block_start(line=f"if ({count})")
        # 'i' holds the payload size in bytes here
        ri.cw.p(f"i = {count} * sizeof({elem});")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        count = f"{var}->_count.{self.c_name}"
        dst = f"{var}->{self.c_name}"
        parse_lines = [
            f"{count} = len / {elem_sz};",
            # Copy whole elements only
            f"len = {count} * {elem_sz};",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return parse_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem = f"__{self.get('sub-type')}"
        return [
            f"{presence} = count;",
            f"count *= sizeof({elem});",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
633
634
class TypeBitfield32(Type):
    """Attribute carried as a struct nla_bitfield32; requires an 'enum'."""

    _BF32_SIZEOF = "sizeof(struct nla_bitfield32)"

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        allowed = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({allowed})"

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, {self._BF32_SIZEOF})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        copy_line = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), {self._BF32_SIZEOF});"
        return copy_line, None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, {self._BF32_SIZEOF});"]
658
659
class TypeNest(Type):
    """Nested attribute set rendered as a struct member (a pointer when recursive outside an op)."""

    def is_recursive(self):
        nested_struct = self.family.pure_nested_structs[self.nested_attrs]
        return nested_struct.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        target = f"{var}->{ref}{self.c_name}"
        if not self.is_recursive_for_op(ri):
            return [f'{self.nested_render_name}_free(&{target});']
        # Held by pointer here, so guard against NULL before freeing
        return [f'if ({target})', f'{self.nested_render_name}_free({target});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        # Selector values living outside the nest are passed to the parser
        parse_args = ["&parg", "attr"] + [f'{var}->{sel.name}'
                                          for sel in nested.external_selectors()]
        get_lines = [f"if ({self.nested_render_name}_parse({', '.join(parse_args)}))",
                     "return YNL_PARSE_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        path = (ref or []) + [self.c_name]
        members = ri.family.pure_nested_structs[self.nested_attrs].member_list()
        # One setter per non-recursive member, named along the nest path
        for _, member_attr in members:
            if not member_attr.is_recursive():
                member_attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
705
706
class TypeMultiAttr(Type):
    """
    Attribute that may repeat: stored as an array plus a 'count' presence.

    Kernel and YNL policies are delegated to base_type (presumably the Type
    object describing a single element).
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        elem_type = self.attr['type']
        if elem_type == 'binary' and 'struct' in self.attr:
            # Struct arrays are declared through arg_member()
            return None
        if elem_type == 'string':
            return 'struct ynl_string *'
        if elem_type in scalars:
            return ('__' if ri.ku_space == 'user' else '') + elem_type
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def arg_member(self, ri):
        if self.type != 'binary' or 'struct' not in self.attr:
            return super().arg_member(ri)
        struct_name = c_lower(self.attr["struct"])
        return [f'struct {struct_name} *{self.c_name}',
                f'unsigned int n_{self.c_name}']

    def free_needs_iter(self):
        return self.attr['type'] in ('nest', 'string')

    def _free_lines(self, ri, var, ref):
        elem_type = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"
        free_array = f"free({array});"
        if elem_type in scalars or elem_type == 'binary':
            return [free_array]
        each = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if elem_type == 'string':
            return [each, f"free({array}[i]);", free_array]
        if elem_type == 'nest':
            return [each, f'{self.nested_render_name}_free(&{array}[i]);', free_array]
        raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        elem_type = self.attr['type']
        elem = f"{var}->{self.c_name}[i]"
        each = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        if elem_type in scalars:
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {elem});")
        elif elem_type == 'binary' and 'struct' in self.attr:
            struct_sz = f"sizeof(struct {c_lower(self.attr['struct'])})"
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{elem}, {struct_sz});")
        elif elem_type == 'string':
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {elem}->str);")
        elif elem_type == 'nest':
            ri.cw.p(each)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, "
                                         f"{self.enum_name}, &{elem})")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
793
794
class TypeArrayNest(Type):
    """
    Indexed array attribute.

    A nest whose members use their index (0, 1, ...) as the attribute type
    and each carry one element of 'sub-type'.  A missing 'sub-type' is
    treated as 'nest' consistently by every method.  Previously
    _attr_typol() and the attr_put() error path raised KeyError for it.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self.attr.get('sub-type')
        if sub_type is None or sub_type == 'nest':
            return self.nested_struct_type
        elif sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {sub_type} not supported yet")

    def arg_member(self, ri):
        if self.attr.get('sub-type') == 'binary' and 'exact-len' in self.checks:
            # Pointer to an array of fixed-size byte arrays
            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        sub_type = self.attr.get('sub-type')
        if sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(sub_type[1:])}, '
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Validate and count the elements; the outer attr is remembered in
        # attr_<name> (presumably for a later pass that copies the elements).
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->_count.{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        sub_type = self.attr.get('sub-type')
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if sub_type in scalars:
            put_type = sub_type
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif sub_type is None or sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put for ArrayNest sub-type {sub_type} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
857
858
class TypeNestTypeValue(Type):
    """'nest-type-value' attribute whose nesting levels carry values in their types.

    For each name listed in 'type-value' the parser descends one nest level
    and records that level's attribute type in a __u32 local, then passes
    all of them to the nested parse function.
    """
    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for lvl in levels:
                get_lines.append(f'attr_{lvl} = ynl_attr_data({cursor});')
                get_lines.append(f'{lvl} = ynl_attr_type(attr_{lvl});')
                cursor = f'attr_{lvl}'
            extra_args = ', ' + ', '.join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
887
888
class TypeSubMessage(TypeNest):
    """Sub-message attribute: a nest whose format is chosen by a selector attr.

    The selector is either a sibling attribute in the same set, or
    "external" (not in this set) and provided from an enclosing level.
    """
    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        pieces = [f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
                  '.is_submsg = 1, ']
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            pieces.append(f'.selector_type = {self.attr_set[self["selector"]].value} ')
        return ''.join(pieces)

    def _attr_get(self, ri, var):
        selector_name = self['selector']
        sel = c_lower(selector_name)
        # External selectors live in a local "_sel_<name>"; in-set ones in the struct
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"

        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [
            f'if (!{sel_var})',
            f'return ynl_submsg_failed(yarg, "{self.name}", "{selector_name}");',
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None
919
920
class Selector:
    """Selector attribute of a sub-message.

    If the selector attribute is part of the same attribute set it is
    resolved immediately and flagged with ``is_selector``; otherwise it is
    external and its attribute is filled in later via set_attr().
    """
    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]

        local = self.name in attr_set
        # An external selector will need to get passed down thru the structs
        self._external = not local
        self.attr = None
        if local:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        """Bind the attribute that provides this selector's value."""
        self.attr = attr

    def is_external(self):
        """Return True if the selector is not in the sub-message's own set."""
        return self._external
939
940
class Struct:
    """C struct model for an attribute set (all of it, or a listed subset).

    ``type_list`` restricts the members to the named attributes; when it is
    None the struct covers the whole set and is considered "nested".
    """
    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # A list never compares equal to a set, so set_inherited() on a
        # Struct created without 'inherited' will raise
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.fixed_header = 'struct ' + c_lower(fixed_header) if fixed_header else None
        self.submsg = submsg

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr and by legacy arrays

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]
        self.attrs = dict(self.attr_list)

        # Track the member with the highest value (last one wins on ties)
        self.attr_max_val = None
        highest = 0
        for _, member in self.attr_list:
            if member.value >= highest:
                highest = member.value
                self.attr_max_val = member

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Confirm *new_inherited* matches and publish the sorted C names."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        # Sort on the raw names first, then convert to C names
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def external_selectors(self):
        """Return Selector objects of sub-message members with external selectors."""
        return [member.selector for _, member in self.attr_list
                if isinstance(member, TypeSubMessage) and member.selector.is_external()]

    def free_needs_iter(self):
        """Return True if any member's free() needs an iterator variable."""
        return any(member.free_needs_iter() for _, member in self.attr_list)
1015
1016
class EnumEntry(SpecEnumEntry):
    """Enum entry that records whether its C value must be spelled out."""
    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # Implicit C numbering continues from the previous entry (or 0)
        implicit = prev.value + 1 if prev else 0
        # Flags always get an explicit value
        self.value_change = self.value != implicit or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1035
1036
class EnumSet(SpecEnumSet):
    """Enum/flags definition with its C type and value-prefix naming."""
    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # 'enum-name' absent -> default name; present but empty -> anonymous
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if the values are contiguous, else (None, None)."""
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)
        if high - low + 1 != len(self.entries):
            return None, None
        return low, high
1072
1073
class AttrSet(SpecAttrSet):
    """Attribute set with C naming (prefix, max/cnt macros) and typed attrs.

    Subsets reuse the naming of the set they are a subset of.
    """
    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        # presumably _C_KW holds reserved C words; suffix to avoid clashes
        if name in _C_KW:
            name += '_'
        # The family's own set gets an empty C name
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching the spec element's 'type'."""
        kind = elem['type']
        by_kind = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
            'sub-message': TypeSubMessage,
        }

        if kind in scalars:
            cls = TypeScalar
        elif kind in by_kind:
            cls = by_kind[kind]
        elif kind == 'binary':
            if 'struct' in elem:
                cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                cls = TypeBinaryScalarArray
            else:
                cls = TypeBinary
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        t = cls(self.family, self, elem, value)

        # multi-attr wraps whatever base type was selected above
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1142
1143
class Operation(SpecOperation):
    """Operation with its C render name, command enum name and notify flag."""
    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs)
        pairs = [(mode, direction)
                 for mode in ('do', 'dump', 'event')
                 for direction in ('request', 'reply')]
        for mode, direction in pairs:
            try:
                msg = yaml[mode][direction]
            except KeyError:
                continue
            msg.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True only when both 'do' and 'dump' define a request
        self.dual_policy = all('request' in yaml.get(m, {}) for m in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that another operation notifies via this one."""
        self.has_ntf = True
1177
1178
class SubMessage(SpecSubMessage):
    """Sub-message definition with a C render name derived from the family."""
    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        name = family.ident_name + '-' + yaml['name']
        self.render_name = c_lower(name)

    def resolve(self):
        self.resolve_up(super())
1187
1188
class Family(SpecFamily):
    """Netlink family spec extended with C code-generation metadata.

    Adds C naming (family/version key macros, uAPI header), and in
    resolve() classifies attribute sets into root sets and pure nested
    structs, marks request/reply use, wires sub-message selectors through
    parent sets and collects per-op pre/post hooks.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" -> "<name>"; anything else falls back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Compute C naming and the derived struct / hook metadata."""
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': names seen, 'list': names in first-seen order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        # Order matters: each step consumes what the previous ones computed
        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """Return True for classic (non-genetlink) 'netlink-raw' families."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        # Flag every operation that some other message names in 'notify'
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per attribute set used directly by ops, the request/reply attrs."""
        for op in self.msgs.values():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so structs come after the nests they embed.

        A struct referencing a not-yet-placed, non-recursive nest is moved to
        the end of the dict and retried; the number of rounds is bounded.
        """
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register the Struct for a 'nested-attributes' spec; return the set name."""
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                # The new Struct keeps a reference to 'inherit', so the
                # additions below are visible to it as well
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Synthesize an attr set + Struct for a sub-message; return its name.

        A sub-message is not an attr_set, but codegen wants attr_sets, so one
        fake attribute is created per format: a nest when the format names an
        attribute set, a binary struct when it only has a fixed header, and a
        flag otherwise.
        """
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr |= {
                    "type": "nest",
                    "nested-attributes": fmt['attribute-set'],
                }
                if 'fixed-header' in fmt:
                    attr |= { "fixed-header": fmt["fixed-header"] }
            elif 'fixed-header' in fmt:
                attr |= {
                    "type": "binary",
                    "struct": fmt["fixed-header"],
                }
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        # AttrSet reads the prefix from 'name-prefix'; the previous
        # 'name-pfx' key was never consulted and the prefix was ignored.
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-prefix": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Discover all nested structs reachable from root sets and propagate use."""
        # Breadth-first walk from the root sets
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for _, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        # Seed request/reply/multi-val flags from the root sets' direct children
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        self.pure_nested_structs[nested].in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark individual attributes as used in requests and/or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Bind external sub-message selectors to an attr of the parent set."""
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    # Only the immediate parent set may provide the selector
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """Build the single shared request policy for 'kernel-policy: global'.

        Raises:
            Exception: if ops use different attribute sets, or none uses one.
        """
        global_set = set()
        attr_set_name = None
        for op in self.ops.values():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        # Without this check the lookup below fails with an opaque KeyError: None
        if attr_set_name is None:
            raise Exception('For a global policy at least one op must use an attribute set')

        self.global_policy = []
        self.global_policy_set = attr_set_name
        # Keep the attribute set's declaration order
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per op mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1549
1550
class RenderInfo:
    """Rendering context for one operation (or bare attribute set) and mode.

    Holds the family, the CodeWriter, the kernel/user space selector and
    the 'request'/'reply' Struct objects the renderers walk.
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            if op.fixed_header != family.fixed_header:
                # Only classic netlink supports a per-op fixed header
                if family.is_classic():
                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
                else:
                    raise Exception("Per-op fixed header not supported, yet")

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        # Guard on op so a missing (None) op does not hit 'dump' in None
        if op and op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_oneside = True
            elif ('reply' in op['do']) != ('reply' in op["dump"]):
                self.type_consistent = False
            elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        # Notifications are rendered from the op's 'do' (or 'dump') definition
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        if op:
            for op_dir in ['request', 'reply']:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """Return True if the *key* struct has no attributes and no fixed header."""
        struct = self.struct[key]
        return len(struct.attr_list) == 0 and struct.fixed_header is None

    def needs_nlflags(self, direction):
        """Return True if a classic-family 'do' request needs netlink flags."""
        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1618
1619
class CodeWriter:
    """Indentation-aware emitter for generated C code.

    Without ``out_file`` all text goes straight to stdout. With ``out_file``
    the text is staged in a temporary file and copied to ``out_file`` by
    close_out_file(), which ``__del__`` also calls.

    Args:
        nlib: library helper object (e.g. a BaseNlLib); stored as ``nlib``.
        out_file: destination path, or None to write to stdout.
        overwrite: when False and ``out_file`` already exists with identical
            contents, the destination file is left untouched.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # a blank line is pending
        self._block_end = False     # a closing '}' is pending
        self._silent_block = False  # previous line opened a brace-less if/while/for/else
        self._ind = 0               # current indentation, in tabs
        self._ifdef_block = None    # currently open CONFIG_* #ifdef, if any
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        # __init__ may have failed before _out was assigned
        if hasattr(self, '_out'):
            self.close_out_file()

    def close_out_file(self):
        """Publish staged output to ``out_file`` and release the temp file.

        No-op when writing to stdout (or when already closed). When
        ``overwrite`` is False and ``out_file`` already holds identical
        contents, the destination is not rewritten. The temporary file is
        closed on every path.
        """
        if self._out is sys.stdout:
            return
        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and
                     os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        # Release the temporary file even when the destination was kept
        self._out.close()
        self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        """Return True if ``line`` starts with 'if', 'while' or 'for' (plain prefix match)."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        A pending '}' is flushed first (or merged into a following 'else').
        Then any pending blank line is written. Lines ending in ':' (labels)
        are out-dented one level. The line after a brace-less
        if/while/for/else gets one extra level. Preprocessor lines ('#...')
        go to column 0. ``add_ind`` adds further tab levels.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith()/startswith() instead of line[-1]/line[0] so that an
        # empty line is written as a blank line instead of raising IndexError
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write ``line`` followed by '{' and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block.

        With no ``line`` the '}' is deferred so that a following 'else' can
        be joined onto it. Otherwise '}' is written immediately followed by
        ``line``, with a space unless ``line`` starts with ';' or ','.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write ``doc`` as ' *'-prefixed comment lines wrapped before column 79."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping arguments to fit 80 columns.

        Args:
            qual_ret: qualifiers and return type, e.g. 'static inline int'.
            name: function name.
            args: argument declarations; empty/None means ``void``.
            doc: optional one-line comment written above the prototype.
            suffix: appended after ')', e.g. ';' for a declaration.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines align under the first argument (tabs + spaces)
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        Accepts a single string or an iterable of strings. The caller's
        list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() rather than list.sort(): don't reorder the caller's list
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, locals and ``body`` lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write ``#define`` lines from (name, value) pairs, values tab-aligned.

        int values are written as-is, str values are quoted; other value
        types produce a bare ``#define NAME``.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write designated initializers ``.name = value,`` with aligned '='."""
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active ``#ifdef CONFIG_<config>`` (None closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1816
1817
# Scalar (integer) attribute type names: u8..u64, s8..s64, uint, sint.
scalars = {f'{sign}{bits}' for sign in 'us' for bits in (8, 16, 32, 64)} | {'uint', 'sint'}

# Suffix appended to generated type names for each message direction
# ('' is used for structs that are not tied to a direction).
direction_to_suffix = {'reply': '_rsp', 'request': '_req', '': ''}

# Suffix of the reply wrapper type for each operation mode.
op_mode_to_wrapper = dict(do='', dump='_list', notify='_ntf', event='')

# C keywords. Their use is outside this chunk; presumably they are checked
# to avoid emitting identifiers that clash with the C language.
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1869
1870
def rdir(direction):
    """Return the opposite message direction.

    'reply' and 'request' are swapped; any other value is returned as-is.
    """
    return ('request' if direction == 'reply'
            else 'reply' if direction == 'request'
            else direction)
1877
1878
def op_prefix(ri, direction, deref=False):
    """Build the C identifier prefix of the generated type for ``ri``.

    The result is ``<family>_<type_name>`` plus a mode/direction suffix:
      * no op mode: no suffix;
      * 'do': the direction suffix ('_req' / '_rsp');
      * other modes, request: '_req', plus '_dump' unless the op is one-sided;
      * other modes, reply, inconsistent do/dump types: '_rsp_dump' with
        ``deref`` else '_rsp_list';
      * other modes, reply, consistent types: the direction suffix with
        ``deref``, else the mode's wrapper suffix (e.g. '_list', '_ntf').
    """
    base = f"{ri.family.c_name}_{ri.type_name}"
    mode = ri.op_mode

    if not mode:
        return base
    if mode == 'do':
        return base + direction_to_suffix[direction]
    if direction == 'request':
        return base + ('_req' if ri.type_oneside else '_req_dump')
    if not ri.type_consistent:
        return base + ('_rsp_dump' if deref else '_rsp_list')
    if deref:
        return base + direction_to_suffix[direction]
    return base + op_mode_to_wrapper[mode]
1902
1903
def type_name(ri, direction, deref=False):
    """Return the full C type, ``struct <prefix>``, for ``ri`` in ``direction``."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1906
1907
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the do/dump entry point for ``ri.op``.

    The function takes ``struct ynl_sock *ys``, plus a request pointer when
    the op mode has a 'request' section. It returns a pointer to the
    opposite-direction type when the mode has a 'reply' section, and ``int``
    otherwise. Dump functions get a '_dump' name suffix. ``terminate`` adds
    ';' (a declaration); ``doc`` is forwarded as a leading comment.
    """
    mode_spec = ri.op[ri.op_mode]
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1924
1925
def print_req_prototype(ri):
    """Declare the request function for ``ri``, commented with the op's doc."""
    op_doc = ri.op['doc']
    print_prototype(ri, direction="request", doc=op_doc)
1928
1929
def print_dump_prototype(ri):
    """Declare the request function for a dump ``ri`` (no doc comment)."""
    print_prototype(ri, direction="request")
1932
1933
def put_typol_submsg(cw, struct):
    """Emit the policy table and nest descriptor for a sub-message struct.

    Each member becomes a YNL_PT_SUBMSG entry indexed by its position in
    ``struct.member_list()``. 'nest' members also point at the nest
    descriptor of their nested attribute set. ``.max_attr`` is the last
    index used (-1 for a struct without members).
    """
    name = struct.render_name
    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[] =')

    count = 0
    for idx, (member, arg) in enumerate(struct.member_list()):
        nest = f" .nest = &{arg.nested_render_name}_nest," if arg.type == 'nest' else ""
        cw.p('[%d] = { .type = YNL_PT_SUBMSG, .name = "%s",%s },' % (idx, member, nest))
        count = idx + 1

    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    cw.p(f'.max_attr = {count - 1},')
    cw.p(f'.table = {name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1954
1955
def put_typol_fwd(cw, struct):
    """Emit an ``extern`` declaration of the struct's policy nest descriptor."""
    decl = f'extern const struct ynl_policy_nest {struct.render_name}_nest;'
    cw.p(decl)
1958
1959
def put_typol(cw, struct):
    """Emit the attribute policy table and nest descriptor for ``struct``.

    Sub-message structs are handled by put_typol_submsg(). Otherwise the
    table is sized ``max_name + 1`` from the attribute set, and each member
    writes its own entry through ``attr_typol()``.
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    max_attr = struct.attr_set.max_name
    name = struct.render_name
    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for line in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()
1979
1980
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit ``<render_name>_str()``, a bounds-checked lookup into ``map_name``.

    The argument type is ``enum.user_type`` when an enum is given, else
    ``int``. For 'flags' enums the value is first turned into the index of
    its lowest set bit with ffs(). Out-of-range indexes return NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = [f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
            'return NULL;',
            f'return {map_name}[{arg_name}];']
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
1994
1995
def put_op_name_fwd(family, cw):
    """Declare ``<family>_op_str()``, which maps an op number to its name."""
    fn_name = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', fn_name, ['int op'], suffix=';')
1998
1999
def put_op_name(family, cw):
    """Emit the op-name string map and its ``<family>_op_str()`` accessor.

    Only messages with a reply value get an entry. The entry is indexed by
    the op's enum name when request and reply values match, otherwise by the
    raw reply value. When several messages share a reply value, only the one
    recorded in ``family.rsp_by_value`` is kept; the others become comments.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2019
2020
def put_enum_to_str_fwd(family, cw, enum):
    """Declare ``<enum>_str()``. ``family`` is not used here."""
    cw.write_func_prot('const char *', f'{enum.render_name}_str',
                       [f'{enum.user_type} value'], suffix=';')
2024
2025
def put_enum_to_str(family, cw, enum):
    """Emit the value-to-name string map for ``enum`` and its ``_str()`` accessor."""
    map_name = enum.render_name + '_strmap'
    entries = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in entries:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
2035
2036
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of ``<struct>_put()``.

    ``suffix`` defaults to ';' (a declaration); pass '' for a definition.
    """
    ri.cw.write_func_prot(
        'int', f'{struct.render_name}_put',
        ['struct nlmsghdr *nlh', 'unsigned int attr_type', f'{struct.ptr_name}obj'],
        suffix=suffix)
2044
2045
def put_req_nested(ri, struct):
    """Emit ``<struct>_put()``, which writes ``obj`` into ``nlh``.

    Regular structs are wrapped in a nest attribute of type ``attr_type``.
    Sub-message structs are written without one. A fixed header, if any, is
    copied in before the attributes.
    """
    cw = ri.cw
    local_vars = []
    init_lines = []

    if struct.submsg is None:
        local_vars.append('struct nlattr *nest;')
        init_lines.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        hdr_size = f'sizeof({struct.fixed_header})'
        local_vars.append('void *hdr;')
        init_lines += [f"hdr = ynl_nlmsg_put_extra_header(nlh, {hdr_size});",
                       f"memcpy(hdr, &obj->_hdr, {hdr_size});"]

    members = [arg for _, arg in struct.member_list()]
    if any(arg.type == 'indexed-array' for arg in members):
        local_vars.append('struct nlattr *array;')
    if any(arg.presence_type() == 'count' for arg in members):
        local_vars.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar(local_vars)

    for line in init_lines:
        cw.p(line)

    for arg in members:
        arg.attr_put(ri, "obj")

    if struct.submsg is None:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2086
2087
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parse function for ``struct``.

    Writes the opening brace, the local variables, ``init_lines`` and the
    main attribute loop. Then it writes the post-loop code that allocates
    and fills indexed-array and multi-attr members, and the closing brace.
    The caller has already written the function prototype.

    Args:
        ri: render info; output goes to ``ri.cw``.
        struct: struct being parsed. ``struct.nested`` selects iterating over
            the ``nested`` attribute instead of the message ``nlh``.
        init_lines: C statements emitted right after the locals. Extended in
            place (``parg.ys = ...`` when a nested parser is called).
        local_vars: C local declarations. Extended in place.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    # Pick the attribute iterator. Nests skip an embedded fixed header;
    # top-level messages skip the family header (or a per-op one).
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        # Per-op fixed headers differing from the family's are only
        # supported for classic families
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Members filled in a second pass after the main loop: indexed arrays
    # (array_nests) and repeated attributes (multi_attrs)
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        # A parse arg is needed to call into nested / sub-message parsers
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per second-pass member (sorted for stable output)
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if struct.fixed_header:
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Generated code errors out if an array member is already populated
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # Main loop: each member's attr_get() emits its own handling
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate n_<name> entries, then walk
    # the nest held in attr_<name> (presumably recorded by attr_get() in
    # the main loop) to fill them
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attrs: allocate, then re-walk the whole
    # message/nest and copy every attribute of this member's type
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most one element's worth of bytes
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Each string is copied into its own NUL-terminated ynl_string
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nest parsers return 0; top-level message callbacks return
    # YNL_PARSE_CB_OK
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2239
2240
def parse_rsp_submsg(ri, struct):
    """Emit the parser for a sub-message struct.

    The generated function picks the member whose name equals the ``sel``
    selector string (an if / else-if chain of strcmp() tests) and parses
    ``nested`` into it. It returns 0, including when no member matches.

    Every member's local declarations are merged into a set to remove
    duplicates. The set is then sorted before emission: iteration order of a
    set of str varies between interpreter runs (hash randomization), and
    write_func_lvar() only orders by length with a stable sort. Without the
    sort, equal-length declarations would come out in a different order on
    each run, making the generated file non-reproducible.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    local_vars = {'const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;'}

    # Union of every member's local declarations (the set de-duplicates)
    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        local_vars.update(l_vars or ())

    ri.cw.block_start()
    # sorted(): deterministic output regardless of PYTHONHASHSEED
    ri.cw.write_func_lvar(sorted(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    for idx, (name, arg) in enumerate(struct.member_list()):
        kw = 'else if' if idx else 'if'

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2275
2276
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of ``<struct>_parse()``.

    Argument order: the parse context ``yarg``, the ``sel`` selector string
    (sub-messages only), the ``nested`` attribute, then one ``_sel_<name>``
    string per external selector and one ``__u32`` per inherited value.
    ``suffix`` defaults to ';' (a declaration).
    """
    selectors = ['const char *_sel_' + sel.name for sel in struct.external_selectors()]
    head = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        head.append('const char *sel')
    inherited = ['__u32 ' + arg for arg in struct.inherited]
    func_args = head + ['const struct nlattr *nested'] + selectors + inherited

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
                          suffix=suffix)
2289
2290
def parse_rsp_nested(ri, struct):
    """Emit ``<struct>_parse()`` for a nested attribute set or sub-message.

    Sub-messages are delegated to parse_rsp_submsg(). A struct without
    members gets a stub that just returns 0.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    local_vars = ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return None

    _multi_parse(ri, struct, init_lines=[], local_vars=local_vars)
    return None
2309
2310
def parse_rsp_msg(ri, deref=False):
    """Emit ``<prefix>_parse()``, the callback that parses one reply message.

    Nothing is emitted when the op mode has no 'reply' section, except for
    events, which are always parsed. ``deref`` is forwarded to op_prefix()
    and type_name() to choose the destination type name. A reply struct
    without members gets a stub returning YNL_PARSE_CB_OK.
    """
    if not ('reply' in ri.op[ri.op_mode] or ri.op_mode == 'event'):
        return

    dst_decl = f'{type_name(ri, "reply", deref=deref)} *dst;'
    local_vars = [dst_decl, 'const struct nlattr *attr;']
    init_lines = ['dst = yarg->data;']

    parser = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parser,
                          ['const struct nlmsghdr *nlh', 'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, init_lines, local_vars)
2332
2333
def print_req(ri):
    """Emit the function that sends a request for ``ri`` via ynl_exec().

    The generated function does the following:
      * starts the message (ynl_msg_start_req for classic families,
        ynl_gemsg_start_req otherwise);
      * sets the request policy and header length on ``ys``;
      * copies the fixed header, puts every request attribute and calls
        ynl_exec().
    When the op mode has a 'reply' section, it allocates and returns the
    parsed reply, and frees it and returns NULL on error. Otherwise it
    returns 0 on success and -1 on error.
    """
    ret_ok = '0'
    ret_err = '-1'  # only emitted on the reply path below
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.struct["request"].fixed_header:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # A loop index is needed if any attribute is a counted array
    for _, attr in ri.struct["request"].member_list():
        if attr.presence_type() == 'count':
            local_vars += ['unsigned int i;']
            break

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    # Classic families take nlmsg flags from the request struct; the others
    # address the family by id
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.struct['request'].fixed_header:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Expected reply command: the op's own id, or its explicit
        # rsp_value when the op has no value of its own
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    # Error path: release the partially filled reply
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
2407
2408
def print_dump(ri):
    """Emit the dump function for ``ri``, built around ynl_exec_dump().

    The generated function configures a ``ynl_dump_state``: reply policy,
    per-element allocation size, parse callback and expected reply command.
    It starts the dump message, copies the fixed header and puts request
    attributes when the op mode has a request. It returns ``yds.first``, or
    frees the list and returns NULL on error.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.struct['request'].fixed_header:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    # deref=True: the callback parses one element, not the list wrapper
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # NOTE(review): this references req->_hdr even when the op mode has no
    # 'request' section (print_prototype() then emits no req argument);
    # presumably fixed-header ops always define a request — confirm
    if ri.struct['request'].fixed_header:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2463
2464
def call_free(ri, direction, var):
    """Return the C statement freeing ``var`` with the matching ``_free()``."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
2467
2468
def free_arg_name(direction):
    """Name of the pointer parameter of a generated ``_free()`` function.

    'req' / 'rsp' for a direction (its suffix without the leading '_'),
    'obj' when ``direction`` is empty.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2473
2474
def print_alloc_wrapper(ri, direction, struct=None):
    """Emit a ``static inline`` ``<prefix>_alloc()`` wrapper around calloc().

    Structs used in multi-value members (``struct.in_multi_val``) take an
    ``unsigned int n`` element count; otherwise exactly one object is
    allocated. On a type-name conflict the struct tag gets a trailing '_'.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name

    if struct and struct.in_multi_val:
        args, cnt = ["unsigned int n"], "n"
    else:
        args, cnt = ["void"], "1"

    cw = ri.cw
    cw.write_func_prot(f'static inline struct {struct_name} *', f"{name}_alloc", args)
    cw.block_start()
    cw.p(f'return calloc({cnt}, sizeof(struct {struct_name}));')
    cw.block_end()
2492
2493
def print_free_prototype(ri, direction, suffix=';'):
    """Write the prototype of ``<prefix>_free()`` (a declaration by default).

    On a type-name conflict the struct tag gets a trailing '_'.
    """
    name = op_prefix(ri, direction)
    struct_name = f"{name}_" if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
2501
2502
def print_nlflags_set(ri, direction):
    """Emit ``<prefix>_set_nlflags()``, a setter for ``req->_nlmsg_flags``."""
    name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot('static inline void', f"{name}_set_nlflags",
                       [f"struct {name} *req", "__u16 nl_flags"])
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2511
2512
def _print_type(ri, direction, struct):
    """
    Emit the C struct definition for `struct` in the given direction.

    The struct is named <family c_name>_<type name><direction suffix>.
    A '_' is added for direction-less types on a type-name conflict, and
    '_dump' is added for dump ops unless ri.type_oneside is set.
    Fields are emitted in this order:
      - _nlmsg_flags (when ri.needs_nlflags(direction)),
      - the fixed header as _hdr,
      - anonymous _present/_len/_count sub-structs (only when non-empty),
      - inherited __u32 arguments,
      - the attribute members.
    """
    parts = [f'_{ri.type_name}', direction_to_suffix[direction]]
    if not direction and ri.type_name_conflict:
        parts.append('_')
    if ri.op_mode == 'dump' and not ri.type_oneside:
        parts.append('_dump')
    suffix = ''.join(parts)

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if struct.fixed_header:
        cw.p(struct.fixed_header + ' _hdr;')
        cw.nl()

    # Presence bookkeeping is grouped per kind into anonymous sub-structs.
    for type_filter in ('present', 'len', 'count'):
        candidates = [attr.presence_member(ri.ku_space, type_filter)
                      for _, attr in struct.member_list()]
        lines = [text for text in candidates if text]
        if not lines:
            continue
        cw.block_start(line="struct")
        for text in lines:
            cw.p(text)
        cw.block_end(line=f'_{type_filter};')
    cw.nl()

    for inherited_arg in struct.inherited:
        cw.p(f"__u32 {inherited_arg};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2551
2552
def print_type(ri, direction):
    """Emit the op struct for `direction` using ri's struct for that direction."""
    target = ri.struct[direction]
    _print_type(ri, direction, target)
2555
2556
def print_type_full(ri, struct):
    """
    Emit a direction-less (nested) struct type.

    For request structs with in_multi_val set, also emit the alloc helper
    and the free prototype.
    """
    _print_type(ri, "", struct)

    if not (struct.request and struct.in_multi_val):
        return
    print_alloc_wrapper(ri, "", struct)
    ri.cw.nl()
    free_rsp_nested_prototype(ri)
    ri.cw.nl()
2565
2566
def print_type_helpers(ri, direction, deref=False):
    """
    Emit the helpers for an op struct of the given direction.

    Always emits the free() prototype. Emits the nlflags setter when
    needed. For user-space requests, also emits per-attribute setters,
    which receive `deref`.
    """
    cw = ri.cw
    print_free_prototype(ri, direction)
    cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    cw.nl()
2578
2579
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and helpers for a non-empty request type."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2585
2586
def print_rsp_type_helpers(ri):
    """Emit reply-type helpers when the current op mode defines a reply."""
    has_reply = 'reply' in ri.op[ri.op_mode]
    if has_reply:
        print_type_helpers(ri, "reply")
2591
2592
def print_parse_prototype(ri, direction, terminate=True):
    """
    Emit the prototype of <op>_rsp_parse() / <op>_req_parse().

    The '_rsp' form is used for direction 'reply'; anything else gets
    '_req'. With `terminate` set, the prototype ends in ';'.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
2601
2602
def print_req_type(ri):
    """Emit the request struct unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2607
2608
def print_req_free(ri):
    """Emit the request free() implementation when the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2613
2614
def print_rsp_type(ri):
    """
    Emit the reply struct.

    This happens for do/dump ops that declare a reply and for events;
    every other case emits nothing.
    """
    has_reply = ri.op_mode in ('do', 'dump') and 'reply' in ri.op[ri.op_mode]
    if not (has_reply or ri.op_mode == 'event'):
        return
    print_type(ri, 'reply')
2623
2624
def print_wrapped_type(ri):
    """
    Emit the list/notification wrapper around a reply object.

    Dump wrappers get a `next` pointer. Notify/event wrappers get family,
    cmd, a generic `next` pointer and a `free` callback. All wrappers embed
    the dereferenced reply as an 8-byte aligned `obj`, followed by the
    reply free() prototype.
    """
    cw = ri.cw
    reply_type = type_name(ri, 'reply')
    cw.block_start(line=reply_type)
    if ri.op_mode == 'dump':
        cw.p(f"{reply_type} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for field in ('__u16 family;', '__u8 cmd;', 'struct ynl_ntf_base_type *next;'):
            cw.p(field)
        cw.p(f"void (*free)({reply_type} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2639
2640
def _free_type_members_iter(ri, struct):
    """Declare the loop index `i` when freeing struct's members needs iteration."""
    if not struct.free_needs_iter():
        return
    ri.cw.p('unsigned int i;')
    ri.cw.nl()
2645
2646
def _free_type_members(ri, var, struct, ref=''):
    """Emit free code for each member of `struct`, accessed through `var` and `ref`."""
    members = (member for _, member in struct.member_list())
    for member in members:
        member.free(ri, var, ref)
2650
2651
def _free_type(ri, direction, struct):
    """
    Emit a complete free() function for `struct`.

    Members are always freed. The object itself is freed only for a
    non-empty direction; nested (direction-less) objects are embedded in
    their parent and are not freed here.
    """
    arg = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2663
2664
def free_rsp_nested_prototype(ri):
    """Declare the free() function of a direction-less (nested) type."""
    print_free_prototype(ri, direction="")
2667
2668
def free_rsp_nested(ri, struct):
    """Define the free() function of a direction-less (nested) type."""
    _free_type(ri, direction="", struct=struct)
2671
2672
def print_rsp_free(ri):
    """Emit the reply free() implementation when the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2677
2678
def print_dump_type_free(ri):
    """
    Emit the free() function for a dump reply list.

    The generated code walks the list until YNL_LIST_END. For each node it
    frees the members of node->obj and then the node itself.
    """
    cw = ri.cw
    node_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2697
2698
def print_ntf_type_free(ri):
    """Emit the free() function for a notification: members of rsp->obj, then rsp."""
    cw = ri.cw
    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, ri.struct['reply'])
    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2707
2708
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """
    Emit the declaration or definition head of an nla_policy array.

    With `terminate` set, an `extern` declaration is written. It is
    suppressed entirely for op policies that are static
    (policy_should_be_static). Otherwise the definition opener
    (' = {') is written, prefixed with 'static ' for such op policies.
    Op policies are named after the op, with '_<mode>' appended for
    dual-policy ops; other policies use the struct's render_name.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    elif ri.op.dual_policy:
        name = ri.op.render_name + '_' + ri.op_mode
    else:
        name = ri.op.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2731
2732
def print_req_policy(cw, struct, ri=None):
    """
    Emit a full nla_policy array definition for `struct`.

    For op policies, the array is wrapped in the op's 'config-cond'
    #ifdef (if any).
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for member in (arg for _, arg in struct.member_list()):
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2742
2743
def kernel_can_gen_family_struct(family):
    """Whether the generator emits the genl_family struct itself (genetlink only)."""
    proto = family.proto
    return proto == 'genetlink'
2746
2747
def policy_should_be_static(family):
    """Policies are file-local for split-policy families or when we emit the family struct."""
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2750
2751
def print_kernel_policy_ranges(family, cw):
    """
    Emit netlink_range_validation structs for 'full-range' request attributes.

    Attributes in subset attribute sets are skipped. Unsigned types (type
    name starting with 'u') use the unsigned struct with ULL limits; all
    others use the _signed variant with LL limits. A header comment is
    emitted before the first struct only.
    """
    ranged = (attr
              for attr_set in family.attr_sets.values() if not attr_set.subset_of
              for _, attr in attr_set.items()
              if attr.request and 'full-range' in attr.checks)

    for idx, attr in enumerate(ranged):
        if idx == 0:
            cw.p('/* Integer value ranges */')

        unsigned = attr.type[0] == 'u'
        sign = '' if unsigned else '_signed'
        suffix = 'ULL' if unsigned else 'LL'
        cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
        members = [(bound, attr.get_limit_str(bound, suffix=suffix))
                   for bound in ('min', 'max') if bound in attr.checks]
        cw.write_struct_init(members)
        cw.block_end(line=';')
        cw.nl()
2779
2780
def print_kernel_policy_sparse_enum_validates(family, cw):
    """
    Emit validation callbacks for request attributes with sparse enums.

    For every request attribute (outside subset attribute sets) that has
    an enum and a 'sparse' check, emit `static int <enum>_validate()`.
    The generated function switches on the attribute value: any listed
    enum entry returns 0, anything else sets an extack message and
    returns -EINVAL. A header comment precedes the first callback only.
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request:
                continue
            if not attr.enum_name:
                continue
            if 'sparse' not in attr.checks:
                continue

            if first:
                cw.p('/* Sparse enums validation callbacks */')
                first = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # All valid entries share a single 'return 0;', so the case
            # labels are chained with fallthrough; between them.
            for idx, entry in enumerate(enum.entries.values()):
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2819
2820
def print_kernel_op_table_fwd(family, cw, terminate):
    """
    Emit the head of the kernel ops table, or its header declarations.

    With `terminate` unset (source file), open the ops array definition.
    With `terminate` set (header), do three things:
      - emit an `extern` declaration of the array, only when the family
        struct is not generated here;
      - emit prototypes for all pre/post do/dump hooks;
      - emit prototypes for every non-async doit/dumpit handler.
    The array is sized explicitly only when exported: one slot per
    do/dump handler for split policy, one per op otherwise.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {decl};")
        else:
            cw.block_start(line=decl + ' =')

    if not terminate:
        return

    cw.nl()
    doit_hook_args = ['const struct genl_split_ops *ops',
                      'struct sk_buff *skb', 'struct genl_info *info']
    dumpit_hook_args = ['struct netlink_callback *cb']
    hook_protos = (('pre', 'do', 'int', doit_hook_args),
                   ('post', 'do', 'void', doit_hook_args),
                   ('pre', 'dump', 'int', dumpit_hook_args),
                   ('post', 'dump', 'int', dumpit_hook_args))
    for stage, mode, ret, params in hook_protos:
        for hook in family.hooks[stage][mode]['list']:
            cw.write_func_prot(ret, c_lower(hook), list(params), suffix=';')

    cw.nl()

    handler_args = (('do', ['struct sk_buff *skb', 'struct genl_info *info']),
                    ('dump', ['struct sk_buff *skb', 'struct netlink_callback *cb']))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue
        for mode, params in handler_args:
            if mode not in op:
                continue
            handler = c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it")
            cw.write_func_prot('int', handler, list(params), suffix=';')
    cw.nl()
2886
2887
def print_kernel_op_table_hdr(family, cw):
    """Emit the header (declaration) side of the kernel ops table."""
    return print_kernel_op_table_fwd(family, cw, True)
2890
2891
def print_kernel_op_table(family, cw):
    """
    Emit the kernel ops table definition for the family.

    'global' and 'per-op' policies produce one entry per non-async op with
    cmd, validate flags, doit/dumpit handlers, the op policy and maxattr
    (per-op only) and flags. 'split' policy produces one entry per
    do/dump mode of each non-async op. Each such entry adds pre/post
    callbacks, a policy when that mode has a request, and
    GENL_CMD_CAP_DO/DUMP in its flags. Entries are wrapped in the op's
    'config-cond' #ifdef, if any.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # Struct is built only to find the max attribute enum name.
                # NOTE(review): indexes op['do']['request'] directly, so a
                # dump-only op here would raise KeyError — confirm specs
                # never combine per-op policy with dump-only ops.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # genl_split_ops field names for the pre/post hooks of each mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the dont-validate flags relevant to this
                    # mode: dump-only flags are dropped from the 'do' entry
                    # and 'strict' is dropped from the 'dump' entry.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2970
2971
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum of <FAMILY>_NLGRP_<GROUP> ids, if there are groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{group['name']},"))
    cw.block_end(';')
    cw.nl()
2982
2983
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by the group enum ids."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for group in groups:
        group_name = group['name']
        enum_id = c_upper(f"{family.ident_name}-nlgrp-{group_name}")
        cw.p(f'[{enum_id}] = {{ "{group_name}", }},')
    cw.block_end(';')
    cw.nl()
2995
2996
def print_kernel_family_struct_hdr(family, cw):
    """
    Declare the genl_family struct and any sock-priv init/destroy hooks.

    Nothing is emitted when the generator does not own the family struct.
    """
    if not kernel_can_gen_family_struct(family):
        return

    c_name = family.c_name
    cw.p(f"extern struct genl_family {c_name}_nl_family;")
    cw.nl()
    if 'sock-priv' in family.kernel_family:
        priv_type = family.kernel_family["sock-priv"]
        for action in ('init', 'destroy'):
            cw.p(f'void {c_name}_nl_sock_priv_{action}({priv_type} *priv);')
        cw.nl()
3007
3008
def print_kernel_family_struct_src(family, cw):
    """
    Define the genl_family struct (genetlink families only).

    When the family has sock-priv, `void *` trampolines for the
    init/destroy hooks are emitted first (the original comment cites CFI
    as the reason).
    """
    if not kernel_can_gen_family_struct(family):
        return

    c_name = family.c_name
    has_priv = 'sock-priv' in family.kernel_family

    if has_priv:
        # Trampolines with a void * argument to keep CFI happy
        for action in ('init', 'destroy'):
            cw.write_func("static void", f"__{c_name}_nl_sock_priv_{action}",
                          [f"{c_name}_nl_sock_priv_{action}(priv);"],
                          ["void *priv"])
            cw.nl()

    fields = ['.name\t\t= ' + family.fam_key + ',',
              '.version\t= ' + family.ver_key + ',',
              '.netnsok\t= true,',
              '.parallel_ops\t= true,',
              '.module\t\t= THIS_MODULE,']
    if family.kernel_policy == 'per-op':
        fields += [f'.ops\t\t= {c_name}_nl_ops,',
                   f'.n_ops\t\t= ARRAY_SIZE({c_name}_nl_ops),']
    elif family.kernel_policy == 'split':
        fields += [f'.split_ops\t= {c_name}_nl_ops,',
                   f'.n_split_ops\t= ARRAY_SIZE({c_name}_nl_ops),']
    if family.mcgrps['list']:
        fields += [f'.mcgrps\t\t= {c_name}_nl_mcgrps,',
                   f'.n_mcgrps\t= ARRAY_SIZE({c_name}_nl_mcgrps),']
    if has_priv:
        fields += [f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),',
                   f'.sock_priv_init\t= __{c_name}_nl_sock_priv_init,',
                   f'.sock_priv_destroy = __{c_name}_nl_sock_priv_destroy,']

    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")
    for field in fields:
        cw.p(field)
    cw.block_end(';')
3044
3045
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """
    Open a uAPI C enum block for `obj`.

    The enum is named as follows:
      - If `obj` has the `enum_name` key, a truthy value names the enum
        and a falsy value keeps it anonymous.
      - Otherwise, if `ckey` is given and present in `obj`, the name is
        <family c_name>_<obj[ckey]>.
      - Otherwise the enum is anonymous.
    """
    header = 'enum'
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            header = f'enum {c_lower(explicit)}'
    elif ckey and ckey in obj:
        header = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    cw.block_start(line=header)
3054
3055
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """
    Emit one enum holding the message IDs of all operations.

    Notifications/events are skipped when `separate_ntf` is set, since they
    go into their own enum. Values are written explicitly only when they
    break the implicit sequence. The enum ends with the count entry and a
    MAX = count - 1 entry, or with a #define when `max_by_define` is set.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(op.enum_name + ',')
        else:
            cw.p(op.enum_name + f" = {op.value},")
            expected = op.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
3081
3082
def _render_uapi_dir_msg_enum(cw, none_line, entries, cnt_name, max_name, max_by_define):
    """
    Emit one anonymous message-ID enum for the directional model.

    `entries` is a sequence of (enum name, value) pairs. A value is written
    explicitly when it is non-zero and differs from the running counter.
    The enum ends with the count entry and MAX, or a MAX #define when
    `max_by_define` is set.
    """
    max_value = f"({cnt_name} - 1)"

    cw.block_start(line='enum')
    cw.p(none_line)
    val = 0
    for name, value in entries:
        suffix = ','
        if value and value != val:
            suffix = f" = {value},"
            val = value
        cw.p(name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()


def render_uapi_directional(family, cw, max_by_define):
    """
    Emit separate user->kernel and kernel->user message-ID enums.

    The user enum lists ops with a 'do' that are not events. The kernel
    enum lists do-ops with a reply (suffixed _REPLY), notifications and
    events.
    """
    user_msgs = [(op.enum_name, op.value) for op in family.msgs.values()
                 if 'do' in op and 'event' not in op]
    _render_uapi_dir_msg_enum(cw, c_upper(f'{family.name}_MSG_USER_NONE = 0,'),
                              user_msgs,
                              f"__{family.op_prefix}USER_CNT",
                              f"{family.op_prefix}USER_MAX",
                              max_by_define)

    kernel_msgs = []
    for op in family.msgs.values():
        if not (('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op):
            continue
        name = op.enum_name
        if 'event' not in op and 'notify' not in op:
            name = f'{name}_REPLY'
        kernel_msgs.append((name, op.value))
    _render_uapi_dir_msg_enum(cw, c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'),
                              kernel_msgs,
                              f"__{family.op_prefix}KERNEL_CNT",
                              f"{family.op_prefix}KERNEL_MAX",
                              max_by_define)
3135
3136
def render_uapi(family, cw):
    """
    Emit the complete uAPI header for `family`.

    Output order:
      - include guard and the family name/version defines;
      - definitions: enums/flags (with kdoc) and consts as #defines;
        entries marked 'header' are skipped;
      - one enum per non-subset attribute set;
      - the command enum(s);
      - an optional separate notification enum;
      - multicast group name defines;
      - the closing #endif.

    Raises Exception for an unsupported message enum model.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive consts are batched into one aligned #define group; any
    # non-const definition flushes the pending batch first.
    defines = []
    for const in family['definitions']:
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The C name override is a per-constant property (just like
            # 'c-define-name' on multicast groups below). Looking it up on
            # the family would give every constant the same name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        # f-string forces a str, so writes_defines emits the name quoted.
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3277
3278
def _render_user_ntf_entry(ri, op):
    """
    Emit one ynl_ntf_info initializer for notification/event `op`.

    Classic families index the entry by the enum name of
    ri.family.req_by_value[op.rsp_value]; other families use op's own
    enum name. The entry sets alloc size, parse callback, nest policy and
    free function.
    """
    if ri.family.is_classic():
        index = ri.family.req_by_value[op.rsp_value].enum_name
    else:
        index = op.enum_name

    cw = ri.cw
    cw.block_start(line=f"[{index}] = ")
    cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    cw.block_end(line=',')
3290
3291
def render_user_family(family, cw, prototype):
    """
    Emit the user-space `ynl_family` descriptor for `family`.

    With `prototype` set, only an extern declaration is written. Otherwise
    the notification info table (when family.ntfs is non-empty) is written
    first, followed by the initialized descriptor. The header length
    depends on whether the family is classic and whether it has a fixed
    header.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_name, ntf in family.ntfs.items():
            if 'notify' in ntf:
                info = RenderInfo(cw, family, "user", family.ops[ntf['notify']], "notify")
            elif 'event' in ntf:
                info = RenderInfo(cw, family, "user", ntf, "event")
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(info, ntf)
        # Operations carrying an 'event' get table entries as well.
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3333
3334
3335def family_contains_bitfield32(family):
3336    for _, attr_set in family.attr_sets.items():
3337        if attr_set.subset_of:
3338            continue
3339        for _, attr in attr_set.items():
3340            if attr.type == "bitfield32":
3341                return True
3342    return False
3343
3344
3345def find_kernel_root(full_path):
3346    sub_path = ''
3347    while True:
3348        sub_path = os.path.join(os.path.basename(full_path), sub_path)
3349        full_path = os.path.dirname(full_path)
3350        maintainers = os.path.join(full_path, "MAINTAINERS")
3351        if os.path.exists(maintainers):
3352            return full_path, sub_path[:-1]
3353
3354
3355def main():
3356    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
3357    parser.add_argument('--mode', dest='mode', type=str, required=True,
3358                        choices=('user', 'kernel', 'uapi'))
3359    parser.add_argument('--spec', dest='spec', type=str, required=True)
3360    parser.add_argument('--header', dest='header', action='store_true', default=None)
3361    parser.add_argument('--source', dest='header', action='store_false')
3362    parser.add_argument('--user-header', nargs='+', default=[])
3363    parser.add_argument('--cmp-out', action='store_true', default=None,
3364                        help='Do not overwrite the output file if the new output is identical to the old')
3365    parser.add_argument('--exclude-op', action='append', default=[])
3366    parser.add_argument('-o', dest='out_file', type=str, default=None)
3367    args = parser.parse_args()
3368
3369    if args.header is None:
3370        parser.error("--header or --source is required")
3371
3372    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
3373
3374    try:
3375        parsed = Family(args.spec, exclude_ops)
3376        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
3377            print('Spec license:', parsed.license)
3378            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
3379            os.sys.exit(1)
3380    except yaml.YAMLError as exc:
3381        print(exc)
3382        os.sys.exit(1)
3383        return
3384
3385    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
3386
3387    _, spec_kernel = find_kernel_root(args.spec)
3388    if args.mode == 'uapi' or args.header:
3389        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
3390    else:
3391        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
3392    cw.p("/* Do not edit directly, auto-generated from: */")
3393    cw.p(f"/*\t{spec_kernel} */")
3394    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
3395    if args.exclude_op or args.user_header:
3396        line = ''
3397        line += ' --user-header '.join([''] + args.user_header)
3398        line += ' --exclude-op '.join([''] + args.exclude_op)
3399        cw.p(f'/* YNL-ARG{line} */')
3400    cw.nl()
3401
3402    if args.mode == 'uapi':
3403        render_uapi(parsed, cw)
3404        return
3405
3406    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
3407    if args.header:
3408        cw.p('#ifndef ' + hdr_prot)
3409        cw.p('#define ' + hdr_prot)
3410        cw.nl()
3411
3412    if args.out_file:
3413        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
3414    else:
3415        hdr_file = "generated_header_file.h"
3416
3417    if args.mode == 'kernel':
3418        cw.p('#include <net/netlink.h>')
3419        cw.p('#include <net/genetlink.h>')
3420        cw.nl()
3421        if not args.header:
3422            if args.out_file:
3423                cw.p(f'#include "{hdr_file}"')
3424            cw.nl()
3425        headers = ['uapi/' + parsed.uapi_header]
3426        headers += parsed.kernel_family.get('headers', [])
3427    else:
3428        cw.p('#include <stdlib.h>')
3429        cw.p('#include <string.h>')
3430        if args.header:
3431            cw.p('#include <linux/types.h>')
3432            if family_contains_bitfield32(parsed):
3433                cw.p('#include <linux/netlink.h>')
3434        else:
3435            cw.p(f'#include "{hdr_file}"')
3436            cw.p('#include "ynl.h"')
3437        headers = []
3438    for definition in parsed['definitions'] + parsed['attribute-sets']:
3439        if 'header' in definition:
3440            headers.append(definition['header'])
3441    if args.mode == 'user':
3442        headers.append(parsed.uapi_header)
3443    seen_header = []
3444    for one in headers:
3445        if one not in seen_header:
3446            cw.p(f"#include <{one}>")
3447            seen_header.append(one)
3448    cw.nl()
3449
3450    if args.mode == "user":
3451        if not args.header:
3452            cw.p("#include <linux/genetlink.h>")
3453            cw.nl()
3454            for one in args.user_header:
3455                cw.p(f'#include "{one}"')
3456        else:
3457            cw.p('struct ynl_sock;')
3458            cw.nl()
3459            render_user_family(parsed, cw, True)
3460        cw.nl()
3461
3462    if args.mode == "kernel":
3463        if args.header:
3464            for _, struct in sorted(parsed.pure_nested_structs.items()):
3465                if struct.request:
3466                    cw.p('/* Common nested types */')
3467                    break
3468            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3469                if struct.request:
3470                    print_req_policy_fwd(cw, struct)
3471            cw.nl()
3472
3473            if parsed.kernel_policy == 'global':
3474                cw.p(f"/* Global operation policy for {parsed.name} */")
3475
3476                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3477                print_req_policy_fwd(cw, struct)
3478                cw.nl()
3479
3480            if parsed.kernel_policy in {'per-op', 'split'}:
3481                for op_name, op in parsed.ops.items():
3482                    if 'do' in op and 'event' not in op:
3483                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
3484                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
3485                        cw.nl()
3486
3487            print_kernel_op_table_hdr(parsed, cw)
3488            print_kernel_mcgrp_hdr(parsed, cw)
3489            print_kernel_family_struct_hdr(parsed, cw)
3490        else:
3491            print_kernel_policy_ranges(parsed, cw)
3492            print_kernel_policy_sparse_enum_validates(parsed, cw)
3493
3494            for _, struct in sorted(parsed.pure_nested_structs.items()):
3495                if struct.request:
3496                    cw.p('/* Common nested types */')
3497                    break
3498            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3499                if struct.request:
3500                    print_req_policy(cw, struct)
3501            cw.nl()
3502
3503            if parsed.kernel_policy == 'global':
3504                cw.p(f"/* Global operation policy for {parsed.name} */")
3505
3506                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3507                print_req_policy(cw, struct)
3508                cw.nl()
3509
3510            for op_name, op in parsed.ops.items():
3511                if parsed.kernel_policy in {'per-op', 'split'}:
3512                    for op_mode in ['do', 'dump']:
3513                        if op_mode in op and 'request' in op[op_mode]:
3514                            cw.p(f"/* {op.enum_name} - {op_mode} */")
3515                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
3516                            print_req_policy(cw, ri.struct['request'], ri=ri)
3517                            cw.nl()
3518
3519            print_kernel_op_table(parsed, cw)
3520            print_kernel_mcgrp_src(parsed, cw)
3521            print_kernel_family_struct_src(parsed, cw)
3522
3523    if args.mode == "user":
3524        if args.header:
3525            cw.p('/* Enums */')
3526            put_op_name_fwd(parsed, cw)
3527
3528            for name, const in parsed.consts.items():
3529                if isinstance(const, EnumSet):
3530                    put_enum_to_str_fwd(parsed, cw, const)
3531            cw.nl()
3532
3533            cw.p('/* Common nested types */')
3534            for attr_set, struct in parsed.pure_nested_structs.items():
3535                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3536                print_type_full(ri, struct)
3537
3538            for op_name, op in parsed.ops.items():
3539                cw.p(f"/* ============== {op.enum_name} ============== */")
3540
3541                if 'do' in op and 'event' not in op:
3542                    cw.p(f"/* {op.enum_name} - do */")
3543                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3544                    print_req_type(ri)
3545                    print_req_type_helpers(ri)
3546                    cw.nl()
3547                    print_rsp_type(ri)
3548                    print_rsp_type_helpers(ri)
3549                    cw.nl()
3550                    print_req_prototype(ri)
3551                    cw.nl()
3552
3553                if 'dump' in op:
3554                    cw.p(f"/* {op.enum_name} - dump */")
3555                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
3556                    print_req_type(ri)
3557                    print_req_type_helpers(ri)
3558                    if not ri.type_consistent or ri.type_oneside:
3559                        print_rsp_type(ri)
3560                    print_wrapped_type(ri)
3561                    print_dump_prototype(ri)
3562                    cw.nl()
3563
3564                if op.has_ntf:
3565                    cw.p(f"/* {op.enum_name} - notify */")
3566                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3567                    if not ri.type_consistent:
3568                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3569                    print_wrapped_type(ri)
3570
3571            for op_name, op in parsed.ntfs.items():
3572                if 'event' in op:
3573                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
3574                    cw.p(f"/* {op.enum_name} - event */")
3575                    print_rsp_type(ri)
3576                    cw.nl()
3577                    print_wrapped_type(ri)
3578            cw.nl()
3579        else:
3580            cw.p('/* Enums */')
3581            put_op_name(parsed, cw)
3582
3583            for name, const in parsed.consts.items():
3584                if isinstance(const, EnumSet):
3585                    put_enum_to_str(parsed, cw, const)
3586            cw.nl()
3587
3588            has_recursive_nests = False
3589            cw.p('/* Policies */')
3590            for struct in parsed.pure_nested_structs.values():
3591                if struct.recursive:
3592                    put_typol_fwd(cw, struct)
3593                    has_recursive_nests = True
3594            if has_recursive_nests:
3595                cw.nl()
3596            for struct in parsed.pure_nested_structs.values():
3597                put_typol(cw, struct)
3598            for name in parsed.root_sets:
3599                struct = Struct(parsed, name)
3600                put_typol(cw, struct)
3601
3602            cw.p('/* Common nested types */')
3603            if has_recursive_nests:
3604                for attr_set, struct in parsed.pure_nested_structs.items():
3605                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3606                    free_rsp_nested_prototype(ri)
3607                    if struct.request:
3608                        put_req_nested_prototype(ri, struct)
3609                    if struct.reply:
3610                        parse_rsp_nested_prototype(ri, struct)
3611                cw.nl()
3612            for attr_set, struct in parsed.pure_nested_structs.items():
3613                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3614
3615                free_rsp_nested(ri, struct)
3616                if struct.request:
3617                    put_req_nested(ri, struct)
3618                if struct.reply:
3619                    parse_rsp_nested(ri, struct)
3620
3621            for op_name, op in parsed.ops.items():
3622                cw.p(f"/* ============== {op.enum_name} ============== */")
3623                if 'do' in op and 'event' not in op:
3624                    cw.p(f"/* {op.enum_name} - do */")
3625                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3626                    print_req_free(ri)
3627                    print_rsp_free(ri)
3628                    parse_rsp_msg(ri)
3629                    print_req(ri)
3630                    cw.nl()
3631
3632                if 'dump' in op:
3633                    cw.p(f"/* {op.enum_name} - dump */")
3634                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
3635                    if not ri.type_consistent or ri.type_oneside:
3636                        parse_rsp_msg(ri, deref=True)
3637                    print_req_free(ri)
3638                    print_dump_type_free(ri)
3639                    print_dump(ri)
3640                    cw.nl()
3641
3642                if op.has_ntf:
3643                    cw.p(f"/* {op.enum_name} - notify */")
3644                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3645                    if not ri.type_consistent:
3646                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3647                    print_ntf_type_free(ri)
3648
3649            for op_name, op in parsed.ntfs.items():
3650                if 'event' in op:
3651                    cw.p(f"/* {op.enum_name} - event */")
3652
3653                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3654                    parse_rsp_msg(ri)
3655
3656                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
3657                    print_ntf_type_free(ri)
3658            cw.nl()
3659            render_user_family(parsed, cw, False)
3660
3661    if args.header:
3662        cw.p(f'#endif /* {hdr_prot} */')
3663
3664
3665if __name__ == "__main__":
3666    main()
3667