xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 37a93dd5c49b5fda807fd204edf2547c3493319c)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3#
4# pylint: disable=line-too-long, missing-class-docstring, missing-function-docstring
5# pylint: disable=too-many-positional-arguments, too-many-arguments, too-many-statements
6# pylint: disable=too-many-branches, too-many-locals, too-many-instance-attributes
7# pylint: disable=too-many-nested-blocks, too-many-lines, too-few-public-methods
8# pylint: disable=broad-exception-raised, broad-exception-caught, protected-access
9
10"""
11ynl_gen_c
12
13A YNL to C code generator for both kernel and userspace protocol stubs.
14"""
15
16import argparse
17import filecmp
18import pathlib
19import os
20import re
21import shutil
22import sys
23import tempfile
24import yaml as pyyaml
25
26# pylint: disable=no-name-in-module,wrong-import-position
27sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
28from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
29from lib import SpecSubMessage
30
31
def c_upper(name):
    """Return ``name`` upper-cased with every '-' turned into '_' (C macro style)."""
    return '_'.join(name.upper().split('-'))
34
35
def c_lower(name):
    """Return ``name`` lower-cased with every '-' turned into '_' (C identifier style)."""
    dash_to_underscore = str.maketrans('-', '_')
    return name.lower().translate(dash_to_underscore)
38
39
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value.

    The accepted form is ``<s|u><width>-<min|max>``. Unsigned minimums are 0
    and unsigned maximums are ``2**width - 1``. Signed limits span
    ``[-2**(width-1), 2**(width-1) - 1]``.

    Raises:
        ValueError: if ``name`` is not a well-formed limit string. A typo
            would otherwise be silently turned into a bogus number.
    """
    match = re.fullmatch(r'([su])([1-9][0-9]*)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit "{name}", expected e.g. "u32-max" or "s64-min"')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u':
        return 0 if bound == 'min' else (1 << width) - 1

    # Signed: one bit of the width is the sign bit
    magnitude = 1 << (width - 1)
    return -magnitude if bound == 'min' else magnitude - 1
53
54
class BaseNlLib:
    """Library-specific C expressions used by the generated code."""

    def get_family_id(self):
        """Return the C expression ``ys->family_id`` used for the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
58
59
class Type(SpecAttr):
    """Base class for one attribute of an attribute set, as rendered to C.

    Subclasses override the hooks below to describe how their attribute type
    shows up in several places:
      - generated struct members;
      - kernel policies;
      - user-space parsing tables ("typol");
      - put/parse code and setter helpers.
    The defaults either implement generic behavior or raise to flag an
    unimplemented case.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE: when the spec provides 'checks' this is the spec's own dict
        # (not a copy), so in-place updates are visible through self.attr too.
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply()
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # A trailing '_' keeps the struct name distinct when a definition
            # of the same name exists in family.consts (presumably to avoid a
            # clash with the C type generated for that definition).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make the spec name usable as a C identifier: suffix names found in
        # _C_KW (presumably C keywords) and prefix names starting with a digit.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve(); declared here for linters, then removed so any
        # use before resolve() fails with AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        """Return the attr this one mirrors in the set it is a subset of.

        Only goes one level down; it does not recurse.
        """
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests, also marking the real attr of a subset."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies, also marking the real attr of a subset."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return the numeric value of check ``limit``, or ``default`` when absent.

        The check may be an int, the name of a family const (its "value" is
        used), or a string limit such as ``u32-max`` (see limit_to_number()).
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check ``limit`` as C source text, or '' when absent.

        How the value is rendered:
          - ints are written literally, followed by ``suffix``;
          - family consts become an upper-case macro name, prefixed with the
            family name unless the const has a 'header' set;
          - any other string is upper-cased.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute ``enum_name``, the upper-case C name of this attribute's type.

        The name is taken from the parent sub-message, from an explicit
        'name-prefix', or from the attr set's prefix. Attrs of a subset must
        not override the real attr's checks.

        Raises:
            Exception: if a subset attr's checks differ from the real attr's.
        """
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        """Falsy by default; types stored as arrays return True."""
        return None

    def is_scalar(self):
        """True for the fixed-width integer types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        """True when the attr (transitively) nests its own attr set; never for the base type."""
        return False

    def is_recursive_for_op(self, ri):
        """True for a recursive attr when ``ri.op`` is falsy.

        Such members are stored by pointer rather than embedded.
        """
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Kind of presence tracking used for this attr.

        Returns 'present' (1-bit flag), 'len' or 'count' (u32 value), or ''
        for no tracking.
        """
        return 'present'

    def presence_member(self, space, type_filter):
        """Return the presence struct member declaration, or ''.

        A declaration is only produced when presence_type() equals
        ``type_filter``. User-space structs use ``__u32``.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return ''

        pfx = '__' if space == 'user' else ''
        if presence == 'present':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return ''

    def _complex_member_type(self, _ri):
        """C type of the struct member for complex types; None means use arg_member()."""
        return None

    def free_needs_iter(self):
        """True when the free code needs a loop index ``i`` declared."""
        return False

    def _free_lines(self, _ri, var, ref):
        """Return C lines freeing this member of ``var`` (``ref`` is the member path prefix)."""
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the C lines that free this member."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    # pylint: disable=assignment-from-none
    def arg_member(self, ri):
        """Return the C parameter declarations a setter takes for this attr.

        Raises:
            Exception: if the type has no complex member type and does not
                override this method.
        """
        member = self._complex_member_type(ri)
        if member is not None:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attr."""
        member = self._complex_member_type(ri)
        if member is not None:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the kernel policy initializer for the given NLA_* ``policy``."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's kernel policy entry.

        u16/u32 attrs declared big-endian use NLA_BE16/NLA_BE32.
        """
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the user-space parsing table fields for this attr type."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's entry in the user-space parsing table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit ``line``, guarded by the presence field for 'present'/'len' tracking."""
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ``ynl_attr_put_<put_type>()`` call for this member."""
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit code putting this attr into a message; must be overridden."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return ``(lines, init_lines, local_vars)`` for parsing this attr."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit this attr's branch of the parse if / else-if chain on ``type``.

        Single-valued attrs are validated and have their presence bit set
        before the type-specific lines. Returns True.
        """
        lines, init_lines, _ = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the C lines storing the setter argument(s) into ``member``."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, _space, direction, deref=False, ref=None, var="req"):
        """Emit a ``static inline`` C setter for this attribute.

        ``ref`` is the path of member names from ``var`` down to (but not
        including) this attribute. Every enclosing nest along the path gets
        its presence bit set. For this attribute itself:
          - with bit presence, its own bit is set too;
          - otherwise its presence expression (the ``_len``/``_count`` field)
            is handed to _setter_lines().
        Setters that free the old value but never allocate get a ``__`` name
        prefix.
        """
        ref = (ref if ref else []) + [self.c_name]
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, name in enumerate(ref):
            parent = '.'.join(ref[:i] + [''])
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{var}->{parent}_{self.presence_type()}.{name}"
                continue
            presence = f"{var}->{parent}_present.{name}"
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        frees = any('free(' in line for line in code)
        allocs = any('alloc(' in line for line in code)
        if frees and not allocs:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
321
322
class TypeUnused(Type):
    """Attribute slot for which no C is generated.

    It has no struct storage, setter, kernel policy, put code or parse
    branch. The user-space parsing table marks it YNL_PT_REJECT.
    """

    def presence_type(self):
        # No presence tracking at all
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        return ['return YNL_PARSE_CB_ERROR;'], None, None

    # The hooks below intentionally emit nothing.

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        return None
347
348
class TypePad(Type):
    """Padding attribute.

    It has no struct storage, setter, kernel policy, put code or parse
    branch. The user-space parsing table marks it YNL_PT_IGNORE.
    """

    def presence_type(self):
        # No presence tracking at all
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The hooks below intentionally emit nothing.

    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        return None
370
371
class TypeScalar(Type):
    """Integer attribute.

    Covers fixed-width types and auto-sized scalars; the latter are stored
    as 64-bit in C.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        # Classic families have some funny enums; checks only feed kernel
        # policies, so don't bother computing them there.
        if not family.is_classic():
            self._init_checks()

        # Added by resolve(); declared for linters, then removed so any use
        # before resolve() fails with AttributeError.
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Finish setup: decide whether this is a bitfield and pick the C type name."""
        self.resolve_up(super())

        attr = self.attr
        if attr.get('enum-as-flags'):
            self.is_bitfield = True
        else:
            self.is_bitfield = ('enum' in attr
                                and self.family.consts[attr['enum']]['type'] == 'flags')

        if 'enum' in attr and not self.is_bitfield:
            self.type_name = self.family.consts[attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = f'__{self.type[0]}64'
        else:
            self.type_name = f'__{self.type}'

    def _init_checks(self):
        """Derive policy checks in place in ``self.checks`` and validate the limits.

        Checks that may be added: min/max from an enum's value range,
        'sparse', 'range' and 'full-range'.

        Raises:
            Exception: if min > max, or an unsigned type has a negative limit.
        """
        checks = self.checks
        if 'enum' in self.attr:
            low, high = self.family.consts[self.attr['enum']].value_range()
            if low is None and high is None:
                checks['sparse'] = True
            else:
                if 'min' not in checks and (low != 0 or self.type[0] == 's'):
                    checks['min'] = low
                if 'max' not in checks:
                    checks['max'] = high

        if 'min' in checks and 'max' in checks:
            lim_min = self.get_limit('min')
            lim_max = self.get_limit('max')
            if lim_min > lim_max:
                raise Exception(f'Invalid limit for "{self.name}" min: {lim_min} max: {lim_max}')
            checks['range'] = True

        first = self.get_limit('min', 0)
        second = self.get_limit('max', 0)
        lowest, highest = min(first, second), max(first, second)
        if lowest < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range need the full-range policy form
        if lowest < -32768 or highest > 32767:
            checks['full-range'] = True

    # pylint: disable=too-many-return-statements
    def _attr_policy(self, policy):
        """Return the kernel policy, using the first applicable check in priority order."""
        checks = self.checks
        if 'flags-mask' in checks or self.is_bitfield:
            if self.is_bitfield:
                mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            else:
                n_flags = len(self.family.consts[checks['flags-mask']]['entries'])
                mask = (1 << n_flags) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in checks:
            # presumably refers to a range object generated elsewhere
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        lim = self.get_limit_str
        if 'range' in checks:
            return f"NLA_POLICY_RANGE({policy}, {lim('min')}, {lim('max')})"
        if 'min' in checks:
            return f"NLA_POLICY_MIN({policy}, {lim('min')})"
        if 'max' in checks:
            return f"NLA_POLICY_MAX({policy}, {lim('max')})"
        if 'sparse' in checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types are parsed by width with the unsigned parser kind
        width = c_upper(self.type[1:])
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        get_line = f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);"
        return get_line, None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
