xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 0939a418b3b092aa8a97a59752084896e4a7c813)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
def c_upper(name):
    """
    Convert a spec name such as 'foo-bar' into an upper-case C token 'FOO_BAR'.
    """
    upper = name.upper()
    return upper.translate(str.maketrans('-', '_'))
22
23
def c_lower(name):
    """
    Convert a spec name such as 'Foo-Bar' into a lower-case C token 'foo_bar'.
    """
    pieces = name.lower().split('-')
    return '_'.join(pieces)
26
27
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The name must have the form ``<u|s><bits>-<min|max>``, e.g. ``u8-max``
    is 255, ``s8-min`` is -128 and every ``u<bits>-min`` is 0.

    Raises:
        ValueError: if *name* does not follow that form (previously such
            names were either silently treated as unsigned limits or
            failed with an unrelated int() parsing error).
    """
    match = re.fullmatch(r'([us])(\d+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit name "{name}", expected e.g. "u32-max" or "s16-min"')
    sign, bits, bound = match.groups()
    width = int(bits)

    if sign == 'u':
        return 0 if bound == 'min' else (1 << width) - 1

    # Signed: one bit of the width is the sign bit
    max_value = (1 << (width - 1)) - 1
    return -max_value - 1 if bound == 'min' else max_value
41
42
class BaseNlLib:
    """Provides C expressions used by the generated code to reach library state."""

    def get_family_id(self):
        """Return the C expression used to reference the family ID."""
        expr = 'ys->family_id'
        return expr
46
47
class Type(SpecAttr):
    """Base class for C code generation of a single netlink attribute.

    Extends SpecAttr with the C-side names (c_name, enum_name and, for
    nests / sub-messages, nested_render_name and nested_struct_type) and
    with hooks that subclasses override to emit struct members, kernel
    policies, type policies ("typol"), put/parse code, free code and
    setters.  Hooks without a generic implementation raise Exception.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped to True by set_request() / set_reply()
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # A nest named like a family const gets a '_' suffixed struct type
            # (NOTE(review): presumably to avoid clashing with a struct of the
            # same name generated for the const - confirm)
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Turn the spec name into a valid C identifier: escape C keywords
        # and names starting with a digit
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests (also marks the real attr of a subset)."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies (also marks the real attr of a subset)."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return the numeric value of check *limit* (e.g. 'min', 'max').

        The check may be an int, the name of a family const, or a string
        such as 'u32-max'.  Returns None if the check is absent and
        *default* is None.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check *limit* as a C expression string ('' when absent).

        Ints are rendered literally with *suffix* appended; const and
        symbolic limits are rendered as upper-case macro names.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute the C enum name for the attr.

        Raises:
            Exception: if a subset attr tries to override the checks of
                its real attr.
        """
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        # True only for recursive attrs when the render info has no op
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Kind of presence tracking: 'present' (a bit), 'len' or 'count'."""
        return 'present'

    def presence_member(self, space, type_filter):
        """Return the presence struct member declaration for this attr.

        Returns None when the attr's presence type is not *type_filter*
        (or is not one of the known kinds).  User space gets '__u32'.
        """
        if self.presence_type() != type_filter:
            return

        pfx = '__' if space == 'user' else ''
        if self.presence_type() == 'present':
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def _free_lines(self, ri, var, ref):
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the C statements releasing this attr's memory."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the C parameter declarations used by the setter."""
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the member declaration(s) for the generated C struct."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            # Recursive nests can only be held by pointer
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's entry of the kernel nla_policy array."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's entry of the type policy (name + type info) array."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the presence bit / length when there is one
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attr inside the attr loop.

        _attr_get() returns (lines, init_lines, local_vars); lines and
        init_lines may each be a single string or a list of strings.
        Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Multi-value attrs are validated elsewhere (not in this branch)
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a static inline ..._set_<path>() helper for this attr.

        *ref* lists the c_names of the enclosing nests, outermost first.
        The helper sets the presence bit of every enclosing level, frees
        the old value, then stores the new one via _setter_lines().
        Helpers that free but never allocate get a '__' name prefix.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        for i, name in enumerate(ref):
            path = '.'.join(ref[:i] + [''])
            presence = f"{var}->{path}_present.{name}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'present':
                presence = f"{var}->{path}_{self.presence_type()}.{name}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
311
312
class TypeUnused(Type):
    """Unused attribute slot.

    The typol marks it YNL_PT_REJECT; no policy, put, parse or setter code
    is emitted for it.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_get(self, ri, var):
        reject = ['return YNL_PARSE_CB_ERROR;']
        return reject, None, None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    # The remaining code-emitting hooks deliberately produce nothing.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
337
338
class TypePad(Type):
    """Padding attribute.

    The typol marks it YNL_PT_IGNORE; no policy, put, parse or setter code
    is emitted for it.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The remaining code-emitting hooks deliberately produce nothing.
    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def attr_policy(self, cw):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
360
361
class TypeScalar(Type):
    """Scalar integer attribute, stored as a plain C integer member.

    For non-classic families it also derives kernel policy checks (ranges,
    masks, sparse-enum validation) from the spec's 'checks' and 'enum'.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve is_bitfield and the C type_name after the base resolve."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Auto-sized scalars are exposed as the 64-bit C type
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """Fill in implicit checks and record which policy form is needed.

        - 'sparse' when the enum's value_range() is (None, None),
        - 'min'/'max' from the enum range unless given explicitly,
        - 'range' when both bounds are set,
        - 'full-range' when a bound lies outside [-32768, 32767].

        Raises:
            Exception: min > max, or a negative limit on an unsigned type.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # NOTE(review): presumably the short range policy stores s16 bounds,
        # so anything outside s16 needs the full-range variant - confirm
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
459
460
class TypeFlag(Type):
    """Zero-length flag attribute: its presence bit is the only state."""

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to copy: the generic attr_get() records the presence bit
        return list(), None, None

    def _setter_lines(self, ri, member, presence):
        # Nothing to store: the generic setter already sets the presence bit
        return list()
