xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision cb39645d9a6a8b84f2e820db7c2b49ebd4b18b2c)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
def c_upper(name):
    """Return *name* upper-cased with every '-' turned into '_'."""
    upper = name.upper()
    return upper.replace('-', '_')
22
23
def c_lower(name):
    """Return *name* lower-cased with every '-' turned into '_'."""
    lower = name.lower()
    return lower.replace('-', '_')
26
27
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The string must have the form ``<u|s><width>-<min|max>``, e.g.
    ``u8-max`` -> 255, ``s16-min`` -> -32768, ``u32-min`` -> 0.

    Raises:
        ValueError: if *name* is not in that form.
    """
    match = re.fullmatch(r'([us])(\d+)-(min|max)', name)
    if not match:
        # Previously the suffix was never checked, so e.g. "u32-foo"
        # silently evaluated to the u32 maximum.
        raise ValueError(f'Unrecognized limit: "{name}"')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u' and bound == 'min':
        return 0
    if sign == 's':
        # The top bit is the sign bit.
        width -= 1
    value = (1 << width) - 1
    if sign == 's' and bound == 'min':
        value = -value - 1
    return value
41
42
class BaseNlLib:
    """Names of library-side state referenced by the generated C code."""

    def get_family_id(self):
        """Return the C expression the generated code uses for the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
46
47
class Type(SpecAttr):
    """Code-generation model of one attribute of an attribute set.

    The base class holds the attribute metadata and the shared C emitters.
    Subclasses specialize the ``_attr_*`` / ``_setter_lines`` hooks (or the
    public emitters) for their attribute type.  Methods that take ``ri``
    write C through ``ri.cw``.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply()
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        # Nests and sub-messages both refer to another set by name
        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # Sets named like an entry in family.consts get a trailing '_',
            # presumably so the struct name doesn't clash with that definition.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # C member name: names in _C_KW (defined elsewhere, presumably C
        # keywords) get a '_' suffix, names starting with a digit a '_' prefix.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve(); deleted so that any access before resolve()
        # raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests; subsets forward to the real attr."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies; subsets forward to the real attr."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return the numeric value of check *limit* (or *default*).

        The check may be an int, the name of a family constant, or a string
        limit such as ``u32-max`` (see limit_to_number()).  Returns None if
        neither the check nor *default* is set.
        """
        value = self.checks.get(limit, default)
        if value is None or isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check *limit* as a C expression string ('' if unset).

        Ints are rendered literally (plus *suffix*); constants become their
        upper-case C name, prefixed with the family name unless the constant
        comes from a header.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name; reject subsets that override the real attr's checks."""
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        # Falsy by default; multi-value types override this
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Name of the presence struct used: 'present', 'len' or 'count'."""
        return 'present'

    def presence_member(self, space, type_filter):
        """Return the presence-struct member declaration, or None.

        Only emits something when this attr's presence type equals
        *type_filter*; user-space members use the ``__u32`` spelling.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'present':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def _free_lines(self, ri, var, ref):
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the setter argument declaration(s) for this attr."""
        member = self._complex_member_type(ri)
        if not member:
            raise Exception(f"Struct member not implemented for class type {self.type}")
        spc = ' ' if member[-1] != '*' else ''
        arg = [member + spc + '*' + self.c_name]
        if self.presence_type() == 'count':
            arg.append('unsigned int n_' + self.c_name)
        return arg

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attr."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the kernel nla_policy entry for this attr."""
        policy = f'NLA_{c_upper(self.type)}'
        # Only u16/u32 have dedicated big-endian policy types here
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the user-space type-policy entry for this attr."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the presence member when one is tracked
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attr; return True.

        *first* selects ``if`` vs ``else if`` for the type dispatch chain.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a static inline setter for this attr.

        *ref* lists the c_names of enclosing nests (outermost first).  The
        setter marks every enclosing level present, frees any previous value,
        then stores the new one.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, part in enumerate(ref):
            parent = '.'.join(ref[:i] + [''])
            presence = f"{var}->{parent}_present.{part}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{var}->{parent}_{self.presence_type()}.{part}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in line for line in code)
        alloc = any('alloc(' in line for line in code)
        # Setters that free but never allocate get a "__" prefix
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
311
312
class TypeUnused(Type):
    """An unused attribute slot.

    It gets no kernel policy entry, no put code and no setter.  The
    user-space type policy marks it YNL_PT_REJECT, and its parse branch
    reports an error.
    """

    def presence_type(self):
        # Nothing to track for an unused slot.
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        error_lines = ['return YNL_PARSE_CB_ERROR;']
        return error_lines, None, None

    def attr_policy(self, cw):
        """No nla_policy entry is emitted."""
        return None

    def attr_put(self, ri, var):
        """No put code is emitted."""
        return None

    def attr_get(self, ri, var, first):
        """No parse branch is emitted."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter is emitted."""
        return None
337
338
class TypePad(Type):
    """Padding attribute.

    The user-space type policy marks it YNL_PT_IGNORE.  No kernel policy,
    put/parse code or setter is generated for it.
    """

    def presence_type(self):
        # Padding carries no value, so nothing to track.
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """No nla_policy entry is emitted."""
        return None

    def attr_put(self, ri, var):
        """No put code is emitted."""
        return None

    def attr_get(self, ri, var, first):
        """No parse branch is emitted."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter is emitted."""
        return None
360
361
class TypeScalar(Type):
    """Integer attribute; see resolve() for how the C type is chosen."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Appended to argument declarations as a C comment, e.g. /* big-endian */
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve(); deleted so early access raises AttributeError
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve the base attr, then decide is_bitfield and the C type_name."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Variable-width scalars are exposed as 64-bit values
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """Derive implicit checks and validate explicit min/max.

        For enum-typed attrs the enum's value range supplies min/max.  If
        value_range() gives no bounds at all, a 'sparse' check is used
        instead.  'range' is added when both limits exist, and 'full-range'
        when a limit falls outside -32768..32767.

        Raises:
            Exception: if min > max, or a negative limit is set on an
                unsigned type.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                # A min of 0 is skipped for unsigned types
                if 'min' not in self.checks:
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Presumably because the inline nla_policy min/max fields are s16
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
459
460
class TypeFlag(Type):
    """Flag attribute: it is put with no payload, so only its presence matters."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # No value to pass; the setter just marks the flag present.
        return []

    def _setter_lines(self, ri, member, presence):
        return []

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to decode beyond presence.
        return [], None, None
476
477
class TypeString(Type):
    """String attribute; presence is tracked as the string length."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        typol = '.type = YNL_PT_NUL_STR, '
        if self.is_selector:
            typol += '.is_selector = 1, '
        return typol

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        mem = '{ .type = ' + policy
        if 'max-len' in self.checks:
            mem += ', .len = ' + self.get_limit_str('max-len')
        return mem + ', }'

    def attr_policy(self, cw):
        """Emit the nla_policy entry.

        The 'unterminated-ok' check selects NLA_STRING instead of NLA_NUL_STRING.
        """
        if self.checks.get('unterminated-ok', False):
            policy = 'NLA_STRING'
        else:
            policy = 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        # Copy at most the attribute's data length into a malloc'ed buffer
        # and NUL-terminate it.
        len_mem = var + '->_len.' + self.c_name
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len + 1);",
                f"memcpy({var}->{self.c_name}, ynl_attr_get_str(attr), len);",
                f"{var}->{self.c_name}[len] = 0;"], \
               ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # The setter keeps its own NUL-terminated copy of the argument
        return [f"{presence} = strlen({self.c_name});",
                f"{member} = malloc({presence} + 1);",
                f'memcpy({member}, {self.c_name}, {presence});',
                f'{member}[{presence}] = 0;']
