xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision a8a9fd042e0995ed63d33f507c26baf56031e581)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
20def c_upper(name):
21    return name.upper().replace('-', '_')
22
23
24def c_lower(name):
25    return name.lower().replace('-', '_')
26
27
28def limit_to_number(name):
29    """
30    Turn a string limit like u32-max or s64-min into its numerical value
31    """
32    if name[0] == 'u' and name.endswith('-min'):
33        return 0
34    width = int(name[1:-4])
35    if name[0] == 's':
36        width -= 1
37    value = (1 << width) - 1
38    if name[0] == 's' and name.endswith('-min'):
39        value = -value - 1
40    return value
41
42
43class BaseNlLib:
44    def get_family_id(self):
45        return 'ys->family_id'
46
47
48class Type(SpecAttr):
49    def __init__(self, family, attr_set, attr, value):
50        super().__init__(family, attr_set, attr, value)
51
52        self.attr = attr
53        self.attr_set = attr_set
54        self.type = attr['type']
55        self.checks = attr.get('checks', {})
56
57        self.request = False
58        self.reply = False
59
60        self.is_selector = False
61
62        if 'len' in attr:
63            self.len = attr['len']
64
65        if 'nested-attributes' in attr:
66            nested = attr['nested-attributes']
67        elif 'sub-message' in attr:
68            nested = attr['sub-message']
69        else:
70            nested = None
71
72        if nested:
73            self.nested_attrs = nested
74            if self.nested_attrs == family.name:
75                self.nested_render_name = c_lower(f"{family.ident_name}")
76            else:
77                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")
78
79            if self.nested_attrs in self.family.consts:
80                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
81            else:
82                self.nested_struct_type = 'struct ' + self.nested_render_name
83
84        self.c_name = c_lower(self.name)
85        if self.c_name in _C_KW:
86            self.c_name += '_'
87        if self.c_name[0].isdigit():
88            self.c_name = '_' + self.c_name
89
90        # Added by resolve():
91        self.enum_name = None
92        delattr(self, "enum_name")
93
94    def _get_real_attr(self):
95        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
96        return self.family.attr_sets[self.attr_set.subset_of][self.name]
97
98    def set_request(self):
99        self.request = True
100        if self.attr_set.subset_of:
101            self._get_real_attr().set_request()
102
103    def set_reply(self):
104        self.reply = True
105        if self.attr_set.subset_of:
106            self._get_real_attr().set_reply()
107
108    def get_limit(self, limit, default=None):
109        value = self.checks.get(limit, default)
110        if value is None:
111            return value
112        if isinstance(value, int):
113            return value
114        if value in self.family.consts:
115            return self.family.consts[value]["value"]
116        return limit_to_number(value)
117
118    def get_limit_str(self, limit, default=None, suffix=''):
119        value = self.checks.get(limit, default)
120        if value is None:
121            return ''
122        if isinstance(value, int):
123            return str(value) + suffix
124        if value in self.family.consts:
125            const = self.family.consts[value]
126            if const.get('header'):
127                return c_upper(value)
128            return c_upper(f"{self.family['name']}-{value}")
129        return c_upper(value)
130
131    def resolve(self):
132        if 'parent-sub-message' in self.attr:
133            enum_name = self.attr['parent-sub-message'].enum_name
134        elif 'name-prefix' in self.attr:
135            enum_name = f"{self.attr['name-prefix']}{self.name}"
136        else:
137            enum_name = f"{self.attr_set.name_prefix}{self.name}"
138        self.enum_name = c_upper(enum_name)
139
140        if self.attr_set.subset_of:
141            if self.checks != self._get_real_attr().checks:
142                raise Exception("Overriding checks not supported by codegen, yet")
143
144    def is_multi_val(self):
145        return None
146
147    def is_scalar(self):
148        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}
149
150    def is_recursive(self):
151        return False
152
153    def is_recursive_for_op(self, ri):
154        return self.is_recursive() and not ri.op
155
156    def presence_type(self):
157        return 'present'
158
159    def presence_member(self, space, type_filter):
160        if self.presence_type() != type_filter:
161            return
162
163        if self.presence_type() == 'present':
164            pfx = '__' if space == 'user' else ''
165            return f"{pfx}u32 {self.c_name}:1;"
166
167        if self.presence_type() in {'len', 'count'}:
168            pfx = '__' if space == 'user' else ''
169            return f"{pfx}u32 {self.c_name};"
170
171    def _complex_member_type(self, ri):
172        return None
173
174    def free_needs_iter(self):
175        return False
176
177    def _free_lines(self, ri, var, ref):
178        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
179            return [f'free({var}->{ref}{self.c_name});']
180        return []
181
182    def free(self, ri, var, ref):
183        lines = self._free_lines(ri, var, ref)
184        for line in lines:
185            ri.cw.p(line)
186
187    def arg_member(self, ri):
188        member = self._complex_member_type(ri)
189        if member:
190            spc = ' ' if member[-1] != '*' else ''
191            arg = [member + spc + '*' + self.c_name]
192            if self.presence_type() == 'count':
193                arg += ['unsigned int n_' + self.c_name]
194            return arg
195        raise Exception(f"Struct member not implemented for class type {self.type}")
196
197    def struct_member(self, ri):
198        member = self._complex_member_type(ri)
199        if member:
200            ptr = '*' if self.is_multi_val() else ''
201            if self.is_recursive_for_op(ri):
202                ptr = '*'
203            spc = ' ' if member[-1] != '*' else ''
204            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
205            return
206        members = self.arg_member(ri)
207        for one in members:
208            ri.cw.p(one + ';')
209
210    def _attr_policy(self, policy):
211        return '{ .type = ' + policy + ', }'
212
213    def attr_policy(self, cw):
214        policy = f'NLA_{c_upper(self.type)}'
215        if self.attr.get('byte-order') == 'big-endian':
216            if self.type in {'u16', 'u32'}:
217                policy = f'NLA_BE{self.type[1:]}'
218
219        spec = self._attr_policy(policy)
220        cw.p(f"\t[{self.enum_name}] = {spec},")
221
222    def _attr_typol(self):
223        raise Exception(f"Type policy not implemented for class type {self.type}")
224
225    def attr_typol(self, cw):
226        typol = self._attr_typol()
227        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')
228
229    def _attr_put_line(self, ri, var, line):
230        presence = self.presence_type()
231        if presence in {'present', 'len'}:
232            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
233        ri.cw.p(f"{line};")
234
235    def _attr_put_simple(self, ri, var, put_type):
236        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
237        self._attr_put_line(ri, var, line)
238
239    def attr_put(self, ri, var):
240        raise Exception(f"Put not implemented for class type {self.type}")
241
242    def _attr_get(self, ri, var):
243        raise Exception(f"Attr get not implemented for class type {self.type}")
244
245    def attr_get(self, ri, var, first):
246        lines, init_lines, local_vars = self._attr_get(ri, var)
247        if type(lines) is str:
248            lines = [lines]
249        if type(init_lines) is str:
250            init_lines = [init_lines]
251
252        kw = 'if' if first else 'else if'
253        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
254        if local_vars:
255            for local in local_vars:
256                ri.cw.p(local)
257            ri.cw.nl()
258
259        if not self.is_multi_val():
260            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
261            ri.cw.p("return YNL_PARSE_CB_ERROR;")
262            if self.presence_type() == 'present':
263                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")
264
265        if init_lines:
266            ri.cw.nl()
267            for line in init_lines:
268                ri.cw.p(line)
269
270        for line in lines:
271            ri.cw.p(line)
272        ri.cw.block_end()
273        return True
274
275    def _setter_lines(self, ri, member, presence):
276        raise Exception(f"Setter not implemented for class type {self.type}")
277
278    def setter(self, ri, space, direction, deref=False, ref=None):
279        ref = (ref if ref else []) + [self.c_name]
280        var = "req"
281        member = f"{var}->{'.'.join(ref)}"
282
283        local_vars = []
284        if self.free_needs_iter():
285            local_vars += ['unsigned int i;']
286
287        code = []
288        presence = ''
289        for i in range(0, len(ref)):
290            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
291            # Every layer below last is a nest, so we know it uses bit presence
292            # last layer is "self" and may be a complex type
293            if i == len(ref) - 1 and self.presence_type() != 'present':
294                presence = f"{var}->{'.'.join(ref[:i] + [''])}_{self.presence_type()}.{ref[i]}"
295                continue
296            code.append(presence + ' = 1;')
297        ref_path = '.'.join(ref[:-1])
298        if ref_path:
299            ref_path += '.'
300        code += self._free_lines(ri, var, ref_path)
301        code += self._setter_lines(ri, member, presence)
302
303        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
304        free = bool([x for x in code if 'free(' in x])
305        alloc = bool([x for x in code if 'alloc(' in x])
306        if free and not alloc:
307            func_name = '__' + func_name
308        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
309                         body=code,
310                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
311
312
313class TypeUnused(Type):
314    def presence_type(self):
315        return ''
316
317    def arg_member(self, ri):
318        return []
319
320    def _attr_get(self, ri, var):
321        return ['return YNL_PARSE_CB_ERROR;'], None, None
322
323    def _attr_typol(self):
324        return '.type = YNL_PT_REJECT, '
325
326    def attr_policy(self, cw):
327        pass
328
329    def attr_put(self, ri, var):
330        pass
331
332    def attr_get(self, ri, var, first):
333        pass
334
335    def setter(self, ri, space, direction, deref=False, ref=None):
336        pass
337
338
339class TypePad(Type):
340    def presence_type(self):
341        return ''
342
343    def arg_member(self, ri):
344        return []
345
346    def _attr_typol(self):
347        return '.type = YNL_PT_IGNORE, '
348
349    def attr_put(self, ri, var):
350        pass
351
352    def attr_get(self, ri, var, first):
353        pass
354
355    def attr_policy(self, cw):
356        pass
357
358    def setter(self, ri, space, direction, deref=False, ref=None):
359        pass
360
361
362class TypeScalar(Type):
363    def __init__(self, family, attr_set, attr, value):
364        super().__init__(family, attr_set, attr, value)
365
366        self.byte_order_comment = ''
367        if 'byte-order' in attr:
368            self.byte_order_comment = f" /* {attr['byte-order']} */"
369
370        # Classic families have some funny enums, don't bother
371        # computing checks, since we only need them for kernel policies
372        if not family.is_classic():
373            self._init_checks()
374
375        # Added by resolve():
376        self.is_bitfield = None
377        delattr(self, "is_bitfield")
378        self.type_name = None
379        delattr(self, "type_name")
380
381    def resolve(self):
382        self.resolve_up(super())
383
384        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
385            self.is_bitfield = True
386        elif 'enum' in self.attr:
387            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
388        else:
389            self.is_bitfield = False
390
391        if not self.is_bitfield and 'enum' in self.attr:
392            self.type_name = self.family.consts[self.attr['enum']].user_type
393        elif self.is_auto_scalar:
394            self.type_name = '__' + self.type[0] + '64'
395        else:
396            self.type_name = '__' + self.type
397
398    def _init_checks(self):
399        if 'enum' in self.attr:
400            enum = self.family.consts[self.attr['enum']]
401            low, high = enum.value_range()
402            if low == None and high == None:
403                self.checks['sparse'] = True
404            else:
405                if 'min' not in self.checks:
406                    if low != 0 or self.type[0] == 's':
407                        self.checks['min'] = low
408                if 'max' not in self.checks:
409                    self.checks['max'] = high
410
411        if 'min' in self.checks and 'max' in self.checks:
412            if self.get_limit('min') > self.get_limit('max'):
413                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
414            self.checks['range'] = True
415
416        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
417        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
418        if low < 0 and self.type[0] == 'u':
419            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
420        if low < -32768 or high > 32767:
421            self.checks['full-range'] = True
422
423    def _attr_policy(self, policy):
424        if 'flags-mask' in self.checks or self.is_bitfield:
425            if self.is_bitfield:
426                enum = self.family.consts[self.attr['enum']]
427                mask = enum.get_mask(as_flags=True)
428            else:
429                flags = self.family.consts[self.checks['flags-mask']]
430                flag_cnt = len(flags['entries'])
431                mask = (1 << flag_cnt) - 1
432            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
433        elif 'full-range' in self.checks:
434            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
435        elif 'range' in self.checks:
436            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
437        elif 'min' in self.checks:
438            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
439        elif 'max' in self.checks:
440            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
441        elif 'sparse' in self.checks:
442            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
443        return super()._attr_policy(policy)
444
445    def _attr_typol(self):
446        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '
447
448    def arg_member(self, ri):
449        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']
450
451    def attr_put(self, ri, var):
452        self._attr_put_simple(ri, var, self.type)
453
454    def _attr_get(self, ri, var):
455        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None
456
457    def _setter_lines(self, ri, member, presence):
458        return [f"{member} = {self.c_name};"]
459
460
461class TypeFlag(Type):
462    def arg_member(self, ri):
463        return []
464
465    def _attr_typol(self):
466        return '.type = YNL_PT_FLAG, '
467
468    def attr_put(self, ri, var):
469        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)")
470
471    def _attr_get(self, ri, var):
472        return [], None, None
473
474    def _setter_lines(self, ri, member, presence):
475        return []
476
477
478class TypeString(Type):
479    def arg_member(self, ri):
480        return [f"const char *{self.c_name}"]
481
482    def presence_type(self):
483        return 'len'
484
485    def struct_member(self, ri):
486        ri.cw.p(f"char *{self.c_name};")
487
488    def _attr_typol(self):
489        typol = f'.type = YNL_PT_NUL_STR, '
490        if self.is_selector:
491            typol += '.is_selector = 1, '
492        return typol
493
494    def _attr_policy(self, policy):
495        if 'exact-len' in self.checks:
496            mem = 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
497        else:
498            mem = '{ .type = ' + policy
499            if 'max-len' in self.checks:
500                mem += ', .len = ' + self.get_limit_str('max-len')
501            mem += ', }'
502        return mem
503
504    def attr_policy(self, cw):
505        if self.checks.get('unterminated-ok', False):
506            policy = 'NLA_STRING'
507        else:
508            policy = 'NLA_NUL_STRING'
509
510        spec = self._attr_policy(policy)
511        cw.p(f"\t[{self.enum_name}] = {spec},")
512
513    def attr_put(self, ri, var):
514        self._attr_put_simple(ri, var, 'str')
515
516    def _attr_get(self, ri, var):
517        len_mem = var + '->_len.' + self.c_name
518        return [f"{len_mem} = len;",
519                f"{var}->{self.c_name} = malloc(len + 1);",
520                f"memcpy({var}->{self.c_name}, ynl_attr_get_str(attr), len);",
521                f"{var}->{self.c_name}[len] = 0;"], \
522               ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));'], \
523               ['unsigned int len;']
524
525    def _setter_lines(self, ri, member, presence):
526        return [f"{presence} = strlen({self.c_name});",
527                f"{member} = malloc({presence} + 1);",
528                f'memcpy({member}, {self.c_name}, {presence});',
529                f'{member}[{presence}] = 0;']
530
531
532class TypeBinary(Type):
533    def arg_member(self, ri):
534        return [f"const void *{self.c_name}", 'size_t len']
535
536    def presence_type(self):
537        return 'len'
538
539    def struct_member(self, ri):
540        ri.cw.p(f"void *{self.c_name};")
541
542    def _attr_typol(self):
543        return f'.type = YNL_PT_BINARY,'
544
545    def _attr_policy(self, policy):
546        if len(self.checks) == 0:
547            pass
548        elif len(self.checks) == 1:
549            check_name = list(self.checks)[0]
550            if check_name not in {'exact-len', 'min-len', 'max-len'}:
551                raise Exception('Unsupported check for binary type: ' + check_name)
552        else:
553            raise Exception('More than one check for binary type not implemented, yet')
554
555        if len(self.checks) == 0:
556            mem = '{ .type = NLA_BINARY, }'
557        elif 'exact-len' in self.checks:
558            mem = 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
559        elif 'min-len' in self.checks:
560            mem = '{ .len = ' + self.get_limit_str('min-len') + ', }'
561        elif 'max-len' in self.checks:
562            mem = 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
563
564        return mem
565
566    def attr_put(self, ri, var):
567        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
568                            f"{var}->{self.c_name}, {var}->_len.{self.c_name})")
569
570    def _attr_get(self, ri, var):
571        len_mem = var + '->_len.' + self.c_name
572        return [f"{len_mem} = len;",
573                f"{var}->{self.c_name} = malloc(len);",
574                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
575               ['len = ynl_attr_data_len(attr);'], \
576               ['unsigned int len;']
577
578    def _setter_lines(self, ri, member, presence):
579        return [f"{presence} = len;",
580                f"{member} = malloc({presence});",
581                f'memcpy({member}, {self.c_name}, {presence});']
582
583
584class TypeBinaryStruct(TypeBinary):
585    def struct_member(self, ri):
586        ri.cw.p(f'struct {c_lower(self.get("struct"))} *{self.c_name};')
587
588    def _attr_get(self, ri, var):
589        struct_sz = 'sizeof(struct ' + c_lower(self.get("struct")) + ')'
590        len_mem = var + '->_' + self.presence_type() + '.' + self.c_name
591        return [f"{len_mem} = len;",
592                f"if (len < {struct_sz})",
593                f"{var}->{self.c_name} = calloc(1, {struct_sz});",
594                "else",
595                f"{var}->{self.c_name} = malloc(len);",
596                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
597               ['len = ynl_attr_data_len(attr);'], \
598               ['unsigned int len;']
599
600
601class TypeBinaryScalarArray(TypeBinary):
602    def arg_member(self, ri):
603        return [f'__{self.get("sub-type")} *{self.c_name}', 'size_t count']
604
605    def presence_type(self):
606        return 'count'
607
608    def struct_member(self, ri):
609        ri.cw.p(f'__{self.get("sub-type")} *{self.c_name};')
610
611    def attr_put(self, ri, var):
612        presence = self.presence_type()
613        ri.cw.block_start(line=f"if ({var}->_{presence}.{self.c_name})")
614        ri.cw.p(f"i = {var}->_{presence}.{self.c_name} * sizeof(__{self.get('sub-type')});")
615        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, " +
616                f"{var}->{self.c_name}, i);")
617        ri.cw.block_end()
618
619    def _attr_get(self, ri, var):
620        len_mem = var + '->_count.' + self.c_name
621        return [f"{len_mem} = len / sizeof(__{self.get('sub-type')});",
622                f"len = {len_mem} * sizeof(__{self.get('sub-type')});",
623                f"{var}->{self.c_name} = malloc(len);",
624                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
625               ['len = ynl_attr_data_len(attr);'], \
626               ['unsigned int len;']
627
628    def _setter_lines(self, ri, member, presence):
629        return [f"{presence} = count;",
630                f"count *= sizeof(__{self.get('sub-type')});",
631                f"{member} = malloc(count);",
632                f'memcpy({member}, {self.c_name}, count);']
633
634
635class TypeBitfield32(Type):
636    def _complex_member_type(self, ri):
637        return "struct nla_bitfield32"
638
639    def _attr_typol(self):
640        return f'.type = YNL_PT_BITFIELD32, '
641
642    def _attr_policy(self, policy):
643        if not 'enum' in self.attr:
644            raise Exception('Enum required for bitfield32 attr')
645        enum = self.family.consts[self.attr['enum']]
646        mask = enum.get_mask(as_flags=True)
647        return f"NLA_POLICY_BITFIELD32({mask})"
648
649    def attr_put(self, ri, var):
650        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
651        self._attr_put_line(ri, var, line)
652
653    def _attr_get(self, ri, var):
654        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None
655
656    def _setter_lines(self, ri, member, presence):
657        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
658
659
660class TypeNest(Type):
661    def is_recursive(self):
662        return self.family.pure_nested_structs[self.nested_attrs].recursive
663
664    def _complex_member_type(self, ri):
665        return self.nested_struct_type
666
667    def _free_lines(self, ri, var, ref):
668        lines = []
669        at = '&'
670        if self.is_recursive_for_op(ri):
671            at = ''
672            lines += [f'if ({var}->{ref}{self.c_name})']
673        lines += [f'{self.nested_render_name}_free({at}{var}->{ref}{self.c_name});']
674        return lines
675
676    def _attr_typol(self):
677        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
678
679    def _attr_policy(self, policy):
680        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'
681
682    def attr_put(self, ri, var):
683        at = '' if self.is_recursive_for_op(ri) else '&'
684        self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
685                            f"{self.enum_name}, {at}{var}->{self.c_name})")
686
687    def _attr_get(self, ri, var):
688        pns = self.family.pure_nested_structs[self.nested_attrs]
689        args = ["&parg", "attr"]
690        for sel in pns.external_selectors():
691            args.append(f'{var}->{sel.name}')
692        get_lines = [f"if ({self.nested_render_name}_parse({', '.join(args)}))",
693                     "return YNL_PARSE_CB_ERROR;"]
694        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
695                      f"parg.data = &{var}->{self.c_name};"]
696        return get_lines, init_lines, None
697
698    def setter(self, ri, space, direction, deref=False, ref=None):
699        ref = (ref if ref else []) + [self.c_name]
700
701        for _, attr in ri.family.pure_nested_structs[self.nested_attrs].member_list():
702            if attr.is_recursive():
703                continue
704            attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref)
705
706
707class TypeMultiAttr(Type):
708    def __init__(self, family, attr_set, attr, value, base_type):
709        super().__init__(family, attr_set, attr, value)
710
711        self.base_type = base_type
712
713    def is_multi_val(self):
714        return True
715
716    def presence_type(self):
717        return 'count'
718
719    def _complex_member_type(self, ri):
720        if 'type' not in self.attr or self.attr['type'] == 'nest':
721            return self.nested_struct_type
722        elif self.attr['type'] == 'binary' and 'struct' in self.attr:
723            return None  # use arg_member()
724        elif self.attr['type'] == 'string':
725            return 'struct ynl_string *'
726        elif self.attr['type'] in scalars:
727            scalar_pfx = '__' if ri.ku_space == 'user' else ''
728            return scalar_pfx + self.attr['type']
729        else:
730            raise Exception(f"Sub-type {self.attr['type']} not supported yet")
731
732    def arg_member(self, ri):
733        if self.type == 'binary' and 'struct' in self.attr:
734            return [f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
735                    f'unsigned int n_{self.c_name}']
736        return super().arg_member(ri)
737
738    def free_needs_iter(self):
739        return self.attr['type'] in {'nest', 'string'}
740
741    def _free_lines(self, ri, var, ref):
742        lines = []
743        if self.attr['type'] in scalars:
744            lines += [f"free({var}->{ref}{self.c_name});"]
745        elif self.attr['type'] == 'binary':
746            lines += [f"free({var}->{ref}{self.c_name});"]
747        elif self.attr['type'] == 'string':
748            lines += [
749                f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
750                f"free({var}->{ref}{self.c_name}[i]);",
751                f"free({var}->{ref}{self.c_name});",
752            ]
753        elif 'type' not in self.attr or self.attr['type'] == 'nest':
754            lines += [
755                f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
756                f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);',
757                f"free({var}->{ref}{self.c_name});",
758            ]
759        else:
760            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")
761        return lines
762
763    def _attr_policy(self, policy):
764        return self.base_type._attr_policy(policy)
765
766    def _attr_typol(self):
767        return self.base_type._attr_typol()
768
769    def _attr_get(self, ri, var):
770        return f'n_{self.c_name}++;', None, None
771
772    def attr_put(self, ri, var):
773        if self.attr['type'] in scalars:
774            put_type = self.type
775            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
776            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
777        elif self.attr['type'] == 'binary' and 'struct' in self.attr:
778            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
779            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i], sizeof(struct {c_lower(self.attr['struct'])}));")
780        elif self.attr['type'] == 'string':
781            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
782            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {var}->{self.c_name}[i]->str);")
783        elif 'type' not in self.attr or self.attr['type'] == 'nest':
784            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
785            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
786                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
787        else:
788            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")
789
790    def _setter_lines(self, ri, member, presence):
791        return [f"{member} = {self.c_name};",
792                f"{presence} = n_{self.c_name};"]
793
794
795class TypeArrayNest(Type):
796    def is_multi_val(self):
797        return True
798
799    def presence_type(self):
800        return 'count'
801
802    def _complex_member_type(self, ri):
803        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
804            return self.nested_struct_type
805        elif self.attr['sub-type'] in scalars:
806            scalar_pfx = '__' if ri.ku_space == 'user' else ''
807            return scalar_pfx + self.attr['sub-type']
808        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
809            return None  # use arg_member()
810        else:
811            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")
812
813    def arg_member(self, ri):
814        if self.sub_type == 'binary' and 'exact-len' in self.checks:
815            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
816                    f'unsigned int n_{self.c_name}']
817        return super().arg_member(ri)
818
819    def _attr_typol(self):
820        if self.attr['sub-type'] in scalars:
821            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
822        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
823            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
824        else:
825            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
826
827    def _attr_get(self, ri, var):
828        local_vars = ['const struct nlattr *attr2;']
829        get_lines = [f'attr_{self.c_name} = attr;',
830                     'ynl_attr_for_each_nested(attr2, attr) {',
831                     '\tif (ynl_attr_validate(yarg, attr2))',
832                     '\t\treturn YNL_PARSE_CB_ERROR;',
833                     f'\t{var}->_count.{self.c_name}++;',
834                     '}']
835        return get_lines, None, local_vars
836
837    def attr_put(self, ri, var):
838        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
839        if self.sub_type in scalars:
840            put_type = self.sub_type
841            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
842            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, i, {var}->{self.c_name}[i]);")
843            ri.cw.block_end()
844        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
845            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
846            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
847        elif self.sub_type == 'nest':
848            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
849            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
850        else:
851            raise Exception(f"Put for ArrayNest sub-type {self.attr['sub-type']} not supported, yet")
852        ri.cw.p('ynl_attr_nest_end(nlh, array);')
853
854    def _setter_lines(self, ri, member, presence):
855        return [f"{member} = {self.c_name};",
856                f"{presence} = n_{self.c_name};"]
857
858
859class TypeNestTypeValue(Type):
860    def _complex_member_type(self, ri):
861        return self.nested_struct_type
862
863    def _attr_typol(self):
864        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
865
866    def _attr_get(self, ri, var):
867        prev = 'attr'
868        tv_args = ''
869        get_lines = []
870        local_vars = []
871        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
872                      f"parg.data = &{var}->{self.c_name};"]
873        if 'type-value' in self.attr:
874            tv_names = [c_lower(x) for x in self.attr["type-value"]]
875            local_vars += [f'const struct nlattr *attr_{", *attr_".join(tv_names)};']
876            local_vars += [f'__u32 {", ".join(tv_names)};']
877            for level in self.attr["type-value"]:
878                level = c_lower(level)
879                get_lines += [f'attr_{level} = ynl_attr_data({prev});']
880                get_lines += [f'{level} = ynl_attr_type(attr_{level});']
881                prev = 'attr_' + level
882
883            tv_args = f", {', '.join(tv_names)}"
884
885        get_lines += [f"{self.nested_render_name}_parse(&parg, {prev}{tv_args});"]
886        return get_lines, init_lines, local_vars
887
888
889class TypeSubMessage(TypeNest):
890    def __init__(self, family, attr_set, attr, value):
891        super().__init__(family, attr_set, attr, value)
892
893        self.selector = Selector(attr, attr_set)
894
895    def _attr_typol(self):
896        typol = f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
897        typol += '.is_submsg = 1, '
898        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
899        # support external selectors. No family uses sub-messages with external
900        # selector for requests so this is fine for now.
901        if not self.selector.is_external():
902            typol += f'.selector_type = {self.attr_set[self["selector"]].value} '
903        return typol
904
905    def _attr_get(self, ri, var):
906        sel = c_lower(self['selector'])
907        if self.selector.is_external():
908            sel_var = f"_sel_{sel}"
909        else:
910            sel_var = f"{var}->{sel}"
911        get_lines = [f'if (!{sel_var})',
912                     f'return ynl_submsg_failed(yarg, "%s", "%s");' %
913                        (self.name, self['selector']),
914                    f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
915                     "return YNL_PARSE_CB_ERROR;"]
916        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
917                      f"parg.data = &{var}->{self.c_name};"]
918        return get_lines, init_lines, None
919
920
921class Selector:
922    def __init__(self, msg_attr, attr_set):
923        self.name = msg_attr["selector"]
924
925        if self.name in attr_set:
926            self.attr = attr_set[self.name]
927            self.attr.is_selector = True
928            self._external = False
929        else:
930            # The selector will need to get passed down thru the structs
931            self.attr = None
932            self._external = True
933
934    def set_attr(self, attr):
935        self.attr = attr
936
937    def is_external(self):
938        return self._external
939
940
941class Struct:
942    def __init__(self, family, space_name, type_list=None, fixed_header=None,
943                 inherited=None, submsg=None):
944        self.family = family
945        self.space_name = space_name
946        self.attr_set = family.attr_sets[space_name]
947        # Use list to catch comparisons with empty sets
948        self._inherited = inherited if inherited is not None else []
949        self.inherited = []
950        self.fixed_header = None
951        if fixed_header:
952            self.fixed_header = 'struct ' + c_lower(fixed_header)
953        self.submsg = submsg
954
955        self.nested = type_list is None
956        if family.name == c_lower(space_name):
957            self.render_name = c_lower(family.ident_name)
958        else:
959            self.render_name = c_lower(family.ident_name + '-' + space_name)
960        self.struct_name = 'struct ' + self.render_name
961        if self.nested and space_name in family.consts:
962            self.struct_name += '_'
963        self.ptr_name = self.struct_name + ' *'
964        # All attr sets this one contains, directly or multiple levels down
965        self.child_nests = set()
966
967        self.request = False
968        self.reply = False
969        self.recursive = False
970        self.in_multi_val = False  # used by a MultiAttr or and legacy arrays
971
972        self.attr_list = []
973        self.attrs = dict()
974        if type_list is not None:
975            for t in type_list:
976                self.attr_list.append((t, self.attr_set[t]),)
977        else:
978            for t in self.attr_set:
979                self.attr_list.append((t, self.attr_set[t]),)
980
981        max_val = 0
982        self.attr_max_val = None
983        for name, attr in self.attr_list:
984            if attr.value >= max_val:
985                max_val = attr.value
986                self.attr_max_val = attr
987            self.attrs[name] = attr
988
989    def __iter__(self):
990        yield from self.attrs
991
992    def __getitem__(self, key):
993        return self.attrs[key]
994
995    def member_list(self):
996        return self.attr_list
997
998    def set_inherited(self, new_inherited):
999        if self._inherited != new_inherited:
1000            raise Exception("Inheriting different members not supported")
1001        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
1002
1003    def external_selectors(self):
1004        sels = []
1005        for name, attr in self.attr_list:
1006            if isinstance(attr, TypeSubMessage) and attr.selector.is_external():
1007                sels.append(attr.selector)
1008        return sels
1009
1010    def free_needs_iter(self):
1011        for _, attr in self.attr_list:
1012            if attr.free_needs_iter():
1013                return True
1014        return False
1015
1016
1017class EnumEntry(SpecEnumEntry):
1018    def __init__(self, enum_set, yaml, prev, value_start):
1019        super().__init__(enum_set, yaml, prev, value_start)
1020
1021        if prev:
1022            self.value_change = (self.value != prev.value + 1)
1023        else:
1024            self.value_change = (self.value != 0)
1025        self.value_change = self.value_change or self.enum_set['type'] == 'flags'
1026
1027        # Added by resolve:
1028        self.c_name = None
1029        delattr(self, "c_name")
1030
1031    def resolve(self):
1032        self.resolve_up(super())
1033
1034        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1035
1036
1037class EnumSet(SpecEnumSet):
1038    def __init__(self, family, yaml):
1039        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])
1040
1041        if 'enum-name' in yaml:
1042            if yaml['enum-name']:
1043                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
1044                self.user_type = self.enum_name
1045            else:
1046                self.enum_name = None
1047        else:
1048            self.enum_name = 'enum ' + self.render_name
1049
1050        if self.enum_name:
1051            self.user_type = self.enum_name
1052        else:
1053            self.user_type = 'int'
1054
1055        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
1056        self.header = yaml.get('header', None)
1057        self.enum_cnt_name = yaml.get('enum-cnt-name', None)
1058
1059        super().__init__(family, yaml)
1060
1061    def new_entry(self, entry, prev_entry, value_start):
1062        return EnumEntry(self, entry, prev_entry, value_start)
1063
1064    def value_range(self):
1065        low = min([x.value for x in self.entries.values()])
1066        high = max([x.value for x in self.entries.values()])
1067
1068        if high - low + 1 != len(self.entries):
1069            return None, None
1070
1071        return low, high
1072
1073
1074class AttrSet(SpecAttrSet):
1075    def __init__(self, family, yaml):
1076        super().__init__(family, yaml)
1077
1078        if self.subset_of is None:
1079            if 'name-prefix' in yaml:
1080                pfx = yaml['name-prefix']
1081            elif self.name == family.name:
1082                pfx = family.ident_name + '-a-'
1083            else:
1084                pfx = f"{family.ident_name}-a-{self.name}-"
1085            self.name_prefix = c_upper(pfx)
1086            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
1087            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))
1088        else:
1089            self.name_prefix = family.attr_sets[self.subset_of].name_prefix
1090            self.max_name = family.attr_sets[self.subset_of].max_name
1091            self.cnt_name = family.attr_sets[self.subset_of].cnt_name
1092
1093        # Added by resolve:
1094        self.c_name = None
1095        delattr(self, "c_name")
1096
1097    def resolve(self):
1098        self.c_name = c_lower(self.name)
1099        if self.c_name in _C_KW:
1100            self.c_name += '_'
1101        if self.c_name == self.family.c_name:
1102            self.c_name = ''
1103
1104    def new_attr(self, elem, value):
1105        if elem['type'] in scalars:
1106            t = TypeScalar(self.family, self, elem, value)
1107        elif elem['type'] == 'unused':
1108            t = TypeUnused(self.family, self, elem, value)
1109        elif elem['type'] == 'pad':
1110            t = TypePad(self.family, self, elem, value)
1111        elif elem['type'] == 'flag':
1112            t = TypeFlag(self.family, self, elem, value)
1113        elif elem['type'] == 'string':
1114            t = TypeString(self.family, self, elem, value)
1115        elif elem['type'] == 'binary':
1116            if 'struct' in elem:
1117                t = TypeBinaryStruct(self.family, self, elem, value)
1118            elif elem.get('sub-type') in scalars:
1119                t = TypeBinaryScalarArray(self.family, self, elem, value)
1120            else:
1121                t = TypeBinary(self.family, self, elem, value)
1122        elif elem['type'] == 'bitfield32':
1123            t = TypeBitfield32(self.family, self, elem, value)
1124        elif elem['type'] == 'nest':
1125            t = TypeNest(self.family, self, elem, value)
1126        elif elem['type'] == 'indexed-array' and 'sub-type' in elem:
1127            if elem["sub-type"] in ['binary', 'nest', 'u32']:
1128                t = TypeArrayNest(self.family, self, elem, value)
1129            else:
1130                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
1131        elif elem['type'] == 'nest-type-value':
1132            t = TypeNestTypeValue(self.family, self, elem, value)
1133        elif elem['type'] == 'sub-message':
1134            t = TypeSubMessage(self.family, self, elem, value)
1135        else:
1136            raise Exception(f"No typed class for type {elem['type']}")
1137
1138        if 'multi-attr' in elem and elem['multi-attr']:
1139            t = TypeMultiAttr(self.family, self, elem, value, t)
1140
1141        return t
1142
1143
1144class Operation(SpecOperation):
1145    def __init__(self, family, yaml, req_value, rsp_value):
1146        # Fill in missing operation properties (for fixed hdr-only msgs)
1147        for mode in ['do', 'dump', 'event']:
1148            for direction in ['request', 'reply']:
1149                try:
1150                    yaml[mode][direction].setdefault('attributes', [])
1151                except KeyError:
1152                    pass
1153
1154        super().__init__(family, yaml, req_value, rsp_value)
1155
1156        self.render_name = c_lower(family.ident_name + '_' + self.name)
1157
1158        self.dual_policy = ('do' in yaml and 'request' in yaml['do']) and \
1159                         ('dump' in yaml and 'request' in yaml['dump'])
1160
1161        self.has_ntf = False
1162
1163        # Added by resolve:
1164        self.enum_name = None
1165        delattr(self, "enum_name")
1166
1167    def resolve(self):
1168        self.resolve_up(super())
1169
1170        if not self.is_async:
1171            self.enum_name = self.family.op_prefix + c_upper(self.name)
1172        else:
1173            self.enum_name = self.family.async_op_prefix + c_upper(self.name)
1174
1175    def mark_has_ntf(self):
1176        self.has_ntf = True
1177
1178
1179class SubMessage(SpecSubMessage):
1180    def __init__(self, family, yaml):
1181        super().__init__(family, yaml)
1182
1183        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])
1184
1185    def resolve(self):
1186        self.resolve_up(super())
1187
1188
1189class Family(SpecFamily):
1190    def __init__(self, file_name, exclude_ops):
1191        # Added by resolve:
1192        self.c_name = None
1193        delattr(self, "c_name")
1194        self.op_prefix = None
1195        delattr(self, "op_prefix")
1196        self.async_op_prefix = None
1197        delattr(self, "async_op_prefix")
1198        self.mcgrps = None
1199        delattr(self, "mcgrps")
1200        self.consts = None
1201        delattr(self, "consts")
1202        self.hooks = None
1203        delattr(self, "hooks")
1204
1205        super().__init__(file_name, exclude_ops=exclude_ops)
1206
1207        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
1208        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
1209
1210        if 'definitions' not in self.yaml:
1211            self.yaml['definitions'] = []
1212
1213        if 'uapi-header' in self.yaml:
1214            self.uapi_header = self.yaml['uapi-header']
1215        else:
1216            self.uapi_header = f"linux/{self.ident_name}.h"
1217        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
1218            self.uapi_header_name = self.uapi_header[6:-2]
1219        else:
1220            self.uapi_header_name = self.ident_name
1221
1222    def resolve(self):
1223        self.resolve_up(super())
1224
1225        self.c_name = c_lower(self.ident_name)
1226        if 'name-prefix' in self.yaml['operations']:
1227            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
1228        else:
1229            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
1230        if 'async-prefix' in self.yaml['operations']:
1231            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
1232        else:
1233            self.async_op_prefix = self.op_prefix
1234
1235        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
1236
1237        self.hooks = dict()
1238        for when in ['pre', 'post']:
1239            self.hooks[when] = dict()
1240            for op_mode in ['do', 'dump']:
1241                self.hooks[when][op_mode] = dict()
1242                self.hooks[when][op_mode]['set'] = set()
1243                self.hooks[when][op_mode]['list'] = []
1244
1245        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
1246        self.root_sets = dict()
1247        # dict space-name -> Struct
1248        self.pure_nested_structs = dict()
1249
1250        self._mark_notify()
1251        self._mock_up_events()
1252
1253        self._load_root_sets()
1254        self._load_nested_sets()
1255        self._load_attr_use()
1256        self._load_selector_passing()
1257        self._load_hooks()
1258
1259        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
1260        if self.kernel_policy == 'global':
1261            self._load_global_policy()
1262
1263    def new_enum(self, elem):
1264        return EnumSet(self, elem)
1265
1266    def new_attr_set(self, elem):
1267        return AttrSet(self, elem)
1268
1269    def new_operation(self, elem, req_value, rsp_value):
1270        return Operation(self, elem, req_value, rsp_value)
1271
1272    def new_sub_message(self, elem):
1273        return SubMessage(self, elem)
1274
1275    def is_classic(self):
1276        return self.proto == 'netlink-raw'
1277
1278    def _mark_notify(self):
1279        for op in self.msgs.values():
1280            if 'notify' in op:
1281                self.ops[op['notify']].mark_has_ntf()
1282
1283    # Fake a 'do' equivalent of all events, so that we can render their response parsing
1284    def _mock_up_events(self):
1285        for op in self.yaml['operations']['list']:
1286            if 'event' in op:
1287                op['do'] = {
1288                    'reply': {
1289                        'attributes': op['event']['attributes']
1290                    }
1291                }
1292
1293    def _load_root_sets(self):
1294        for op_name, op in self.msgs.items():
1295            if 'attribute-set' not in op:
1296                continue
1297
1298            req_attrs = set()
1299            rsp_attrs = set()
1300            for op_mode in ['do', 'dump']:
1301                if op_mode in op and 'request' in op[op_mode]:
1302                    req_attrs.update(set(op[op_mode]['request']['attributes']))
1303                if op_mode in op and 'reply' in op[op_mode]:
1304                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
1305            if 'event' in op:
1306                rsp_attrs.update(set(op['event']['attributes']))
1307
1308            if op['attribute-set'] not in self.root_sets:
1309                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
1310            else:
1311                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
1312                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
1313
1314    def _sort_pure_types(self):
1315        # Try to reorder according to dependencies
1316        pns_key_list = list(self.pure_nested_structs.keys())
1317        pns_key_seen = set()
1318        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
1319        for _ in range(rounds):
1320            if len(pns_key_list) == 0:
1321                break
1322            name = pns_key_list.pop(0)
1323            finished = True
1324            for _, spec in self.attr_sets[name].items():
1325                if 'nested-attributes' in spec:
1326                    nested = spec['nested-attributes']
1327                elif 'sub-message' in spec:
1328                    nested = spec.sub_message
1329                else:
1330                    continue
1331
1332                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
1333                if self.pure_nested_structs[nested].recursive:
1334                    continue
1335                if nested not in pns_key_seen:
1336                    # Dicts are sorted, this will make struct last
1337                    struct = self.pure_nested_structs.pop(name)
1338                    self.pure_nested_structs[name] = struct
1339                    finished = False
1340                    break
1341            if finished:
1342                pns_key_seen.add(name)
1343            else:
1344                pns_key_list.append(name)
1345
1346    def _load_nested_set_nest(self, spec):
1347        inherit = set()
1348        nested = spec['nested-attributes']
1349        if nested not in self.root_sets:
1350            if nested not in self.pure_nested_structs:
1351                self.pure_nested_structs[nested] = \
1352                    Struct(self, nested, inherited=inherit,
1353                           fixed_header=spec.get('fixed-header'))
1354        else:
1355            raise Exception(f'Using attr set as root and nested not supported - {nested}')
1356
1357        if 'type-value' in spec:
1358            if nested in self.root_sets:
1359                raise Exception("Inheriting members to a space used as root not supported")
1360            inherit.update(set(spec['type-value']))
1361        elif spec['type'] == 'indexed-array':
1362            inherit.add('idx')
1363        self.pure_nested_structs[nested].set_inherited(inherit)
1364
1365        return nested
1366
1367    def _load_nested_set_submsg(self, spec):
1368        # Fake the struct type for the sub-message itself
1369        # its not a attr_set but codegen wants attr_sets.
1370        submsg = self.sub_msgs[spec["sub-message"]]
1371        nested = submsg.name
1372
1373        attrs = []
1374        for name, fmt in submsg.formats.items():
1375            attr = {
1376                "name": name,
1377                "parent-sub-message": spec,
1378            }
1379            if 'attribute-set' in fmt:
1380                attr |= {
1381                    "type": "nest",
1382                    "nested-attributes": fmt['attribute-set'],
1383                }
1384                if 'fixed-header' in fmt:
1385                    attr |= { "fixed-header": fmt["fixed-header"] }
1386            elif 'fixed-header' in fmt:
1387                attr |= {
1388                    "type": "binary",
1389                    "struct": fmt["fixed-header"],
1390                }
1391            else:
1392                attr["type"] = "flag"
1393            attrs.append(attr)
1394
1395        self.attr_sets[nested] = AttrSet(self, {
1396            "name": nested,
1397            "name-pfx": self.name + '-' + spec.name + '-',
1398            "attributes": attrs
1399        })
1400
1401        if nested not in self.pure_nested_structs:
1402            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)
1403
1404        return nested
1405
1406    def _load_nested_sets(self):
1407        attr_set_queue = list(self.root_sets.keys())
1408        attr_set_seen = set(self.root_sets.keys())
1409
1410        while len(attr_set_queue):
1411            a_set = attr_set_queue.pop(0)
1412            for attr, spec in self.attr_sets[a_set].items():
1413                if 'nested-attributes' in spec:
1414                    nested = self._load_nested_set_nest(spec)
1415                elif 'sub-message' in spec:
1416                    nested = self._load_nested_set_submsg(spec)
1417                else:
1418                    continue
1419
1420                if nested not in attr_set_seen:
1421                    attr_set_queue.append(nested)
1422                    attr_set_seen.add(nested)
1423
1424        for root_set, rs_members in self.root_sets.items():
1425            for attr, spec in self.attr_sets[root_set].items():
1426                if 'nested-attributes' in spec:
1427                    nested = spec['nested-attributes']
1428                elif 'sub-message' in spec:
1429                    nested = spec.sub_message
1430                else:
1431                    nested = None
1432
1433                if nested:
1434                    if attr in rs_members['request']:
1435                        self.pure_nested_structs[nested].request = True
1436                    if attr in rs_members['reply']:
1437                        self.pure_nested_structs[nested].reply = True
1438
1439                    if spec.is_multi_val():
1440                        child = self.pure_nested_structs.get(nested)
1441                        child.in_multi_val = True
1442
1443        self._sort_pure_types()
1444
1445        # Propagate the request / reply / recursive
1446        for attr_set, struct in reversed(self.pure_nested_structs.items()):
1447            for _, spec in self.attr_sets[attr_set].items():
1448                if attr_set in struct.child_nests:
1449                    struct.recursive = True
1450
1451                if 'nested-attributes' in spec:
1452                    child_name = spec['nested-attributes']
1453                elif 'sub-message' in spec:
1454                    child_name = spec.sub_message
1455                else:
1456                    continue
1457
1458                struct.child_nests.add(child_name)
1459                child = self.pure_nested_structs.get(child_name)
1460                if child:
1461                    if not child.recursive:
1462                        struct.child_nests.update(child.child_nests)
1463                    child.request |= struct.request
1464                    child.reply |= struct.reply
1465                    if spec.is_multi_val():
1466                        child.in_multi_val = True
1467
1468        self._sort_pure_types()
1469
1470    def _load_attr_use(self):
1471        for _, struct in self.pure_nested_structs.items():
1472            if struct.request:
1473                for _, arg in struct.member_list():
1474                    arg.set_request()
1475            if struct.reply:
1476                for _, arg in struct.member_list():
1477                    arg.set_reply()
1478
1479        for root_set, rs_members in self.root_sets.items():
1480            for attr, spec in self.attr_sets[root_set].items():
1481                if attr in rs_members['request']:
1482                    spec.set_request()
1483                if attr in rs_members['reply']:
1484                    spec.set_reply()
1485
1486    def _load_selector_passing(self):
1487        def all_structs():
1488            for k, v in reversed(self.pure_nested_structs.items()):
1489                yield k, v
1490            for k, _ in self.root_sets.items():
1491                yield k, None  # we don't have a struct, but it must be terminal
1492
1493        for attr_set, struct in all_structs():
1494            for _, spec in self.attr_sets[attr_set].items():
1495                if 'nested-attributes' in spec:
1496                    child_name = spec['nested-attributes']
1497                elif 'sub-message' in spec:
1498                    child_name = spec.sub_message
1499                else:
1500                    continue
1501
1502                child = self.pure_nested_structs.get(child_name)
1503                for selector in child.external_selectors():
1504                    if selector.name in self.attr_sets[attr_set]:
1505                        sel_attr = self.attr_sets[attr_set][selector.name]
1506                        selector.set_attr(sel_attr)
1507                    else:
1508                        raise Exception("Passing selector thru more than one layer not supported")
1509
1510    def _load_global_policy(self):
1511        global_set = set()
1512        attr_set_name = None
1513        for op_name, op in self.ops.items():
1514            if not op:
1515                continue
1516            if 'attribute-set' not in op:
1517                continue
1518
1519            if attr_set_name is None:
1520                attr_set_name = op['attribute-set']
1521            if attr_set_name != op['attribute-set']:
1522                raise Exception('For a global policy all ops must use the same set')
1523
1524            for op_mode in ['do', 'dump']:
1525                if op_mode in op:
1526                    req = op[op_mode].get('request')
1527                    if req:
1528                        global_set.update(req.get('attributes', []))
1529
1530        self.global_policy = []
1531        self.global_policy_set = attr_set_name
1532        for attr in self.attr_sets[attr_set_name]:
1533            if attr in global_set:
1534                self.global_policy.append(attr)
1535
1536    def _load_hooks(self):
1537        for op in self.ops.values():
1538            for op_mode in ['do', 'dump']:
1539                if op_mode not in op:
1540                    continue
1541                for when in ['pre', 'post']:
1542                    if when not in op[op_mode]:
1543                        continue
1544                    name = op[op_mode][when]
1545                    if name in self.hooks[when][op_mode]['set']:
1546                        continue
1547                    self.hooks[when][op_mode]['set'].add(name)
1548                    self.hooks[when][op_mode]['list'].append(name)
1549
1550
1551class RenderInfo:
1552    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
1553        self.family = family
1554        self.nl = cw.nlib
1555        self.ku_space = ku_space
1556        self.op_mode = op_mode
1557        self.op = op
1558
1559        fixed_hdr = op.fixed_header if op else None
1560        self.fixed_hdr_len = 'ys->family->hdr_len'
1561        if op and op.fixed_header:
1562            if op.fixed_header != family.fixed_header:
1563                if family.is_classic():
1564                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
1565                else:
1566                    raise Exception(f"Per-op fixed header not supported, yet")
1567
1568
1569        # 'do' and 'dump' response parsing is identical
1570        self.type_consistent = True
1571        self.type_oneside = False
1572        if op_mode != 'do' and 'dump' in op:
1573            if 'do' in op:
1574                if ('reply' in op['do']) != ('reply' in op["dump"]):
1575                    self.type_consistent = False
1576                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
1577                    self.type_consistent = False
1578            else:
1579                self.type_consistent = True
1580                self.type_oneside = True
1581
1582        self.attr_set = attr_set
1583        if not self.attr_set:
1584            self.attr_set = op['attribute-set']
1585
1586        self.type_name_conflict = False
1587        if op:
1588            self.type_name = c_lower(op.name)
1589        else:
1590            self.type_name = c_lower(attr_set)
1591            if attr_set in family.consts:
1592                self.type_name_conflict = True
1593
1594        self.cw = cw
1595
1596        self.struct = dict()
1597        if op_mode == 'notify':
1598            op_mode = 'do' if 'do' in op else 'dump'
1599        for op_dir in ['request', 'reply']:
1600            if op:
1601                type_list = []
1602                if op_dir in op[op_mode]:
1603                    type_list = op[op_mode][op_dir]['attributes']
1604                self.struct[op_dir] = Struct(family, self.attr_set,
1605                                             fixed_header=fixed_hdr,
1606                                             type_list=type_list)
1607        if op_mode == 'event':
1608            self.struct['reply'] = Struct(family, self.attr_set,
1609                                          fixed_header=fixed_hdr,
1610                                          type_list=op['event']['attributes'])
1611
1612    def type_empty(self, key):
1613        return len(self.struct[key].attr_list) == 0 and \
1614            self.struct['request'].fixed_header is None
1615
1616    def needs_nlflags(self, direction):
1617        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1618
1619
1620class CodeWriter:
1621    def __init__(self, nlib, out_file=None, overwrite=True):
1622        self.nlib = nlib
1623        self._overwrite = overwrite
1624
1625        self._nl = False
1626        self._block_end = False
1627        self._silent_block = False
1628        self._ind = 0
1629        self._ifdef_block = None
1630        if out_file is None:
1631            self._out = os.sys.stdout
1632        else:
1633            self._out = tempfile.NamedTemporaryFile('w+')
1634            self._out_file = out_file
1635
1636    def __del__(self):
1637        self.close_out_file()
1638
1639    def close_out_file(self):
1640        if self._out == os.sys.stdout:
1641            return
1642        # Avoid modifying the file if contents didn't change
1643        self._out.flush()
1644        if not self._overwrite and os.path.isfile(self._out_file):
1645            if filecmp.cmp(self._out.name, self._out_file, shallow=False):
1646                return
1647        with open(self._out_file, 'w+') as out_file:
1648            self._out.seek(0)
1649            shutil.copyfileobj(self._out, out_file)
1650            self._out.close()
1651        self._out = os.sys.stdout
1652
1653    @classmethod
1654    def _is_cond(cls, line):
1655        return line.startswith('if') or line.startswith('while') or line.startswith('for')
1656
1657    def p(self, line, add_ind=0):
1658        if self._block_end:
1659            self._block_end = False
1660            if line.startswith('else'):
1661                line = '} ' + line
1662            else:
1663                self._out.write('\t' * self._ind + '}\n')
1664
1665        if self._nl:
1666            self._out.write('\n')
1667            self._nl = False
1668
1669        ind = self._ind
1670        if line[-1] == ':':
1671            ind -= 1
1672        if self._silent_block:
1673            ind += 1
1674        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
1675        self._silent_block |= line.strip() == 'else'
1676        if line[0] == '#':
1677            ind = 0
1678        if add_ind:
1679            ind += add_ind
1680        self._out.write('\t' * ind + line + '\n')
1681
1682    def nl(self):
1683        self._nl = True
1684
1685    def block_start(self, line=''):
1686        if line:
1687            line = line + ' '
1688        self.p(line + '{')
1689        self._ind += 1
1690
1691    def block_end(self, line=''):
1692        if line and line[0] not in {';', ','}:
1693            line = ' ' + line
1694        self._ind -= 1
1695        self._nl = False
1696        if not line:
1697            # Delay printing closing bracket in case "else" comes next
1698            if self._block_end:
1699                self._out.write('\t' * (self._ind + 1) + '}\n')
1700            self._block_end = True
1701        else:
1702            self.p('}' + line)
1703
1704    def write_doc_line(self, doc, indent=True):
1705        words = doc.split()
1706        line = ' *'
1707        for word in words:
1708            if len(line) + len(word) >= 79:
1709                self.p(line)
1710                line = ' *'
1711                if indent:
1712                    line += '  '
1713            line += ' ' + word
1714        self.p(line)
1715
1716    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
1717        if not args:
1718            args = ['void']
1719
1720        if doc:
1721            self.p('/*')
1722            self.p(' * ' + doc)
1723            self.p(' */')
1724
1725        oneline = qual_ret
1726        if qual_ret[-1] != '*':
1727            oneline += ' '
1728        oneline += f"{name}({', '.join(args)}){suffix}"
1729
1730        if len(oneline) < 80:
1731            self.p(oneline)
1732            return
1733
1734        v = qual_ret
1735        if len(v) > 3:
1736            self.p(v)
1737            v = ''
1738        elif qual_ret[-1] != '*':
1739            v += ' '
1740        v += name + '('
1741        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
1742        delta_ind = len(v) - len(ind)
1743        v += args[0]
1744        i = 1
1745        while i < len(args):
1746            next_len = len(v) + len(args[i])
1747            if v[0] == '\t':
1748                next_len += delta_ind
1749            if next_len > 76:
1750                self.p(v + ',')
1751                v = ind
1752            else:
1753                v += ', '
1754            v += args[i]
1755            i += 1
1756        self.p(v + ')' + suffix)
1757
1758    def write_func_lvar(self, local_vars):
1759        if not local_vars:
1760            return
1761
1762        if type(local_vars) is str:
1763            local_vars = [local_vars]
1764
1765        local_vars.sort(key=len, reverse=True)
1766        for var in local_vars:
1767            self.p(var)
1768        self.nl()
1769
1770    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
1771        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
1772        self.block_start()
1773        self.write_func_lvar(local_vars=local_vars)
1774
1775        for line in body:
1776            self.p(line)
1777        self.block_end()
1778
1779    def writes_defines(self, defines):
1780        longest = 0
1781        for define in defines:
1782            if len(define[0]) > longest:
1783                longest = len(define[0])
1784        longest = ((longest + 8) // 8) * 8
1785        for define in defines:
1786            line = '#define ' + define[0]
1787            line += '\t' * ((longest - len(define[0]) + 7) // 8)
1788            if type(define[1]) is int:
1789                line += str(define[1])
1790            elif type(define[1]) is str:
1791                line += '"' + define[1] + '"'
1792            self.p(line)
1793
1794    def write_struct_init(self, members):
1795        longest = max([len(x[0]) for x in members])
1796        longest += 1  # because we prepend a .
1797        longest = ((longest + 8) // 8) * 8
1798        for one in members:
1799            line = '.' + one[0]
1800            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
1801            line += '= ' + str(one[1]) + ','
1802            self.p(line)
1803
1804    def ifdef_block(self, config):
1805        config_option = None
1806        if config:
1807            config_option = 'CONFIG_' + c_upper(config)
1808        if self._ifdef_block == config_option:
1809            return
1810
1811        if self._ifdef_block:
1812            self.p('#endif /* ' + self._ifdef_block + ' */')
1813        if config_option:
1814            self.p('#ifdef ' + config_option)
1815        self._ifdef_block = config_option
1816
1817
1818scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}
1819
1820direction_to_suffix = {
1821    'reply': '_rsp',
1822    'request': '_req',
1823    '': ''
1824}
1825
1826op_mode_to_wrapper = {
1827    'do': '',
1828    'dump': '_list',
1829    'notify': '_ntf',
1830    'event': '',
1831}
1832
1833_C_KW = {
1834    'auto',
1835    'bool',
1836    'break',
1837    'case',
1838    'char',
1839    'const',
1840    'continue',
1841    'default',
1842    'do',
1843    'double',
1844    'else',
1845    'enum',
1846    'extern',
1847    'float',
1848    'for',
1849    'goto',
1850    'if',
1851    'inline',
1852    'int',
1853    'long',
1854    'register',
1855    'return',
1856    'short',
1857    'signed',
1858    'sizeof',
1859    'static',
1860    'struct',
1861    'switch',
1862    'typedef',
1863    'union',
1864    'unsigned',
1865    'void',
1866    'volatile',
1867    'while'
1868}
1869
1870
1871def rdir(direction):
1872    if direction == 'reply':
1873        return 'request'
1874    if direction == 'request':
1875        return 'reply'
1876    return direction
1877
1878
1879def op_prefix(ri, direction, deref=False):
1880    suffix = f"_{ri.type_name}"
1881
1882    if not ri.op_mode:
1883        pass
1884    elif ri.op_mode == 'do':
1885        suffix += f"{direction_to_suffix[direction]}"
1886    else:
1887        if direction == 'request':
1888            suffix += '_req'
1889            if not ri.type_oneside:
1890                suffix += '_dump'
1891        else:
1892            if ri.type_consistent:
1893                if deref:
1894                    suffix += f"{direction_to_suffix[direction]}"
1895                else:
1896                    suffix += op_mode_to_wrapper[ri.op_mode]
1897            else:
1898                suffix += '_rsp'
1899                suffix += '_dump' if deref else '_list'
1900
1901    return f"{ri.family.c_name}{suffix}"
1902
1903
1904def type_name(ri, direction, deref=False):
1905    return f"struct {op_prefix(ri, direction, deref=deref)}"
1906
1907
1908def print_prototype(ri, direction, terminate=True, doc=None):
1909    suffix = ';' if terminate else ''
1910
1911    fname = ri.op.render_name
1912    if ri.op_mode == 'dump':
1913        fname += '_dump'
1914
1915    args = ['struct ynl_sock *ys']
1916    if 'request' in ri.op[ri.op_mode]:
1917        args.append(f"{type_name(ri, direction)} *" + f"{direction_to_suffix[direction][1:]}")
1918
1919    ret = 'int'
1920    if 'reply' in ri.op[ri.op_mode]:
1921        ret = f"{type_name(ri, rdir(direction))} *"
1922
1923    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=suffix)
1924
1925
1926def print_req_prototype(ri):
1927    print_prototype(ri, "request", doc=ri.op['doc'])
1928
1929
1930def print_dump_prototype(ri):
1931    print_prototype(ri, "request")
1932
1933
1934def put_typol_submsg(cw, struct):
1935    cw.block_start(line=f'const struct ynl_policy_attr {struct.render_name}_policy[] =')
1936
1937    i = 0
1938    for name, arg in struct.member_list():
1939        nest = ""
1940        if arg.type == 'nest':
1941            nest = f" .nest = &{arg.nested_render_name}_nest,"
1942        cw.p('[%d] = { .type = YNL_PT_SUBMSG, .name = "%s",%s },' %
1943             (i, name, nest))
1944        i += 1
1945
1946    cw.block_end(line=';')
1947    cw.nl()
1948
1949    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
1950    cw.p(f'.max_attr = {i - 1},')
1951    cw.p(f'.table = {struct.render_name}_policy,')
1952    cw.block_end(line=';')
1953    cw.nl()
1954
1955
1956def put_typol_fwd(cw, struct):
1957    cw.p(f'extern const struct ynl_policy_nest {struct.render_name}_nest;')
1958
1959
1960def put_typol(cw, struct):
1961    if struct.submsg:
1962        put_typol_submsg(cw, struct)
1963        return
1964
1965    type_max = struct.attr_set.max_name
1966    cw.block_start(line=f'const struct ynl_policy_attr {struct.render_name}_policy[{type_max} + 1] =')
1967
1968    for _, arg in struct.member_list():
1969        arg.attr_typol(cw)
1970
1971    cw.block_end(line=';')
1972    cw.nl()
1973
1974    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
1975    cw.p(f'.max_attr = {type_max},')
1976    cw.p(f'.table = {struct.render_name}_policy,')
1977    cw.block_end(line=';')
1978    cw.nl()
1979
1980
1981def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1982    args = [f'int {arg_name}']
1983    if enum:
1984        args = [enum.user_type + ' ' + arg_name]
1985    cw.write_func_prot('const char *', f'{render_name}_str', args)
1986    cw.block_start()
1987    if enum and enum.type == 'flags':
1988        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1989    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
1990    cw.p('return NULL;')
1991    cw.p(f'return {map_name}[{arg_name}];')
1992    cw.block_end()
1993    cw.nl()
1994
1995
1996def put_op_name_fwd(family, cw):
1997    cw.write_func_prot('const char *', f'{family.c_name}_op_str', ['int op'], suffix=';')
1998
1999
2000def put_op_name(family, cw):
2001    map_name = f'{family.c_name}_op_strmap'
2002    cw.block_start(line=f"static const char * const {map_name}[] =")
2003    for op_name, op in family.msgs.items():
2004        if op.rsp_value:
2005            # Make sure we don't add duplicated entries, if multiple commands
2006            # produce the same response in legacy families.
2007            if family.rsp_by_value[op.rsp_value] != op:
2008                cw.p(f'// skip "{op_name}", duplicate reply value')
2009                continue
2010
2011            if op.req_value == op.rsp_value:
2012                cw.p(f'[{op.enum_name}] = "{op_name}",')
2013            else:
2014                cw.p(f'[{op.rsp_value}] = "{op_name}",')
2015    cw.block_end(line=';')
2016    cw.nl()
2017
2018    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2019
2020
2021def put_enum_to_str_fwd(family, cw, enum):
2022    args = [enum.user_type + ' value']
2023    cw.write_func_prot('const char *', f'{enum.render_name}_str', args, suffix=';')
2024
2025
2026def put_enum_to_str(family, cw, enum):
2027    map_name = f'{enum.render_name}_strmap'
2028    cw.block_start(line=f"static const char * const {map_name}[] =")
2029    for entry in enum.entries.values():
2030        cw.p(f'[{entry.value}] = "{entry.name}",')
2031    cw.block_end(line=';')
2032    cw.nl()
2033
2034    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
2035
2036
2037def put_req_nested_prototype(ri, struct, suffix=';'):
2038    func_args = ['struct nlmsghdr *nlh',
2039                 'unsigned int attr_type',
2040                 f'{struct.ptr_name}obj']
2041
2042    ri.cw.write_func_prot('int', f'{struct.render_name}_put', func_args,
2043                          suffix=suffix)
2044
2045
2046def put_req_nested(ri, struct):
2047    local_vars = []
2048    init_lines = []
2049
2050    if struct.submsg is None:
2051        local_vars.append('struct nlattr *nest;')
2052        init_lines.append("nest = ynl_attr_nest_start(nlh, attr_type);")
2053    if struct.fixed_header:
2054        local_vars.append('void *hdr;')
2055        struct_sz = f'sizeof({struct.fixed_header})'
2056        init_lines.append(f"hdr = ynl_nlmsg_put_extra_header(nlh, {struct_sz});")
2057        init_lines.append(f"memcpy(hdr, &obj->_hdr, {struct_sz});")
2058
2059    has_anest = False
2060    has_count = False
2061    for _, arg in struct.member_list():
2062        has_anest |= arg.type == 'indexed-array'
2063        has_count |= arg.presence_type() == 'count'
2064    if has_anest:
2065        local_vars.append('struct nlattr *array;')
2066    if has_count:
2067        local_vars.append('unsigned int i;')
2068
2069    put_req_nested_prototype(ri, struct, suffix='')
2070    ri.cw.block_start()
2071    ri.cw.write_func_lvar(local_vars)
2072
2073    for line in init_lines:
2074        ri.cw.p(line)
2075
2076    for _, arg in struct.member_list():
2077        arg.attr_put(ri, "obj")
2078
2079    if struct.submsg is None:
2080        ri.cw.p("ynl_attr_nest_end(nlh, nest);")
2081
2082    ri.cw.nl()
2083    ri.cw.p('return 0;')
2084    ri.cw.block_end()
2085    ri.cw.nl()
2086
2087
2088def _multi_parse(ri, struct, init_lines, local_vars):
2089    if struct.fixed_header:
2090        local_vars += ['void *hdr;']
2091    if struct.nested:
2092        if struct.fixed_header:
2093            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
2094        else:
2095            iter_line = "ynl_attr_for_each_nested(attr, nested)"
2096    else:
2097        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
2098        if ri.op.fixed_header != ri.family.fixed_header:
2099            if ri.family.is_classic():
2100                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
2101            else:
2102                raise Exception(f"Per-op fixed header not supported, yet")
2103
2104    array_nests = set()
2105    multi_attrs = set()
2106    needs_parg = False
2107    for arg, aspec in struct.member_list():
2108        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
2109            if aspec["sub-type"] in {'binary', 'nest'}:
2110                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
2111                array_nests.add(arg)
2112            elif aspec['sub-type'] in scalars:
2113                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
2114                array_nests.add(arg)
2115            else:
2116                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
2117        if 'multi-attr' in aspec:
2118            multi_attrs.add(arg)
2119        needs_parg |= 'nested-attributes' in aspec
2120        needs_parg |= 'sub-message' in aspec
2121    if array_nests or multi_attrs:
2122        local_vars.append('int i;')
2123    if needs_parg:
2124        local_vars.append('struct ynl_parse_arg parg;')
2125        init_lines.append('parg.ys = yarg->ys;')
2126
2127    all_multi = array_nests | multi_attrs
2128
2129    for anest in sorted(all_multi):
2130        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
2131
2132    ri.cw.block_start()
2133    ri.cw.write_func_lvar(local_vars)
2134
2135    for line in init_lines:
2136        ri.cw.p(line)
2137    ri.cw.nl()
2138
2139    for arg in struct.inherited:
2140        ri.cw.p(f'dst->{arg} = {arg};')
2141
2142    if struct.fixed_header:
2143        if struct.nested:
2144            ri.cw.p('hdr = ynl_attr_data(nested);')
2145        elif ri.family.is_classic():
2146            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
2147        else:
2148            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
2149        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
2150    for anest in sorted(all_multi):
2151        aspec = struct[anest]
2152        ri.cw.p(f"if (dst->{aspec.c_name})")
2153        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
2154
2155    ri.cw.nl()
2156    ri.cw.block_start(line=iter_line)
2157    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
2158    ri.cw.nl()
2159
2160    first = True
2161    for _, arg in struct.member_list():
2162        good = arg.attr_get(ri, 'dst', first=first)
2163        # First may be 'unused' or 'pad', ignore those
2164        first &= not good
2165
2166    ri.cw.block_end()
2167    ri.cw.nl()
2168
2169    for anest in sorted(array_nests):
2170        aspec = struct[anest]
2171
2172        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
2173        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
2174        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
2175        ri.cw.p('i = 0;')
2176        if 'nested-attributes' in aspec:
2177            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
2178        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
2179        if 'nested-attributes' in aspec:
2180            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
2181            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
2182            ri.cw.p('return YNL_PARSE_CB_ERROR;')
2183        elif aspec.sub_type in scalars:
2184            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
2185        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
2186            # Length is validated by typol
2187            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
2188        else:
2189            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
2190        ri.cw.p('i++;')
2191        ri.cw.block_end()
2192        ri.cw.block_end()
2193    ri.cw.nl()
2194
2195    for anest in sorted(multi_attrs):
2196        aspec = struct[anest]
2197        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
2198        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
2199        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
2200        ri.cw.p('i = 0;')
2201        if 'nested-attributes' in aspec:
2202            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
2203        ri.cw.block_start(line=iter_line)
2204        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
2205        if 'nested-attributes' in aspec:
2206            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
2207            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
2208            ri.cw.p('return YNL_PARSE_CB_ERROR;')
2209        elif aspec.type in scalars:
2210            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
2211        elif aspec.type == 'binary' and 'struct' in aspec:
2212            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
2213            ri.cw.nl()
2214            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
2215            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
2216            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
2217        elif aspec.type == 'string':
2218            ri.cw.p('unsigned int len;')
2219            ri.cw.nl()
2220            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
2221            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
2222            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
2223            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
2224            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
2225        else:
2226            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
2227        ri.cw.p('i++;')
2228        ri.cw.block_end()
2229        ri.cw.block_end()
2230        ri.cw.block_end()
2231    ri.cw.nl()
2232
2233    if struct.nested:
2234        ri.cw.p('return 0;')
2235    else:
2236        ri.cw.p('return YNL_PARSE_CB_OK;')
2237    ri.cw.block_end()
2238    ri.cw.nl()
2239
2240
2241def parse_rsp_submsg(ri, struct):
2242    parse_rsp_nested_prototype(ri, struct, suffix='')
2243
2244    var = 'dst'
2245    local_vars = {'const struct nlattr *attr = nested;',
2246                  f'{struct.ptr_name}{var} = yarg->data;',
2247                  'struct ynl_parse_arg parg;'}
2248
2249    for _, arg in struct.member_list():
2250        _, _, l_vars = arg._attr_get(ri, var)
2251        local_vars |= set(l_vars) if l_vars else set()
2252
2253    ri.cw.block_start()
2254    ri.cw.write_func_lvar(list(local_vars))
2255    ri.cw.p('parg.ys = yarg->ys;')
2256    ri.cw.nl()
2257
2258    first = True
2259    for name, arg in struct.member_list():
2260        kw = 'if' if first else 'else if'
2261        first = False
2262
2263        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
2264        get_lines, init_lines, _ = arg._attr_get(ri, var)
2265        for line in init_lines or []:
2266            ri.cw.p(line)
2267        for line in get_lines:
2268            ri.cw.p(line)
2269        if arg.presence_type() == 'present':
2270            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
2271        ri.cw.block_end()
2272    ri.cw.p('return 0;')
2273    ri.cw.block_end()
2274    ri.cw.nl()
2275
2276
2277def parse_rsp_nested_prototype(ri, struct, suffix=';'):
2278    func_args = ['struct ynl_parse_arg *yarg',
2279                 'const struct nlattr *nested']
2280    for sel in struct.external_selectors():
2281        func_args.append('const char *_sel_' + sel.name)
2282    if struct.submsg:
2283        func_args.insert(1, 'const char *sel')
2284    for arg in struct.inherited:
2285        func_args.append('__u32 ' + arg)
2286
2287    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
2288                          suffix=suffix)
2289
2290
2291def parse_rsp_nested(ri, struct):
2292    if struct.submsg:
2293        return parse_rsp_submsg(ri, struct)
2294
2295    parse_rsp_nested_prototype(ri, struct, suffix='')
2296
2297    local_vars = ['const struct nlattr *attr;',
2298                  f'{struct.ptr_name}dst = yarg->data;']
2299    init_lines = []
2300
2301    if struct.member_list():
2302        _multi_parse(ri, struct, init_lines, local_vars)
2303    else:
2304        # Empty nest
2305        ri.cw.block_start()
2306        ri.cw.p('return 0;')
2307        ri.cw.block_end()
2308        ri.cw.nl()
2309
2310
2311def parse_rsp_msg(ri, deref=False):
2312    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
2313        return
2314
2315    func_args = ['const struct nlmsghdr *nlh',
2316                 'struct ynl_parse_arg *yarg']
2317
2318    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
2319                  'const struct nlattr *attr;']
2320    init_lines = ['dst = yarg->data;']
2321
2322    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', func_args)
2323
2324    if ri.struct["reply"].member_list():
2325        _multi_parse(ri, ri.struct["reply"], init_lines, local_vars)
2326    else:
2327        # Empty reply
2328        ri.cw.block_start()
2329        ri.cw.p('return YNL_PARSE_CB_OK;')
2330        ri.cw.block_end()
2331        ri.cw.nl()
2332
2333
2334def print_req(ri):
2335    ret_ok = '0'
2336    ret_err = '-1'
2337    direction = "request"
2338    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
2339                  'struct nlmsghdr *nlh;',
2340                  'int err;']
2341
2342    if 'reply' in ri.op[ri.op_mode]:
2343        ret_ok = 'rsp'
2344        ret_err = 'NULL'
2345        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']
2346
2347    if ri.struct["request"].fixed_header:
2348        local_vars += ['size_t hdr_len;',
2349                       'void *hdr;']
2350
2351    for _, attr in ri.struct["request"].member_list():
2352        if attr.presence_type() == 'count':
2353            local_vars += ['unsigned int i;']
2354            break
2355
2356    print_prototype(ri, direction, terminate=False)
2357    ri.cw.block_start()
2358    ri.cw.write_func_lvar(local_vars)
2359
2360    if ri.family.is_classic():
2361        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
2362    else:
2363        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
2364
2365    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
2366    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
2367    if 'reply' in ri.op[ri.op_mode]:
2368        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
2369    ri.cw.nl()
2370
2371    if ri.struct['request'].fixed_header:
2372        ri.cw.p("hdr_len = sizeof(req->_hdr);")
2373        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
2374        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
2375        ri.cw.nl()
2376
2377    for _, attr in ri.struct["request"].member_list():
2378        attr.attr_put(ri, "req")
2379    ri.cw.nl()
2380
2381    if 'reply' in ri.op[ri.op_mode]:
2382        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
2383        ri.cw.p('yrs.yarg.data = rsp;')
2384        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
2385        if ri.op.value is not None:
2386            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
2387        else:
2388            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
2389        ri.cw.nl()
2390    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
2391    ri.cw.p('if (err < 0)')
2392    if 'reply' in ri.op[ri.op_mode]:
2393        ri.cw.p('goto err_free;')
2394    else:
2395        ri.cw.p('return -1;')
2396    ri.cw.nl()
2397
2398    ri.cw.p(f"return {ret_ok};")
2399    ri.cw.nl()
2400
2401    if 'reply' in ri.op[ri.op_mode]:
2402        ri.cw.p('err_free:')
2403        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
2404        ri.cw.p(f"return {ret_err};")
2405
2406    ri.cw.block_end()
2407
2408
2409def print_dump(ri):
2410    direction = "request"
2411    print_prototype(ri, direction, terminate=False)
2412    ri.cw.block_start()
2413    local_vars = ['struct ynl_dump_state yds = {};',
2414                  'struct nlmsghdr *nlh;',
2415                  'int err;']
2416
2417    if ri.struct['request'].fixed_header:
2418        local_vars += ['size_t hdr_len;',
2419                       'void *hdr;']
2420
2421    ri.cw.write_func_lvar(local_vars)
2422
2423    ri.cw.p('yds.yarg.ys = ys;')
2424    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
2425    ri.cw.p("yds.yarg.data = NULL;")
2426    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
2427    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
2428    if ri.op.value is not None:
2429        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
2430    else:
2431        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
2432    ri.cw.nl()
2433    if ri.family.is_classic():
2434        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
2435    else:
2436        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
2437
2438    if ri.struct['request'].fixed_header:
2439        ri.cw.p("hdr_len = sizeof(req->_hdr);")
2440        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
2441        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
2442        ri.cw.nl()
2443
2444    if "request" in ri.op[ri.op_mode]:
2445        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
2446        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
2447        ri.cw.nl()
2448        for _, attr in ri.struct["request"].member_list():
2449            attr.attr_put(ri, "req")
2450    ri.cw.nl()
2451
2452    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
2453    ri.cw.p('if (err < 0)')
2454    ri.cw.p('goto free_list;')
2455    ri.cw.nl()
2456
2457    ri.cw.p('return yds.first;')
2458    ri.cw.nl()
2459    ri.cw.p('free_list:')
2460    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
2461    ri.cw.p('return NULL;')
2462    ri.cw.block_end()
2463
2464
2465def call_free(ri, direction, var):
2466    return f"{op_prefix(ri, direction)}_free({var});"
2467
2468
2469def free_arg_name(direction):
2470    if direction:
2471        return direction_to_suffix[direction][1:]
2472    return 'obj'
2473
2474
2475def print_alloc_wrapper(ri, direction):
2476    name = op_prefix(ri, direction)
2477    ri.cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", [f"void"])
2478    ri.cw.block_start()
2479    ri.cw.p(f'return calloc(1, sizeof(struct {name}));')
2480    ri.cw.block_end()
2481
2482
2483def print_free_prototype(ri, direction, suffix=';'):
2484    name = op_prefix(ri, direction)
2485    struct_name = name
2486    if ri.type_name_conflict:
2487        struct_name += '_'
2488    arg = free_arg_name(direction)
2489    ri.cw.write_func_prot('void', f"{name}_free", [f"struct {struct_name} *{arg}"], suffix=suffix)
2490
2491
2492def print_nlflags_set(ri, direction):
2493    name = op_prefix(ri, direction)
2494    ri.cw.write_func_prot(f'static inline void', f"{name}_set_nlflags",
2495                          [f"struct {name} *req", "__u16 nl_flags"])
2496    ri.cw.block_start()
2497    ri.cw.p('req->_nlmsg_flags = nl_flags;')
2498    ri.cw.block_end()
2499    ri.cw.nl()
2500
2501
2502def _print_type(ri, direction, struct):
2503    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
2504    if not direction and ri.type_name_conflict:
2505        suffix += '_'
2506
2507    if ri.op_mode == 'dump' and not ri.type_oneside:
2508        suffix += '_dump'
2509
2510    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")
2511
2512    if ri.needs_nlflags(direction):
2513        ri.cw.p('__u16 _nlmsg_flags;')
2514        ri.cw.nl()
2515    if struct.fixed_header:
2516        ri.cw.p(struct.fixed_header + ' _hdr;')
2517        ri.cw.nl()
2518
2519    for type_filter in ['present', 'len', 'count']:
2520        meta_started = False
2521        for _, attr in struct.member_list():
2522            line = attr.presence_member(ri.ku_space, type_filter)
2523            if line:
2524                if not meta_started:
2525                    ri.cw.block_start(line=f"struct")
2526                    meta_started = True
2527                ri.cw.p(line)
2528        if meta_started:
2529            ri.cw.block_end(line=f'_{type_filter};')
2530    ri.cw.nl()
2531
2532    for arg in struct.inherited:
2533        ri.cw.p(f"__u32 {arg};")
2534
2535    for _, attr in struct.member_list():
2536        attr.struct_member(ri)
2537
2538    ri.cw.block_end(line=';')
2539    ri.cw.nl()
2540
2541
2542def print_type(ri, direction):
2543    _print_type(ri, direction, ri.struct[direction])
2544
2545
2546def print_type_full(ri, struct):
2547    _print_type(ri, "", struct)
2548
2549
2550def print_type_helpers(ri, direction, deref=False):
2551    print_free_prototype(ri, direction)
2552    ri.cw.nl()
2553
2554    if ri.needs_nlflags(direction):
2555        print_nlflags_set(ri, direction)
2556
2557    if ri.ku_space == 'user' and direction == 'request':
2558        for _, attr in ri.struct[direction].member_list():
2559            attr.setter(ri, ri.attr_set, direction, deref=deref)
2560    ri.cw.nl()
2561
2562
2563def print_req_type_helpers(ri):
2564    if ri.type_empty("request"):
2565        return
2566    print_alloc_wrapper(ri, "request")
2567    print_type_helpers(ri, "request")
2568
2569
2570def print_rsp_type_helpers(ri):
2571    if 'reply' not in ri.op[ri.op_mode]:
2572        return
2573    print_type_helpers(ri, "reply")
2574
2575
2576def print_parse_prototype(ri, direction, terminate=True):
2577    suffix = "_rsp" if direction == "reply" else "_req"
2578    term = ';' if terminate else ''
2579
2580    ri.cw.write_func_prot('void', f"{ri.op.render_name}{suffix}_parse",
2581                          ['const struct nlattr **tb',
2582                           f"struct {ri.op.render_name}{suffix} *req"],
2583                          suffix=term)
2584
2585
2586def print_req_type(ri):
2587    if ri.type_empty("request"):
2588        return
2589    print_type(ri, "request")
2590
2591
2592def print_req_free(ri):
2593    if 'request' not in ri.op[ri.op_mode]:
2594        return
2595    _free_type(ri, 'request', ri.struct['request'])
2596
2597
2598def print_rsp_type(ri):
2599    if (ri.op_mode == 'do' or ri.op_mode == 'dump') and 'reply' in ri.op[ri.op_mode]:
2600        direction = 'reply'
2601    elif ri.op_mode == 'event':
2602        direction = 'reply'
2603    else:
2604        return
2605    print_type(ri, direction)
2606
2607
2608def print_wrapped_type(ri):
2609    ri.cw.block_start(line=f"{type_name(ri, 'reply')}")
2610    if ri.op_mode == 'dump':
2611        ri.cw.p(f"{type_name(ri, 'reply')} *next;")
2612    elif ri.op_mode == 'notify' or ri.op_mode == 'event':
2613        ri.cw.p('__u16 family;')
2614        ri.cw.p('__u8 cmd;')
2615        ri.cw.p('struct ynl_ntf_base_type *next;')
2616        ri.cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
2617    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
2618    ri.cw.block_end(line=';')
2619    ri.cw.nl()
2620    print_free_prototype(ri, 'reply')
2621    ri.cw.nl()
2622
2623
2624def _free_type_members_iter(ri, struct):
2625    if struct.free_needs_iter():
2626        ri.cw.p('unsigned int i;')
2627        ri.cw.nl()
2628
2629
2630def _free_type_members(ri, var, struct, ref=''):
2631    for _, attr in struct.member_list():
2632        attr.free(ri, var, ref)
2633
2634
2635def _free_type(ri, direction, struct):
2636    var = free_arg_name(direction)
2637
2638    print_free_prototype(ri, direction, suffix='')
2639    ri.cw.block_start()
2640    _free_type_members_iter(ri, struct)
2641    _free_type_members(ri, var, struct)
2642    if direction:
2643        ri.cw.p(f'free({var});')
2644    ri.cw.block_end()
2645    ri.cw.nl()
2646
2647
2648def free_rsp_nested_prototype(ri):
2649        print_free_prototype(ri, "")
2650
2651
2652def free_rsp_nested(ri, struct):
2653    _free_type(ri, "", struct)
2654
2655
2656def print_rsp_free(ri):
2657    if 'reply' not in ri.op[ri.op_mode]:
2658        return
2659    _free_type(ri, 'reply', ri.struct['reply'])
2660
2661
2662def print_dump_type_free(ri):
2663    sub_type = type_name(ri, 'reply')
2664
2665    print_free_prototype(ri, 'reply', suffix='')
2666    ri.cw.block_start()
2667    ri.cw.p(f"{sub_type} *next = rsp;")
2668    ri.cw.nl()
2669    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
2670    _free_type_members_iter(ri, ri.struct['reply'])
2671    ri.cw.p('rsp = next;')
2672    ri.cw.p('next = rsp->next;')
2673    ri.cw.nl()
2674
2675    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2676    ri.cw.p(f'free(rsp);')
2677    ri.cw.block_end()
2678    ri.cw.block_end()
2679    ri.cw.nl()
2680
2681
2682def print_ntf_type_free(ri):
2683    print_free_prototype(ri, 'reply', suffix='')
2684    ri.cw.block_start()
2685    _free_type_members_iter(ri, ri.struct['reply'])
2686    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2687    ri.cw.p(f'free(rsp);')
2688    ri.cw.block_end()
2689    ri.cw.nl()
2690
2691
2692def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
2693    if terminate and ri and policy_should_be_static(struct.family):
2694        return
2695
2696    if terminate:
2697        prefix = 'extern '
2698    else:
2699        if ri and policy_should_be_static(struct.family):
2700            prefix = 'static '
2701        else:
2702            prefix = ''
2703
2704    suffix = ';' if terminate else ' = {'
2705
2706    max_attr = struct.attr_max_val
2707    if ri:
2708        name = ri.op.render_name
2709        if ri.op.dual_policy:
2710            name += '_' + ri.op_mode
2711    else:
2712        name = struct.render_name
2713    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2714
2715
2716def print_req_policy(cw, struct, ri=None):
2717    if ri and ri.op:
2718        cw.ifdef_block(ri.op.get('config-cond', None))
2719    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
2720    for _, arg in struct.member_list():
2721        arg.attr_policy(cw)
2722    cw.p("};")
2723    cw.ifdef_block(None)
2724    cw.nl()
2725
2726
2727def kernel_can_gen_family_struct(family):
2728    return family.proto == 'genetlink'
2729
2730
2731def policy_should_be_static(family):
2732    return family.kernel_policy == 'split' or kernel_can_gen_family_struct(family)
2733
2734
2735def print_kernel_policy_ranges(family, cw):
2736    first = True
2737    for _, attr_set in family.attr_sets.items():
2738        if attr_set.subset_of:
2739            continue
2740
2741        for _, attr in attr_set.items():
2742            if not attr.request:
2743                continue
2744            if 'full-range' not in attr.checks:
2745                continue
2746
2747            if first:
2748                cw.p('/* Integer value ranges */')
2749                first = False
2750
2751            sign = '' if attr.type[0] == 'u' else '_signed'
2752            suffix = 'ULL' if attr.type[0] == 'u' else 'LL'
2753            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
2754            members = []
2755            if 'min' in attr.checks:
2756                members.append(('min', attr.get_limit_str('min', suffix=suffix)))
2757            if 'max' in attr.checks:
2758                members.append(('max', attr.get_limit_str('max', suffix=suffix)))
2759            cw.write_struct_init(members)
2760            cw.block_end(line=';')
2761            cw.nl()
2762
2763
2764def print_kernel_policy_sparse_enum_validates(family, cw):
2765    first = True
2766    for _, attr_set in family.attr_sets.items():
2767        if attr_set.subset_of:
2768            continue
2769
2770        for _, attr in attr_set.items():
2771            if not attr.request:
2772                continue
2773            if not attr.enum_name:
2774                continue
2775            if 'sparse' not in attr.checks:
2776                continue
2777
2778            if first:
2779                cw.p('/* Sparse enums validation callbacks */')
2780                first = False
2781
2782            sign = '' if attr.type[0] == 'u' else '_signed'
2783            suffix = 'ULL' if attr.type[0] == 'u' else 'LL'
2784            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
2785                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
2786            cw.block_start()
2787            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
2788            enum = family.consts[attr['enum']]
2789            first_entry = True
2790            for entry in enum.entries.values():
2791                if first_entry:
2792                    first_entry = False
2793                else:
2794                    cw.p('fallthrough;')
2795                cw.p(f'case {entry.c_name}:')
2796            cw.p('return 0;')
2797            cw.block_end()
2798            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
2799            cw.p('return -EINVAL;')
2800            cw.block_end()
2801            cw.nl()
2802
2803
2804def print_kernel_op_table_fwd(family, cw, terminate):
2805    exported = not kernel_can_gen_family_struct(family)
2806
2807    if not terminate or exported:
2808        cw.p(f"/* Ops table for {family.ident_name} */")
2809
2810        pol_to_struct = {'global': 'genl_small_ops',
2811                         'per-op': 'genl_ops',
2812                         'split': 'genl_split_ops'}
2813        struct_type = pol_to_struct[family.kernel_policy]
2814
2815        if not exported:
2816            cnt = ""
2817        elif family.kernel_policy == 'split':
2818            cnt = 0
2819            for op in family.ops.values():
2820                if 'do' in op:
2821                    cnt += 1
2822                if 'dump' in op:
2823                    cnt += 1
2824        else:
2825            cnt = len(family.ops)
2826
2827        qual = 'static const' if not exported else 'const'
2828        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
2829        if terminate:
2830            cw.p(f"extern {line};")
2831        else:
2832            cw.block_start(line=line + ' =')
2833
2834    if not terminate:
2835        return
2836
2837    cw.nl()
2838    for name in family.hooks['pre']['do']['list']:
2839        cw.write_func_prot('int', c_lower(name),
2840                           ['const struct genl_split_ops *ops',
2841                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2842    for name in family.hooks['post']['do']['list']:
2843        cw.write_func_prot('void', c_lower(name),
2844                           ['const struct genl_split_ops *ops',
2845                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2846    for name in family.hooks['pre']['dump']['list']:
2847        cw.write_func_prot('int', c_lower(name),
2848                           ['struct netlink_callback *cb'], suffix=';')
2849    for name in family.hooks['post']['dump']['list']:
2850        cw.write_func_prot('int', c_lower(name),
2851                           ['struct netlink_callback *cb'], suffix=';')
2852
2853    cw.nl()
2854
2855    for op_name, op in family.ops.items():
2856        if op.is_async:
2857            continue
2858
2859        if 'do' in op:
2860            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
2861            cw.write_func_prot('int', name,
2862                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2863
2864        if 'dump' in op:
2865            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
2866            cw.write_func_prot('int', name,
2867                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
2868    cw.nl()
2869
2870
2871def print_kernel_op_table_hdr(family, cw):
2872    print_kernel_op_table_fwd(family, cw, terminate=True)
2873
2874
2875def print_kernel_op_table(family, cw):
2876    print_kernel_op_table_fwd(family, cw, terminate=False)
2877    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
2878        for op_name, op in family.ops.items():
2879            if op.is_async:
2880                continue
2881
2882            cw.ifdef_block(op.get('config-cond', None))
2883            cw.block_start()
2884            members = [('cmd', op.enum_name)]
2885            if 'dont-validate' in op:
2886                members.append(('validate',
2887                                ' | '.join([c_upper('genl-dont-validate-' + x)
2888                                            for x in op['dont-validate']])), )
2889            for op_mode in ['do', 'dump']:
2890                if op_mode in op:
2891                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2892                    members.append((op_mode + 'it', name))
2893            if family.kernel_policy == 'per-op':
2894                struct = Struct(family, op['attribute-set'],
2895                                type_list=op['do']['request']['attributes'])
2896
2897                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2898                members.append(('policy', name))
2899                members.append(('maxattr', struct.attr_max_val.enum_name))
2900            if 'flags' in op:
2901                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
2902            cw.write_struct_init(members)
2903            cw.block_end(line=',')
2904    elif family.kernel_policy == 'split':
2905        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
2906                    'dump': {'pre': 'start', 'post': 'done'}}
2907
2908        for op_name, op in family.ops.items():
2909            for op_mode in ['do', 'dump']:
2910                if op.is_async or op_mode not in op:
2911                    continue
2912
2913                cw.ifdef_block(op.get('config-cond', None))
2914                cw.block_start()
2915                members = [('cmd', op.enum_name)]
2916                if 'dont-validate' in op:
2917                    dont_validate = []
2918                    for x in op['dont-validate']:
2919                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
2920                            continue
2921                        if op_mode == "dump" and x == 'strict':
2922                            continue
2923                        dont_validate.append(x)
2924
2925                    if dont_validate:
2926                        members.append(('validate',
2927                                        ' | '.join([c_upper('genl-dont-validate-' + x)
2928                                                    for x in dont_validate])), )
2929                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2930                if 'pre' in op[op_mode]:
2931                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
2932                members.append((op_mode + 'it', name))
2933                if 'post' in op[op_mode]:
2934                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
2935                if 'request' in op[op_mode]:
2936                    struct = Struct(family, op['attribute-set'],
2937                                    type_list=op[op_mode]['request']['attributes'])
2938
2939                    if op.dual_policy:
2940                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
2941                    else:
2942                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2943                    members.append(('policy', name))
2944                    members.append(('maxattr', struct.attr_max_val.enum_name))
2945                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
2946                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
2947                cw.write_struct_init(members)
2948                cw.block_end(line=',')
2949    cw.ifdef_block(None)
2950
2951    cw.block_end(line=';')
2952    cw.nl()
2953
2954
2955def print_kernel_mcgrp_hdr(family, cw):
2956    if not family.mcgrps['list']:
2957        return
2958
2959    cw.block_start('enum')
2960    for grp in family.mcgrps['list']:
2961        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp['name']},")
2962        cw.p(grp_id)
2963    cw.block_end(';')
2964    cw.nl()
2965
2966
2967def print_kernel_mcgrp_src(family, cw):
2968    if not family.mcgrps['list']:
2969        return
2970
2971    cw.block_start('static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] =')
2972    for grp in family.mcgrps['list']:
2973        name = grp['name']
2974        grp_id = c_upper(f"{family.ident_name}-nlgrp-{name}")
2975        cw.p('[' + grp_id + '] = { "' + name + '", },')
2976    cw.block_end(';')
2977    cw.nl()
2978
2979
2980def print_kernel_family_struct_hdr(family, cw):
2981    if not kernel_can_gen_family_struct(family):
2982        return
2983
2984    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
2985    cw.nl()
2986    if 'sock-priv' in family.kernel_family:
2987        cw.p(f'void {family.c_name}_nl_sock_priv_init({family.kernel_family["sock-priv"]} *priv);')
2988        cw.p(f'void {family.c_name}_nl_sock_priv_destroy({family.kernel_family["sock-priv"]} *priv);')
2989        cw.nl()
2990
2991
2992def print_kernel_family_struct_src(family, cw):
2993    if not kernel_can_gen_family_struct(family):
2994        return
2995
2996    if 'sock-priv' in family.kernel_family:
2997        # Generate "trampolines" to make CFI happy
2998        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
2999                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
3000                      ["void *priv"])
3001        cw.nl()
3002        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
3003                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
3004                      ["void *priv"])
3005        cw.nl()
3006
3007    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")
3008    cw.p('.name\t\t= ' + family.fam_key + ',')
3009    cw.p('.version\t= ' + family.ver_key + ',')
3010    cw.p('.netnsok\t= true,')
3011    cw.p('.parallel_ops\t= true,')
3012    cw.p('.module\t\t= THIS_MODULE,')
3013    if family.kernel_policy == 'per-op':
3014        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
3015        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
3016    elif family.kernel_policy == 'split':
3017        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
3018        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
3019    if family.mcgrps['list']:
3020        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
3021        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
3022    if 'sock-priv' in family.kernel_family:
3023        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
3024        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
3025        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
3026    cw.block_end(';')
3027
3028
3029def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
3030    start_line = 'enum'
3031    if enum_name in obj:
3032        if obj[enum_name]:
3033            start_line = 'enum ' + c_lower(obj[enum_name])
3034    elif ckey and ckey in obj:
3035        start_line = 'enum ' + family.c_name + '_' + c_lower(obj[ckey])
3036    cw.block_start(line=start_line)
3037
3038
3039def render_uapi_unified(family, cw, max_by_define, separate_ntf):
3040    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
3041    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
3042    max_value = f"({cnt_name} - 1)"
3043
3044    uapi_enum_start(family, cw, family['operations'], 'enum-name')
3045    val = 0
3046    for op in family.msgs.values():
3047        if separate_ntf and ('notify' in op or 'event' in op):
3048            continue
3049
3050        suffix = ','
3051        if op.value != val:
3052            suffix = f" = {op.value},"
3053            val = op.value
3054        cw.p(op.enum_name + suffix)
3055        val += 1
3056    cw.nl()
3057    cw.p(cnt_name + ('' if max_by_define else ','))
3058    if not max_by_define:
3059        cw.p(f"{max_name} = {max_value}")
3060    cw.block_end(line=';')
3061    if max_by_define:
3062        cw.p(f"#define {max_name} {max_value}")
3063    cw.nl()
3064
3065
3066def render_uapi_directional(family, cw, max_by_define):
3067    max_name = f"{family.op_prefix}USER_MAX"
3068    cnt_name = f"__{family.op_prefix}USER_CNT"
3069    max_value = f"({cnt_name} - 1)"
3070
3071    cw.block_start(line='enum')
3072    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
3073    val = 0
3074    for op in family.msgs.values():
3075        if 'do' in op and 'event' not in op:
3076            suffix = ','
3077            if op.value and op.value != val:
3078                suffix = f" = {op.value},"
3079                val = op.value
3080            cw.p(op.enum_name + suffix)
3081            val += 1
3082    cw.nl()
3083    cw.p(cnt_name + ('' if max_by_define else ','))
3084    if not max_by_define:
3085        cw.p(f"{max_name} = {max_value}")
3086    cw.block_end(line=';')
3087    if max_by_define:
3088        cw.p(f"#define {max_name} {max_value}")
3089    cw.nl()
3090
3091    max_name = f"{family.op_prefix}KERNEL_MAX"
3092    cnt_name = f"__{family.op_prefix}KERNEL_CNT"
3093    max_value = f"({cnt_name} - 1)"
3094
3095    cw.block_start(line='enum')
3096    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
3097    val = 0
3098    for op in family.msgs.values():
3099        if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
3100            enum_name = op.enum_name
3101            if 'event' not in op and 'notify' not in op:
3102                enum_name = f'{enum_name}_REPLY'
3103
3104            suffix = ','
3105            if op.value and op.value != val:
3106                suffix = f" = {op.value},"
3107                val = op.value
3108            cw.p(enum_name + suffix)
3109            val += 1
3110    cw.nl()
3111    cw.p(cnt_name + ('' if max_by_define else ','))
3112    if not max_by_define:
3113        cw.p(f"{max_name} = {max_value}")
3114    cw.block_end(line=';')
3115    if max_by_define:
3116        cw.p(f"#define {max_name} {max_value}")
3117    cw.nl()
3118
3119
3120def render_uapi(family, cw):
3121    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
3122    hdr_prot = hdr_prot.replace('/', '_')
3123    cw.p('#ifndef ' + hdr_prot)
3124    cw.p('#define ' + hdr_prot)
3125    cw.nl()
3126
3127    defines = [(family.fam_key, family["name"]),
3128               (family.ver_key, family.get('version', 1))]
3129    cw.writes_defines(defines)
3130    cw.nl()
3131
3132    defines = []
3133    for const in family['definitions']:
3134        if const.get('header'):
3135            continue
3136
3137        if const['type'] != 'const':
3138            cw.writes_defines(defines)
3139            defines = []
3140            cw.nl()
3141
3142        # Write kdoc for enum and flags (one day maybe also structs)
3143        if const['type'] == 'enum' or const['type'] == 'flags':
3144            enum = family.consts[const['name']]
3145
3146            if enum.header:
3147                continue
3148
3149            if enum.has_doc():
3150                if enum.has_entry_doc():
3151                    cw.p('/**')
3152                    doc = ''
3153                    if 'doc' in enum:
3154                        doc = ' - ' + enum['doc']
3155                    cw.write_doc_line(enum.enum_name + doc)
3156                else:
3157                    cw.p('/*')
3158                    cw.write_doc_line(enum['doc'], indent=False)
3159                for entry in enum.entries.values():
3160                    if entry.has_doc():
3161                        doc = '@' + entry.c_name + ': ' + entry['doc']
3162                        cw.write_doc_line(doc)
3163                cw.p(' */')
3164
3165            uapi_enum_start(family, cw, const, 'name')
3166            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
3167            for entry in enum.entries.values():
3168                suffix = ','
3169                if entry.value_change:
3170                    suffix = f" = {entry.user_value()}" + suffix
3171                cw.p(entry.c_name + suffix)
3172
3173            if const.get('render-max', False):
3174                cw.nl()
3175                cw.p('/* private: */')
3176                if const['type'] == 'flags':
3177                    max_name = c_upper(name_pfx + 'mask')
3178                    max_val = f' = {enum.get_mask()},'
3179                    cw.p(max_name + max_val)
3180                else:
3181                    cnt_name = enum.enum_cnt_name
3182                    max_name = c_upper(name_pfx + 'max')
3183                    if not cnt_name:
3184                        cnt_name = '__' + name_pfx + 'max'
3185                    cw.p(c_upper(cnt_name) + ',')
3186                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
3187            cw.block_end(line=';')
3188            cw.nl()
3189        elif const['type'] == 'const':
3190            defines.append([c_upper(family.get('c-define-name',
3191                                               f"{family.ident_name}-{const['name']}")),
3192                            const['value']])
3193
3194    if defines:
3195        cw.writes_defines(defines)
3196        cw.nl()
3197
3198    max_by_define = family.get('max-by-define', False)
3199
3200    for _, attr_set in family.attr_sets.items():
3201        if attr_set.subset_of:
3202            continue
3203
3204        max_value = f"({attr_set.cnt_name} - 1)"
3205
3206        val = 0
3207        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
3208        for _, attr in attr_set.items():
3209            suffix = ','
3210            if attr.value != val:
3211                suffix = f" = {attr.value},"
3212                val = attr.value
3213            val += 1
3214            cw.p(attr.enum_name + suffix)
3215        if attr_set.items():
3216            cw.nl()
3217        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
3218        if not max_by_define:
3219            cw.p(f"{attr_set.max_name} = {max_value}")
3220        cw.block_end(line=';')
3221        if max_by_define:
3222            cw.p(f"#define {attr_set.max_name} {max_value}")
3223        cw.nl()
3224
3225    # Commands
3226    separate_ntf = 'async-prefix' in family['operations']
3227
3228    if family.msg_id_model == 'unified':
3229        render_uapi_unified(family, cw, max_by_define, separate_ntf)
3230    elif family.msg_id_model == 'directional':
3231        render_uapi_directional(family, cw, max_by_define)
3232    else:
3233        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')
3234
3235    if separate_ntf:
3236        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
3237        for op in family.msgs.values():
3238            if separate_ntf and not ('notify' in op or 'event' in op):
3239                continue
3240
3241            suffix = ','
3242            if 'value' in op:
3243                suffix = f" = {op['value']},"
3244            cw.p(op.enum_name + suffix)
3245        cw.block_end(line=';')
3246        cw.nl()
3247
3248    # Multicast
3249    defines = []
3250    for grp in family.mcgrps['list']:
3251        name = grp['name']
3252        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
3253                        f'{name}'])
3254    cw.nl()
3255    if defines:
3256        cw.writes_defines(defines)
3257        cw.nl()
3258
3259    cw.p(f'#endif /* {hdr_prot} */')
3260
3261
3262def _render_user_ntf_entry(ri, op):
3263    if not ri.family.is_classic():
3264        ri.cw.block_start(line=f"[{op.enum_name}] = ")
3265    else:
3266        crud_op = ri.family.req_by_value[op.rsp_value]
3267        ri.cw.block_start(line=f"[{crud_op.enum_name}] = ")
3268    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
3269    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
3270    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
3271    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
3272    ri.cw.block_end(line=',')
3273
3274
3275def render_user_family(family, cw, prototype):
3276    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
3277    if prototype:
3278        cw.p(f'extern {symbol};')
3279        return
3280
3281    if family.ntfs:
3282        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
3283        for ntf_op_name, ntf_op in family.ntfs.items():
3284            if 'notify' in ntf_op:
3285                op = family.ops[ntf_op['notify']]
3286                ri = RenderInfo(cw, family, "user", op, "notify")
3287            elif 'event' in ntf_op:
3288                ri = RenderInfo(cw, family, "user", ntf_op, "event")
3289            else:
3290                raise Exception('Invalid notification ' + ntf_op_name)
3291            _render_user_ntf_entry(ri, ntf_op)
3292        for op_name, op in family.ops.items():
3293            if 'event' not in op:
3294                continue
3295            ri = RenderInfo(cw, family, "user", op, "event")
3296            _render_user_ntf_entry(ri, op)
3297        cw.block_end(line=";")
3298        cw.nl()
3299
3300    cw.block_start(f'{symbol} = ')
3301    cw.p(f'.name\t\t= "{family.c_name}",')
3302    if family.is_classic():
3303        cw.p(f'.is_classic\t= true,')
3304        cw.p(f'.classic_id\t= {family.get("protonum")},')
3305    if family.is_classic():
3306        if family.fixed_header:
3307            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
3308    elif family.fixed_header:
3309        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
3310    else:
3311        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
3312    if family.ntfs:
3313        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
3314        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
3315    cw.block_end(line=';')
3316
3317
3318def family_contains_bitfield32(family):
3319    for _, attr_set in family.attr_sets.items():
3320        if attr_set.subset_of:
3321            continue
3322        for _, attr in attr_set.items():
3323            if attr.type == "bitfield32":
3324                return True
3325    return False
3326
3327
3328def find_kernel_root(full_path):
3329    sub_path = ''
3330    while True:
3331        sub_path = os.path.join(os.path.basename(full_path), sub_path)
3332        full_path = os.path.dirname(full_path)
3333        maintainers = os.path.join(full_path, "MAINTAINERS")
3334        if os.path.exists(maintainers):
3335            return full_path, sub_path[:-1]
3336
3337
3338def main():
3339    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
3340    parser.add_argument('--mode', dest='mode', type=str, required=True,
3341                        choices=('user', 'kernel', 'uapi'))
3342    parser.add_argument('--spec', dest='spec', type=str, required=True)
3343    parser.add_argument('--header', dest='header', action='store_true', default=None)
3344    parser.add_argument('--source', dest='header', action='store_false')
3345    parser.add_argument('--user-header', nargs='+', default=[])
3346    parser.add_argument('--cmp-out', action='store_true', default=None,
3347                        help='Do not overwrite the output file if the new output is identical to the old')
3348    parser.add_argument('--exclude-op', action='append', default=[])
3349    parser.add_argument('-o', dest='out_file', type=str, default=None)
3350    args = parser.parse_args()
3351
3352    if args.header is None:
3353        parser.error("--header or --source is required")
3354
3355    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
3356
3357    try:
3358        parsed = Family(args.spec, exclude_ops)
3359        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
3360            print('Spec license:', parsed.license)
3361            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
3362            os.sys.exit(1)
3363    except yaml.YAMLError as exc:
3364        print(exc)
3365        os.sys.exit(1)
3366        return
3367
3368    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
3369
3370    _, spec_kernel = find_kernel_root(args.spec)
3371    if args.mode == 'uapi' or args.header:
3372        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
3373    else:
3374        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
3375    cw.p("/* Do not edit directly, auto-generated from: */")
3376    cw.p(f"/*\t{spec_kernel} */")
3377    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
3378    if args.exclude_op or args.user_header:
3379        line = ''
3380        line += ' --user-header '.join([''] + args.user_header)
3381        line += ' --exclude-op '.join([''] + args.exclude_op)
3382        cw.p(f'/* YNL-ARG{line} */')
3383    cw.nl()
3384
3385    if args.mode == 'uapi':
3386        render_uapi(parsed, cw)
3387        return
3388
3389    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
3390    if args.header:
3391        cw.p('#ifndef ' + hdr_prot)
3392        cw.p('#define ' + hdr_prot)
3393        cw.nl()
3394
3395    if args.out_file:
3396        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
3397    else:
3398        hdr_file = "generated_header_file.h"
3399
3400    if args.mode == 'kernel':
3401        cw.p('#include <net/netlink.h>')
3402        cw.p('#include <net/genetlink.h>')
3403        cw.nl()
3404        if not args.header:
3405            if args.out_file:
3406                cw.p(f'#include "{hdr_file}"')
3407            cw.nl()
3408        headers = ['uapi/' + parsed.uapi_header]
3409        headers += parsed.kernel_family.get('headers', [])
3410    else:
3411        cw.p('#include <stdlib.h>')
3412        cw.p('#include <string.h>')
3413        if args.header:
3414            cw.p('#include <linux/types.h>')
3415            if family_contains_bitfield32(parsed):
3416                cw.p('#include <linux/netlink.h>')
3417        else:
3418            cw.p(f'#include "{hdr_file}"')
3419            cw.p('#include "ynl.h"')
3420        headers = []
3421    for definition in parsed['definitions'] + parsed['attribute-sets']:
3422        if 'header' in definition:
3423            headers.append(definition['header'])
3424    if args.mode == 'user':
3425        headers.append(parsed.uapi_header)
3426    seen_header = []
3427    for one in headers:
3428        if one not in seen_header:
3429            cw.p(f"#include <{one}>")
3430            seen_header.append(one)
3431    cw.nl()
3432
3433    if args.mode == "user":
3434        if not args.header:
3435            cw.p("#include <linux/genetlink.h>")
3436            cw.nl()
3437            for one in args.user_header:
3438                cw.p(f'#include "{one}"')
3439        else:
3440            cw.p('struct ynl_sock;')
3441            cw.nl()
3442            render_user_family(parsed, cw, True)
3443        cw.nl()
3444
3445    if args.mode == "kernel":
3446        if args.header:
3447            for _, struct in sorted(parsed.pure_nested_structs.items()):
3448                if struct.request:
3449                    cw.p('/* Common nested types */')
3450                    break
3451            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3452                if struct.request:
3453                    print_req_policy_fwd(cw, struct)
3454            cw.nl()
3455
3456            if parsed.kernel_policy == 'global':
3457                cw.p(f"/* Global operation policy for {parsed.name} */")
3458
3459                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3460                print_req_policy_fwd(cw, struct)
3461                cw.nl()
3462
3463            if parsed.kernel_policy in {'per-op', 'split'}:
3464                for op_name, op in parsed.ops.items():
3465                    if 'do' in op and 'event' not in op:
3466                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
3467                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
3468                        cw.nl()
3469
3470            print_kernel_op_table_hdr(parsed, cw)
3471            print_kernel_mcgrp_hdr(parsed, cw)
3472            print_kernel_family_struct_hdr(parsed, cw)
3473        else:
3474            print_kernel_policy_ranges(parsed, cw)
3475            print_kernel_policy_sparse_enum_validates(parsed, cw)
3476
3477            for _, struct in sorted(parsed.pure_nested_structs.items()):
3478                if struct.request:
3479                    cw.p('/* Common nested types */')
3480                    break
3481            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3482                if struct.request:
3483                    print_req_policy(cw, struct)
3484            cw.nl()
3485
3486            if parsed.kernel_policy == 'global':
3487                cw.p(f"/* Global operation policy for {parsed.name} */")
3488
3489                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3490                print_req_policy(cw, struct)
3491                cw.nl()
3492
3493            for op_name, op in parsed.ops.items():
3494                if parsed.kernel_policy in {'per-op', 'split'}:
3495                    for op_mode in ['do', 'dump']:
3496                        if op_mode in op and 'request' in op[op_mode]:
3497                            cw.p(f"/* {op.enum_name} - {op_mode} */")
3498                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
3499                            print_req_policy(cw, ri.struct['request'], ri=ri)
3500                            cw.nl()
3501
3502            print_kernel_op_table(parsed, cw)
3503            print_kernel_mcgrp_src(parsed, cw)
3504            print_kernel_family_struct_src(parsed, cw)
3505
3506    if args.mode == "user":
3507        if args.header:
3508            cw.p('/* Enums */')
3509            put_op_name_fwd(parsed, cw)
3510
3511            for name, const in parsed.consts.items():
3512                if isinstance(const, EnumSet):
3513                    put_enum_to_str_fwd(parsed, cw, const)
3514            cw.nl()
3515
3516            cw.p('/* Common nested types */')
3517            for attr_set, struct in parsed.pure_nested_structs.items():
3518                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3519                print_type_full(ri, struct)
3520                if struct.request and struct.in_multi_val:
3521                    free_rsp_nested_prototype(ri)
3522                    cw.nl()
3523
3524            for op_name, op in parsed.ops.items():
3525                cw.p(f"/* ============== {op.enum_name} ============== */")
3526
3527                if 'do' in op and 'event' not in op:
3528                    cw.p(f"/* {op.enum_name} - do */")
3529                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3530                    print_req_type(ri)
3531                    print_req_type_helpers(ri)
3532                    cw.nl()
3533                    print_rsp_type(ri)
3534                    print_rsp_type_helpers(ri)
3535                    cw.nl()
3536                    print_req_prototype(ri)
3537                    cw.nl()
3538
3539                if 'dump' in op:
3540                    cw.p(f"/* {op.enum_name} - dump */")
3541                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
3542                    print_req_type(ri)
3543                    print_req_type_helpers(ri)
3544                    if not ri.type_consistent or ri.type_oneside:
3545                        print_rsp_type(ri)
3546                    print_wrapped_type(ri)
3547                    print_dump_prototype(ri)
3548                    cw.nl()
3549
3550                if op.has_ntf:
3551                    cw.p(f"/* {op.enum_name} - notify */")
3552                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3553                    if not ri.type_consistent:
3554                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3555                    print_wrapped_type(ri)
3556
3557            for op_name, op in parsed.ntfs.items():
3558                if 'event' in op:
3559                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
3560                    cw.p(f"/* {op.enum_name} - event */")
3561                    print_rsp_type(ri)
3562                    cw.nl()
3563                    print_wrapped_type(ri)
3564            cw.nl()
3565        else:
3566            cw.p('/* Enums */')
3567            put_op_name(parsed, cw)
3568
3569            for name, const in parsed.consts.items():
3570                if isinstance(const, EnumSet):
3571                    put_enum_to_str(parsed, cw, const)
3572            cw.nl()
3573
3574            has_recursive_nests = False
3575            cw.p('/* Policies */')
3576            for struct in parsed.pure_nested_structs.values():
3577                if struct.recursive:
3578                    put_typol_fwd(cw, struct)
3579                    has_recursive_nests = True
3580            if has_recursive_nests:
3581                cw.nl()
3582            for struct in parsed.pure_nested_structs.values():
3583                put_typol(cw, struct)
3584            for name in parsed.root_sets:
3585                struct = Struct(parsed, name)
3586                put_typol(cw, struct)
3587
3588            cw.p('/* Common nested types */')
3589            if has_recursive_nests:
3590                for attr_set, struct in parsed.pure_nested_structs.items():
3591                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3592                    free_rsp_nested_prototype(ri)
3593                    if struct.request:
3594                        put_req_nested_prototype(ri, struct)
3595                    if struct.reply:
3596                        parse_rsp_nested_prototype(ri, struct)
3597                cw.nl()
3598            for attr_set, struct in parsed.pure_nested_structs.items():
3599                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3600
3601                free_rsp_nested(ri, struct)
3602                if struct.request:
3603                    put_req_nested(ri, struct)
3604                if struct.reply:
3605                    parse_rsp_nested(ri, struct)
3606
3607            for op_name, op in parsed.ops.items():
3608                cw.p(f"/* ============== {op.enum_name} ============== */")
3609                if 'do' in op and 'event' not in op:
3610                    cw.p(f"/* {op.enum_name} - do */")
3611                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3612                    print_req_free(ri)
3613                    print_rsp_free(ri)
3614                    parse_rsp_msg(ri)
3615                    print_req(ri)
3616                    cw.nl()
3617
3618                if 'dump' in op:
3619                    cw.p(f"/* {op.enum_name} - dump */")
3620                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
3621                    if not ri.type_consistent or ri.type_oneside:
3622                        parse_rsp_msg(ri, deref=True)
3623                    print_req_free(ri)
3624                    print_dump_type_free(ri)
3625                    print_dump(ri)
3626                    cw.nl()
3627
3628                if op.has_ntf:
3629                    cw.p(f"/* {op.enum_name} - notify */")
3630                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3631                    if not ri.type_consistent:
3632                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3633                    print_ntf_type_free(ri)
3634
3635            for op_name, op in parsed.ntfs.items():
3636                if 'event' in op:
3637                    cw.p(f"/* {op.enum_name} - event */")
3638
3639                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3640                    parse_rsp_msg(ri)
3641
3642                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
3643                    print_ntf_type_free(ri)
3644            cw.nl()
3645            render_user_family(parsed, cw, False)
3646
3647    if args.header:
3648        cw.p(f'#endif /* {hdr_prot} */')
3649
3650
3651if __name__ == "__main__":
3652    main()
3653