476
477
class TypeString(Type):
    """String attribute, stored as a malloc'ed NUL-terminated copy with a _len member."""

    def arg_member(self, ri):
        return ["const char *" + self.c_name]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p("char *" + self.c_name + ";")

    def _attr_typol(self):
        parts = ['.type = YNL_PT_NUL_STR, ']
        if self.is_selector:
            parts.append('.is_selector = 1, ')
        return ''.join(parts)

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.get_limit_str('exact-len')})"
        len_part = ''
        if 'max-len' in self.checks:
            len_part = f", .len = {self.get_limit_str('max-len')}"
        return f"{{ .type = {policy}{len_part}, }}"

    def attr_policy(self, cw):
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        # len is the string length bounded by the attr payload size
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        src = self.c_name
        return [
            f"{presence} = strlen({src});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {src}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
530
531
class TypeBinary(Type):
    """Opaque binary attribute, stored as a malloc'ed buffer plus a _len member."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Render the kernel policy; at most one length check is supported.

        Raises:
            Exception: more than one check is set, or the single check is
                not one of exact-len / min-len / max-len.
        """
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not self.checks:
            return '{ .type = NLA_BINARY, }'

        check_name = next(iter(self.checks))
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            # Leaves .type unset; NOTE(review): relies on the kernel treating
            # .len of an untyped policy entry as a minimum length - confirm
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if check_name == 'max-len':
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        raise Exception('Unsupported check for binary type: ' + check_name)

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->{self.c_name}, {var}->_len.{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_len.' + self.c_name
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [f"{presence} = len;",
                f"{member} = malloc({presence});",
                f'memcpy({member}, {self.c_name}, {presence});']
582
583
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is the C struct named by the 'struct' property."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        # Payloads shorter than the struct get a zeroed struct-sized buffer,
        # so the buffer is always at least sizeof(struct)
        get_lines = [
            f"{len_mem} = len;",
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
599
600
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute holding a packed array of 'sub-type' scalars, tracked by _count."""

    def arg_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        return [f'{elem} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem} *{self.c_name};')

    def attr_put(self, ri, var):
        count_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        # 'i' is reused here to hold the payload size in bytes
        ri.cw.block_start(line=f"if ({count_mem})")
        ri.cw.p(f"i = {count_mem} * {elem_sz};")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        count_mem = f"{var}->_count.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        dst = f"{var}->{self.c_name}"
        # Trailing bytes that do not form a whole element are dropped
        get_lines = [
            f"{count_mem} = len / {elem_sz};",
            f"len = {count_mem} * {elem_sz};",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        return [
            f"{presence} = count;",
            f"count *= {elem_sz};",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
633
634
class TypeBitfield32(Type):
    """struct nla_bitfield32 attribute; the policy mask comes from its 'enum' flags."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """Render NLA_POLICY_BITFIELD32 with the enum's flag mask.

        Raises:
            Exception: if the attr has no 'enum'.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
658
659
class TypeNest(Type):
    """Nested attribute set, rendered as an embedded struct of the nested type.

    When recursive for the current op, the member is a pointer instead,
    and free/put code drops the address-of operator.
    """

    def is_recursive(self):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        field = f'{var}->{ref}{self.c_name}'
        if not self.is_recursive_for_op(ri):
            return [f'{self.nested_render_name}_free(&{field});']
        # Pointer member: only free it when set
        return [f'if ({field})', f'{self.nested_render_name}_free({field});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        amp = '' if self.is_recursive_for_op(ri) else '&'
        line = f"{self.nested_render_name}_put(nlh, {self.enum_name}, {amp}{var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return YNL_PARSE_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # A nest has no setter of its own: emit one per non-recursive member,
        # with this nest appended to the member's path
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member_attr in nested.member_list():
            if not member_attr.is_recursive():
                member_attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
701
702
class TypeMultiAttr(Type):
    """Attribute which may repeat; every occurrence is collected into an array.

    The element kind is attr['type'] (scalar, string, binary struct or
    nest); policy and typol rendering are delegated to *base_type*.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)
        # Single-instance type used to render policy / typol entries
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        elem = self.attr['type']
        if elem == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if elem == 'string':
            return 'struct ynl_string *'
        if elem in scalars:
            return ('__' if ri.ku_space == 'user' else '') + elem
        raise Exception(f"Sub-type {elem} not supported yet")

    def arg_member(self, ri):
        if self.type != 'binary' or 'struct' not in self.attr:
            return super().arg_member(ri)
        struct_name = c_lower(self.attr["struct"])
        return [f'struct {struct_name} *{self.c_name}',
                f'unsigned int n_{self.c_name}']

    def free_needs_iter(self):
        # string and nest elements are released one by one in a loop over 'i'
        return self.attr['type'] in ('nest', 'string')

    def _free_lines(self, ri, var, ref):
        kind = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"
        each = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if kind in scalars or kind == 'binary':
            return [f"free({array});"]
        if kind == 'string':
            return [each, f"free({array}[i]);", f"free({array});"]
        if kind == 'nest':
            return [each, f'{self.nested_render_name}_free(&{array}[i]);', f"free({array});"]
        raise Exception(f"Free of MultiAttr sub-type {kind} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only counts occurrences here
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        kind = self.attr['type']
        elem = f"{var}->{self.c_name}[i]"
        if kind in scalars:
            put = f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {elem});"
        elif kind == 'binary' and 'struct' in self.attr:
            struct_name = c_lower(self.attr['struct'])
            put = f"ynl_attr_put(nlh, {self.enum_name}, &{elem}, sizeof(struct {struct_name}));"
        elif kind == 'string':
            put = f"ynl_attr_put_str(nlh, {self.enum_name}, {elem}->str);"
        elif kind == 'nest':
            put = None
        else:
            raise Exception(f"Put of MultiAttr sub-type {kind} not supported yet")

        ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
        if put is None:
            # Nest elements go through the common put-line helper
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{elem})")
        else:
            ri.cw.p(put)

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
789
790
class TypeArrayNest(Type):
    """Indexed array: a nest whose entries are put with the array index as attr type.

    Entries share one 'sub-type' (scalar, fixed-length binary or nest) and
    are stored as a C array with a _count presence member.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def arg_member(self, ri):
        if self.sub_type == 'binary' and 'exact-len' in self.checks:
            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        if self.attr['sub-type'] in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Only validates and counts the entries here; attr_<c_name> is a
        # variable expected to be declared by the surrounding generated code
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->_count.{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        """Emit a nest whose i-th entry uses i as its attr type.

        Raises:
            Exception: for an unsupported (or missing) sub-type.
        """
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            put_type = self.sub_type
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            # .get(): a missing sub-type must surface this message, not a KeyError
            raise Exception(f"Put for ArrayNest sub-type {self.attr.get('sub-type')} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
853
854
class TypeNestTypeValue(Type):
    """Nest whose member types are taken from the types of enclosing attrs."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        if 'type-value' not in self.attr:
            return ([f"{self.nested_render_name}_parse(&parg, attr);"],
                    init_lines, [])

        # Descend one nest level per 'type-value' entry, remembering each
        # level's attribute type so it can be passed to the parser.
        levels = [c_lower(x) for x in self.attr["type-value"]]
        local_vars = [f'const struct nlattr *attr_{", *attr_".join(levels)};',
                      f'__u32 {", ".join(levels)};']
        get_lines = []
        prev = 'attr'
        for level in levels:
            get_lines.append(f'attr_{level} = ynl_attr_data({prev});')
            get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
            prev = f'attr_{level}'
        get_lines.append(f"{self.nested_render_name}_parse(&parg, {prev}, {', '.join(levels)});")
        return get_lines, init_lines, local_vars
883
884
class TypeSubMessage(TypeNest):
    """Attribute whose nested format is chosen by a sibling selector attr."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        sel_value = self.attr_set[self["selector"]].value
        return (f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
                f'.is_submsg = 1, .selector_type = {sel_value} ')

    def _attr_get(self, ri, var):
        sel = c_lower(self['selector'])
        # Parsing requires the selector to have been seen already
        get_lines = [
            f'if (!{var}->{sel})',
            f'return ynl_submsg_failed(yarg, "{self.name}", "{self["selector"]}");',
            f"if ({self.nested_render_name}_parse(&parg, {var}->{sel}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None
906
907
class Selector:
    """Link a sub-message attribute to the attribute that selects its format.

    Only selectors that live in the same attribute set are supported; the
    selector attribute is marked with ``is_selector = True``.
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]

        if self.name not in attr_set:
            raise Exception("Passing selectors from external nests not supported")

        self.attr = attr_set[self.name]
        self.attr.is_selector = True
        self._external = False
918
919
class Struct:
    """Rendering info for the C struct generated from an attribute set.

    With ``type_list`` the struct holds only the listed attributes (a
    request/reply of an op); without it, all attrs of the set (a pure
    nested struct).
    """

    def __init__(self, family, space_name, type_list=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.submsg = submsg

        self.nested = type_list is None
        if c_lower(space_name) == family.name:
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)

        self.struct_name = 'struct ' + self.render_name
        # Avoid clashing with a same-named struct from the definitions
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or and legacy arrays

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]

        self.attrs = dict()
        self.attr_max_val = None
        max_val = 0
        for name, attr in self.attr_list:
            self.attrs[name] = attr
            # '>=' so that the last attr with the highest value wins
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        return self.attr_list

    def set_inherited(self, new_inherited):
        if new_inherited != self._inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def free_needs_iter(self):
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
984
985
class EnumEntry(SpecEnumEntry):
    """Single enum/flags entry with C naming information."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value differs from the one C would assign implicitly
        # (previous + 1, or 0 for the first entry); always True for flags.
        implied = prev.value + 1 if prev else 0
        self.value_change = (self.value != implied
                             or self.enum_set['type'] == 'flags')

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1004
1005
class EnumSet(SpecEnumSet):
    """C rendering view of an enum/flags definition from the spec."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # An explicitly empty 'enum-name' means the C enum is left unnamed
        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # C type used for values of this set; unnamed enums fall back to int
        if self.enum_name:
            self.user_type = self.enum_name
        else:
            self.user_type = 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if the entry values cover every integer in
        [low, high], otherwise (None, None).

        Distinct values are counted, so duplicated values can neither hide
        a gap nor make a gap-free set look sparse. An empty set also
        yields (None, None).
        """
        values = {x.value for x in self.entries.values()}
        if not values:
            return None, None

        low = min(values)
        high = max(values)
        if high - low + 1 != len(values):
            return None, None

        return low, high
1041
1042
class AttrSet(SpecAttrSet):
    """Attribute set with C enum naming and Type-subclass construction."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets share naming with the set they are carved from
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        # The family's own set gets no extra name component
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        kind = elem['type']
        # Types that map straight to one class without further inspection
        direct = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
            'sub-message': TypeSubMessage,
        }

        if kind in scalars:
            cls = TypeScalar
        elif kind in direct:
            cls = direct[kind]
        elif kind == 'binary':
            if 'struct' in elem:
                cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                cls = TypeBinaryScalarArray
            else:
                cls = TypeBinary
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        t = cls(self.family, self, elem, value)
        # multi-attr wraps whatever the base type turned out to be
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)
        return t
1111
1112
class Operation(SpecOperation):
    """Operation with C naming and policy/notification bookkeeping."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # every request/reply that exists gets at least an empty list.
        slots = [(mode, direction)
                 for mode in ('do', 'dump', 'event')
                 for direction in ('request', 'reply')]
        for mode, direction in slots:
            try:
                msg = yaml[mode][direction]
            except KeyError:
                continue
            msg.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # Both 'do' and 'dump' carry a request
        self.dual_policy = all(m in yaml and 'request' in yaml[m]
                               for m in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
1146
1147
class SubMessage(SpecSubMessage):
    """Sub-message definition with its C render name."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)
        self.render_name = c_lower(f"{family.ident_name}-{yaml['name']}")

    def resolve(self):
        self.resolve_up(super())
1156
1157
class Family(SpecFamily):
    """C code generation view of a netlink family spec.

    Adds C naming (c_name, op prefixes, header names) to SpecFamily and runs
    the analysis passes the generator relies on. The passes find:
    - which attr sets are used directly by messages ("root sets");
    - which are only ever nested, and the order their structs are emitted in;
    - which attributes appear in requests and/or replies.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # Each assign-then-delattr pair names an attribute that resolve()
        # sets later; until then accessing it raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # C macro names for the family name / version, overridable in the spec
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # 'linux/<name>.h' -> '<name>'; anything else falls back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        # Used for async ops (see Operation.resolve); defaults to op_prefix
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] = {'set': names seen, 'list': names in
        # first-seen order}; filled by _load_hooks()
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        # The _load_* passes build on each other: nested sets are found from
        # the root sets, and attr use is derived from both.
        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        # Flag every op that some message names as its 'notify' target
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    # NOTE: this assigns op['do'] unconditionally for ops that have an 'event'
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Record, per attr set used directly by a message, the attribute
        names that appear in requests and in replies (do, dump and event)."""
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so a struct comes after the nests it uses.

        A struct that references a not-yet-placed, non-recursive nest is
        moved to the end of the dict and retried later. Work is capped at
        len(pure_nested_structs) ** 2 rounds.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register the struct for a 'nested-attributes' attr and return the
        nested set name.

        Members inherited from the parent are recorded as well: the
        'type-value' levels, or 'idx' for an indexed-array.

        Raises if the nested set is also used as a root set.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Synthesize an attr set (one 'nest' attr per format) and a struct
        for a sub-message attribute; return the sub-message name."""
        # Fake the struct type for the sub-message itself
        # its not a attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attrs.append({
                "name": name,
                "type": "nest",
                "parent-sub-message": spec,
                "nested-attributes": fmt['attribute-set']
            })

        # NOTE(review): AttrSet.__init__ reads 'name-prefix', not 'name-pfx',
        # so this key appears to be ignored — confirm intent.
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Discover all nested-only attr sets and their request/reply use.

        Nested sets are found by a breadth-first walk from the root sets.
        Request/reply use is then marked from the root sets and propagated
        down through the nests, recursion is detected, and the structs are
        sorted by dependency.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # (walks parents before the children they depend on, i.e. reverse
        # of the sorted emission order)
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark each attribute as used in requests and/or replies.

        Every member of a nested struct inherits the struct's flags; root
        set attrs are marked according to the messages that use them.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_global_policy(self):
        """Build the single policy shared by all ops ('kernel-policy: global').

        global_policy lists, in attr-set order, every attribute used in any
        do/dump request. Raises if the ops use different attr sets.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect the distinct pre/post hook names per op mode, keeping
        first-seen order in the 'list' entry."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1478
1479
class RenderInfo:
    """Context for rendering one op in one mode (or one bare attr set).

    ``op`` may be falsy, in which case ``attr_set`` names the set to render
    and no request/reply structs are built.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op
        self.cw = cw

        self.fixed_hdr = None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)
            if op.fixed_header != family.fixed_header:
                # Only classic netlink can size a per-op header statically
                if not family.is_classic():
                    raise Exception("Per-op fixed header not supported, yet")
                self.fixed_hdr_len = f"sizeof({self.fixed_hdr})"

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_oneside = True
            else:
                do_op, dump_op = op['do'], op['dump']
                do_has_reply = 'reply' in do_op
                if do_has_reply != ('reply' in dump_op):
                    self.type_consistent = False
                elif do_has_reply and do_op['reply'] != dump_op['reply']:
                    self.type_consistent = False

        self.attr_set = attr_set or op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        # Notifications reuse the layout of the op they mirror
        self.struct = dict()
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        if op:
            for op_dir in ('request', 'reply'):
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                else:
                    type_list = []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])

    def type_empty(self, key):
        return not self.struct[key].attr_list and self.fixed_hdr is None

    def needs_nlflags(self, direction):
        return (self.op_mode == 'do' and direction == 'request'
                and self.family.is_classic())
1543
1544
1545class CodeWriter:
1546    def __init__(self, nlib, out_file=None, overwrite=True):
1547        self.nlib = nlib
1548        self._overwrite = overwrite
1549
1550        self._nl = False
1551        self._block_end = False
1552        self._silent_block = False
1553        self._ind = 0
1554        self._ifdef_block = None
1555        if out_file is None:
1556            self._out = os.sys.stdout
1557        else:
1558            self._out = tempfile.NamedTemporaryFile('w+')
1559            self._out_file = out_file
1560
1561    def __del__(self):
1562        self.close_out_file()
1563
1564    def close_out_file(self):
1565        if self._out == os.sys.stdout:
1566            return
1567        # Avoid modifying the file if contents didn't change
1568        self._out.flush()
1569        if not self._overwrite and os.path.isfile(self._out_file):
1570            if filecmp.cmp(self._out.name, self._out_file, shallow=False):
1571                return
1572        with open(self._out_file, 'w+') as out_file:
1573            self._out.seek(0)
1574            shutil.copyfileobj(self._out, out_file)
1575            self._out.close()
1576        self._out = os.sys.stdout
1577
1578    @classmethod
1579    def _is_cond(cls, line):
1580        return line.startswith('if') or line.startswith('while') or line.startswith('for')
1581
1582    def p(self, line, add_ind=0):
1583        if self._block_end:
1584            self._block_end = False
1585            if line.startswith('else'):
1586                line = '} ' + line
1587            else:
1588                self._out.write('\t' * self._ind + '}\n')
1589
1590        if self._nl:
1591            self._out.write('\n')
1592            self._nl = False
1593
1594        ind = self._ind
1595        if line[-1] == ':':
1596            ind -= 1
1597        if self._silent_block:
1598            ind += 1
1599        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
1600        self._silent_block |= line.strip() == 'else'
1601        if line[0] == '#':
1602            ind = 0
1603        if add_ind:
1604            ind += add_ind
1605        self._out.write('\t' * ind + line + '\n')
1606
1607    def nl(self):
1608        self._nl = True
1609
1610    def block_start(self, line=''):
1611        if line:
1612            line = line + ' '
1613        self.p(line + '{')
1614        self._ind += 1
1615
    def block_end(self, line=''):
        """Close the innermost block opened by block_start().

        With no *line*, the '}' is not written right away. It is deferred so
        that p() can join a following ``else`` onto it (``} else``).

        With *line* (e.g. ';' or a declarator name), '}' + *line* is written
        immediately via p(). Text not starting with ';' or ',' is separated
        from the brace by a space.

        Any pending nl() request is dropped, so no blank line precedes the
        brace.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                # A brace from an inner block is still pending: flush it at
                # that inner block's level before deferring this one.
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)
1628
1629    def write_doc_line(self, doc, indent=True):
1630        words = doc.split()
1631        line = ' *'
1632        for word in words:
1633            if len(line) + len(word) >= 79:
1634                self.p(line)
1635                line = ' *'
1636                if indent:
1637                    line += '  '
1638            line += ' ' + word
1639        self.p(line)
1640
    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Emit a C function prototype (declaration or definition header).

        Args:
            qual_ret: return type with qualifiers, e.g. 'static inline int'
                or 'const char *'. No space is inserted after a trailing '*'.
            name: function name.
            args: list of C parameter declarations; None/empty means '(void)'.
            doc: optional one-line comment emitted above as a /* ... */
                block (the text is not wrapped).
            suffix: text appended after ')', e.g. ';' for a declaration.

        Everything goes on one line when that line is shorter than 80
        characters. Otherwise a return type longer than 3 characters goes on
        its own line. The arguments are then wrapped, with continuation lines
        aligned under the opening parenthesis using tabs plus spaces.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        # Short return types (e.g. "int") stay on the same line as the name.
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation indent reaching the column right after '(':
        # as many 8-column tabs as fit, then spaces.
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        # len() counts a tab as one character; delta_ind is the extra visual
        # width (7 columns per tab) that the tabs in 'ind' add.
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                # Continuation line: correct the length for tab width.
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)
1682
1683    def write_func_lvar(self, local_vars):
1684        if not local_vars:
1685            return
1686
1687        if type(local_vars) is str:
1688            local_vars = [local_vars]
1689
1690        local_vars.sort(key=len, reverse=True)
1691        for var in local_vars:
1692            self.p(var)
1693        self.nl()
1694
1695    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
1696        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
1697        self.block_start()
1698        self.write_func_lvar(local_vars=local_vars)
1699
1700        for line in body:
1701            self.p(line)
1702        self.block_end()
1703
1704    def writes_defines(self, defines):
1705        longest = 0
1706        for define in defines:
1707            if len(define[0]) > longest:
1708                longest = len(define[0])
1709        longest = ((longest + 8) // 8) * 8
1710        for define in defines:
1711            line = '#define ' + define[0]
1712            line += '\t' * ((longest - len(define[0]) + 7) // 8)
1713            if type(define[1]) is int:
1714                line += str(define[1])
1715            elif type(define[1]) is str:
1716                line += '"' + define[1] + '"'
1717            self.p(line)
1718
1719    def write_struct_init(self, members):
1720        longest = max([len(x[0]) for x in members])
1721        longest += 1  # because we prepend a .
1722        longest = ((longest + 8) // 8) * 8
1723        for one in members:
1724            line = '.' + one[0]
1725            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
1726            line += '= ' + str(one[1]) + ','
1727            self.p(line)
1728
1729    def ifdef_block(self, config):
1730        config_option = None
1731        if config:
1732            config_option = 'CONFIG_' + c_upper(config)
1733        if self._ifdef_block == config_option:
1734            return
1735
1736        if self._ifdef_block:
1737            self.p('#endif /* ' + self._ifdef_block + ' */')
1738        if config_option:
1739            self.p('#ifdef ' + config_option)
1740        self._ifdef_block = config_option
1741
1742
# Attribute types treated as plain integers; e.g. indexed-array and
# multi-attr elements of these types are read with ynl_attr_get_<type>().
scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}

# Suffix appended to generated type names for each message direction;
# '' is used for direction-less (e.g. nested) structs.
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix of the wrapper type for replies of each op mode (used by
# op_prefix() for type-consistent ops when not dereferencing).
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C language keywords. Its use is outside this section; presumably it guards
# against emitting identifiers that collide with a keyword.
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1794
1795
def rdir(direction):
    """Return the opposite message direction ('request' <-> 'reply').

    Any other value (e.g. '' for direction-less structs) is returned as is.
    """
    opposite = {'reply': 'request', 'request': 'reply'}
    return opposite.get(direction, direction)
1802
1803
def op_prefix(ri, direction, deref=False):
    """Return the C type/function name prefix for *ri* in *direction*.

    The result is '<family>_<type_name>' plus a suffix. The suffix depends
    on the op mode, the direction, whether the type is one-sided or
    consistent, and *deref*. *deref* selects the inner (dereferenced) reply
    type rather than the list/notification wrapper.
    """
    parts = [f"_{ri.type_name}"]

    if not ri.op_mode or ri.op_mode == 'do':
        parts.append(direction_to_suffix[direction])
    elif direction == 'request':
        parts.append('_req' if ri.type_oneside else '_req_dump')
    elif not ri.type_consistent:
        parts.append('_rsp_dump' if deref else '_rsp_list')
    elif deref:
        parts.append(direction_to_suffix[direction])
    else:
        parts.append(op_mode_to_wrapper[ri.op_mode])

    return f"{ri.family.c_name}{''.join(parts)}"
1825
1826
def type_name(ri, direction, deref=False):
    """Return the full C struct type ('struct <prefix>') for *direction*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return "struct " + prefix
1829
1830
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-facing request (or dump) function.

    The function takes the socket, plus the request struct when the op has
    a request. It returns a pointer to the reply type when the op has a
    reply, otherwise int. With *terminate* a trailing ';' makes it a
    declaration.
    """
    op_spec = ri.op[ri.op_mode]
    func_name = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    params = ['struct ynl_sock *ys']
    if 'request' in op_spec:
        param_name = direction_to_suffix[direction][1:]
        params.append(f"{type_name(ri, direction)} *{param_name}")

    if 'reply' in op_spec:
        ret_type = f"{type_name(ri, rdir(direction))} *"
    else:
        ret_type = 'int'

    ri.cw.write_func_prot(ret_type, func_name, params, doc=doc,
                          suffix=';' if terminate else '')
1847
1848
def print_req_prototype(ri):
    """Declare the request function, with the op's doc as its comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1851
1852
def print_dump_prototype(ri):
    """Declare the dump request function (no doc comment)."""
    print_prototype(ri, direction="request")
1855
1856
def put_typol_submsg(cw, struct):
    """Emit the policy table and policy nest for a sub-message struct.

    Each member gets a YNL_PT_SUBMSG entry indexed by its position and
    pointing at the member's nested policy. max_attr is the last index,
    or -1 when there are no members.
    """
    policy = f'{struct.render_name}_policy'
    cw.block_start(line=f'const struct ynl_policy_attr {policy}[] =')

    count = 0
    for idx, (name, arg) in enumerate(struct.member_list()):
        cw.p(f'[{idx}] = {{ .type = YNL_PT_SUBMSG, .name = "{name}", '
             f'.nest = &{arg.nested_render_name}_nest, }},')
        count = idx + 1

    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {count - 1},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1874
1875
def put_typol_fwd(cw, struct):
    """Emit an extern declaration of the struct's policy nest."""
    nest_name = struct.render_name + '_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest_name};')
1878
1879
def put_typol(cw, struct):
    """Emit the attribute policy table and policy nest for *struct*.

    Sub-message structs are delegated to put_typol_submsg(). Otherwise each
    member writes its own table entry via attr_typol(), and the table is
    sized by the attribute set's max enum value.
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    max_name = struct.attr_set.max_name
    name = struct.render_name

    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_name} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    for field in (f'.max_attr = {max_name},', f'.table = {name}_policy,'):
        cw.p(field)
    cw.block_end(line=';')
    cw.nl()
1899
1900
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit '<render_name>_str()', which maps a value to its string in *map_name*.

    The argument is typed 'int' unless *enum* provides a user type. For
    'flags' enums the bit is first converted to its index with ffs() - 1.
    Out-of-range values make the generated function return NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()

    body = []
    if enum and enum.type == 'flags':
        body.append(f'{arg_name} = ffs({arg_name}) - 1;')
    body += [
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    ]
    for stmt in body:
        cw.p(stmt)

    cw.block_end()
    cw.nl()
1914
1915
def put_op_name_fwd(family, cw):
    """Declare '<family>_op_str(int op)', the op-number-to-name lookup."""
    func_name = family.c_name + '_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1918
1919
def put_op_name(family, cw):
    """Emit the op-value -> op-name string table and its *_op_str() lookup.

    Only ops with a reply value get an entry. The index is the op's enum
    name when request and reply values match, otherwise the numeric reply
    value.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1939
1940
def put_enum_to_str_fwd(family, cw, enum):
    """Declare '<enum>_str()'; *family* is accepted but not used."""
    cw.write_func_prot('const char *', f'{enum.render_name}_str',
                       [f'{enum.user_type} value'], suffix=';')
1944
1945
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name table for *enum* and its *_str() lookup."""
    table = f'{enum.render_name}_strmap'
    cw.block_start(line=f"static const char * const {table}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, table, 'value', enum=enum)
1955
1956
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<struct>_put()', which adds a nest to a request."""
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', f'{struct.render_name}_put', params, suffix=suffix)
1964
1965
def put_req_nested(ri, struct):
    """Emit '<struct>_put()', which serializes a nested struct into a request.

    Regular nests are wrapped in ynl_attr_nest_start()/ynl_attr_nest_end().
    Sub-message structs (struct.submsg set) put their members directly.
    Helper locals are declared only when a member needs them: 'array' for
    indexed arrays and 'i' for counted members.
    """
    is_nest = struct.submsg is None

    decls = []
    if is_nest:
        decls.append('struct nlattr *nest;')

    needs_array = False
    needs_index = False
    for _, member in struct.member_list():
        needs_array |= member.type == 'indexed-array'
        needs_index |= member.presence_type() == 'count'
    if needs_array:
        decls.append('struct nlattr *array;')
    if needs_index:
        decls.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    cw = ri.cw
    cw.block_start()
    cw.write_func_lvar(decls)

    if is_nest:
        cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    if is_nest:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2001
2002
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a C parse function for *struct*.

    The caller has already written the function prototype. This emits
    everything from the opening brace to the closing one.

    The generated code walks the attributes once, running each member's
    attr_get() code. Indexed-array and multi-attr members get extra passes
    afterwards: a counter n_<member> is declared (initialized to 0), an
    array of that many elements is calloc()ed, and the elements are filled
    by iterating again.

    Args:
        ri: render-info object; its .cw, .family, .op and .fixed_hdr are
            used here.
        struct: the struct whose members are parsed.
        init_lines: C statements emitted right after the local declarations.
            Appended to here (parg setup).
        local_vars: C local declarations. Appended to here.

    Raises:
        Exception: for a per-op fixed header in a non-classic family, and for
            unsupported indexed-array sub-types or element types.
    """
    if struct.nested:
        # Nested struct: walk the attributes inside the 'nested' attribute.
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        # Top-level message: attributes start after the family header or,
        # for classic families with a per-op fixed header, after that header.
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({ri.fixed_hdr}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Collect members needing an extra pass: indexed arrays (array_nests)
    # and repeated attributes (multi_attrs).
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # Element counters. Presumably incremented by the per-member code that
    # attr_get() emits in the main walk below.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited values arrive as extra function arguments; store them.
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        if ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # The array passes below assign freshly calloc()ed arrays to these
    # fields, so refuse to parse into a struct where one is already set.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # Main walk over the attributes.
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those: 'first' stays True
        # until a member's attr_get() reports that it emitted code.
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Indexed arrays: allocate, then walk the attr_<name> nest (presumably
    # recorded by the main walk) and fill one element per entry.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Multi-attr members: walk all attributes again and copy those whose
    # type matches the member.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most one element; a shorter payload leaves the rest
            # of the (calloc()ed) element zeroed.
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Each string gets its own ynl_string allocation, NUL-terminated.
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message-level parsers use the callback code.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2149
2150
def parse_rsp_submsg(ri, struct):
    """Emit '<struct>_parse()' for a sub-message struct.

    The generated C function dispatches on the selector string 'sel', with
    one 'if'/'else if (!strcmp(sel, "<name>"))' branch per member. Each
    branch runs that member's parse code against 'dst' and sets its
    _present bit when the member tracks presence. An unmatched selector
    falls through to 'return 0;'.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'

    ri.cw.block_start()
    ri.cw.write_func_lvar(['const struct nlattr *attr = nested;',
                          f'{struct.ptr_name}{var} = yarg->data;',
                          'struct ynl_parse_arg parg;'])

    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    for idx, (name, arg) in enumerate(struct.member_list()):
        kw = 'else if' if idx else 'if'

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        # _attr_get() is the raw per-type hook and its result is not
        # normalized here. Accept a single line given as a plain str (which
        # would otherwise be emitted one character per line) or None, as
        # well as a list of lines.
        for lines in (init_lines, get_lines):
            if isinstance(lines, str):
                lines = [lines]
            for line in lines or []:
                ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2181
2182
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<struct>_parse()'.

    Sub-message parsers take an extra 'const char *sel' selector before
    the nest, and each inherited value is appended as a '__u32' argument.
    """
    params = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        params.append('const char *sel')
    params.append('const struct nlattr *nested')
    params += [f'__u32 {inherited}' for inherited in struct.inherited]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params,
                          suffix=suffix)
2193
2194
def parse_rsp_nested(ri, struct):
    """Emit '<struct>_parse()' for a nested (or sub-message) struct.

    Sub-messages are delegated to parse_rsp_submsg(). A nest with no
    members gets a parser that just returns 0.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, struct, [], decls)
2213
2214
def parse_rsp_msg(ri, deref=False):
    """Emit the '<prefix>_parse()' callback for a reply (or event) message.

    Nothing is emitted when the op mode has no reply and is not an event.
    A reply with no members gets a callback that just returns
    YNL_PARSE_CB_OK.
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    decls = [f'{type_name(ri, "reply", deref=deref)} *dst;',
             'const struct nlattr *attr;']
    prefix = op_prefix(ri, "reply", deref=deref)

    ri.cw.write_func_prot('int', f'{prefix}_parse',
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply_struct = ri.struct["reply"]
    if not reply_struct.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply_struct, ['dst = yarg->data;'], decls)
2236
2237
def print_req(ri):
    """Emit the C implementation of a request ("do") function.

    The generated function:
      1. starts the message (classic netlink or genetlink header);
      2. registers the request policy and fixed-header length on the socket;
      3. copies the fixed header, if any, and puts all request attributes;
      4. when the op has a reply, calloc()s the reply struct and sets up
         the parse callback and expected reply command;
      5. runs ynl_exec().

    With a reply it returns the reply struct, or frees it and returns NULL
    on error. Without a reply it returns 0, or -1 on error.
    """
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # A single loop index is shared by all counted members.
    for _, attr in ri.struct["request"].member_list():
        if attr.presence_type() == 'count':
            local_vars += ['unsigned int i;']
            break

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Expected reply command: the op's enum name when it has a value,
        # otherwise its numeric reply value.
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        # Error path: free the reply struct, which may be partially filled.
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
2311
2312
def print_dump(ri):
    """Emit the C implementation of a dump request function.

    The generated function does the following:
      1. configures a ynl_dump_state (reply policy, per-entry allocation
         size, parse callback, expected reply command);
      2. starts a dump message (classic or genetlink);
      3. adds the fixed header and, if the op has a request section, the
         request attributes;
      4. runs ynl_exec_dump().

    It returns yds.first, which is presumably the head of the list of
    parsed replies. On error it frees that list and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        # NOTE(review): this references req->_hdr, but print_prototype()
        # only declares a 'req' parameter when the op mode has a "request"
        # section. A dump with a fixed header and no request section would
        # generate C that does not compile; confirm no spec hits this.
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2367
2368
def call_free(ri, direction, var):
    """Return the C statement that frees *var* with the matching *_free()."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
2371
2372
def free_arg_name(direction):
    """Return the parameter name used by *_free(): 'req', 'rsp' or 'obj'.

    Direction-less (falsy) directions use 'obj'; otherwise the direction's
    type suffix is used without its leading underscore.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
2377
2378
def print_alloc_wrapper(ri, direction):
    """Emit a static inline '<prefix>_alloc()' returning a zeroed struct."""
    struct_name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot(f'static inline struct {struct_name} *',
                       f"{struct_name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    cw.block_end()
2385
2386
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of '<prefix>_free()'.

    When the type name conflicts, the struct tag of the parameter gets a
    trailing '_', while the function name keeps the plain prefix.
    """
    name = op_prefix(ri, direction)
    struct_name = f"{name}_" if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
2394
2395
def print_nlflags_set(ri, direction):
    """Emit a static inline '<prefix>_set_nlflags()' setter for _nlmsg_flags."""
    struct_name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot('static inline void', f"{struct_name}_set_nlflags",
                       [f"struct {struct_name} *req", "__u16 nl_flags"])
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2404
2405
def _print_type(ri, direction, struct):
    """Emit the C struct definition for *struct* in the given *direction*.

    The struct is named '<family>_<type_name><direction suffix>'. A '_' is
    appended for direction-less structs whose type name conflicts, and
    '_dump' for dump ops unless the type is one-sided.

    The body holds, in order:
      - _nlmsg_flags, when needed;
      - the fixed header (_hdr);
      - _present/_len/_count metadata sub-structs;
      - inherited __u32 fields;
      - the members themselves.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'

    if ri.op_mode == 'dump' and not ri.type_oneside:
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.needs_nlflags(direction):
        ri.cw.p('__u16 _nlmsg_flags;')
        ri.cw.nl()
    if ri.fixed_hdr:
        ri.cw.p(ri.fixed_hdr + ' _hdr;')
        ri.cw.nl()

    # Metadata sub-structs (_present, _len, _count). Each is opened on the
    # first member that contributes a line, and omitted if none do.
    for type_filter in ['present', 'len', 'count']:
        meta_started = False
        for _, attr in struct.member_list():
            line = attr.presence_member(ri.ku_space, type_filter)
            if line:
                if not meta_started:
                    ri.cw.block_start(line=f"struct")
                    meta_started = True
                ri.cw.p(line)
        if meta_started:
            ri.cw.block_end(line=f'_{type_filter};')
    ri.cw.nl()

    # Inherited values are stored as __u32 fields; the parse function
    # receives them as extra __u32 arguments.
    for arg in struct.inherited:
        ri.cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
2444
2445
def print_type(ri, direction):
    """Emit the struct definition for ri.struct[direction]."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2448
2449
def print_type_full(ri, struct):
    """Emit a direction-less struct definition for *struct*."""
    _print_type(ri, direction="", struct=struct)
2452
2453
def print_type_helpers(ri, direction, deref=False):
    """Emit helper declarations for a type.

    This always declares the free function. It adds the nlflags setter when
    needed, and for user-space request types the per-attribute setters.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2465
2466
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and helpers for a non-empty request type."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2472
2473
def print_rsp_type_helpers(ri):
    """Emit reply-type helpers when the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2478
2479
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of '<op>_{rsp,req}_parse()' taking a tb[] array.

    Any direction other than 'reply' uses the '_req' naming. The struct
    parameter is always named 'req'.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    struct_name = f"{ri.op.render_name}{kind}"
    ri.cw.write_func_prot('void', f"{struct_name}_parse",
                          ['const struct nlattr **tb',
                           f"struct {struct_name} *req"],
                          suffix=';' if terminate else '')
2488
2489
def print_req_type(ri):
    """Emit the request struct definition unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2494
2495
def print_req_free(ri):
    """Emit the request free function when the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2500
2501
def print_rsp_type(ri):
    """Emit the reply struct definition.

    This is done for events, and for do/dump ops that have a reply.
    """
    mode = ri.op_mode
    if mode in ('do', 'dump') and 'reply' in ri.op[mode]:
        print_type(ri, 'reply')
    elif mode == 'event':
        print_type(ri, 'reply')
2510
2511
def print_wrapped_type(ri):
    """Emit the reply wrapper struct and declare its free function.

    Dumps get a 'next' pointer of the wrapper type. Notifications and
    events get family/cmd, a generic 'next' and a 'free' callback. The
    dereferenced reply is embedded as 'obj', aligned to 8 bytes.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=wrapper)
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for member in ('__u16 family;',
                       '__u8 cmd;',
                       'struct ynl_ntf_base_type *next;',
                       f"void (*free)({wrapper} *ntf);"):
            cw.p(member)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()

    print_free_prototype(ri, 'reply')
    cw.nl()
2526
2527
def _free_type_members_iter(ri, struct):
    """Declare the loop index 'i' if freeing *struct* iterates over members."""
    if not struct.free_needs_iter():
        return
    ri.cw.p('unsigned int i;')
    ri.cw.nl()
2532
2533
def _free_type_members(ri, var, struct, ref=''):
    """Emit free code for every member of *struct* held in *var*.

    *ref* is a field path prefix (e.g. 'obj.') passed to each member.
    """
    members = (member for _, member in struct.member_list())
    for member in members:
        member.free(ri, var, ref)
2537
2538
def _free_type(ri, direction, struct):
    """Emit the '<prefix>_free()' function for *struct*.

    The generated function frees the members. For a non-empty *direction*
    it also frees the object itself; direction-less structs only have their
    members freed.
    """
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw = ri.cw
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2550
2551
def free_rsp_nested_prototype(ri):
    """Emit the prototype of the free helper for a nested reply struct.

    Nested structs use the empty direction, matching free_rsp_nested().
    """
    print_free_prototype(ri, "")
2554
2555
def free_rsp_nested(ri, struct):
    """Emit the free helper for a nested reply struct (empty direction)."""
    _free_type(ri, direction="", struct=struct)
2558
2559
def print_rsp_free(ri):
    """Emit the reply free helper when the current op mode has a reply."""
    mode_spec = ri.op[ri.op_mode]
    if 'reply' in mode_spec:
        _free_type(ri, 'reply', ri.struct['reply'])
2564
2565
def print_dump_type_free(ri):
    """Emit the free function for a dump reply list.

    The generated C walks the list via each node's ``next`` pointer until it
    reaches YNL_LIST_END.  For every node it frees the members of ``obj`` and
    then the node itself.
    """
    cw = ri.cw
    node_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()

    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    reply_struct = ri.struct['reply']
    _free_type_members_iter(ri, reply_struct)
    # Advance before freeing so the next pointer is read while rsp is valid
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2584
2585
def print_ntf_type_free(ri):
    """Emit the free function for a notification object.

    Frees the members of ``rsp->obj`` and then ``rsp`` itself.
    """
    cw = ri.cw
    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    reply_struct = ri.struct['reply']
    _free_type_members_iter(ri, reply_struct)
    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2594
2595
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the head line of the nla_policy array for struct.

    terminate=True emits an ``extern`` declaration ending in ';'.  Nothing is
    emitted when ri is given and policy_should_be_static() holds for the
    family.

    terminate=False opens the definition with ' = {', prefixed with
    'static ' under that same condition.

    The array is named after the op's render_name (plus '_<op_mode>' for
    dual-policy ops) when ri is given, otherwise after struct.render_name.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)
    if terminate and is_static:
        return

    if terminate:
        prefix, suffix = 'extern ', ';'
    elif is_static:
        prefix, suffix = 'static ', ' = {'
    else:
        prefix, suffix = '', ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2618