530
531
class TypeBinary(Type):
    """Opaque binary attribute; presence is tracked as the payload length."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # NOTE(review): unlike the other typols there is no trailing space;
        # kept as-is so the generated output does not change.
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Return the nla_policy initializer; at most one length check is allowed.

        Raises:
            Exception: if more than one check, or an unsupported check, is set.
        """
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not self.checks:
            return '{ .type = NLA_BINARY, }'

        check_name = next(iter(self.checks))
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if check_name == 'max-len':
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        raise Exception('Unsupported check for binary type: ' + check_name)

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->{self.c_name}, {var}->_len.{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_len.' + self.c_name
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [f"{presence} = len;",
                f"{member} = malloc({presence});",
                f'memcpy({member}, {self.c_name}, {presence});']
582
583
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is the C struct named by ``struct``."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        # A payload shorter than the struct goes into a zeroed struct-sized
        # buffer; otherwise the buffer is exactly the payload length.
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        len_mem = f'{var}->_{self.presence_type()}.{self.c_name}'
        dst = f'{var}->{self.c_name}'
        get_lines = [
            f"{len_mem} = len;",
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars
599
600
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute holding a packed array of ``sub-type`` scalars.

    Presence is tracked as an element count rather than a byte length.
    """

    def presence_type(self):
        return 'count'

    def arg_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        return [f'{elem_type} *{self.c_name}', 'size_t count']

    def struct_member(self, ri):
        elem_type = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem_type} *{self.c_name};')

    def attr_put(self, ri, var):
        count = f"{var}->_{self.presence_type()}.{self.c_name}"
        ri.cw.block_start(line=f"if ({count})")
        ri.cw.p(f"i = {count} * sizeof(__{self.get('sub-type')});")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        # The byte length is rounded down to whole elements before copying.
        count = f'{var}->_count.{self.c_name}'
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{count} = len / {elem_sz};",
            f"len = {count} * {elem_sz};",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        return [
            f"{presence} = count;",
            f"count *= {elem_sz};",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
633
634
class TypeBitfield32(Type):
    """struct nla_bitfield32 attribute; the allowed bits come from its enum."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """Return NLA_POLICY_BITFIELD32 with the enum's flag mask.

        Raises:
            Exception: if the attribute has no 'enum'.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
658
659
class TypeNest(Type):
    """Nested attribute set, rendered as an embedded struct.

    Recursive nests (outside of an op) are held by pointer instead.
    """

    def is_recursive(self):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        target = f'{var}->{ref}{self.c_name}'
        if self.is_recursive_for_op(ri):
            # Pointer member: free only if it was allocated.
            return [f'if ({target})',
                    f'{self.nested_render_name}_free({target});']
        return [f'{self.nested_render_name}_free(&{target});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Selectors living outside the nest are passed on to its parser.
        pns = self.family.pure_nested_structs[self.nested_attrs]
        selector_args = [f'{var}->{sel.name}' for sel in pns.external_selectors()]
        parse_args = ', '.join(["&parg", "attr", *selector_args])
        get_lines = [f"if ({self.nested_render_name}_parse({parse_args}))",
                     "return YNL_PARSE_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for every non-recursive member of the nest."""
        path = (ref or []) + [self.c_name]
        members = ri.family.pure_nested_structs[self.nested_attrs].member_list()
        for _, member in members:
            if not member.is_recursive():
                member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
705
706
class TypeMultiAttr(Type):
    """Attribute that may repeat; occurrences are collected into an array.

    *base_type* describes a single occurrence, and policy/type-policy
    generation is delegated to it.  Presence is tracked as a count.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        attr_type = self.attr['type']
        if attr_type == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if attr_type == 'string':
            return 'struct ynl_string *'
        if attr_type in scalars:
            return ('__' if ri.ku_space == 'user' else '') + attr_type
        raise Exception(f"Sub-type {attr_type} not supported yet")

    def arg_member(self, ri):
        if self.type == 'binary' and 'struct' in self.attr:
            struct_name = c_lower(self.attr["struct"])
            return [f'struct {struct_name} *{self.c_name}',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def free_needs_iter(self):
        # Nests and strings free each element, so the setter needs "i"
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        attr_type = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"
        if attr_type in scalars or attr_type == 'binary':
            return [f"free({array});"]
        each = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if attr_type == 'string':
            return [each, f"free({array}[i]);", f"free({array});"]
        if attr_type == 'nest':
            return [each,
                    f'{self.nested_render_name}_free(&{array}[i]);',
                    f"free({array});"]
        raise Exception(f"Free of MultiAttr sub-type {attr_type} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        each = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        if self.attr['type'] in scalars:
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif self.attr['type'] == 'binary' and 'struct' in self.attr:
            struct_sz = f"sizeof(struct {c_lower(self.attr['struct'])})"
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i], {struct_sz});")
        elif self.attr['type'] == 'string':
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {var}->{self.c_name}[i]->str);")
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(each)
            put_call = (f"{self.nested_render_name}_put(nlh, "
                        f"{self.enum_name}, &{var}->{self.c_name}[i])")
            self._attr_put_line(ri, var, put_call)
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
793
794
class TypeArrayNest(Type):
    """Indexed array: a nest whose entries use their array index as type.

    Entries are stored as an array with an element count.  A missing
    'sub-type' is treated as an array of nests (see _complex_member_type()).
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def arg_member(self, ri):
        if self.sub_type == 'binary' and 'exact-len' in self.checks:
            # Pointer to an array of fixed-size byte arrays
            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        # .get(): a missing sub-type means nests (as in _complex_member_type)
        # rather than a KeyError.
        sub_type = self.attr.get('sub-type')
        if sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(sub_type[1:])}, '
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Validate and count the entries.  The nest is remembered in
        # attr_<name>, presumably for a later pass that fills the array.
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->_count.{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        """Emit a nest holding one entry per element, typed by index.

        Raises:
            Exception: for sub-types without put support.
        """
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            put_type = self.sub_type
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            # Use self.sub_type: indexing self.attr['sub-type'] here raised
            # KeyError (masking this error) when the key was absent.
            raise Exception(f"Put for ArrayNest sub-type {self.sub_type} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
857
858
class TypeNestTypeValue(Type):
    """Nest attribute whose 'type-value' levels carry data in attr types.

    For every level listed in 'type-value', the generated parser descends
    one nest deeper and records that level's attribute type in a __u32
    local. Those values are then passed to the nested _parse() call.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        parent = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({parent});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                parent = f'attr_{level}'
            extra_args = ', ' + ', '.join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {parent}{extra_args});")
        return get_lines, init_lines, local_vars
887
888
class TypeSubMessage(TypeNest):
    """Sub-message attribute: its nest layout depends on a selector value.

    The selector is either a sibling attribute in the same set, or an
    external one handed down from a parent struct (see Selector).
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        parts = [f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
                 '.is_submsg = 1, ']
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with an
        # external selector for requests, so this is fine for now.
        if not self.selector.is_external():
            parts.append(f'.selector_type = {self.attr_set[self["selector"]].value} ')
        return ''.join(parts)

    def _attr_get(self, ri, var):
        selector = self['selector']
        sel = c_lower(selector)
        # External selectors arrive as a _sel_<name> local; internal ones
        # are read from the struct being filled in.
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"
        parse_fn = f"{self.nested_render_name}_parse"

        # Parsing fails if the selector is still unset (zero/NULL)
        get_lines = [
            f'if (!{sel_var})',
            f'return ynl_submsg_failed(yarg, "{self.name}", "{selector}");',
            f"if ({parse_fn}(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None
919
920
class Selector:
    """Selector attribute that picks the format of a sub-message.

    If the selector lives in the same attr set as the sub-message, it is
    resolved immediately and marked with is_selector. Otherwise it is
    "external": it must be passed down through the structs, and attr stays
    None until set_attr() binds it.
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]
        self._external = self.name not in attr_set
        if self._external:
            self.attr = None
        else:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        self.attr = attr

    def is_external(self):
        return self._external
939
940
class Struct:
    """C struct model for an attribute set.

    With type_list=None the struct covers every attribute of the set and is
    treated as a nest. Otherwise it only covers the listed attributes, in
    that order (per-message request/reply structs).
    """

    def __init__(self, family, space_name, type_list=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Default to a list, not a set, so that set_inherited() with an
        # empty set is still reported as a mismatch
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.submsg = submsg

        self.nested = type_list is None

        base = family.ident_name
        if family.name != c_lower(space_name):
            base += '-' + space_name
        self.render_name = c_lower(base)

        # Nests sharing a name with a definition get a trailing '_'
        suffix = '_' if self.nested and space_name in family.consts else ''
        self.struct_name = 'struct ' + self.render_name + suffix
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        # Set when a multi-value attribute (multi-attr or indexed array)
        # holds this struct
        self.in_multi_val = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(n, self.attr_set[n]) for n in names]
        self.attrs = dict(self.attr_list)

        # Highest-valued member (the last one on ties); None if empty
        self.attr_max_val = None
        top = 0
        for _, attr in self.attr_list:
            if attr.value >= top:
                top = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Confirm new_inherited matches the constructor's and store it sorted."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def external_selectors(self):
        """Return the Selectors of sub-message members that point outside the set."""
        return [attr.selector for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1012
1013
class EnumEntry(SpecEnumEntry):
    """Enum entry with C naming and a flag for explicit-value rendering."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change is True when the value breaks the implicit C
        # sequence (predecessor + 1, or 0 for the first entry). Flags
        # always count as changed.
        expected = prev.value + 1 if prev else 0
        self.value_change = (self.value != expected or
                             self.enum_set['type'] == 'flags')

        # Added by resolve (declared here, then removed until resolve()):
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1032
1033
class EnumSet(SpecEnumSet):
    """Enum/flags set with C naming (type name, value prefix, header)."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # 'enum-name' absent -> derived from render_name; present but
        # empty -> anonymous enum; otherwise used as given.
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None
        # Anonymous enums are exposed to users as plain int
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (min, max) of the entry values, or (None, None).

        (None, None) is returned when max - min + 1 differs from the number
        of entries, i.e. when the values (assumed unique) leave gaps.
        """
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)
        if high - low + 1 != len(self.entries):
            return None, None
        return low, high
1069
1070
class AttrSet(SpecAttrSet):
    """Attribute set with C enum naming and typed attribute construction."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are taken from
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        # Avoid clashing with C keywords
        if name in _C_KW:
            name += '_'
        # The family's main set gets no extra name component
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Build the Type subclass matching elem['type'].

        multi-attr elements are wrapped in TypeMultiAttr.
        """
        kind = elem['type']
        args = (self.family, self, elem, value)

        if kind in scalars:
            t = TypeScalar(*args)
        elif kind == 'binary':
            if 'struct' in elem:
                t = TypeBinaryStruct(*args)
            elif elem.get('sub-type') in scalars:
                t = TypeBinaryScalarArray(*args)
            else:
                t = TypeBinary(*args)
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            t = TypeArrayNest(*args)
        else:
            simple_types = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'nest-type-value': TypeNestTypeValue,
                'sub-message': TypeSubMessage,
            }
            if kind not in simple_types:
                raise Exception(f"No typed class for type {elem['type']}")
            t = simple_types[kind](*args)

        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1139
1140
class Operation(SpecOperation):
    """Operation spec with C render/enum names and notification tracking."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # every present mode/direction gets at least an empty attribute list
        for mode in ('do', 'dump', 'event'):
            for direction in ('request', 'reply'):
                try:
                    msg = yaml[mode][direction]
                except KeyError:
                    continue
                msg.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # Both do and dump take requests -> separate policies
        self.dual_policy = all(m in yaml and 'request' in yaml[m]
                               for m in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
1174
1175
class SubMessage(SpecSubMessage):
    """Sub-message spec carrying its C render name (<family>_<name>)."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)
        self.render_name = c_lower('-'.join((family.ident_name, yaml['name'])))

    def resolve(self):
        self.resolve_up(super())
1184
1185
class Family(SpecFamily):
    """Spec family specialised for C code generation.

    Besides SpecFamily parsing, resolve() does the following:
    - derives C naming (family c_name, op enum prefixes);
    - collects pre/post hooks;
    - classifies attribute sets as message roots (root_sets) or pure nests
      (pure_nested_structs, one Struct each);
    - propagates request/reply/recursion information down into those nests.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # C macro names for the family name / version, overridable in spec
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        # uapi_header_name is the <name> part of "linux/<name>.h", falling
        # back to the family ident name for any other header path
        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Compute C naming and the struct layout for the whole family.

        Calls resolve_up(super()) first, then runs the generator-specific
        passes below in a fixed order.
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        # Command enum prefix: explicit 'name-prefix' or "<name>-cmd-"
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        # Async (notification/event) ops share op_prefix unless overridden
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': seen names, 'list': names in order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        # Order matters: each pass consumes state built by the earlier ones
        # (root_sets -> pure_nested_structs -> attr usage -> selectors)
        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        # "Classic" families use the netlink-raw protocol
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        # Flag every op named as the 'notify' target of some message
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Record which attrs of each message-level set are used where.

        root_sets[set_name] holds the union, over all ops using that set,
        of request attrs and of reply/event attrs.
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so nests come before their users.

        A struct is moved to the end of the dict until all the nests it
        references have settled. References to recursive nests are
        ignored. The loop is capped at len**2 rounds.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register the Struct for spec's 'nested-attributes' set; return its name.

        Also records the members the nest inherits from its parent: the
        'type-value' levels, or 'idx' for indexed arrays.

        Raises if the nested set is also used as a message root.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                # Same set object as passed to set_inherited() below
                self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Synthesize an attr set and Struct for spec's sub-message; return its name.

        Each sub-message format becomes a 'nest' attribute pointing at that
        format's attribute set.
        """
        # Fake the struct type for the sub-message itself
        # its not a attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attrs.append({
                "name": name,
                "type": "nest",
                "parent-sub-message": spec,
                "nested-attributes": fmt['attribute-set']
            })

        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            # NOTE(review): AttrSet.__init__ looks for 'name-prefix', not
            # 'name-pfx'. This value appears to be ignored, so the default
            # "<ident>-a-<name>-" prefix is used. Confirm before "fixing":
            # changing it would alter the generated names.
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Discover all nests reachable from root sets and annotate them.

        Steps:
        1. Breadth-first walk from the root sets, creating a Struct per nest.
        2. Mark the nests used directly by root-set attrs as request/reply.
        3. Propagate request/reply, recursion and in_multi_val downwards.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # (parents are visited before children thanks to the reverse walk
        # over the dependency-sorted dict)
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        # Sort again: the recursive flags set above change which
        # dependencies _sort_pure_types() honours
        self._sort_pure_types()

    def _load_attr_use(self):
        # Mark each attribute as used in requests and/or replies: every
        # member of a nested struct follows the struct, while root-set
        # attrs follow root_sets
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Bind each external sub-message selector to its parent's attribute.

        Only one level of passing is supported: the selector must be an
        attribute of the immediately enclosing set.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """Build the single shared policy used with kernel-policy 'global'.

        All ops must use the same attribute set. The policy lists, in the
        set's own order, every attribute appearing in any do/dump request.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        # NOTE(review): if no op has an 'attribute-set', attr_set_name is
        # still None here and this lookup presumably raises KeyError.
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        # Collect unique pre/post hook names per mode, keeping first-seen order
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1531
1532
class RenderInfo:
    """Context for rendering one op/mode, or one nested attr set.

    Bundles the CodeWriter, family, kernel/user space selector and, when
    op is given, the request/reply Structs for that op.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op
        self.cw = cw

        self.fixed_hdr = None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)
            if op.fixed_header != family.fixed_header:
                # A per-op header differing from the family one is only
                # handled for classic (netlink-raw) families
                if not family.is_classic():
                    raise Exception("Per-op fixed header not supported, yet")
                self.fixed_hdr_len = f"sizeof({self.fixed_hdr})"

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_oneside = True
            else:
                do_has_reply = 'reply' in op['do']
                if do_has_reply != ('reply' in op['dump']):
                    self.type_consistent = False
                elif do_has_reply and op['do']['reply'] != op['dump']['reply']:
                    self.type_consistent = False

        self.attr_set = attr_set or op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        # Notifications reuse the layout of their do (or dump) reply
        mode = op_mode
        if mode == 'notify':
            mode = 'do' if 'do' in op else 'dump'

        self.struct = {}
        if op:
            for op_dir in ('request', 'reply'):
                msg = op[mode]
                members = msg[op_dir]['attributes'] if op_dir in msg else []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=members)
        if mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        # No attributes and no fixed header -> nothing to render
        return not self.struct[key].attr_list and self.fixed_hdr is None

    def needs_nlflags(self, direction):
        # Only classic families expose nlmsg flags on 'do' requests
        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1596