470
471
class TypeFlag(Type):
    """Presence-only attribute: it is put with no payload and setters take no value."""

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to copy; the presence bit is set by the generic parse code
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        return []
487
488
class TypeString(Type):
    """String attribute.

    Stored as a heap-allocated, NUL-terminated ``char *``; its length is
    tracked in the ``_len`` presence field.
    """

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        selector = '.is_selector = 1, ' if self.is_selector else ''
        return '.type = YNL_PT_NUL_STR, ' + selector

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.get_limit_str('exact-len')})"
        max_len = ''
        if 'max-len' in self.checks:
            max_len = ', .len = ' + self.get_limit_str('max-len')
        return f'{{ .type = {policy}{max_len}, }}'

    def attr_policy(self, cw):
        # Strings must be NUL-terminated unless the spec opts out
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [f"{var}->_len.{self.c_name} = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, ynl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        # Length is bounded by the payload size, so an unterminated string is not overrun
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        src = self.c_name
        return [f"{presence} = strlen({src});",
                f"{member} = malloc({presence} + 1);",
                f'memcpy({member}, {src}, {presence});',
                f'{member}[{presence}] = 0;']
541
542
class TypeBinary(Type):
    """Opaque binary attribute.

    Stored as a heap-allocated ``void *``; its length is tracked in the
    ``_len`` presence field.
    """

    # Kernel policy constructor for each single length check a binary attr may carry
    _LEN_CHECK_POLICIES = {
        'exact-len': 'NLA_POLICY_EXACT_LEN',
        'min-len': 'NLA_POLICY_MIN_LEN',
        'max-len': 'NLA_POLICY_MAX_LEN',
    }

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # Trailing space keeps the generated entry formatted like every other typol
        return '.type = YNL_PT_BINARY, '

    def _attr_policy(self, policy):
        """Return the kernel policy: plain NLA_BINARY, or one length-check policy.

        Raises:
            Exception: for an unsupported check or more than one check.
        """
        if not self.checks:
            return '{ .type = NLA_BINARY, }'
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')

        check_name = next(iter(self.checks))
        if check_name not in self._LEN_CHECK_POLICIES:
            raise Exception('Unsupported check for binary type: ' + check_name)
        return f'{self._LEN_CHECK_POLICIES[check_name]}({self.get_limit_str(check_name)})'

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->{self.c_name}, {var}->_len.{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_len.' + self.c_name
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [f"{presence} = len;",
                f"{member} = malloc({presence});",
                f'memcpy({member}, {self.c_name}, {presence});']
595
596
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is the C struct named by the spec's 'struct' key."""

    def struct_member(self, ri):
        struct_c = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_c} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        dst = f'{var}->{self.c_name}'
        # A payload shorter than the struct gets a zeroed struct-sized buffer,
        # so reading any member stays inside the allocation.
        get_lines = [
            f"{var}->_{self.presence_type()}.{self.c_name} = len;",
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
612
613
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute holding a packed array of the spec's 'sub-type' scalars.

    The element count is tracked in the ``_count`` presence field.
    """

    def arg_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        return [f'{elem} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem} *{self.c_name};')

    def attr_put(self, ri, var):
        count = f"{var}->_{self.presence_type()}.{self.c_name}"
        ri.cw.block_start(line=f"if ({count})")
        # Assumes the enclosing put function declares `i`; it holds the byte length here
        ri.cw.p(f"i = {count} * sizeof(__{self.get('sub-type')});")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        count = f'{var}->_count.{self.c_name}'
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        get_lines = [f"{count} = len / {elem_sz};",
                     # Round len down to whole elements before copying
                     f"len = {count} * {elem_sz};",
                     f"{var}->{self.c_name} = malloc(len);",
                     f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        return [f"{presence} = count;",
                f"count *= {elem_sz};",
                f"{member} = malloc(count);",
                f'memcpy({member}, {self.c_name}, count);']
646
647
class TypeBitfield32(Type):
    """``struct nla_bitfield32`` attribute; the valid-bits mask comes from the spec's enum."""

    def _complex_member_type(self, _ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """Return NLA_POLICY_BITFIELD32 with the enum's flag mask.

        Raises:
            Exception: if the attr has no 'enum'.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        valid_mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({valid_mask})"

    def attr_put(self, ri, var):
        size = 'sizeof(struct nla_bitfield32)'
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, {size})")

    def _attr_get(self, ri, var):
        size = 'sizeof(struct nla_bitfield32)'
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), {size});", None, None

    def _setter_lines(self, ri, member, presence):
        size = 'sizeof(struct nla_bitfield32)'
        return [f"memcpy(&{member}, {self.c_name}, {size});"]
671
672
class TypeNest(Type):
    """Attribute containing a nested attribute set, rendered as a struct member."""

    def is_recursive(self):
        nested_struct = self.family.pure_nested_structs[self.nested_attrs]
        return nested_struct.recursive

    def _complex_member_type(self, _ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        member = f'{var}->{ref}{self.c_name}'
        free_fn = f'{self.nested_render_name}_free'
        if self.is_recursive_for_op(ri):
            # Stored by pointer in this case, so it is only freed when set
            return [f'if ({member})', f'{free_fn}({member});']
        return [f'{free_fn}(&{member});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        selectors = self.family.pure_nested_structs[self.nested_attrs].external_selectors()
        # Selectors the nest needs but does not contain are passed in from this struct
        parse_args = ["&parg", "attr"] + [f'{var}->{sel.name}' for sel in selectors]
        get_lines = [f"if ({self.nested_render_name}_parse({', '.join(parse_args)}))",
                     "return YNL_PARSE_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, _space, direction, deref=False, ref=None, var="req"):
        """Emit setters for each non-recursive member of the nest, one path level deeper."""
        path = (ref if ref else []) + [self.c_name]

        members = ri.family.pure_nested_structs[self.nested_attrs].member_list()
        for _, member_attr in members:
            # Recursive members are skipped (presumably to avoid endless descent)
            if not member_attr.is_recursive():
                member_attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path,
                                   var=var)
719
720
class TypeMultiAttr(Type):
    """Attribute that may appear several times; the values are kept as a counted array."""

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        # Type object that policy and typol generation are delegated to
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr:
            return self.nested_struct_type
        sub = self.attr['type']
        if sub == 'nest':
            return self.nested_struct_type
        if sub == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if sub == 'string':
            return 'struct ynl_string *'
        if sub in scalars:
            pfx = '__' if ri.ku_space == 'user' else ''
            width = self.type[0] + '64' if self.is_auto_scalar else sub
            return pfx + width
        raise Exception(f"Sub-type {sub} not supported yet")

    def arg_member(self, ri):
        if self.type != 'binary' or 'struct' not in self.attr:
            return super().arg_member(ri)
        return [f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
                f'unsigned int n_{self.c_name}']

    def free_needs_iter(self):
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, _ri, var, ref):
        sub = self.attr['type']
        arr = f"{var}->{ref}{self.c_name}"
        if sub in scalars or sub == 'binary':
            return [f"free({arr});"]
        # Element-owning arrays free each element first, then the array
        each = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if sub == 'string':
            return [each, f"free({arr}[i]);", f"free({arr});"]
        if sub == 'nest':
            return [each, f'{self.nested_render_name}_free(&{arr}[i]);', f"free({arr});"]
        raise Exception(f"Free of MultiAttr sub-type {sub} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        sub = self.attr['type']
        each = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        if sub in scalars:
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif sub == 'binary' and 'struct' in self.attr:
            ri.cw.p(each)
            struct_sz = f"sizeof(struct {c_lower(self.attr['struct'])})"
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i], {struct_sz});")
        elif sub == 'string':
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {var}->{self.c_name}[i]->str);")
        elif sub == 'nest':
            ri.cw.p(each)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, "
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {sub} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
810
811
class TypeIndexedArray(Type):
    """
    'indexed-array' attribute: a nest whose members are array elements put
    with their position as the attribute type; 'sub-type' gives the element
    kind (scalar, fixed-length binary, or nest).
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        # Presence is tracked as an element count
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' in self.attr and self.attr['sub-type'] != 'nest':
            sub_type = self.attr['sub-type']
            if sub_type in scalars:
                # User space uses the __uNN spelling, the kernel plain uNN
                prefix = '__' if ri.ku_space == 'user' else ''
                return prefix + sub_type
            if sub_type == 'binary' and 'exact-len' in self.checks:
                # Declared through arg_member() instead
                return None
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")
        return self.nested_struct_type

    def arg_member(self, ri):
        if not (self.sub_type == 'binary' and 'exact-len' in self.checks):
            return super().arg_member(ri)
        # Pointer to an array of fixed-size byte blobs, plus its length
        length = self.checks["exact-len"]
        return [f'unsigned char (*{self.c_name})[{length}]',
                f'unsigned int n_{self.c_name}']

    def _attr_policy(self, policy):
        if self.attr['sub-type'] != 'nest':
            return super()._attr_policy(policy)
        return f'NLA_POLICY_NESTED_ARRAY({self.nested_render_name}_nl_policy)'

    def _attr_typol(self):
        sub_type = self.attr['sub-type']
        if sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        if sub_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        if sub_type == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        raise Exception(f"Typol for IndexedArray sub-type {self.attr['sub-type']} not supported, yet")

    def _attr_get(self, ri, var):
        # Emitted C remembers the nest in attr_<name> and counts its
        # validated elements into n_<name>
        count_elements = [
            'ynl_attr_for_each_nested(attr2, attr) {',
            '\tif (__ynl_attr_validate(yarg, attr2, type))',
            '\t\treturn YNL_PARSE_CB_ERROR;',
            f'\tn_{self.c_name}++;',
            '}',
        ]
        get_lines = [f'attr_{self.c_name} = attr;'] + count_elements
        return get_lines, None, ['const struct nlattr *attr2;']

    def attr_put(self, ri, var):
        cw = ri.cw
        loop_head = f'for (i = 0; i < {var}->_count.{self.c_name}; i++)'
        cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            cw.block_start(line=loop_head)
            cw.p(f"ynl_attr_put_{self.sub_type}(nlh, i, {var}->{self.c_name}[i]);")
            cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            cw.p(loop_head)
            cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            cw.p(loop_head)
            cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put for IndexedArray sub-type {self.attr['sub-type']} not supported, yet")
        cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        array_name = self.c_name
        return [
            f"{member} = {array_name};",
            f"{presence} = n_{array_name};",
        ]

    def free_needs_iter(self):
        # Only nested elements own memory that needs a per-element free
        return self.sub_type == 'nest'

    def _free_lines(self, _ri, var, ref):
        lines = []
        if self.sub_type == 'nest':
            lines.append(f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)")
            lines.append(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
        lines.append(f"free({var}->{ref}{self.c_name});")
        return lines
892
class TypeNestTypeValue(Type):
    """
    'nest-type-value' attribute. For each level listed in 'type-value' the
    generated parser descends one nest and records that nest's attribute
    type, then passes the recorded values to the nested parse function.
    """

    def _complex_member_type(self, _ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        cursor = 'attr'
        extra_args = ''
        if 'type-value' in self.attr:
            names = [c_lower(level) for level in self.attr["type-value"]]
            local_vars.append('const struct nlattr *attr_' + ', *attr_'.join(names) + ';')
            local_vars.append('__u32 ' + ', '.join(names) + ';')
            for name in names:
                # Step into the current nest and remember its attribute type
                get_lines.append(f'attr_{name} = ynl_attr_data({cursor});')
                get_lines.append(f'{name} = ynl_attr_type(attr_{name});')
                cursor = 'attr_' + name
            extra_args = ', ' + ', '.join(names)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
921
922
class TypeSubMessage(TypeNest):
    """
    Nest attribute whose layout is one of several sub-message formats,
    chosen by the value of a selector attribute.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        pieces = [
            f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
            '.is_submsg = 1, ',
        ]
        # ynl_err_walk() in ynl.c cannot reverse-parse the policy when the
        # selector is external. No family uses such sub-messages in requests,
        # so the selector type is only recorded for local selectors.
        if not self.selector.is_external():
            selector_attr = self.attr_set[self["selector"]]
            pieces.append(f'.selector_type = {selector_attr.value} ')
        return ''.join(pieces)

    def _attr_get(self, ri, var):
        selector = self['selector']
        sel = c_lower(selector)
        # External selectors are read from a _sel_<name> variable in the
        # generated code; local ones from the struct being filled in
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"

        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [
            f'if (!{sel_var})',
            f'return ynl_submsg_failed(yarg, "{self.name}", "{selector}");',
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None
953
954
class Selector:
    """
    The attribute that selects a sub-message's format.

    If the selector is an attribute of the same set as the sub-message, it is
    bound right away and flagged with ``is_selector``. Otherwise it is
    external: ``attr`` starts as None and is bound later with set_attr().
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]
        self._external = self.name not in attr_set
        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
        else:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        self.attr = attr

    def is_external(self):
        return self._external
973
974
class Struct:
    """
    Rendering model for one C struct built from an attribute set.

    With ``type_list`` only the listed attributes are members. Without it,
    every attribute of the set is a member and the struct is marked
    ``nested``.
    """

    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # A list never compares equal to a set, so set_inherited() rejects
        # structs that were created without an inherited set
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.fixed_header = ('struct ' + c_lower(fixed_header)) if fixed_header else None
        self.submsg = submsg

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = c_lower(family.ident_name)
        else:
            self.render_name = c_lower(family.ident_name + '-' + space_name)
        # Nested structs get a trailing '_' when the set name is also a const
        suffix = '_' if self.nested and space_name in family.consts else ''
        self.struct_name = 'struct ' + self.render_name + suffix
        self.ptr_name = self.struct_name + ' *'
        # Every attr set reachable from this one, directly or transitively
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or and legacy arrays

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]
        self.attrs = dict(self.attr_list)

        # Highest-valued member; on ties the later one in attr_list wins
        self.attr_max_val = None
        highest = 0
        for _name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        # Yields member names in attr_list order
        for name in self.attrs:
            yield name

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        return self.attr_list

    def set_inherited(self, new_inherited):
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = list(map(c_lower, sorted(self._inherited)))

    def external_selectors(self):
        return [attr.selector for _name, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        return any(attr.free_needs_iter() for _name, attr in self.attr_list)
1049
1050
class EnumEntry(SpecEnumEntry):
    """
    Enum entry with its C name.

    ``value_change`` is True when the value is not exactly one more than the
    previous entry's (or not 0 for the first entry), and always for entries
    of a 'flags' set.
    """

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        expected = prev.value + 1 if prev else 0
        self.value_change = (self.value != expected
                             or self.enum_set['type'] == 'flags')

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        # e.g. prefix 'fam-set-' + 'foo' -> FAM_SET_FOO
        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1069
1070
class EnumSet(SpecEnumSet):
    """
    Enum / flags set with the C names used when rendering it.

    enum_name is 'enum <name>' and comes from 'enum-name', or from the family
    and set name when that key is absent. It is None when 'enum-name' is
    present but empty. user_type is the C type for values of the set: the
    enum type, or 'int' when the enum is unnamed.
    """

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            # Explicitly empty 'enum-name': leave the enum unnamed
            self.enum_name = None

        # One assignment covers both the named and the unnamed case
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """
        Return (min, max) of the entry values when max - min + 1 equals the
        number of entries, otherwise (None, None).

        Raises ValueError if the set has no entries.
        """
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            return None, None

        return low, high
1106
1107
class AttrSet(SpecAttrSet):
    """
    Attribute set with its C naming (attribute enum prefix, max and count
    names). new_attr() maps each attribute spec to its Type* class.
    """

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the naming of the set it is carved out of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            # Avoid clashing with a C keyword
            c_name += '_'
        # The family's own set gets an empty C name
        self.c_name = '' if c_name == self.family.c_name else c_name

    def new_attr(self, elem, value):
        """Build the Type* object for one attribute spec; raises on unsupported types."""
        kind = elem['type']
        simple_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
            'sub-message': TypeSubMessage,
        }

        if kind in scalars:
            type_cls = TypeScalar
        elif kind == 'binary':
            if 'struct' in elem:
                type_cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                type_cls = TypeBinaryScalarArray
            else:
                type_cls = TypeBinary
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            type_cls = TypeIndexedArray
        elif kind in simple_types:
            type_cls = simple_types[kind]
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        attr = type_cls(self.family, self, elem, value)
        if elem.get('multi-attr'):
            # Multi-attr wraps the per-element type
            attr = TypeMultiAttr(self.family, self, elem, value, attr)
        return attr
1176
1177
class Operation(SpecOperation):
    """Operation spec with its C render name, enum name and notification flag."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # every present request/reply section gets at least an empty list
        for mode in ('do', 'dump', 'event'):
            for direction in ('request', 'reply'):
                try:
                    section = yaml[mode][direction]
                except KeyError:
                    continue
                section.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True only when both 'do' and 'dump' define a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Becomes True via mark_has_ntf()
        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
1211
1212
class SubMessage(SpecSubMessage):
    """Sub-message spec plus the lower-case C name used when rendering it."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        self.render_name = c_lower(f"{family.ident_name}-{yaml['name']}")

    def resolve(self):
        self.resolve_up(super())
1221
1222
class Family(SpecFamily):
    """
    Code-generation view of a YNL family spec.

    On top of SpecFamily it derives the C naming (prefixes, header names).
    In resolve() it precomputes:
    - the root and pure nested attribute-set structs,
    - per-attribute request/reply usage,
    - sub-message selector binding,
    - the op hooks,
    - for 'kernel-policy: global', the shared policy.
    """

    def __init__(self, file_name, exclude_ops, fn_prefix):
        """
        Load the spec and compute the naming that does not need resolve().

        file_name and exclude_ops are passed to SpecFamily. If fn_prefix is
        truthy it replaces the default '<ident_name>-nl' in self.fn_prefix.
        """
        # Added by resolve:
        # (each is assigned and immediately deleted, so the attribute is
        # named in __init__ but only exists once it is set later)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        self.root_sets = {}
        self.pure_nested_structs = {}
        self.kernel_policy = None
        self.global_policy = None
        self.global_policy_set = None

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Upper-cased C names of the family name / version constants,
        # overridable with 'c-family-name' / 'c-version-name'
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # 'linux/<x>.h' -> '<x>'; any other header falls back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

        self.fn_prefix = fn_prefix if fn_prefix else f'{self.ident_name}-nl'

    def resolve(self):
        """
        Resolve the spec, then compute the C prefixes, hook lists and struct
        data the renderers use.

        The _load_* steps run in this order because later ones read state
        that earlier ones build (e.g. _load_nested_sets() starts from
        root_sets, _load_attr_use() walks pure_nested_structs).
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode]: 'set' de-duplicates hook names, 'list' keeps
        # them in first-seen order (filled by _load_hooks())
        self.hooks = {}
        for when in ['pre', 'post']:
            self.hooks[when] = {}
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = {}
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = {}
        # dict space-name -> Struct
        self.pure_nested_structs = {}

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    # Factory overrides returning this module's C-aware classes
    # (presumably invoked by SpecFamily while parsing the spec)
    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """True when the spec's protocol is 'netlink-raw'."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Mark every op named by another message's 'notify' key as having a notification."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """
        For each attribute set used directly by an op, collect the names of
        attributes used in any request, and in any reply or event.
        """
        for _op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """
        Reorder pure_nested_structs so each struct comes after the
        non-recursive nested structs it references.

        Each round takes one name. If it references a nest not yet placed,
        the struct is moved to the end of the dict and the name is queued
        again. Otherwise it is marked as placed. At most len(structs) ** 2
        rounds run; after that the current order is kept as-is.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts keep insertion order: re-inserting moves this struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """
        Register a pure nested Struct for a 'nested-attributes' attribute and
        record its inherited members: the 'type-value' levels, or 'idx' for
        indexed arrays.

        Returns the nested set's name. Raises if that set is also used as a
        root set.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                # 'inherit' is stored by reference, so filling it below also
                # fills a freshly created struct's inherited set; an existing
                # struct is instead checked against it in set_inherited()
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """
        Build an attribute set for a sub-message and register a Struct for it.

        The set has one attribute per format:
        - a nest when the format names an attribute set,
        - a binary struct when it has only a fixed header,
        - a flag otherwise.

        Returns the sub-message name.
        """
        # Fake the struct type for the sub-message itself
        # its not a attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr |= {
                    "type": "nest",
                    "nested-attributes": fmt['attribute-set'],
                }
                if 'fixed-header' in fmt:
                    attr |= { "fixed-header": fmt["fixed-header"] }
            elif 'fixed-header' in fmt:
                attr |= {
                    "type": "binary",
                    "struct": fmt["fixed-header"],
                }
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        # NOTE(review): AttrSet.__init__ looks up 'name-prefix', not
        # 'name-pfx', so this key appears to be ignored and the default
        # prefix is used instead — confirm intent before changing it, as
        # renaming the key would change the generated names.
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """
        Create Structs for every attr set reachable from the root sets.

        Each struct is then marked as used in requests and/or replies, inside
        multi-value attributes, or recursively.
        """
        # Breadth-first walk from the root sets over nests and sub-messages
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while attr_set_queue:
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        # Seed request/reply/multi-value flags from the root sets' own usage
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # Reversed sort order visits a struct before the structs it
        # references, so the flags flow from parents down to children
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                # A struct found among its own (transitive) children is recursive
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """
        Mark attributes as used in requests and/or replies.

        Every member of a nested struct takes the struct's direction(s). Root
        set attributes follow the names collected in root_sets.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """
        Bind external sub-message selectors to an attribute of the
        referencing set.

        Each external selector of a child struct must name an attribute of
        the set that directly references the child. Anything deeper raises.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, _struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """
        Build the single policy shared by all ops ('kernel-policy: global').

        All ops with an 'attribute-set' must use the same set. global_policy
        lists every attribute used by any do/dump request, in the order of
        the attribute set.
        """
        global_set = set()
        attr_set_name = None
        for _op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        # NOTE(review): if no op has an 'attribute-set', attr_set_name is
        # still None here and the lookup below is done with None
        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """
        Collect the unique 'pre' / 'post' hook names of do and dump handlers,
        keeping first-seen order in hooks[when][op_mode]['list'].
        """
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1591
1592
1593class RenderInfo:
1594    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
1595        self.family = family
1596        self.nl = cw.nlib
1597        self.ku_space = ku_space
1598        self.op_mode = op_mode
1599        self.op = op
1600
1601        fixed_hdr = op.fixed_header if op else None
1602        self.fixed_hdr_len = 'ys->family->hdr_len'
1603        if op and op.fixed_header:
1604            if op.fixed_header != family.fixed_header:
1605                if family.is_classic():
1606                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
1607                else:
1608                    raise Exception("Per-op fixed header not supported, yet")
1609
1610
1611        # 'do' and 'dump' response parsing is identical
1612        self.type_consistent = True
1613        self.type_oneside = False
1614        if op_mode != 'do' and 'dump' in op:
1615            if 'do' in op:
1616                if ('reply' in op['do']) != ('reply' in op["dump"]):
1617                    self.type_consistent = False
1618                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
1619                    self.type_consistent = False
1620            else:
1621                self.type_consistent = True
1622                self.type_oneside = True
1623
1624        self.attr_set = attr_set
1625        if not self.attr_set:
1626            self.attr_set = op['attribute-set']
1627
1628        self.type_name_conflict = False
1629        if op:
1630            self.type_name = c_lower(op.name)
1631        else:
1632            self.type_name = c_lower(attr_set)
1633            if attr_set in family.consts:
1634                self.type_name_conflict = True
1635
1636        self.cw = cw
1637
1638        self.struct = {}
1639        if op_mode == 'notify':
1640            op_mode = 'do' if 'do' in op else 'dump'
1641        for op_dir in ['request', 'reply']:
1642            if op:
1643                type_list = []
1644                if op_dir in op[op_mode]:
1645                    type_list = op[op_mode][op_dir]['attributes']
1646                self.struct[op_dir] = Struct(family, self.attr_set,
1647                                             fixed_header=fixed_hdr,
1648                                             type_list=type_list)
1649        if op_mode == 'event':
1650            self.struct['reply'] = Struct(family, self.attr_set,
1651                                          fixed_header=fixed_hdr,
1652                                          type_list=op['event']['attributes'])
1653
1654    def type_empty(self, key):
1655        return len(self.struct[key].attr_list) == 0 and \
1656            self.struct['request'].fixed_header is None
1657
1658    def needs_nlflags(self, direction):
1659        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1660
1661
1662class CodeWriter:
1663    def __init__(self, nlib, out_file=None, overwrite=True):
1664        self.nlib = nlib
1665        self._overwrite = overwrite
1666
1667        self._nl = False
1668        self._block_end = False
1669        self._silent_block = False
1670        self._ind = 0
1671        self._ifdef_block = None
1672        if out_file is None:
1673            self._out = os.sys.stdout
1674        else:
1675            # pylint: disable=consider-using-with
1676            self._out = tempfile.NamedTemporaryFile('w+')
1677            self._out_file = out_file
1678
1679    def __del__(self):
1680        self.close_out_file()
1681
1682    def close_out_file(self):
1683        if self._out == os.sys.stdout:
1684            return
1685        # Avoid modifying the file if contents didn't change
1686        self._out.flush()
1687        if not self._overwrite and os.path.isfile(self._out_file):
1688            if filecmp.cmp(self._out.name, self._out_file, shallow=False):
1689                return
1690        with open(self._out_file, 'w+', encoding='utf-8') as out_file:
1691            self._out.seek(0)
1692            shutil.copyfileobj(self._out, out_file)
1693            self._out.close()
1694        self._out = os.sys.stdout
1695
1696    @classmethod
1697    def _is_cond(cls, line):
1698        return line.startswith('if') or line.startswith('while') or line.startswith('for')
1699
1700    def p(self, line, add_ind=0):
1701        if self._block_end:
1702            self._block_end = False
1703            if line.startswith('else'):
1704                line = '} ' + line
1705            else:
1706                self._out.write('\t' * self._ind + '}\n')
1707
1708        if self._nl:
1709            self._out.write('\n')
1710            self._nl = False
1711
1712        ind = self._ind
1713        if line[-1] == ':':
1714            ind -= 1
1715        if self._silent_block:
1716            ind += 1
1717        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
1718        self._silent_block |= line.strip() == 'else'
1719        if line[0] == '#':
1720            ind = 0
1721        if add_ind:
1722            ind += add_ind
1723        self._out.write('\t' * ind + line + '\n')
1724
1725    def nl(self):
1726        self._nl = True
1727
1728    def block_start(self, line=''):
1729        if line:
1730            line = line + ' '
1731        self.p(line + '{')
1732        self._ind += 1
1733
1734    def block_end(self, line=''):
1735        if line and line[0] not in {';', ','}:
1736            line = ' ' + line
1737        self._ind -= 1
1738        self._nl = False
1739        if not line:
1740            # Delay printing closing bracket in case "else" comes next
1741            if self._block_end:
1742                self._out.write('\t' * (self._ind + 1) + '}\n')
1743            self._block_end = True
1744        else:
1745            self.p('}' + line)
1746
1747    def write_doc_line(self, doc, indent=True):
1748        words = doc.split()
1749        line = ' *'
1750        for word in words:
1751            if len(line) + len(word) >= 79:
1752                self.p(line)
1753                line = ' *'
1754                if indent:
1755                    line += '  '
1756            line += ' ' + word
1757        self.p(line)
1758
1759    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
1760        if not args:
1761            args = ['void']
1762
1763        if doc:
1764            self.p('/*')
1765            self.p(' * ' + doc)
1766            self.p(' */')
1767
1768        oneline = qual_ret
1769        if qual_ret[-1] != '*':
1770            oneline += ' '
1771        oneline += f"{name}({', '.join(args)}){suffix}"
1772
1773        if len(oneline) < 80:
1774            self.p(oneline)
1775            return
1776
1777        v = qual_ret
1778        if len(v) > 3:
1779            self.p(v)
1780            v = ''
1781        elif qual_ret[-1] != '*':
1782            v += ' '
1783        v += name + '('
1784        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
1785        delta_ind = len(v) - len(ind)
1786        v += args[0]
1787        i = 1
1788        while i < len(args):
1789            next_len = len(v) + len(args[i])
1790            if v[0] == '\t':
1791                next_len += delta_ind
1792            if next_len > 76:
1793                self.p(v + ',')
1794                v = ind
1795            else:
1796                v += ', '
1797            v += args[i]
1798            i += 1
1799        self.p(v + ')' + suffix)
1800
1801    def write_func_lvar(self, local_vars):
1802        if not local_vars:
1803            return
1804
1805        if isinstance(local_vars, str):
1806            local_vars = [local_vars]
1807
1808        local_vars.sort(key=len, reverse=True)
1809        for var in local_vars:
1810            self.p(var)
1811        self.nl()
1812
1813    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
1814        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
1815        self.block_start()
1816        self.write_func_lvar(local_vars=local_vars)
1817
1818        for line in body:
1819            self.p(line)
1820        self.block_end()
1821
1822    def writes_defines(self, defines):
1823        longest = 0
1824        for define in defines:
1825            longest = max(len(define[0]), longest)
1826        longest = ((longest + 8) // 8) * 8
1827        for define in defines:
1828            line = '#define ' + define[0]
1829            line += '\t' * ((longest - len(define[0]) + 7) // 8)
1830            if isinstance(define[1], int):
1831                line += str(define[1])
1832            elif isinstance(define[1], str):
1833                line += '"' + define[1] + '"'
1834            self.p(line)
1835
1836    def write_struct_init(self, members):
1837        longest = max(len(x[0]) for x in members)
1838        longest += 1  # because we prepend a .
1839        longest = ((longest + 8) // 8) * 8
1840        for one in members:
1841            line = '.' + one[0]
1842            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
1843            line += '= ' + str(one[1]) + ','
1844            self.p(line)
1845
1846    def ifdef_block(self, config):
1847        config_option = None
1848        if config:
1849            config_option = 'CONFIG_' + c_upper(config)
1850        if self._ifdef_block == config_option:
1851            return
1852
1853        if self._ifdef_block:
1854            self.p('#endif /* ' + self._ifdef_block + ' */')
1855        if config_option:
1856            self.p('#ifdef ' + config_option)
1857        self._ifdef_block = config_option
1858
1859
1860scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}
1861
1862direction_to_suffix = {
1863    'reply': '_rsp',
1864    'request': '_req',
1865    '': ''
1866}
1867
1868op_mode_to_wrapper = {
1869    'do': '',
1870    'dump': '_list',
1871    'notify': '_ntf',
1872    'event': '',
1873}
1874
1875_C_KW = {
1876    'auto',
1877    'bool',
1878    'break',
1879    'case',
1880    'char',
1881    'const',
1882    'continue',
1883    'default',
1884    'do',
1885    'double',
1886    'else',
1887    'enum',
1888    'extern',
1889    'float',
1890    'for',
1891    'goto',
1892    'if',
1893    'inline',
1894    'int',
1895    'long',
1896    'register',
1897    'return',
1898    'short',
1899    'signed',
1900    'sizeof',
1901    'static',
1902    'struct',
1903    'switch',
1904    'typedef',
1905    'union',
1906    'unsigned',
1907    'void',
1908    'volatile',
1909    'while'
1910}
1911
1912
1913def rdir(direction):
1914    if direction == 'reply':
1915        return 'request'
1916    if direction == 'request':
1917        return 'reply'
1918    return direction
1919
1920
1921def op_prefix(ri, direction, deref=False):
1922    suffix = f"_{ri.type_name}"
1923
1924    if not ri.op_mode:
1925        pass
1926    elif ri.op_mode == 'do':
1927        suffix += f"{direction_to_suffix[direction]}"
1928    else:
1929        if direction == 'request':
1930            suffix += '_req'
1931            if not ri.type_oneside:
1932                suffix += '_dump'
1933        else:
1934            if ri.type_consistent:
1935                if deref:
1936                    suffix += f"{direction_to_suffix[direction]}"
1937                else:
1938                    suffix += op_mode_to_wrapper[ri.op_mode]
1939            else:
1940                suffix += '_rsp'
1941                suffix += '_dump' if deref else '_list'
1942
1943    return f"{ri.family.c_name}{suffix}"
1944
1945
1946def type_name(ri, direction, deref=False):
1947    return f"struct {op_prefix(ri, direction, deref=deref)}"
1948
1949
1950def print_prototype(ri, direction, terminate=True, doc=None):
1951    suffix = ';' if terminate else ''
1952
1953    fname = ri.op.render_name
1954    if ri.op_mode == 'dump':
1955        fname += '_dump'
1956
1957    args = ['struct ynl_sock *ys']
1958    if 'request' in ri.op[ri.op_mode]:
1959        args.append(f"{type_name(ri, direction)} *" + f"{direction_to_suffix[direction][1:]}")
1960
1961    ret = 'int'
1962    if 'reply' in ri.op[ri.op_mode]:
1963        ret = f"{type_name(ri, rdir(direction))} *"
1964
1965    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=suffix)
1966
1967
1968def print_req_prototype(ri):
1969    print_prototype(ri, "request", doc=ri.op['doc'])
1970
1971
1972def print_dump_prototype(ri):
1973    print_prototype(ri, "request")
1974
1975
1976def put_typol_submsg(cw, struct):
1977    cw.block_start(line=f'const struct ynl_policy_attr {struct.render_name}_policy[] =')
1978
1979    i = 0
1980    for name, arg in struct.member_list():
1981        nest = ""
1982        if arg.type == 'nest':
1983            nest = f" .nest = &{arg.nested_render_name}_nest,"
1984        cw.p('[%d] = { .type = YNL_PT_SUBMSG, .name = "%s",%s },' %
1985             (i, name, nest))
1986        i += 1
1987
1988    cw.block_end(line=';')
1989    cw.nl()
1990
1991    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
1992    cw.p(f'.max_attr = {i - 1},')
1993    cw.p(f'.table = {struct.render_name}_policy,')
1994    cw.block_end(line=';')
1995    cw.nl()
1996
1997
1998def put_typol_fwd(cw, struct):
1999    cw.p(f'extern const struct ynl_policy_nest {struct.render_name}_nest;')
2000
2001
2002def put_typol(cw, struct):
2003    if struct.submsg:
2004        put_typol_submsg(cw, struct)
2005        return
2006
2007    type_max = struct.attr_set.max_name
2008    cw.block_start(line=f'const struct ynl_policy_attr {struct.render_name}_policy[{type_max} + 1] =')
2009
2010    for _, arg in struct.member_list():
2011        arg.attr_typol(cw)
2012
2013    cw.block_end(line=';')
2014    cw.nl()
2015
2016    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
2017    cw.p(f'.max_attr = {type_max},')
2018    cw.p(f'.table = {struct.render_name}_policy,')
2019    cw.block_end(line=';')
2020    cw.nl()
2021
2022
2023def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
2024    args = [f'int {arg_name}']
2025    if enum:
2026        args = [enum.user_type + ' ' + arg_name]
2027    cw.write_func_prot('const char *', f'{render_name}_str', args)
2028    cw.block_start()
2029    if enum and enum.type == 'flags':
2030        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
2031    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
2032    cw.p('return NULL;')
2033    cw.p(f'return {map_name}[{arg_name}];')
2034    cw.block_end()
2035    cw.nl()
2036
2037
2038def put_op_name_fwd(family, cw):
2039    cw.write_func_prot('const char *', f'{family.c_name}_op_str', ['int op'], suffix=';')
2040
2041
2042def put_op_name(family, cw):
2043    map_name = f'{family.c_name}_op_strmap'
2044    cw.block_start(line=f"static const char * const {map_name}[] =")
2045    for op_name, op in family.msgs.items():
2046        if op.rsp_value:
2047            # Make sure we don't add duplicated entries, if multiple commands
2048            # produce the same response in legacy families.
2049            if family.rsp_by_value[op.rsp_value] != op:
2050                cw.p(f'// skip "{op_name}", duplicate reply value')
2051                continue
2052
2053            if op.req_value == op.rsp_value:
2054                cw.p(f'[{op.enum_name}] = "{op_name}",')
2055            else:
2056                cw.p(f'[{op.rsp_value}] = "{op_name}",')
2057    cw.block_end(line=';')
2058    cw.nl()
2059
2060    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2061
2062
2063def put_enum_to_str_fwd(_family, cw, enum):
2064    args = [enum.user_type + ' value']
2065    cw.write_func_prot('const char *', f'{enum.render_name}_str', args, suffix=';')
2066
2067
2068def put_enum_to_str(_family, cw, enum):
2069    map_name = f'{enum.render_name}_strmap'
2070    cw.block_start(line=f"static const char * const {map_name}[] =")
2071    for entry in enum.entries.values():
2072        cw.p(f'[{entry.value}] = "{entry.name}",')
2073    cw.block_end(line=';')
2074    cw.nl()
2075
2076    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
2077
2078
2079def put_local_vars(struct):
2080    local_vars = []
2081    has_array = False
2082    has_count = False
2083    for _, arg in struct.member_list():
2084        has_array |= arg.type == 'indexed-array'
2085        has_count |= arg.presence_type() == 'count'
2086    if has_array:
2087        local_vars.append('struct nlattr *array;')
2088    if has_count:
2089        local_vars.append('unsigned int i;')
2090    return local_vars
2091
2092
2093def put_req_nested_prototype(ri, struct, suffix=';'):
2094    func_args = ['struct nlmsghdr *nlh',
2095                 'unsigned int attr_type',
2096                 f'{struct.ptr_name}obj']
2097
2098    ri.cw.write_func_prot('int', f'{struct.render_name}_put', func_args,
2099                          suffix=suffix)
2100
2101
2102def put_req_nested(ri, struct):
2103    local_vars = []
2104    init_lines = []
2105
2106    if struct.submsg is None:
2107        local_vars.append('struct nlattr *nest;')
2108        init_lines.append("nest = ynl_attr_nest_start(nlh, attr_type);")
2109    if struct.fixed_header:
2110        local_vars.append('void *hdr;')
2111        struct_sz = f'sizeof({struct.fixed_header})'
2112        init_lines.append(f"hdr = ynl_nlmsg_put_extra_header(nlh, {struct_sz});")
2113        init_lines.append(f"memcpy(hdr, &obj->_hdr, {struct_sz});")
2114
2115    local_vars += put_local_vars(struct)
2116
2117    put_req_nested_prototype(ri, struct, suffix='')
2118    ri.cw.block_start()
2119    ri.cw.write_func_lvar(local_vars)
2120
2121    for line in init_lines:
2122        ri.cw.p(line)
2123
2124    for _, arg in struct.member_list():
2125        arg.attr_put(ri, "obj")
2126
2127    if struct.submsg is None:
2128        ri.cw.p("ynl_attr_nest_end(nlh, nest);")
2129
2130    ri.cw.nl()
2131    ri.cw.p('return 0;')
2132    ri.cw.block_end()
2133    ri.cw.nl()
2134
2135
2136def _multi_parse(ri, struct, init_lines, local_vars):
2137    if struct.fixed_header:
2138        local_vars += ['void *hdr;']
2139    if struct.nested:
2140        if struct.fixed_header:
2141            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
2142        else:
2143            iter_line = "ynl_attr_for_each_nested(attr, nested)"
2144    else:
2145        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
2146        if ri.op.fixed_header != ri.family.fixed_header:
2147            if ri.family.is_classic():
2148                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
2149            else:
2150                raise Exception("Per-op fixed header not supported, yet")
2151
2152    indexed_arrays = set()
2153    multi_attrs = set()
2154    needs_parg = False
2155    var_set = set()
2156    for arg, aspec in struct.member_list():
2157        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
2158            if aspec["sub-type"] in {'binary', 'nest'}:
2159                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
2160                indexed_arrays.add(arg)
2161            elif aspec['sub-type'] in scalars:
2162                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
2163                indexed_arrays.add(arg)
2164            else:
2165                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
2166        if 'multi-attr' in aspec:
2167            multi_attrs.add(arg)
2168        needs_parg |= 'nested-attributes' in aspec
2169        needs_parg |= 'sub-message' in aspec
2170
2171        try:
2172            _, _, l_vars = aspec._attr_get(ri, '')
2173            var_set |= set(l_vars) if l_vars else set()
2174        except Exception:
2175            pass  # _attr_get() not implemented by simple types, ignore
2176    local_vars += list(var_set)
2177    if indexed_arrays or multi_attrs:
2178        local_vars.append('int i;')
2179    if needs_parg:
2180        local_vars.append('struct ynl_parse_arg parg;')
2181        init_lines.append('parg.ys = yarg->ys;')
2182
2183    all_multi = indexed_arrays | multi_attrs
2184
2185    for arg in sorted(all_multi):
2186        local_vars.append(f"unsigned int n_{struct[arg].c_name} = 0;")
2187
2188    ri.cw.block_start()
2189    ri.cw.write_func_lvar(local_vars)
2190
2191    for line in init_lines:
2192        ri.cw.p(line)
2193    ri.cw.nl()
2194
2195    for arg in struct.inherited:
2196        ri.cw.p(f'dst->{arg} = {arg};')
2197
2198    if struct.fixed_header:
2199        if struct.nested:
2200            ri.cw.p('hdr = ynl_attr_data(nested);')
2201        elif ri.family.is_classic():
2202            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
2203        else:
2204            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
2205        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
2206    for arg in sorted(all_multi):
2207        aspec = struct[arg]
2208        ri.cw.p(f"if (dst->{aspec.c_name})")
2209        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
2210
2211    ri.cw.nl()
2212    ri.cw.block_start(line=iter_line)
2213    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
2214    ri.cw.nl()
2215
2216    first = True
2217    for _, arg in struct.member_list():
2218        good = arg.attr_get(ri, 'dst', first=first)
2219        # First may be 'unused' or 'pad', ignore those
2220        first &= not good
2221
2222    ri.cw.block_end()
2223    ri.cw.nl()
2224
2225    for arg in sorted(indexed_arrays):
2226        aspec = struct[arg]
2227
2228        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
2229        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
2230        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
2231        ri.cw.p('i = 0;')
2232        if 'nested-attributes' in aspec:
2233            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
2234        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
2235        if 'nested-attributes' in aspec:
2236            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
2237            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
2238            ri.cw.p('return YNL_PARSE_CB_ERROR;')
2239        elif aspec.sub_type in scalars:
2240            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
2241        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
2242            # Length is validated by typol
2243            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
2244        else:
2245            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
2246        ri.cw.p('i++;')
2247        ri.cw.block_end()
2248        ri.cw.block_end()
2249    ri.cw.nl()
2250
2251    for arg in sorted(multi_attrs):
2252        aspec = struct[arg]
2253        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
2254        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
2255        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
2256        ri.cw.p('i = 0;')
2257        if 'nested-attributes' in aspec:
2258            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
2259        ri.cw.block_start(line=iter_line)
2260        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
2261        if 'nested-attributes' in aspec:
2262            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
2263            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
2264            ri.cw.p('return YNL_PARSE_CB_ERROR;')
2265        elif aspec.type in scalars:
2266            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
2267        elif aspec.type == 'binary' and 'struct' in aspec:
2268            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
2269            ri.cw.nl()
2270            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
2271            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
2272            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
2273        elif aspec.type == 'string':
2274            ri.cw.p('unsigned int len;')
2275            ri.cw.nl()
2276            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
2277            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
2278            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
2279            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
2280            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
2281        else:
2282            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
2283        ri.cw.p('i++;')
2284        ri.cw.block_end()
2285        ri.cw.block_end()
2286        ri.cw.block_end()
2287    ri.cw.nl()
2288
2289    if struct.nested:
2290        ri.cw.p('return 0;')
2291    else:
2292        ri.cw.p('return YNL_PARSE_CB_OK;')
2293    ri.cw.block_end()
2294    ri.cw.nl()
2295
2296
2297def parse_rsp_submsg(ri, struct):
2298    parse_rsp_nested_prototype(ri, struct, suffix='')
2299
2300    var = 'dst'
2301    local_vars = {'const struct nlattr *attr = nested;',
2302                  f'{struct.ptr_name}{var} = yarg->data;',
2303                  'struct ynl_parse_arg parg;'}
2304
2305    for _, arg in struct.member_list():
2306        _, _, l_vars = arg._attr_get(ri, var)
2307        local_vars |= set(l_vars) if l_vars else set()
2308
2309    ri.cw.block_start()
2310    ri.cw.write_func_lvar(list(local_vars))
2311    ri.cw.p('parg.ys = yarg->ys;')
2312    ri.cw.nl()
2313
2314    first = True
2315    for name, arg in struct.member_list():
2316        kw = 'if' if first else 'else if'
2317        first = False
2318
2319        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
2320        get_lines, init_lines, _ = arg._attr_get(ri, var)
2321        for line in init_lines or []:
2322            ri.cw.p(line)
2323        for line in get_lines:
2324            ri.cw.p(line)
2325        if arg.presence_type() == 'present':
2326            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
2327        ri.cw.block_end()
2328    ri.cw.p('return 0;')
2329    ri.cw.block_end()
2330    ri.cw.nl()
2331
2332
2333def parse_rsp_nested_prototype(ri, struct, suffix=';'):
2334    func_args = ['struct ynl_parse_arg *yarg',
2335                 'const struct nlattr *nested']
2336    for sel in struct.external_selectors():
2337        func_args.append('const char *_sel_' + sel.name)
2338    if struct.submsg:
2339        func_args.insert(1, 'const char *sel')
2340    for arg in struct.inherited:
2341        func_args.append('__u32 ' + arg)
2342
2343    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
2344                          suffix=suffix)
2345
2346
2347def parse_rsp_nested(ri, struct):
2348    if struct.submsg:
2349        parse_rsp_submsg(ri, struct)
2350        return
2351
2352    parse_rsp_nested_prototype(ri, struct, suffix='')
2353
2354    local_vars = ['const struct nlattr *attr;',
2355                  f'{struct.ptr_name}dst = yarg->data;']
2356    init_lines = []
2357
2358    if struct.member_list():
2359        _multi_parse(ri, struct, init_lines, local_vars)
2360    else:
2361        # Empty nest
2362        ri.cw.block_start()
2363        ri.cw.p('return 0;')
2364        ri.cw.block_end()
2365        ri.cw.nl()
2366
2367
2368def parse_rsp_msg(ri, deref=False):
2369    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
2370        return
2371
2372    func_args = ['const struct nlmsghdr *nlh',
2373                 'struct ynl_parse_arg *yarg']
2374
2375    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
2376                  'const struct nlattr *attr;']
2377    init_lines = ['dst = yarg->data;']
2378
2379    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', func_args)
2380
2381    if ri.struct["reply"].member_list():
2382        _multi_parse(ri, ri.struct["reply"], init_lines, local_vars)
2383    else:
2384        # Empty reply
2385        ri.cw.block_start()
2386        ri.cw.p('return YNL_PARSE_CB_OK;')
2387        ri.cw.block_end()
2388        ri.cw.nl()
2389
2390
2391def print_req(ri):
2392    ret_ok = '0'
2393    ret_err = '-1'
2394    direction = "request"
2395    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
2396                  'struct nlmsghdr *nlh;',
2397                  'int err;']
2398
2399    if 'reply' in ri.op[ri.op_mode]:
2400        ret_ok = 'rsp'
2401        ret_err = 'NULL'
2402        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']
2403
2404    if ri.struct["request"].fixed_header:
2405        local_vars += ['size_t hdr_len;',
2406                       'void *hdr;']
2407
2408    local_vars += put_local_vars(ri.struct['request'])
2409
2410    print_prototype(ri, direction, terminate=False)
2411    ri.cw.block_start()
2412    ri.cw.write_func_lvar(local_vars)
2413
2414    if ri.family.is_classic():
2415        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
2416    else:
2417        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
2418
2419    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
2420    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
2421    if 'reply' in ri.op[ri.op_mode]:
2422        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
2423    ri.cw.nl()
2424
2425    if ri.struct['request'].fixed_header:
2426        ri.cw.p("hdr_len = sizeof(req->_hdr);")
2427        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
2428        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
2429        ri.cw.nl()
2430
2431    for _, attr in ri.struct["request"].member_list():
2432        attr.attr_put(ri, "req")
2433    ri.cw.nl()
2434
2435    if 'reply' in ri.op[ri.op_mode]:
2436        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
2437        ri.cw.p('yrs.yarg.data = rsp;')
2438        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
2439        if ri.op.value is not None:
2440            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
2441        else:
2442            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
2443        ri.cw.nl()
2444    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
2445    ri.cw.p('if (err < 0)')
2446    if 'reply' in ri.op[ri.op_mode]:
2447        ri.cw.p('goto err_free;')
2448    else:
2449        ri.cw.p('return -1;')
2450    ri.cw.nl()
2451
2452    ri.cw.p(f"return {ret_ok};")
2453    ri.cw.nl()
2454
2455    if 'reply' in ri.op[ri.op_mode]:
2456        ri.cw.p('err_free:')
2457        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
2458        ri.cw.p(f"return {ret_err};")
2459
2460    ri.cw.block_end()
2461
2462
def print_dump(ri):
    """
    Emit the userspace C implementation of a dump operation.

    The generated function builds the dump request (optional fixed header
    plus request attributes), runs it through ynl_exec_dump() and returns
    yds.first; on error it frees yds.first and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    cw = ri.cw
    cw.block_start()

    req_struct = ri.struct['request']
    has_hdr = req_struct.fixed_header
    has_req_attrs = 'request' in ri.op[ri.op_mode]

    lvars = [
        'struct ynl_dump_state yds = {};',
        'struct nlmsghdr *nlh;',
        'int err;',
    ]
    if has_hdr:
        lvars.extend(['size_t hdr_len;', 'void *hdr;'])
    if has_req_attrs:
        lvars.extend(put_local_vars(req_struct))
    cw.write_func_lvar(lvars)

    # Replies are matched by the op's own ID when it has one, else by rsp_value
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    state_init = [
        'yds.yarg.ys = ys;',
        f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;",
        "yds.yarg.data = NULL;",
        f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});",
        f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;",
        f'yds.rsp_cmd = {rsp_cmd};',
    ]
    for stmt in state_init:
        cw.p(stmt)
    cw.nl()

    if ri.family.is_classic():
        msg_start = f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});"
    else:
        msg_start = f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);"
    cw.p(msg_start)

    if has_hdr:
        for stmt in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(stmt)
        cw.nl()

    if has_req_attrs:
        cw.p(f"ys->req_policy = &{req_struct.render_name}_nest;")
        cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        cw.nl()
        for _, member in req_struct.member_list():
            member.attr_put(ri, "req")
    cw.nl()

    for stmt in ('err = ynl_exec_dump(ys, nlh, &yds);',
                 'if (err < 0)',
                 'goto free_list;'):
        cw.p(stmt)
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
2520
2521
def call_free(ri, direction, var):
    """Return the C statement calling the op's ``_free`` helper on *var*."""
    free_fn = f"{op_prefix(ri, direction)}_free"
    return f"{free_fn}({var});"