2619
def print_req_policy(cw, struct, ri=None):
    """Emit a full nla_policy array definition for struct.

    When ri carries an op, the definition is wrapped in that op's
    'config-cond' ifdef block.  Each member contributes its own policy entry.
    """
    wrap_in_cond = ri and ri.op
    if wrap_in_cond:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _member_name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2629
2630
def kernel_can_gen_family_struct(family):
    """Return True when the genl_family struct is generated (plain 'genetlink' only)."""
    return family.proto in ('genetlink',)
2633
2634
def policy_should_be_static(family):
    """Return whether generated per-op policies should be file-local (static).

    True for the 'split' kernel policy, or whenever the family struct itself
    is generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2637
2638
def print_kernel_policy_ranges(family, cw):
    """Emit netlink_range_validation structs for 'full-range' request attrs.

    Every request attribute with a 'full-range' check, in attr sets that are
    not subsets, gets one struct.  A comment header precedes the first one.

    Unsigned types ('u...') use the plain struct with ULL-suffixed limits.
    Other types use the _signed variant with LL-suffixed limits.
    """
    announced = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not announced:
                cw.p('/* Integer value ranges */')
                announced = True

            is_unsigned = attr.type[0] == 'u'
            sign = '' if is_unsigned else '_signed'
            suffix = 'ULL' if is_unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(limit, attr.get_limit_str(limit, suffix=suffix))
                       for limit in ('min', 'max') if limit in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2666
2667
def print_kernel_policy_sparse_enum_validates(family, cw):
    """Emit validation callbacks for request attributes with sparse enums.

    For every enum-typed request attribute that carries a 'sparse' check, in
    attr sets that are not subsets, this generates a static
    ``<enum>_validate()`` function.  The function switches on the attribute
    value and returns 0 for any listed enum entry.  Otherwise it sets an
    extack message and returns -EINVAL.  A comment header precedes the first
    callback.
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request:
                continue
            if not attr.enum_name:
                continue
            if 'sparse' not in attr.checks:
                continue

            if first:
                cw.p('/* Sparse enums validation callbacks */')
                first = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # All valid values share one 'return 0'; consecutive case labels
            # are separated by explicit 'fallthrough;' markers.
            for idx, entry in enumerate(enum.entries.values()):
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2706
2707
def print_kernel_op_table_fwd(family, cw, terminate):
    """Open the kernel ops table, or emit its header-side declarations.

    terminate=False opens the ops array definition.  The caller emits the
    entries and closes the block.

    terminate=True emits an ``extern`` for the array, but only when it is
    exported (the family struct is not generated).  It then emits prototypes
    for the family's pre/post do/dump hooks and for every non-async op's
    doit/dumpit handler.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        if not exported:
            # Static array: size is taken from the initializer
            cnt = ""
        else:
            # The explicit size must match the number of entries that
            # print_kernel_op_table() emits, and it skips async ops.
            sync_ops = [op for op in family.ops.values() if not op.is_async]
            if family.kernel_policy == 'split':
                # Split ops get one entry per do handler and per dump handler
                cnt = 0
                for op in sync_ops:
                    if 'do' in op:
                        cnt += 1
                    if 'dump' in op:
                        cnt += 1
            else:
                cnt = len(sync_ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2773
2774
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side declarations for the kernel ops table."""
    declarations_only = True
    print_kernel_op_table_fwd(family, cw, declarations_only)