1597
class CodeWriter:
    """Write C source text with automatic indentation and brace handling.

    Output goes to stdout or, when *out_file* is given, to a temporary
    file. close_out_file() copies that file's content to *out_file*. With
    overwrite=False, an existing *out_file* whose content already matches
    is left untouched.

    A bare closing brace from block_end() is held back until the next line
    is printed, so that a following "else" can be joined as "} else".
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # a blank line is pending
        self._block_end = False     # a bare '}' is pending (see block_end())
        self._silent_block = False  # last line was a brace-less if/for/while/else
        self._ind = 0               # current indentation depth, in tabs
        self._ifdef_block = None    # CONFIG_* option of the open #ifdef, if any
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def _flush_block_end(self):
        """Write out a closing brace that block_end() is still holding back."""
        if self._block_end:
            self._block_end = False
            self._out.write('\t' * self._ind + '}\n')

    def close_out_file(self):
        """Finish the output and, in file mode, install the generated file.

        A pending '}' is written first. In file mode the temporary file is
        copied to *out_file*, unless overwrite=False and the content is
        unchanged. The temporary file is then closed in either case. Output
        reverts to stdout, so calling this again (e.g. from __del__) is
        harmless.
        """
        # A '}' deferred for a possible "else" must not be lost at the end.
        self._flush_block_end()
        if self._out == os.sys.stdout:
            return
        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and
                     os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        # Release the temporary file on both paths, not only after copying.
        self._out.close()
        self._out = os.sys.stdout

    @classmethod
    def _is_cond(cls, line):
        """Return True if *line* starts a C if/while/for statement."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line of code at the current indentation.

        Pending output is emitted first. A pending '}' is merged into the
        line when the line starts with "else"; a pending blank line is
        written as-is.

        Indentation rules:
        - Lines ending in ':' (labels) are outdented one level.
        - Preprocessor lines ('#...') start at column 0.
        - The statement after a brace-less if/for/while/else is indented
          one extra level.
        - *add_ind* adds further tabs.

        An empty *line* produces an empty output line.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith()/startswith() instead of indexing so '' does not crash
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write(('\t' * ind + line + '\n') if line else '\n')

    def nl(self):
        """Request a blank line before the next printed line."""
        self._nl = True

    def block_start(self, line=''):
        """Open a '{' block, optionally preceded by *line*, and indent."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block.

        With no *line*, the '}' is deferred until the next printed line so
        an "else" can follow on the same line. With *line* (e.g. ';' or a
        declarator name), it is printed right after the brace.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write *doc* as ' * ...' comment lines, word-wrapped near 79 columns.

        Continuation lines get two extra spaces of indent when *indent* is
        set.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping long argument lists.

        *qual_ret* is the qualified return type. *args* defaults to
        ['void']. *doc* becomes a preceding block comment. *suffix* is
        appended after ')', e.g. ';' for a declaration.

        When the one-line form would reach 80 columns, a return type longer
        than 3 characters goes on its own line. Arguments are then wrapped
        and aligned under the opening parenthesis.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        *local_vars* is a single declaration string or a sequence of them.
        The caller's sequence is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, locals and *body* lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define NAME<tabs>VALUE' lines with tab-aligned values.

        *defines* holds (name, value) pairs. int values are printed as-is
        and str values are double-quoted. Any other value is omitted, so
        only the name is printed.
        """
        longest = max((len(define[0]) for define in defines), default=0)
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            # exact type checks on purpose: bool values are not printed
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write '.name<tabs>= value,' designated-initializer lines.

        *members* holds (name, value) pairs. An empty sequence writes
        nothing.
        """
        if not members:
            return
        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Make '#ifdef CONFIG_<config>' the active conditional block.

        Any different open block is closed with '#endif' first. A falsy
        *config* just closes the current block.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1794
1795
# Integer attribute types read with the ynl_attr_get_<type>() helpers.
scalars = ({f'{sign}{bits}' for sign in ('u', 's') for bits in (8, 16, 32, 64)}
           | {'uint', 'sint'})

# Struct/function name suffix for each message direction ('' = none).
direction_to_suffix = dict([
    ('reply', '_rsp'),
    ('request', '_req'),
    ('', ''),
])

# Suffix of the wrapper type emitted for each op mode (used by op_prefix()).
op_mode_to_wrapper = dict([
    ('do', ''),
    ('dump', '_list'),
    ('notify', '_ntf'),
    ('event', ''),
])

# C keywords (presumably checked against generated identifiers elsewhere in
# this file so they do not collide).
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1847
1848
def rdir(direction):
    """Return the opposite message direction.

    'reply' and 'request' swap; any other value is returned unchanged.
    """
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1855
1856
def op_prefix(ri, direction, deref=False):
    """Return the C identifier prefix of the *direction* struct of *ri*.

    The result is '<family c_name>_<type_name>' followed by:
      - for 'do' (or no op mode): the direction suffix ('_req' / '_rsp');
      - for other modes, request side: '_req', plus '_dump' unless
        ri.type_oneside;
      - for other modes, reply side, when ri.type_consistent: the
        direction suffix if *deref*, else the op-mode wrapper suffix;
      - otherwise: '_rsp_dump' if *deref*, else '_rsp_list'.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = f"{direction_to_suffix[direction]}"
    elif direction == 'request':
        tail = '_req' if ri.type_oneside else '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = f"{direction_to_suffix[direction]}"
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1878
1879
def type_name(ri, direction, deref=False):
    """Return the C struct type ('struct <prefix>') for *direction* of *ri*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1882
1883
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-facing request or dump function.

    The function takes the ynl socket and, when the op mode has a request,
    a pointer to the request struct. The argument is named after the
    direction suffix without its leading '_'. It returns a pointer to the
    reply struct when there is a reply, otherwise int. Dump functions get a
    '_dump' name suffix. With terminate=False no ';' is added, so the
    caller can follow with a body.
    """
    op_spec = ri.op[ri.op_mode]

    func_name = ri.op.render_name
    if ri.op_mode == 'dump':
        func_name += '_dump'

    params = ['struct ynl_sock *ys']
    if 'request' in op_spec:
        req_type = type_name(ri, direction)
        params.append(f"{req_type} *{direction_to_suffix[direction][1:]}")

    if 'reply' in op_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, func_name, params, doc=doc,
                          suffix=';' if terminate else '')