2524
2525
def free_arg_name(direction):
    """
    Name of the pointer parameter of a generated ``_free`` function.

    A falsy *direction* (nested types) uses ``obj``; otherwise the direction
    suffix without its leading underscore is used.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2530
2531
def print_alloc_wrapper(ri, direction, struct=None):
    """
    Emit a ``static inline`` ``<prefix>_alloc()`` helper that calloc()s the
    type's struct.

    When *struct* is used as a multi-value member, the helper takes a count
    ``n`` and allocates ``n`` elements instead of one.
    """
    prefix = op_prefix(ri, direction)
    # A trailing underscore disambiguates clashes with same-named constants
    struct_name = f"{prefix}_" if ri.type_name_conflict else prefix

    is_array = bool(struct and struct.in_multi_val)
    params = ["unsigned int n"] if is_array else ["void"]
    count = "n" if is_array else "1"

    cw = ri.cw
    cw.write_func_prot(f'static inline struct {struct_name} *',
                       f"{prefix}_alloc", params)
    cw.block_start()
    cw.p(f'return calloc({count}, sizeof(struct {struct_name}));')
    cw.block_end()
2549
2550
def print_free_prototype(ri, direction, suffix=';'):
    """
    Emit the prototype of ``void <prefix>_free(struct <type> *<arg>)``.

    *suffix* is appended after the prototype (``';'`` for a declaration,
    ``''`` when a function body follows).
    """
    fn_prefix = op_prefix(ri, direction)
    conflict_tag = '_' if ri.type_name_conflict else ''
    param = f"struct {fn_prefix}{conflict_tag} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{fn_prefix}_free", [param], suffix=suffix)
2558
2559
def print_nlflags_set(ri, direction):
    """Emit a ``<prefix>_set_nlflags()`` helper storing extra nlmsg flags in the request."""
    cw = ri.cw
    prefix = op_prefix(ri, direction)
    params = [f"struct {prefix} *req", "__u16 nl_flags"]

    cw.write_func_prot('static inline void', f"{prefix}_set_nlflags", params)
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2568
2569
def _print_type(ri, direction, struct):
    """
    Emit the C struct definition for *struct* in the given *direction*.

    Layout: optional ``_nlmsg_flags``, optional fixed header ``_hdr``,
    anonymous ``_present``/``_len``/``_count`` bookkeeping sub-structs
    (only those with at least one member), inherited ``__u32`` args, then
    the attribute members themselves.
    """
    cw = ri.cw
    tail = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.type_name_conflict and not direction:
        tail += '_'
    if ri.op_mode == 'dump' and not ri.type_oneside:
        tail += '_dump'

    cw.block_start(line=f"struct {ri.family.c_name}{tail}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if struct.fixed_header:
        cw.p(struct.fixed_header + ' _hdr;')
        cw.nl()

    for type_filter in ('present', 'len', 'count'):
        meta_lines = [
            line
            for line in (attr.presence_member(ri.ku_space, type_filter)
                         for _, attr in struct.member_list())
            if line
        ]
        if not meta_lines:
            continue
        cw.block_start(line="struct")
        for line in meta_lines:
            cw.p(line)
        cw.block_end(line=f'_{type_filter};')
    cw.nl()

    for inherited_arg in struct.inherited:
        cw.p(f"__u32 {inherited_arg};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2608
2609
def print_type(ri, direction):
    """Emit the C struct for the op's *direction* ('request' or 'reply')."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2612