2777
2778
def print_kernel_op_table(family, cw):
    """Emit the kernel ops array definition for the family.

    The array is opened via print_kernel_op_table_fwd(terminate=False).
    For the 'global' and 'per-op' kernel policies, one initializer is emitted
    per non-async op.  For the 'split' policy, one initializer is emitted per
    do/dump handler of every non-async op.  Each entry is wrapped in its op's
    'config-cond' ifdef block, and the array is closed at the end.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            # One entry carries both the doit and the dumpit handler
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): this indexes op['do']['request'] unconditionally,
                # so a per-op family with a dump-only op (or a 'do' without a
                # 'request') raises KeyError here — confirm that is intended.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the pre/post hook members per handler kind
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags relevant to this handler kind: the
                    # dump flags are dropped for 'do', 'strict' for 'dump'.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have separate do and dump policy arrays
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split op advertises which handler kind it implements
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Close any config-cond block still open from the last entry
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2857
2858
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the anonymous enum of multicast group indices for the kernel header."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{group['name']},"))
    cw.block_end(';')
    cw.nl()
2869
2870
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by the group enum."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for group in groups:
        grp_name = group['name']
        index = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{index}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2882
2883
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated genl_family and, if configured, sock-priv hooks.

    Nothing is emitted unless kernel_can_gen_family_struct() holds.  When
    'sock-priv' is set, the init/destroy prototypes take that type.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()
    if 'sock-priv' not in family.kernel_family:
        return

    priv_type = family.kernel_family["sock-priv"]
    for hook in ('init', 'destroy'):
        cw.p(f'void {family.c_name}_nl_sock_priv_{hook}({priv_type} *priv);')
    cw.nl()
2894
2895
def print_kernel_family_struct_src(family, cw):
    """Emit the genl_family definition (only when the family struct is generated).

    When 'sock-priv' is configured, static ``void *`` trampolines wrapping
    the typed init/destroy helpers are emitted first.  The struct is named
    with family.c_name, the same identifier print_kernel_family_struct_hdr()
    uses in its extern declaration, so the definition and declaration always
    refer to the same symbol.
    """
    if not kernel_can_gen_family_struct(family):
        return

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy: the genl_family hooks
        # take void *, the user-provided helpers take the typed pointer.
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2931
2932
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uapi enum block, choosing its tag.

    If obj has the enum_name key, its value is the tag; an empty or null
    value yields an anonymous enum.  Otherwise, if ckey is given and present
    in obj, the tag is '<family c_name>_<obj[ckey]>'.  Otherwise the enum is
    anonymous.
    """
    tag = None
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            tag = c_lower(explicit)
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])
    cw.block_start(line='enum' if tag is None else 'enum ' + tag)