1900
1901
def print_req_prototype(ri):
    """Emit the request function declaration, preceded by the op's doc text."""
    op_doc = ri.op['doc']
    print_prototype(ri, direction="request", doc=op_doc)
1904
1905
def print_dump_prototype(ri):
    """Emit the dump function declaration (no doc comment)."""
    print_prototype(ri, direction="request")
1908
1909
def put_typol_submsg(cw, struct):
    """Emit the policy table and nest descriptor for a sub-message struct.

    Each member becomes a YNL_PT_SUBMSG entry indexed by its position, and
    the nest's max_attr is the index of the last entry.
    """
    members = list(struct.member_list())
    policy = f'{struct.render_name}_policy'

    cw.block_start(line=f'const struct ynl_policy_attr {policy}[] =')
    for idx, (name, arg) in enumerate(members):
        cw.p(f'[{idx}] = {{ .type = YNL_PT_SUBMSG, .name = "{name}", '
             f'.nest = &{arg.nested_render_name}_nest, }},')
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {len(members) - 1},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1927
1928
def put_typol_fwd(cw, struct):
    """Emit an extern declaration of *struct*'s policy nest descriptor."""
    decl = f'extern const struct ynl_policy_nest {struct.render_name}_nest;'
    cw.p(decl)
1931
1932
def put_typol(cw, struct):
    """Emit the attribute policy table and nest descriptor for *struct*.

    Sub-message structs are handed to put_typol_submsg(). For regular
    attribute sets the table is sized by the set's max attribute name, and
    each member writes its own entry via attr_typol().
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    prefix = struct.render_name
    max_name = struct.attr_set.max_name

    cw.block_start(line=f'const struct ynl_policy_attr {prefix}_policy[{max_name} + 1] =')
    for member in (arg for _, arg in struct.member_list()):
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {prefix}_nest =')
    for line in (f'.max_attr = {max_name},', f'.table = {prefix}_policy,'):
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()
1952
1953
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit '<render_name>_str()', a bounds-checked lookup into *map_name*.

    The argument is an int, or the enum's user type when *enum* is given.
    For flags enums the value is first turned into the index of its lowest
    set bit with ffs(). Out-of-range values make the function return NULL.
    """
    if enum:
        params = [enum.user_type + ' ' + arg_name]
    else:
        params = [f'int {arg_name}']
    cw.write_func_prot('const char *', f'{render_name}_str', params)

    body = []
    if enum and enum.type == 'flags':
        body.append(f'{arg_name} = ffs({arg_name}) - 1;')
    body += [
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    ]

    cw.block_start()
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
1967
1968
def put_op_name_fwd(family, cw):
    """Emit the declaration of '<family>_op_str()'."""
    func = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', func, ['int op'], suffix=';')