2613
def print_type_full(ri, struct):
    """
    Emit a nested (direction-less) struct type.

    For request-side structs used as multi-value members, this also emits
    the alloc wrapper, the free prototype and (userspace only, without name
    conflicts) per-member setters operating on ``obj``.
    """
    _print_type(ri, "", struct)

    if not (struct.request and struct.in_multi_val):
        return

    print_alloc_wrapper(ri, "", struct)
    ri.cw.nl()
    free_rsp_nested_prototype(ri)
    ri.cw.nl()

    # Name conflicts are too hard to deal with with the current code base,
    # they are very rare so don't bother printing setters in that case.
    want_setters = ri.ku_space == 'user' and not ri.type_name_conflict
    if want_setters:
        for _, member in struct.member_list():
            member.setter(ri, ri.attr_set, "", var="obj")
    ri.cw.nl()
2629
2630
def print_type_helpers(ri, direction, deref=False):
    """
    Emit the helpers for an op's *direction* type: the free prototype, the
    nlflags setter when needed, and (userspace requests only) member setters.
    """
    cw = ri.cw
    print_free_prototype(ri, direction)
    cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    if ri.ku_space == 'user' and direction == 'request':
        members = ri.struct[direction].member_list()
        for _, member in members:
            member.setter(ri, ri.attr_set, direction, deref=deref)
    cw.nl()