2941
2942
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Emit the single command enum used by the 'unified' message-id model.

    When separate_ntf is set, notify/event messages are left out (the caller
    renders them into their own enum).  Values are written explicitly only
    where they break the running sequence.  The enum ends with a count entry
    and a MAX value, rendered as an enum member or as a #define depending on
    max_by_define.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value != expected:
            cw.p(f"{op.enum_name} = {op.value},")
            expected = op.value
        else:
            cw.p(op.enum_name + ',')
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
2968
2969
def render_uapi_directional(family, cw, max_by_define):
    """Emit the two message-id enums used by the 'directional' model.

    The USER enum lists ops that have a 'do' section and are not events.
    The KERNEL enum lists 'do' ops with a reply (suffixed _REPLY), plus
    notifications and events.  Each enum starts with an explicit *_NONE = 0
    entry and ends with a count entry and a MAX value, rendered as an enum
    member or as a #define depending on max_by_define.
    """
    def emit_entry(name, value, expected):
        # Spell out the value only when it is non-zero and breaks the sequence
        if value and value != expected:
            cw.p(f"{name} = {value},")
            expected = value
        else:
            cw.p(name + ',')
        return expected + 1

    def close_enum(direction):
        cnt_name = f"__{family.op_prefix}{direction}_CNT"
        max_name = f"{family.op_prefix}{direction}_MAX"
        max_value = f"({cnt_name} - 1)"
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {max_name} {max_value}")
        cw.nl()

    # Messages sent from user space to the kernel
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    expected = 0
    for op in family.msgs.values():
        if 'do' in op and 'event' not in op:
            expected = emit_entry(op.enum_name, op.value, expected)
    close_enum('USER')

    # Messages sent from the kernel to user space
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    expected = 0
    for op in family.msgs.values():
        is_reply = 'do' in op and 'reply' in op['do']
        if not (is_reply or 'notify' in op or 'event' in op):
            continue
        name = op.enum_name
        if 'event' not in op and 'notify' not in op:
            name = f'{name}_REPLY'
        expected = emit_entry(name, op.value, expected)
    close_enum('KERNEL')