1971
1972
def put_op_name(family, cw):
    """Emit the op-name string map and its '<family>_op_str()' lookup.

    Entries are keyed by reply message id; ops without a reply value are
    not listed. When several commands share one reply id (legacy
    families), only the op that family.rsp_by_value maps the id to gets an
    entry. The others are emitted as comments.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
        elif op.req_value == op.rsp_value:
            cw.p(f'[{op.enum_name}] = "{op_name}",')
        else:
            cw.p(f'[{op.rsp_value}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1992
1993
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the declaration of '<enum>_str()'; *family* is not used."""
    func = f'{enum.render_name}_str'
    cw.write_func_prot('const char *', func, [enum.user_type + ' value'], suffix=';')
1997
1998
def put_enum_to_str(family, cw, enum):
    """Emit '<enum>_strmap' (entry names indexed by entry.value) and the
    matching '<enum>_str()' lookup. *family* is not used.
    """
    map_name = f'{enum.render_name}_strmap'
    entries = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in entries:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
2008
2009
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<struct>_put()'.

    The helper takes the message, the attribute type to use and a pointer
    to the object. *suffix* is ';' for a declaration and '' for a
    definition.
    """
    func = f'{struct.render_name}_put'
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', func, params, suffix=suffix)
2017
2018
def put_req_nested(ri, struct):
    """Emit '<struct>_put()', which writes *struct* into a request message.

    Regular structs are wrapped in a nest opened with
    ynl_attr_nest_start(). Sub-message structs (struct.submsg set) are
    written without a nest. Helper locals are declared only when some
    member needs them.
    """
    wrap_in_nest = struct.submsg is None

    local_vars = []
    if wrap_in_nest:
        local_vars.append('struct nlattr *nest;')
    if any(arg.type == 'indexed-array' for _, arg in struct.member_list()):
        local_vars.append('struct nlattr *array;')
    if any(arg.presence_type() == 'count' for _, arg in struct.member_list()):
        local_vars.append('unsigned int i;')

    cw = ri.cw
    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar(local_vars)

    if wrap_in_nest:
        cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for _, arg in struct.member_list():
        arg.attr_put(ri, "obj")

    if wrap_in_nest:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2054
2055
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parse function for *struct*.

    The caller has already printed the prototype; this prints everything
    from the opening '{'. *init_lines* and *local_vars* carry what the
    caller wants declared / executed first, and are extended in place.

    The generated code walks the attributes once, emitting each member's
    attr_get() code. Indexed arrays and multi-attr members get n_<name>
    counters declared here, presumably incremented by the attr_get() code
    (verify). They are then allocated with calloc() and filled in a
    second pass.

    Raises Exception for indexed-array sub-types other than binary, nest
    and scalars, for per-op fixed headers on non-classic families, and for
    second-pass member types it cannot parse.
    """
    # Choose how the generated loop walks attributes: inside a nest, or
    # over the message payload after its header(s).
    if struct.nested:
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        # The op's fixed header differs from the family-wide one
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({ri.fixed_hdr}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Find members needing a second pass and the extra locals they use.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        # Nested and sub-message members are parsed through a ynl_parse_arg
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Values handed in as extra parse arguments are stored in the struct
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # Copy the fixed header; for non-classic families it follows the
    # struct genlmsghdr.
    if ri.fixed_hdr:
        if ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Refuse to parse into arrays that are already populated
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # First pass: one loop over all attributes, with code from each member
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate n_<name> entries, then walk
    # the saved array nest and fill them in order.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: allocate, then walk all
    # attributes again and pick those of this member's type.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most one array element's worth of payload
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Each string gets its own NUL-terminated ynl_string allocation
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message parse callbacks return YNL_PARSE_CB_OK
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2202
2203
def parse_rsp_submsg(ri, struct):
    """Emit '<struct>_parse()' for a sub-message.

    The generated code compares the selector string 'sel' with each member
    name in turn and runs that member's parse code on the nested
    attribute. For members with presence type 'present', it also sets the
    _present bit.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    cw = ri.cw

    cw.block_start()
    cw.write_func_lvar(['const struct nlattr *attr = nested;',
                        f'{struct.ptr_name}{var} = yarg->data;',
                        'struct ynl_parse_arg parg;'])
    cw.p('parg.ys = yarg->ys;')
    cw.nl()

    for idx, (name, arg) in enumerate(struct.member_list()):
        keyword = 'else if' if idx else 'if'
        cw.block_start(line=f'{keyword} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in [*init_lines, *get_lines]:
            cw.p(line)
        if arg.presence_type() == 'present':
            cw.p(f"{var}->_present.{arg.c_name} = 1;")
        cw.block_end()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2234
2235
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<struct>_parse()' for a nested attribute set.

    Besides (yarg, nested), the generated parser takes these arguments:
    - sub-message parsers take the selector string as their second
      argument;
    - one 'const char *_sel_<name>' per external selector;
    - a trailing __u32 per inherited value.
    """
    selector_args = ['const char *_sel_' + sel.name for sel in struct.external_selectors()]
    inherited_args = ['__u32 ' + name for name in struct.inherited]

    func_args = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        func_args.append('const char *sel')
    func_args.append('const struct nlattr *nested')
    func_args.extend(selector_args)
    func_args.extend(inherited_args)

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
                          suffix=suffix)
2248
2249
def parse_rsp_nested(ri, struct):
    """Emit '<struct>_parse()' for a nested attribute set.

    Sub-messages are delegated to parse_rsp_submsg(). A struct without
    members gets a trivial parser returning 0. Otherwise the body comes
    from _multi_parse().
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    local_vars = ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;']
    _multi_parse(ri, struct, [], local_vars)
2268
2269
def parse_rsp_msg(ri, deref=False):
    """Emit the '<reply prefix>_parse()' callback for a reply message.

    Nothing is emitted when the op mode has no reply section. Events are
    the exception: their attributes form the reply struct. *deref* is
    passed through to op_prefix()/type_name() to pick the reply type name.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    reply_type = type_name(ri, "reply", deref=deref)
    cb_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', cb_name, ['const struct nlmsghdr *nlh',
                                           'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    local_vars = [f'{reply_type} *dst;', 'const struct nlattr *attr;']
    _multi_parse(ri, reply, ['dst = yarg->data;'], local_vars)
2291
2292
def print_req(ri):
    """Emit the user-facing request function for *ri*.

    The generated C function does the following:
    - starts the request message;
    - sets the request policy and header length;
    - copies the optional fixed header and puts the request attributes;
    - runs the request with ynl_exec().

    When the op mode has a reply, the function also calloc()s the reply
    struct and registers the reply parse callback and expected reply
    command. It returns the reply pointer, or frees it and returns NULL on
    error. Without a reply it returns 0, or -1 on error.
    """
    # Return expressions for the success / error paths of the C function
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # Counted (array) attributes need a loop index
    for _, attr in ri.struct["request"].member_list():
        if attr.presence_type() == 'count':
            local_vars += ['unsigned int i;']
            break

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    # Classic families pass caller-chosen netlink flags (req->_nlmsg_flags)
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Ops without their own value are matched by their reply value
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    # Error path: free the partially parsed reply
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
2366
2367
def print_dump(ri):
    """Emit the user-facing dump function for *ri*.

    The generated C function prepares a ynl_dump_state with:
    - the reply policy;
    - the per-element size (alloc_sz), presumably used by
      ynl_exec_dump() to allocate each element;
    - the parse callback;
    - the expected reply command.

    It then starts a dump message, adds the optional fixed header and the
    request attributes, runs ynl_exec_dump() and returns yds.first. On
    error it frees the list built so far and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    # Ops without their own value are matched by their reply value
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # NOTE(review): this reads req->_hdr, but print_prototype() only adds a
    # 'req' argument when the dump has a 'request' section - confirm that
    # fixed-header dumps always have one.
    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    # Error path: free whatever list was built before the failure
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2422
2423
def call_free(ri, direction, var):
    """Return the C statement freeing *var* with the type's _free() helper."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
2426
2427
def free_arg_name(direction):
    """Return the pointer argument name used by a _free() helper.

    This is the direction suffix without its leading '_' ('req' / 'rsp'),
    or 'obj' when no direction is given.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
2432
2433
def print_alloc_wrapper(ri, direction):
    """Emit a static inline '<type>_alloc()' returning one calloc()ed struct."""
    name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {name}));')
    cw.block_end()
2440
2441
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of '<type>_free()'.

    Types flagged with type_name_conflict (their name matches a family
    constant) are spelled with a trailing '_' in the argument type.
    """
    name = op_prefix(ri, direction)
    struct_name = f"{name}_" if ri.type_name_conflict else name
    ri.cw.write_func_prot('void', f"{name}_free",
                          [f"struct {struct_name} *{free_arg_name(direction)}"],
                          suffix=suffix)