2642
2643
def print_req_type_helpers(ri):
    """Emit alloc and other helpers for the request type, if it is non-empty."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2649
2650
def print_rsp_type_helpers(ri):
    """Emit helpers for the reply type when the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2655
2656
def print_parse_prototype(ri, direction, terminate=True):
    """
    Emit the prototype of ``<op>_rsp_parse`` (reply) or ``<op>_req_parse``
    (anything else), taking the attribute table and the target struct.

    With *terminate* the prototype ends with ``;``, otherwise with nothing.
    """
    side = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{side}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
2665
2666
def print_req_type(ri):
    """Emit the request struct unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2671
2672
def print_req_free(ri):
    """Emit the request free function when the current op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2677
2678
def print_rsp_type(ri):
    """
    Emit the reply struct for events, and for do/dump ops that declare a reply.
    """
    mode = ri.op_mode
    wants_reply = (mode == 'event'
                   or (mode in ('do', 'dump') and 'reply' in ri.op[mode]))
    if wants_reply:
        print_type(ri, 'reply')
2687
2688
def print_wrapped_type(ri):
    """
    Emit the wrapper struct around a reply object.

    Dump wrappers carry a ``next`` list pointer. Notify/event wrappers carry
    ``family``, ``cmd``, a ``next`` pointer and a ``free`` callback. In all
    cases the reply object ``obj`` follows, aligned to 8 bytes. The wrapper's
    free prototype is emitted afterwards.
    """
    cw = ri.cw
    reply_t = type_name(ri, 'reply')

    cw.block_start(line=f"{reply_t}")
    mode = ri.op_mode
    if mode == 'dump':
        cw.p(f"{reply_t} *next;")
    elif mode in ('notify', 'event'):
        for field in ('__u16 family;',
                      '__u8 cmd;',
                      'struct ynl_ntf_base_type *next;',
                      f"void (*free)({reply_t} *ntf);"):
            cw.p(field)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()

    print_free_prototype(ri, 'reply')
    cw.nl()