3022
3023
def render_uapi(family, cw):
    """Render the complete uapi header for the family.

    Output order: the include guard, the family name/version defines, the
    definitions ('const' entries as #defines, enums/flags with optional kdoc
    and private count/mask entries), one enum per non-subset attribute set,
    the command enum(s), an optional separate notification enum, the
    multicast group name defines, and the closing #endif.

    Raises:
        Exception: if family.msg_id_model is neither 'unified' nor
            'directional'.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    # Header names may contain a directory component; keep the guard a valid identifier
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are batched into one group of defines
    defines = []
    for const in family['definitions']:
        if const.get('header'):
            continue

        # Any non-const definition flushes the pending batch of defines first
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # Only spell out values where the entry's value_change flag is set
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    # Flags get a combined mask entry instead of a count/max pair
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # NOTE(review): 'c-define-name' is looked up on the family, not on
            # this const, so the default name is used unless the family itself
            # has that key — confirm whether const.get() was intended.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    # One enum per attribute set; subsets reuse their parent's enum
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            # Emit an explicit value only where the numbering breaks sequence
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    # Notifications/events get their own enum when 'async-prefix' is set
    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            # NOTE(review): the separate_ntf test is always true inside this branch
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3164
3165
def _render_user_ntf_entry(ri, op):
    """Emit one ynl_ntf_info initializer entry for a notification or event.

    Classic families index the entry by the enum name of the request found
    via ``req_by_value[op.rsp_value]``.  Other families index it by the op's
    own enum name.
    """
    cw = ri.cw
    family = ri.family
    if family.is_classic():
        index_name = family.req_by_value[op.rsp_value].enum_name
    else:
        index_name = op.enum_name

    cw.block_start(line=f"[{index_name}] = ")
    fields = [
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    ]
    for field in fields:
        cw.p(field)
    cw.block_end(line=',')
3177
3178
def render_user_family(family, cw, prototype):
    """Emit the user-space ynl_family descriptor.

    With prototype=True only the extern declaration is written.  Otherwise
    the notification info table is emitted first (when the family has
    notifications), followed by the descriptor.  The descriptor's hdr_len
    depends on whether the family is classic and whether it has a fixed
    header.

    Raises:
        Exception: for an entry in family.ntfs with neither 'notify' nor
            'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Rendered using the op named by the 'notify' key
                ri = RenderInfo(cw, family, "user", family.ops[ntf_op['notify']], "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for _, event_op in family.ops.items():
            if 'event' in event_op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", event_op, "event"), event_op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        # Classic families have no genlmsghdr; only the fixed header counts
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3220
3221
def family_contains_bitfield32(family):
    """Return True if any non-subset attribute set has a bitfield32 attribute."""
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
3230
3231
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains full_path.

    Walks up from full_path's parent directory, looking for a directory that
    holds a MAINTAINERS file.

    Returns:
        (root, sub_path): the directory holding MAINTAINERS, and full_path
        expressed relative to it.

    Raises:
        FileNotFoundError: if no ancestor directory contains MAINTAINERS.
            The walk stops once dirname() no longer changes the path (the
            filesystem root, or '' for relative paths) instead of spinning
            forever.
    """
    start_path = full_path
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        if parent == full_path:
            raise FileNotFoundError(
                f"no MAINTAINERS file found in any parent of {start_path}")
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        full_path = parent
        maintainers = os.path.join(full_path, "MAINTAINERS")
        if os.path.exists(maintainers):
            # sub_path carries a trailing separator from the initial join
            return full_path, sub_path[:-1]
3240
3241
3242def main():
3243    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
3244    parser.add_argument('--mode', dest='mode', type=str, required=True,
3245                        choices=('user', 'kernel', 'uapi'))
3246    parser.add_argument('--spec', dest='spec', type=str, required=True)
3247    parser.add_argument('--header', dest='header', action='store_true', default=None)
3248    parser.add_argument('--source', dest='header', action='store_false')
3249    parser.add_argument('--user-header', nargs='+', default=[])
3250    parser.add_argument('--cmp-out', action='store_true', default=None,
3251                        help='Do not overwrite the output file if the new output is identical to the old')
3252    parser.add_argument('--exclude-op', action='append', default=[])
3253    parser.add_argument('-o', dest='out_file', type=str, default=None)
3254    args = parser.parse_args()
3255
3256    if args.header is None:
3257        parser.error("--header or --source is required")
3258
3259    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
3260
3261    try:
3262        parsed = Family(args.spec, exclude_ops)
3263        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
3264            print('Spec license:', parsed.license)
3265            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
3266            os.sys.exit(1)
3267    except yaml.YAMLError as exc:
3268        print(exc)
3269        os.sys.exit(1)
3270        return
3271
3272    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
3273
3274    _, spec_kernel = find_kernel_root(args.spec)
3275    if args.mode == 'uapi' or args.header:
3276        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
3277    else:
3278        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
3279    cw.p("/* Do not edit directly, auto-generated from: */")
3280    cw.p(f"/*\t{spec_kernel} */")
3281    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
3282    if args.exclude_op or args.user_header:
3283        line = ''
3284        line += ' --user-header '.join([''] + args.user_header)
3285        line += ' --exclude-op '.join([''] + args.exclude_op)
3286        cw.p(f'/* YNL-ARG{line} */')
3287    cw.nl()
3288
3289    if args.mode == 'uapi':
3290        render_uapi(parsed, cw)
3291        return
3292
3293    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
3294    if args.header:
3295        cw.p('#ifndef ' + hdr_prot)
3296        cw.p('#define ' + hdr_prot)
3297        cw.nl()
3298
3299    if args.out_file:
3300        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
3301    else:
3302        hdr_file = "generated_header_file.h"
3303
3304    if args.mode == 'kernel':
3305        cw.p('#include <net/netlink.h>')
3306        cw.p('#include <net/genetlink.h>')
3307        cw.nl()
3308        if not args.header:
3309            if args.out_file:
3310                cw.p(f'#include "{hdr_file}"')
3311            cw.nl()
3312        headers = ['uapi/' + parsed.uapi_header]
3313        headers += parsed.kernel_family.get('headers', [])
3314    else:
3315        cw.p('#include <stdlib.h>')
3316        cw.p('#include <string.h>')
3317        if args.header:
3318            cw.p('#include <linux/types.h>')
3319            if family_contains_bitfield32(parsed):
3320                cw.p('#include <linux/netlink.h>')
3321        else:
3322            cw.p(f'#include "{hdr_file}"')
3323            cw.p('#include "ynl.h"')
3324        headers = []
3325    for definition in parsed['definitions'] + parsed['attribute-sets']:
3326        if 'header' in definition:
3327            headers.append(definition['header'])
3328    if args.mode == 'user':
3329        headers.append(parsed.uapi_header)
3330    seen_header = []
3331    for one in headers:
3332        if one not in seen_header:
3333            cw.p(f"#include <{one}>")
3334            seen_header.append(one)
3335    cw.nl()
3336
3337    if args.mode == "user":
3338        if not args.header:
3339            cw.p("#include <linux/genetlink.h>")
3340            cw.nl()
3341            for one in args.user_header:
3342                cw.p(f'#include "{one}"')
3343        else:
3344            cw.p('struct ynl_sock;')
3345            cw.nl()
3346            render_user_family(parsed, cw, True)
3347        cw.nl()
3348
3349    if args.mode == "kernel":
3350        if args.header:
3351            for _, struct in sorted(parsed.pure_nested_structs.items()):
3352                if struct.request:
3353                    cw.p('/* Common nested types */')
3354                    break
3355            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3356                if struct.request:
3357                    print_req_policy_fwd(cw, struct)
3358            cw.nl()
3359
3360            if parsed.kernel_policy == 'global':
3361                cw.p(f"/* Global operation policy for {parsed.name} */")
3362
3363                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3364                print_req_policy_fwd(cw, struct)
3365                cw.nl()
3366
3367            if parsed.kernel_policy in {'per-op', 'split'}:
3368                for op_name, op in parsed.ops.items():
3369                    if 'do' in op and 'event' not in op:
3370                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
3371                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
3372                        cw.nl()
3373
3374            print_kernel_op_table_hdr(parsed, cw)
3375            print_kernel_mcgrp_hdr(parsed, cw)
3376            print_kernel_family_struct_hdr(parsed, cw)
3377        else:
3378            print_kernel_policy_ranges(parsed, cw)
3379            print_kernel_policy_sparse_enum_validates(parsed, cw)
3380
3381            for _, struct in sorted(parsed.pure_nested_structs.items()):
3382                if struct.request:
3383                    cw.p('/* Common nested types */')
3384                    break
3385            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3386                if struct.request:
3387                    print_req_policy(cw, struct)
3388            cw.nl()
3389
3390            if parsed.kernel_policy == 'global':
3391                cw.p(f"/* Global operation policy for {parsed.name} */")
3392
3393                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3394                print_req_policy(cw, struct)
3395                cw.nl()
3396
3397            for op_name, op in parsed.ops.items():
3398                if parsed.kernel_policy in {'per-op', 'split'}:
3399                    for op_mode in ['do', 'dump']:
3400                        if op_mode in op and 'request' in op[op_mode]:
3401                            cw.p(f"/* {op.enum_name} - {op_mode} */")
3402                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
3403                            print_req_policy(cw, ri.struct['request'], ri=ri)
3404                            cw.nl()
3405
3406            print_kernel_op_table(parsed, cw)
3407            print_kernel_mcgrp_src(parsed, cw)
3408            print_kernel_family_struct_src(parsed, cw)
3409
3410    if args.mode == "user":
3411        if args.header:
3412            cw.p('/* Enums */')
3413            put_op_name_fwd(parsed, cw)
3414
3415            for name, const in parsed.consts.items():
3416                if isinstance(const, EnumSet):
3417                    put_enum_to_str_fwd(parsed, cw, const)
3418            cw.nl()
3419
3420            cw.p('/* Common nested types */')
3421            for attr_set, struct in parsed.pure_nested_structs.items():
3422                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3423                print_type_full(ri, struct)
3424                if struct.request and struct.in_multi_val:
3425                    free_rsp_nested_prototype(ri)
3426                    cw.nl()
3427
3428            for op_name, op in parsed.ops.items():
3429                cw.p(f"/* ============== {op.enum_name} ============== */")
3430
3431                if 'do' in op and 'event' not in op:
3432                    cw.p(f"/* {op.enum_name} - do */")
3433                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3434                    print_req_type(ri)
3435                    print_req_type_helpers(ri)
3436                    cw.nl()
3437                    print_rsp_type(ri)
3438                    print_rsp_type_helpers(ri)
3439                    cw.nl()
3440                    print_req_prototype(ri)
3441                    cw.nl()
3442
3443                if 'dump' in op:
3444                    cw.p(f"/* {op.enum_name} - dump */")
3445                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
3446                    print_req_type(ri)
3447                    print_req_type_helpers(ri)
3448                    if not ri.type_consistent or ri.type_oneside:
3449                        print_rsp_type(ri)
3450                    print_wrapped_type(ri)
3451                    print_dump_prototype(ri)
3452                    cw.nl()
3453
3454                if op.has_ntf:
3455                    cw.p(f"/* {op.enum_name} - notify */")
3456                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3457                    if not ri.type_consistent:
3458                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3459                    print_wrapped_type(ri)
3460
3461            for op_name, op in parsed.ntfs.items():
3462                if 'event' in op:
3463                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
3464                    cw.p(f"/* {op.enum_name} - event */")
3465                    print_rsp_type(ri)
3466                    cw.nl()
3467                    print_wrapped_type(ri)
3468            cw.nl()
3469        else:
3470            cw.p('/* Enums */')
3471            put_op_name(parsed, cw)
3472
3473            for name, const in parsed.consts.items():
3474                if isinstance(const, EnumSet):
3475                    put_enum_to_str(parsed, cw, const)
3476            cw.nl()
3477
3478            has_recursive_nests = False
3479            cw.p('/* Policies */')
3480            for struct in parsed.pure_nested_structs.values():
3481                if struct.recursive:
3482                    put_typol_fwd(cw, struct)
3483                    has_recursive_nests = True
3484            if has_recursive_nests:
3485                cw.nl()
3486            for struct in parsed.pure_nested_structs.values():
3487                put_typol(cw, struct)
3488            for name in parsed.root_sets:
3489                struct = Struct(parsed, name)
3490                put_typol(cw, struct)
3491
3492            cw.p('/* Common nested types */')
3493            if has_recursive_nests:
3494                for attr_set, struct in parsed.pure_nested_structs.items():
3495                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3496                    free_rsp_nested_prototype(ri)
3497                    if struct.request:
3498                        put_req_nested_prototype(ri, struct)
3499                    if struct.reply:
3500                        parse_rsp_nested_prototype(ri, struct)
3501                cw.nl()
3502            for attr_set, struct in parsed.pure_nested_structs.items():
3503                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3504
3505                free_rsp_nested(ri, struct)
3506                if struct.request:
3507                    put_req_nested(ri, struct)
3508                if struct.reply:
3509                    parse_rsp_nested(ri, struct)
3510
3511            for op_name, op in parsed.ops.items():
3512                cw.p(f"/* ============== {op.enum_name} ============== */")
3513                if 'do' in op and 'event' not in op:
3514                    cw.p(f"/* {op.enum_name} - do */")
3515                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3516                    print_req_free(ri)
3517                    print_rsp_free(ri)
3518                    parse_rsp_msg(ri)
3519                    print_req(ri)
3520                    cw.nl()
3521
3522                if 'dump' in op:
3523                    cw.p(f"/* {op.enum_name} - dump */")
3524                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
3525                    if not ri.type_consistent or ri.type_oneside:
3526                        parse_rsp_msg(ri, deref=True)
3527                    print_req_free(ri)
3528                    print_dump_type_free(ri)
3529                    print_dump(ri)
3530                    cw.nl()
3531
3532                if op.has_ntf:
3533                    cw.p(f"/* {op.enum_name} - notify */")
3534                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3535                    if not ri.type_consistent:
3536                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3537                    print_ntf_type_free(ri)
3538
3539            for op_name, op in parsed.ntfs.items():
3540                if 'event' in op:
3541                    cw.p(f"/* {op.enum_name} - event */")
3542
3543                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3544                    parse_rsp_msg(ri)
3545
3546                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
3547                    print_ntf_type_free(ri)
3548            cw.nl()
3549            render_user_family(parsed, cw, False)
3550
3551    if args.header:
3552        cw.p(f'#endif /* {hdr_prot} */')
3553
3554
if __name__ == "__main__":
    # Script entry point: code generation runs only when this file is
    # executed directly; importing the module does not call main().
    main()
3557