2449
2450
def print_nlflags_set(ri, direction):
    """Emit a static inline '<type>_set_nlflags()' setter for _nlmsg_flags."""
    name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot('static inline void', f"{name}_set_nlflags",
                       [f"struct {name} *req", "__u16 nl_flags"])
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2459
2460
def _print_type(ri, direction, struct):
    """Emit the C struct definition for *struct*.

    The name is '<family>_<type_name>' plus the direction suffix. A
    trailing '_' is added for direction-less types flagged with
    type_name_conflict, and '_dump' for dump structs unless
    ri.type_oneside.

    Members, in order:
    - '_nlmsg_flags', when needs_nlflags();
    - the fixed header '_hdr';
    - the '_present' / '_len' / '_count' sub-structs, each only if some
      attribute contributes a line to it;
    - inherited __u32 values;
    - one member per attribute.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'

    if ri.op_mode == 'dump' and not ri.type_oneside:
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.needs_nlflags(direction):
        ri.cw.p('__u16 _nlmsg_flags;')
        ri.cw.nl()
    if ri.fixed_hdr:
        ri.cw.p(ri.fixed_hdr + ' _hdr;')
        ri.cw.nl()

    # Bookkeeping sub-structs, opened lazily on the first contributing attr
    for type_filter in ['present', 'len', 'count']:
        meta_started = False
        for _, attr in struct.member_list():
            line = attr.presence_member(ri.ku_space, type_filter)
            if line:
                if not meta_started:
                    ri.cw.block_start(line=f"struct")
                    meta_started = True
                ri.cw.p(line)
        if meta_started:
            ri.cw.block_end(line=f'_{type_filter};')
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
2499
2500
def print_type(ri, direction):
    """Emit the C struct for the *direction* side of *ri*."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2503
2504
def print_type_full(ri, struct):
    """Emit the C struct for *struct* with no direction suffix."""
    _print_type(ri, direction="", struct=struct)
2507
2508
def print_type_helpers(ri, direction, deref=False):
    """Emit the helpers declared next to a request/reply struct.

    The helpers are:
    - the _free() prototype;
    - the _set_nlflags() setter when needs_nlflags();
    - for user-space request structs, one setter per attribute.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    emit_setters = ri.ku_space == 'user' and direction == 'request'
    members = ri.struct[direction].member_list() if emit_setters else []
    for _, attr in members:
        attr.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2520
2521
def print_req_type_helpers(ri):
    """Emit the alloc wrapper and helpers for a non-empty request type."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2527
2528
def print_rsp_type_helpers(ri):
    """Emit reply-type helpers when the current op mode declares a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2533
2534
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of ``<op>_{rsp,req}_parse(tb, req)``.

    The ``_rsp`` variant is used for the reply direction and ``_req`` for
    anything else.  With @terminate the prototype ends with ';'; otherwise
    no suffix is written, so a body can follow.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    struct_name = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb', f"struct {struct_name} *req"]
    ri.cw.write_func_prot('void', f"{struct_name}_parse", params,
                          suffix=';' if terminate else '')
2543
2544
def print_req_type(ri):
    """Emit the request struct unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2549
2550
def print_req_free(ri):
    """Emit the free() function for the request type, if the op has one."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2555
2556
def print_rsp_type(ri):
    """Emit the reply struct for do/dump ops that declare a reply, and for events."""
    mode = ri.op_mode
    has_reply = mode in ('do', 'dump') and 'reply' in ri.op[mode]
    if has_reply or mode == 'event':
        print_type(ri, 'reply')
2565
2566
def print_wrapped_type(ri):
    """Emit the wrapper struct around the reply type and its free() prototype.

    Dump wrappers get a ``next`` pointer.  Notify/event wrappers get the
    family id, command, a generic ``next`` link and a ``free`` callback.
    Every wrapper embeds the dereferenced reply type as an 8-byte-aligned
    ``obj`` member.
    """
    wrapper = type_name(ri, 'reply')
    mode = ri.op_mode

    ri.cw.block_start(line=f"{wrapper}")
    if mode == 'dump':
        ri.cw.p(f"{wrapper} *next;")
    elif mode in ('notify', 'event'):
        ri.cw.p('__u16 family;')
        ri.cw.p('__u8 cmd;')
        ri.cw.p('struct ynl_ntf_base_type *next;')
        ri.cw.p(f"void (*free)({wrapper} *ntf);")
    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()
    print_free_prototype(ri, 'reply')
    ri.cw.nl()
2581
2582
2583def _free_type_members_iter(ri, struct):
2584    if struct.free_needs_iter():
2585        ri.cw.p('unsigned int i;')
2586        ri.cw.nl()
2587
2588
2589def _free_type_members(ri, var, struct, ref=''):
2590    for _, attr in struct.member_list():
2591        attr.free(ri, var, ref)
2592
2593
def _free_type(ri, direction, struct):
    """Emit a complete free() function for @struct.

    The function frees every member.  It also frees the object itself, but
    only when @direction is non-empty; the direction-less variant is used
    for nested types.
    """
    arg_name = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg_name, struct)
    if direction:
        ri.cw.p(f'free({arg_name});')
    ri.cw.block_end()
    ri.cw.nl()
2605
2606
def free_rsp_nested_prototype(ri):
    """Emit the free() prototype for a nested (direction-less) type."""
    print_free_prototype(ri, "")
2609
2610
def free_rsp_nested(ri, struct):
    """Emit the free() function for a nested (direction-less) @struct."""
    _free_type(ri, direction="", struct=struct)
2613
2614
def print_rsp_free(ri):
    """Emit the free() function for the reply type, if the op has one."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2619
2620
def print_dump_type_free(ri):
    """Emit the free() function for a dump reply list.

    The generated C walks the list until YNL_LIST_END.  For each node it
    advances ``next`` first, then frees the node's members (reached via
    ``obj.``) and the node itself.
    """
    node_type = type_name(ri, 'reply')
    reply_struct = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{node_type} *next = rsp;")
    ri.cw.nl()
    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply_struct)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        ri.cw.p(stmt)
    ri.cw.nl()

    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.block_end()
    ri.cw.nl()
2639
2640
def print_ntf_type_free(ri):
    """Emit the free() function for a notification wrapper (members via ``obj.``)."""
    reply_struct = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, reply_struct)
    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2649
2650
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit a policy array declaration or the opening line of its definition.

    With @terminate, an ``extern`` declaration ending in ';' is written.
    That declaration is skipped entirely for per-op policies (@ri given)
    that policy_should_be_static() says must be static.  Without
    @terminate, the definition header ending in ' = {' is written, marked
    ``static`` in that same case.  The array is named after the op (plus
    '_<op_mode>' for dual-policy ops) when @ri is given, otherwise after
    @struct.
    """
    static_policy = ri and policy_should_be_static(struct.family)

    if terminate:
        if static_policy:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if static_policy else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2673
2674
def print_req_policy(cw, struct, ri=None):
    """Emit a full nla_policy array definition for @struct.

    When rendering for an op, the definition is wrapped in that op's
    'config-cond' via cw.ifdef_block().
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2684
2685
def kernel_can_gen_family_struct(family):
    """Return True if the genl_family struct can be generated (genetlink only)."""
    is_genetlink = family.proto == 'genetlink'
    return is_genetlink
2688
2689
def policy_should_be_static(family):
    """Return True if policies should be file-local (static) in the source.

    This holds for split-policy families and for families whose genl_family
    struct is generated here.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2692
2693
def print_kernel_policy_ranges(family, cw):
    """Emit netlink_range_validation structs for 'full-range' request attrs.

    Attributes in subset attr-sets are skipped.  Signed attribute types use
    the ``_signed`` struct variant and 'LL' literal suffixes; unsigned ones
    use 'ULL'.  Only the bounds present in attr.checks ('min'/'max') are
    initialised.  A single header comment precedes the first struct.
    """
    header_done = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_done:
                cw.p('/* Integer value ranges */')
                header_done = True

            is_unsigned = attr.type[0] == 'u'
            sign = '' if is_unsigned else '_signed'
            suffix = 'ULL' if is_unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(bound, attr.get_limit_str(bound, suffix=suffix))
                       for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2721
2722
def print_kernel_policy_sparse_enum_validates(family, cw):
    """Emit validation callbacks for request attributes with a 'sparse' check.

    For each request attribute outside subset attr-sets that has an
    enum_name and the 'sparse' check, a static
    ``<attr enum_name>_validate()`` function is generated.  It switches on
    the attribute value (read with ``nla_get_<type>``) and accepts only the
    entries of the enum referenced by attr['enum'].  Any other value fails
    with an extack message and -EINVAL.  A single header comment precedes
    the first callback.
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request:
                continue
            if not attr.enum_name:
                continue
            if 'sparse' not in attr.checks:
                continue

            if first:
                cw.p('/* Sparse enums validation callbacks */')
                first = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # All valid values share a single "return 0"; the case labels are
            # stacked with explicit fallthrough annotations between them.
            first_entry = True
            for entry in enum.entries.values():
                if first_entry:
                    first_entry = False
                else:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2761