2703
2704
2705def _free_type_members_iter(ri, struct):
2706    if struct.free_needs_iter():
2707        ri.cw.p('unsigned int i;')
2708        ri.cw.nl()
2709
2710
2711def _free_type_members(ri, var, struct, ref=''):
2712    for _, attr in struct.member_list():
2713        attr.free(ri, var, ref)
2714
2715
def _free_type(ri, direction, struct):
    """
    Emit the full definition of the ``_free`` function for *struct*.

    Members are always released. The object itself is ``free()``d only when
    *direction* is set; direction-less (nested) types release members only.
    """
    cw = ri.cw
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2727
2728
def free_rsp_nested_prototype(ri):
    """Emit the free prototype for a nested (direction-less) type."""
    print_free_prototype(ri, direction="")
2731
2732
def free_rsp_nested(ri, struct):
    """Emit the free function body for a nested (direction-less) type."""
    _free_type(ri, direction="", struct=struct)
2735
2736
def print_rsp_free(ri):
    """Emit the reply free function when the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2741
2742
def print_dump_type_free(ri):
    """
    Emit the free function for a dump reply list.

    The generated code walks the list via ``next`` until YNL_LIST_END,
    releasing each element's members (accessed through ``obj.``) and then
    the element itself.
    """
    cw = ri.cw
    elem_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{elem_type} *next = rsp;")
    cw.nl()

    reply = ri.struct['reply']
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()  # while loop
    cw.block_end()  # function body
    cw.nl()
2761
2762
def print_ntf_type_free(ri):
    """Emit the free function for a single notification wrapper ``rsp``."""
    cw = ri.cw
    print_free_prototype(ri, 'reply', suffix='')
    reply = ri.struct['reply']

    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2771
2772
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """
    Emit the ``nla_policy`` array declaration (terminate=True) or the opening
    line of its definition (terminate=False).

    Op-specific policies (``ri`` given) are named after the op, with the op
    mode appended for dual-policy ops; otherwise the struct's render name is
    used. Op policies that should be static get no extern declaration.
    """
    static_policy = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if static_policy:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if static_policy else ''
        suffix = ' = {'

    if not ri:
        name = struct.render_name
    elif ri.op.dual_policy:
        name = f"{ri.op.render_name}_{ri.op_mode}"
    else:
        name = ri.op.render_name

    bound = struct.attr_max_val.enum_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{bound} + 1]{suffix}")
2795
2796
def print_req_policy(cw, struct, ri=None):
    """
    Emit the full ``nla_policy`` array definition for *struct*.

    When an op is given, the definition is wrapped in the op's
    'config-cond' #ifdef block.
    """
    has_op = ri and ri.op
    if has_op:
        cw.ifdef_block(ri.op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for member in (arg for _, arg in struct.member_list()):
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2806
2807
def kernel_can_gen_family_struct(family):
    """
    Return True when the kernel ``struct genl_family`` can be generated,
    which is only the case for plain 'genetlink' protocol families.
    """
    is_plain_genl = family.proto == 'genetlink'
    return is_plain_genl
2810
2811
def policy_should_be_static(family):
    """
    Policies are file-local when the family uses split ops or when the
    family struct itself is generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2814
2815
def print_kernel_policy_ranges(family, cw):
    """
    Emit ``netlink_range_validation`` structs for request attributes with a
    'full-range' check.

    Unsigned attributes use the plain struct with ``ULL`` limits; signed ones
    use the ``_signed`` variant with ``LL`` limits. A banner comment precedes
    the first struct. Subset attribute sets are skipped.
    """
    banner_done = False
    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not banner_done:
                cw.p('/* Integer value ranges */')
                banner_done = True

            is_unsigned = attr.type[0] == 'u'
            sign, suffix = ('', 'ULL') if is_unsigned else ('_signed', 'LL')
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(bound, attr.get_limit_str(bound, suffix=suffix))
                       for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2843
2844
def print_kernel_policy_sparse_enum_validates(family, cw):
    """
    Emit validation callbacks for request attributes with sparse enum values.

    Each ``<attr>_validate()`` switches on the attribute value, returns 0 for
    any listed enum entry, and otherwise sets an extack message and returns
    -EINVAL. Subset attribute sets are skipped.
    """
    banner_done = False
    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or not attr.enum_name or 'sparse' not in attr.checks:
                continue

            if not banner_done:
                cw.p('/* Sparse enums validation callbacks */')
                banner_done = True

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # All valid values fall through to a single shared "return 0"
            for idx, entry in enumerate(enum.entries.values()):
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2881
2882
def print_kernel_op_table_fwd(family, cw, terminate):
    """
    Emit the kernel ops table declaration, or the opening line of its
    definition.

    With terminate=False this opens the table definition (the entries are
    emitted by print_kernel_op_table()). With terminate=True it emits the
    header-side declarations: an extern table declaration (only when the
    table is exported, i.e. the family struct is not generated), then
    prototypes for the pre/post hooks and for every synchronous op's
    doit/dumpit handler.

    For exported tables the declared size must match the number of entries
    print_kernel_op_table() writes. That is one per do/dump handler for
    'split' policy and one per op otherwise; async ops get no entry in
    either case.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        if not exported:
            # Static definition: size is inferred from the initializer
            cnt = ""
        elif family.kernel_policy == 'split':
            cnt = 0
            for op in family.ops.values():
                if op.is_async:
                    continue
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            # Must not count async ops: print_kernel_op_table() skips them,
            # so counting them would leave zero-filled trailing entries.
            cnt = sum(1 for op in family.ops.values() if not op.is_async)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.fn_prefix}-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.fn_prefix}-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2948
2949
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side declarations for the kernel ops table and handlers."""
    print_kernel_op_table_fwd(family, cw, True)
2952
2953
def print_kernel_op_table(family, cw):
    """
    Emit the kernel ops table definition for *family*.

    For 'global' and 'per-op' kernel policies, one entry is emitted per
    non-async op. It carries the cmd, optional validate flags, the
    doit/dumpit handlers, for 'per-op' the op's policy and maxattr, and any
    op flags.

    For 'split' policy, a separate entry is emitted for each of an op's
    'do' and 'dump' modes. Each entry gets mode-filtered validate flags,
    pre/post hooks, a policy when the mode has a request, and the op flags
    plus GENL_CMD_CAP_<MODE>.

    Each entry is wrapped in the op's 'config-cond' #ifdef block.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy in ('global', 'per-op'):
        for op_name, op in family.ops.items():
            # Async ops get no table entry
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.fn_prefix}-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): assumes every op here has a 'do' section with a
                # 'request'; a dump-only op would raise KeyError - confirm intended.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # genl_split_ops member names for the pre/post hooks of each mode
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # 'do' entries drop 'dump'/'dump-strict'; 'dump' entries drop 'strict'
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.fn_prefix}-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have one policy per mode, named with the mode
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Close any #ifdef still open from the last entry
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
3032
3033
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum with one ``<FAMILY>_NLGRP_<NAME>`` ID per multicast group."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
3044
3045
def print_kernel_mcgrp_src(family, cw):
    """Emit the ``genl_multicast_group`` array mapping group IDs to their names."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