2762
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the kernel ops-table declaration, or open its definition.

    With terminate=True (header output), an ``extern`` declaration of the
    ops array is printed only when the table is exported, i.e. the family
    struct is not generated by this tool.  It is followed by prototypes for
    the spec's pre/post do/dump hooks and for the doit/dumpit handler of
    every non-async op.

    With terminate=False (source output), the comment and the opening line
    of the array definition (``... =`` plus a block start) are printed and
    the function returns.  The caller emits the entries and closes the
    block.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Array length: left empty for a non-exported table.  For an
        # exported table it is spelled out: split policy has one entry per
        # do/dump mode of each op; other policies use len(family.ops).
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            cnt = 0
            for op in family.ops.values():
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            cnt = len(family.ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    # Prototypes for user-supplied hooks: do-hooks take genl_split_ops/skb/info,
    # dump-hooks take the netlink_callback.
    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # Prototypes for the handlers the ops table points at.
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2828
2829
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops-table declarations and handler prototypes."""
    return print_kernel_op_table_fwd(family, cw, terminate=True)
2832
2833
def print_kernel_op_table(family, cw):
    """Emit the kernel ops array definition for @family.

    For 'global' and 'per-op' policies there is one entry per non-async op,
    carrying its doit/dumpit handlers; 'per-op' also adds the op's policy
    and maxattr.  For 'split' policy there is one entry per (op, do/dump)
    mode, with pre/post callbacks, a per-mode policy when the mode has a
    request, and the matching GENL_CMD_CAP_* flag.  Each entry is passed
    its op's 'config-cond' through cw.ifdef_block().
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): this indexes op['do']['request'] unconditionally,
                # so a per-op family op without a 'do' request would raise KeyError.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # genl_split_ops callback member names for each mode's pre/post hooks.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the dont-validate flags relevant to this mode:
                    # dump-only flags are dropped for 'do', 'strict' for 'dump'.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Presumably closes any #ifdef left open by the last entry -- confirm in CodeWriter.
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2912
2913
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum of multicast group ids (nothing if there are none)."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2924
2925
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by the generated group ids."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2937
2938
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern genl_family declaration and sock-priv hook prototypes.

    Nothing is emitted unless kernel_can_gen_family_struct() allows
    generating the family struct.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()
    if 'sock-priv' in family.kernel_family:
        priv_type = family.kernel_family["sock-priv"]
        for hook in ('init', 'destroy'):
            cw.p(f'void {family.c_name}_nl_sock_priv_{hook}({priv_type} *priv);')
        cw.nl()