3057
3058
def print_kernel_family_struct_hdr(family, cw):
    """
    Emit the extern ``genl_family`` declaration and, if the family has a
    'sock-priv' type, the prototypes of its sock-priv init/destroy hooks.
    """
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.c_name
    cw.p(f"extern struct genl_family {fam}_nl_family;")
    cw.nl()
    if 'sock-priv' in family.kernel_family:
        priv_type = family.kernel_family["sock-priv"]
        for hook in ('init', 'destroy'):
            cw.p(f'void {fam}_nl_sock_priv_{hook}({priv_type} *priv);')
        cw.nl()
3069
3070
def print_kernel_family_struct_src(family, cw):
    """
    Emit the ``struct genl_family`` definition for a generated family.

    When a 'sock-priv' type is configured, static ``void *``-taking
    trampolines for the init/destroy hooks are emitted first and wired into
    the struct.
    """
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.c_name
    has_priv = 'sock-priv' in family.kernel_family

    if has_priv:
        # Generate "trampolines" to make CFI happy
        for hook in ('init', 'destroy'):
            cw.write_func("static void", f"__{fam}_nl_sock_priv_{hook}",
                          [f"{fam}_nl_sock_priv_{hook}(priv);"],
                          ["void *priv"])
            cw.nl()

    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")
    fields = [
        f'.name\t\t= {family.fam_key},',
        f'.version\t= {family.ver_key},',
        '.netnsok\t= true,',
        '.parallel_ops\t= true,',
        '.module\t\t= THIS_MODULE,',
    ]
    if family.kernel_policy == 'per-op':
        fields += [f'.ops\t\t= {fam}_nl_ops,',
                   f'.n_ops\t\t= ARRAY_SIZE({fam}_nl_ops),']
    elif family.kernel_policy == 'split':
        fields += [f'.split_ops\t= {fam}_nl_ops,',
                   f'.n_split_ops\t= ARRAY_SIZE({fam}_nl_ops),']
    if family.mcgrps['list']:
        fields += [f'.mcgrps\t\t= {fam}_nl_mcgrps,',
                   f'.n_mcgrps\t= ARRAY_SIZE({fam}_nl_mcgrps),']
    if has_priv:
        fields += [f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),',
                   f'.sock_priv_init\t= __{fam}_nl_sock_priv_init,',
                   f'.sock_priv_destroy = __{fam}_nl_sock_priv_destroy,']
    for field in fields:
        cw.p(field)
    cw.block_end(';')
3106
3107
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """
    Open a UAPI enum block, choosing its tag name.

    If *obj* has *enum_name*, its value is used; an empty value yields an
    anonymous enum. Otherwise, if *ckey* is present, the tag is
    ``<family c_name>_<obj[ckey]>``. With neither, the enum is anonymous.
    """
    header = 'enum'
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            header += ' ' + c_lower(explicit)
    elif ckey and ckey in obj:
        header += f' {family.c_name}_{c_lower(obj[ckey])}'
    cw.block_start(line=header)
3116
3117
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """
    Render the single command enum used by the 'unified' message-ID model.

    Explicit ``= value`` initializers are written only where a value breaks
    the running sequence. Notifications/events are left out when they get
    their own async enum (*separate_ntf*). The enum ends with a count entry
    and a max entry, or a max ``#define`` when *max_by_define* is set.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        is_ntf = 'notify' in op or 'event' in op
        if separate_ntf and is_ntf:
            continue

        if op.value == expected:
            cw.p(op.enum_name + ',')
        else:
            cw.p(f"{op.enum_name} = {op.value},")
            expected = op.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
3143
3144
def render_uapi_directional(family, cw, max_by_define):
    """
    Render the two command enums used by the 'directional' message-ID model.

    The USER enum lists ops with a 'do' section (excluding events). The
    KERNEL enum lists do-replies (suffixed ``_REPLY``), notifications and
    events. Each enum starts with a ``*_NONE = 0`` entry and ends with count
    and max entries (or a max ``#define`` when *max_by_define* is set).
    """

    def emit_enum(direction, entries):
        # entries: (C enumerator name, op) pairs, in message order
        max_name = f"{family.op_prefix}{direction}_MAX"
        cnt_name = f"__{family.op_prefix}{direction}_CNT"
        max_value = f"({cnt_name} - 1)"

        cw.block_start(line='enum')
        cw.p(c_upper(f'{family.name}_MSG_{direction}_NONE = 0,'))
        expected = 0
        for enumerator, op in entries:
            if op.value and op.value != expected:
                cw.p(f"{enumerator} = {op.value},")
                expected = op.value
            else:
                cw.p(enumerator + ',')
            expected += 1
        cw.nl()

        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    def user_msgs():
        for op in family.msgs.values():
            if 'do' in op and 'event' not in op:
                yield op.enum_name, op

    def kernel_msgs():
        for op in family.msgs.values():
            if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
                if 'event' in op or 'notify' in op:
                    yield op.enum_name, op
                else:
                    yield f'{op.enum_name}_REPLY', op

    emit_enum('USER', user_msgs())
    emit_enum('KERNEL', kernel_msgs())
3197
3198
def render_uapi(family, cw):
    """
    Render the UAPI header for *family* into *cw*.

    Emits, in order: the include guard, the family name/version defines,
    the constants/enums/flags from 'definitions' (with kdoc when documented),
    one enum per attribute set, the command enum(s), an optional separate
    async (notification) enum, and the multicast group name defines.

    Raises:
        Exception: if the family's message enum model is not supported.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Runs of 'const' definitions are collected into one batch of defines,
    # flushed whenever a non-const definition is reached.
    defines = []
    for const in family['definitions']:
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            name_pfx = const.get('name-prefix', f"{family.ident_name}-")
            # The define name override is per constant. A family-wide lookup
            # would give every constant the same name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{name_pfx}{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        # Notifications/events get their own enum
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3340
3341
def _render_user_ntf_entry(ri, op):
    """
    Emit one designated-initializer entry of the user space ynl_ntf_info table.

    The entry is keyed by op's enum name. For classic families the key is
    instead the enum name of the request op found in req_by_value under
    op.rsp_value. The entry wires up the event struct size, the reply parser,
    the reply policy and the notify free helper.
    """
    family = ri.family
    if family.is_classic():
        index_name = family.req_by_value[op.rsp_value].enum_name
    else:
        index_name = op.enum_name

    writer = ri.cw
    writer.block_start(line=f"[{index_name}] = ")
    writer.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    writer.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    writer.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    writer.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    writer.block_end(line=',')
3353
3354
def render_user_family(family, cw, prototype):
    """
    Emit the user space `struct ynl_family` descriptor for family.

    With prototype set, only the extern declaration is written. Otherwise the
    notification info table is written first (when the family has
    notifications), followed by the family initializer itself.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    ntf_table = f"{family.c_name}_ntf_info"
    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_table}[] = ")
        for ntf_name, ntf in family.ntfs.items():
            if 'notify' in ntf:
                # 'notify' names the op whose reply layout the notification reuses
                ri = RenderInfo(cw, family, "user", family.ops[ntf['notify']], "notify")
            elif 'event' in ntf:
                ri = RenderInfo(cw, family, "user", ntf, "event")
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(ri, ntf)
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        # classic families have no genlmsghdr, only the optional fixed header
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
3396
3397
def family_contains_bitfield32(family):
    """
    Tell whether any attribute of family is of type bitfield32.

    Attribute sets that are a subset of another set are skipped; only full
    attribute sets are inspected.
    """
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
3406
3407
def find_kernel_root(full_path):
    """
    Locate the kernel source tree that contains full_path.

    Walk up the directory hierarchy from full_path until a directory that
    holds a MAINTAINERS file is found.

    Returns a (root, sub_path) tuple: the directory holding MAINTAINERS and
    full_path expressed relative to that directory.

    Raises an Exception when the walk reaches the top of the path (the
    filesystem root for absolute paths, the current directory for relative
    ones) without finding MAINTAINERS. Previously that case looped forever.
    """
    start_path = full_path
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        # Prepend the stripped component; the join leaves a trailing
        # separator which is trimmed off on return.
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            return parent, sub_path[:-1]
        if parent == full_path:
            # dirname() reached a fixed point ('/' or ''), nothing left to walk
            raise Exception(f'Could not find kernel root (no MAINTAINERS file) above {start_path}')
        full_path = parent
3416
3417
def main():
    """
    Command line entry point of the generator.

    Parses the arguments, loads the spec with Family and checks its license.
    It then writes the generated code through a CodeWriter bound to -o:
    either the uAPI header (--mode uapi) or the kernel / user space header
    or source, selected with --header / --source.

    Exits with status 1 if loading the spec raises a YAML error, or if the
    spec carries a license other than the one required for generated code.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination; it is None unless one of
    # them was given, which the check below relies on.
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    parser.add_argument('--function-prefix', dest='fn_prefix', type=str)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops, args.fn_prefix)
    except pyyaml.YAMLError as exc:
        print(exc)
        sys.exit(1)
    if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
        print('Spec license:', parsed.license)
        print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=not args.cmp_out)

    # File banner: license, source spec path (relative to the kernel tree)
    # and the arguments needed to regenerate the file.
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header or args.fn_prefix:
        line = ''
        if args.user_header:
            line += ' --user-header '.join([''] + args.user_header)
        if args.exclude_op:
            line += ' --exclude-op '.join([''] + args.exclude_op)
        if args.fn_prefix:
            line += f' --function-prefix {args.fn_prefix}'
        cw.p(f'/* YNL-ARG{line} */')
    cw.p('/* To regenerate run: tools/net/ynl/ynl-regen.sh */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.out_file:
        # assumes -o ends in a two-character suffix such as ".c" or ".h"
        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, keeping first-seen order
    seen_header = set()
    for one in headers:
        if one not in seen_header:
            cw.p(f"#include <{one}>")
            seen_header.add(one)
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for _op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)
            print_kernel_policy_sparse_enum_validates(parsed, cw)

            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            # The policy model does not change per op, so test it once
            if parsed.kernel_policy in {'per-op', 'split'}:
                for _op_name, op in parsed.ops.items():
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for const in parsed.consts.values():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for _op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent or ri.type_oneside:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for _op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for const in parsed.consts.values():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for struct in parsed.pure_nested_structs.values():
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            if has_recursive_nests:
                # Recursive nests reference each other, so prototypes go first
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for _op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent or ri.type_oneside:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for _op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    # the event's parser is rendered from a "do" RenderInfo
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3732
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
3735