2949
2950
def print_kernel_family_struct_src(family, cw):
    """Emit the genl_family struct definition for @family.

    Only done when kernel_can_gen_family_struct() allows it.  The struct is
    named ``<c_name>_nl_family`` so that it matches the extern declaration
    written by print_kernel_family_struct_hdr() and the other ``<c_name>_nl_*``
    symbols referenced here.  If the spec requests per-socket private data
    ('sock-priv'), static ``void *`` trampolines are emitted first and the
    struct points at those.
    """
    if not kernel_can_gen_family_struct(family):
        return

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Use c_name (not ident_name): the header declares
    # "extern struct genl_family <c_name>_nl_family;".
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2986
2987
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a C enum block for the spec object @obj.

    The tag is chosen as follows:
    - If obj has the @enum_name key, the tag is c_lower(obj[enum_name]).
      An empty or false value yields an anonymous enum.
    - Otherwise, if @ckey is given and present, the tag is the family C
      name plus '_' and c_lower(obj[ckey]).
    - Otherwise the enum is anonymous.
    """
    if enum_name in obj:
        tag = obj[enum_name]
        opener = f'enum {c_lower(tag)}' if tag else 'enum'
    elif ckey and ckey in obj:
        opener = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    else:
        opener = 'enum'
    cw.block_start(line=opener)
2996
2997
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Emit the single command enum used by the 'unified' message-id model.

    Explicit values are written only where an op's value breaks the running
    sequence.  Notifications and events are left out when @separate_ntf is
    set (they get their own enum).  The enum ends with a count member and a
    MAX, either as an enum member or, with @max_by_define, as a #define
    after the enum.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(f"{op.enum_name},")
        else:
            cw.p(f"{op.enum_name} = {op.value},")
            expected = op.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} ({cnt_name} - 1)")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = ({cnt_name} - 1)")
        cw.block_end(line=';')
    cw.nl()
3023
3024
def render_uapi_directional(family, cw, max_by_define):
    """Emit the USER and KERNEL message enums of the 'directional' id model.

    The USER enum lists ops that have a 'do' and are not events.  The
    KERNEL enum lists ops that have a 'do' reply, notifications and
    events; plain do replies get a '_REPLY' suffix.  Each enum starts with
    a *_NONE = 0 member, writes explicit values only where a non-zero op
    value breaks the running sequence, and ends with a count member plus a
    MAX (an enum member, or a #define after the enum with @max_by_define).
    """
    def emit_entries(entries):
        # entries: iterable of (C enum member name, op value)
        running = 0
        for member, value in entries:
            if value and value != running:
                cw.p(f"{member} = {value},")
                running = value
            else:
                cw.p(member + ',')
            running += 1

    def close_enum(direction):
        cnt = f"__{family.op_prefix}{direction}_CNT"
        top = f"{family.op_prefix}{direction}_MAX"
        top_value = f"({cnt} - 1)"
        cw.nl()
        if max_by_define:
            cw.p(cnt)
            cw.block_end(line=';')
            cw.p(f"#define {top} {top_value}")
        else:
            cw.p(cnt + ',')
            cw.p(f"{top} = {top_value}")
            cw.block_end(line=';')
        cw.nl()

    def kernel_entries():
        for op in family.msgs.values():
            if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
                is_async = 'event' in op or 'notify' in op
                member = op.enum_name if is_async else f'{op.enum_name}_REPLY'
                yield member, op.value

    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    emit_entries((op.enum_name, op.value) for op in family.msgs.values()
                 if 'do' in op and 'event' not in op)
    close_enum('USER')

    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    emit_entries(kernel_entries())
    close_enum('KERNEL')
3077
3078
def render_uapi(family, cw):
    """Render the complete UAPI header for @family.

    Emits, in order:
    - the include guard and the family name/version #defines;
    - the spec definitions: consts become #defines, enums/flags become C
      enums with optional kdoc (definitions owned by an external header are
      skipped);
    - one enum per (non-subset) attribute set;
    - the command enum(s), according to the message-id model;
    - an optional separate enum for notifications/events;
    - the multicast group name #defines.

    Raises:
        Exception: for an unsupported message enum-model.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive consts are batched and flushed as one group of #defines
    # whenever a non-const definition is reached, and once after the loop.
    defines = []
    for const in family['definitions']:
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The name override is looked up on the definition itself (like
            # the multicast group 'c-define-name' below).  A family-level
            # lookup would give every const the same #define name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        # Separate enum holding only notifications and events.
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3219
3220
def _render_user_ntf_entry(ri, op):
    """Emit one designated-initializer entry of the ynl_ntf_info table.

    The array index is @op's own enum name.  For classic netlink families
    it is instead the enum name of the request op whose value equals
    op.rsp_value.
    """
    if ri.family.is_classic():
        index_op = ri.family.req_by_value[op.rsp_value]
    else:
        index_op = op
    ri.cw.block_start(line=f"[{index_op.enum_name}] = ")
    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    ri.cw.block_end(line=',')
3232
3233
def render_user_family(family, cw, prototype):
    """Emit the user-space ``ynl_<c_name>_family`` descriptor.

    With @prototype only an ``extern`` declaration is written.  Otherwise,
    if the family has notifications, a static ynl_ntf_info table is emitted
    first.  It gets one entry per family.ntfs item and one per op that has
    an 'event'.  The descriptor struct follows: name, classic-netlink
    id/flag, header length and the notification table.

    Raises:
        Exception: if a family.ntfs item has neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Render with the op named by the 'notify' key, in "notify" mode.
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p(f'.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
    # Header length: classic families carry only their fixed header (if any;
    # otherwise .hdr_len is not set), genetlink families add genlmsghdr.
    if family.is_classic():
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3275
3276
def family_contains_bitfield32(family):
    """Return True if any non-subset attribute set has a bitfield32 attribute."""
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
3285
3286
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains @full_path.

    Walks up from @full_path until it finds a directory holding a
    MAINTAINERS file.

    Returns:
        (root, sub_path): the directory containing MAINTAINERS, and the
        path of @full_path relative to that directory.

    Raises:
        FileNotFoundError: if no ancestor of @full_path contains MAINTAINERS.
    """
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        # dirname() is a fixed point at the filesystem root ('/') and at ''
        # for relative paths; without this check the walk never terminates
        # when MAINTAINERS is not found.
        if parent == full_path:
            raise FileNotFoundError(
                f"MAINTAINERS not found in any parent directory of the spec path")
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        full_path = parent
        maintainers = os.path.join(full_path, "MAINTAINERS")
        if os.path.exists(maintainers):
            return full_path, sub_path[:-1]
3295
3296
def _render_includes(parsed, cw, mode, header, out_file):
    """Emit the #include section for kernel/user header or source output.

    Args:
        parsed: the loaded Family.
        cw: CodeWriter receiving the output.
        mode: 'kernel' or 'user' (uapi output never reaches this point).
        header: True when generating the header, False for the source.
        out_file: the -o path, or None when writing to stdout.
    """
    # The companion header shares the output's basename; this assumes the
    # output path ends in a two-character suffix such as ".c" or ".h".
    if out_file:
        hdr_file = os.path.basename(out_file[:-2]) + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not header:
            if out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []

    # Definitions and attribute sets may name extra headers they depend on.
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if mode == 'user':
        headers.append(parsed.uapi_header)

    # Drop duplicates while keeping first-seen order.
    for one in dict.fromkeys(headers):
        cw.p(f"#include <{one}>")
    cw.nl()


def _render_kernel_shared_policies(parsed, cw, print_policy):
    """Emit request policies not tied to a single op.

    These are the policies for pure nested attribute sets used in requests
    and, when the family's kernel policy is 'global', the global op policy.

    Args:
        parsed: the loaded Family.
        cw: CodeWriter receiving the output.
        print_policy: print_req_policy_fwd (header) or print_req_policy
            (source); called as print_policy(cw, struct).
    """
    nested = sorted(parsed.pure_nested_structs.items())
    if any(struct.request for _, struct in nested):
        cw.p('/* Common nested types */')
    for _, struct in nested:
        if struct.request:
            print_policy(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_policy(cw, struct)
        cw.nl()


def _render_kernel_header(parsed, cw, mode):
    """Emit the body of a kernel-mode header (policy/op-table declarations)."""
    _render_kernel_shared_policies(parsed, cw, print_req_policy_fwd)

    if parsed.kernel_policy in {'per-op', 'split'}:
        for op in parsed.ops.values():
            if 'do' in op and 'event' not in op:
                ri = RenderInfo(cw, parsed, mode, op, "do")
                print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                cw.nl()

    print_kernel_op_table_hdr(parsed, cw)
    print_kernel_mcgrp_hdr(parsed, cw)
    print_kernel_family_struct_hdr(parsed, cw)


def _render_kernel_source(parsed, cw, mode):
    """Emit the body of a kernel-mode source file (policies, op table, family)."""
    print_kernel_policy_ranges(parsed, cw)
    print_kernel_policy_sparse_enum_validates(parsed, cw)

    _render_kernel_shared_policies(parsed, cw, print_req_policy)

    # The policy kind is a property of the whole family, so test it once
    # instead of once per op.
    if parsed.kernel_policy in {'per-op', 'split'}:
        for op in parsed.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, mode, op, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _render_user_header(parsed, cw, mode):
    """Emit the body of a user-mode header.

    The output covers the family declaration, enum helpers, nested types,
    and per-op request/response types with their prototypes.

    Raises:
        Exception: if an op with notifications has inconsistent types.
    """
    cw.p('struct ynl_sock;')
    cw.nl()
    render_user_family(parsed, cw, True)
    cw.nl()

    cw.p('/* Enums */')
    put_op_name_fwd(parsed, cw)

    for const in parsed.consts.values():
        if isinstance(const, EnumSet):
            put_enum_to_str_fwd(parsed, cw, const)
    cw.nl()

    cw.p('/* Common nested types */')
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
        print_type_full(ri, struct)
        if struct.request and struct.in_multi_val:
            free_rsp_nested_prototype(ri)
            cw.nl()

    for op in parsed.ops.values():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, 'dump')
            print_req_type(ri)
            print_req_type_helpers(ri)
            if not ri.type_consistent or ri.type_oneside:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_wrapped_type(ri)

    for op in parsed.ntfs.values():
        if 'event' in op:
            ri = RenderInfo(cw, parsed, mode, op, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)
    cw.nl()


def _render_user_source(parsed, cw, mode, user_headers):
    """Emit the body of a user-mode source file.

    The output covers enum helpers, policies, nested (de)serializers,
    per-op request/parse/free code, and the family definition.

    Args:
        user_headers: extra headers (--user-header) to #include by quote.

    Raises:
        Exception: if an op with notifications has inconsistent types.
    """
    cw.p("#include <linux/genetlink.h>")
    cw.nl()
    for one in user_headers:
        cw.p(f'#include "{one}"')
    cw.nl()

    cw.p('/* Enums */')
    put_op_name(parsed, cw)

    for const in parsed.consts.values():
        if isinstance(const, EnumSet):
            put_enum_to_str(parsed, cw, const)
    cw.nl()

    has_recursive_nests = False
    cw.p('/* Policies */')
    # Recursive nests reference each other, so their policies need forward
    # declarations before any definition.
    for struct in parsed.pure_nested_structs.values():
        if struct.recursive:
            put_typol_fwd(cw, struct)
            has_recursive_nests = True
    if has_recursive_nests:
        cw.nl()
    for struct in parsed.pure_nested_structs.values():
        put_typol(cw, struct)
    for name in parsed.root_sets:
        struct = Struct(parsed, name)
        put_typol(cw, struct)

    cw.p('/* Common nested types */')
    if has_recursive_nests:
        # Same reasoning as above: prototypes first so the helpers may call
        # each other regardless of definition order.
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
            free_rsp_nested_prototype(ri)
            if struct.request:
                put_req_nested_prototype(ri, struct)
            if struct.reply:
                parse_rsp_nested_prototype(ri, struct)
        cw.nl()
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for op in parsed.ops.values():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_free(ri)
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, "dump")
            if not ri.type_consistent or ri.type_oneside:
                parse_rsp_msg(ri, deref=True)
            print_req_free(ri)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_ntf_type_free(ri)

    for op in parsed.ntfs.values():
        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")

            ri = RenderInfo(cw, parsed, mode, op, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, mode, op, "event")
            print_ntf_type_free(ri)
    cw.nl()
    render_user_family(parsed, cw, False)


def main():
    """Command-line entry point: generate C code from a YNL netlink spec.

    --mode selects 'user', 'kernel' or 'uapi' output; exactly one of
    --header / --source must be given. Output goes to the -o file, or to
    the CodeWriter's default (stdout) when -o is omitted. With --cmp-out an
    existing output file is left untouched if the new content is identical.

    Exits with status 1 (diagnostic on stderr) if the spec fails to parse as
    YAML or does not carry the required license.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    # --header/--source share one dest; None means neither was given.
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    # Diagnostics go to stderr so they never mix into generated code, which
    # is written to stdout when -o is not given.
    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print(f'License must be: {required_license}', file=sys.stderr)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # Banner: license tag, provenance, and the arguments needed to regenerate.
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''.join(f' --user-header {hdr}' for hdr in args.user_header)
        line += ''.join(f' --exclude-op {expr}' for expr in args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    _render_includes(parsed, cw, args.mode, args.header, args.out_file)

    if args.mode == 'kernel':
        if args.header:
            _render_kernel_header(parsed, cw, args.mode)
        else:
            _render_kernel_source(parsed, cw, args.mode)
    elif args.mode == 'user':
        if args.header:
            _render_user_header(parsed, cw, args.mode)
        else:
            _render_user_source(parsed, cw, args.mode, args.user_header)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3608
3609
if __name__ == '__main__':
    # Generate code only when run as a script; importing this module is inert.
    main()
3612