xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 85502b2214d50ba0ddf2a5fb454e4d28a160d175)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
def c_upper(name):
    """Return *name* as an upper-case C identifier ('-' becomes '_')."""
    upper = name.upper()
    return upper.replace('-', '_')
22
23
def c_lower(name):
    """Return *name* as a lower-case C identifier ('-' becomes '_')."""
    lowered = name.lower()
    return lowered.replace('-', '_')
26
27
def limit_to_number(name):
    """
    Convert a symbolic limit such as 'u32-max' or 's64-min' to an integer.

    The first character is the signedness ('u' or 's'), the characters
    between it and the 4-character '-min'/'-max' suffix are the bit width.
    """
    kind = name[0]
    is_min = name.endswith('-min')
    if kind == 'u' and is_min:
        return 0

    bits = int(name[1:-4])
    if kind == 's':
        # One bit is taken by the sign.
        bits -= 1
    top = (1 << bits) - 1

    if kind == 's' and is_min:
        return -top - 1
    return top
41
42
class BaseNlLib:
    """Library-specific snippets used when rendering generated C code."""

    def get_family_id(self):
        """Return the C expression that evaluates to the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
46
47
class Type(SpecAttr):
    """
    Code-generation model of one attribute of an attribute set.

    Renders the C snippets the generator needs for the attribute: struct
    members, kernel policy entries, user-space type-policy entries, put and
    parse code, setters and free code. Subclasses specialize the hooks per
    attribute type; several defaults here raise so that an unsupported
    combination fails loudly at generation time.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply().
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            # A nested set named after the family renders as just the family.
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # Suffix the struct name with '_' when a family constant has the
            # same name (presumably to avoid clashing with its C definition).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make the spec name a valid C member name: escape C keywords and
        # names that start with a digit.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve(); declared here for readers, then removed so any
        # use before resolve() raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        """Return the attr this subset attr mirrors (one level, no recursion)."""
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests, propagating to the real attr."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies, propagating to the real attr."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return check *limit* (e.g. 'min', 'max-len') as a number.

        The check may be an int, the name of a family constant, or a symbolic
        limit such as 'u32-max'. When the check is absent, *default* is
        resolved the same way; None is returned as-is.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check *limit* as a C expression string ('' when absent).

        Integers are printed with *suffix* appended. Constant names become
        their C macro name, prefixed with the family name unless the constant
        declares a 'header'. Anything else is upper-cased.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """
        Compute the C enum name of the attr and validate subset attrs.

        Raises:
            Exception: a subset attr's checks differ from the real attr's.
        """
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        """Truthy for attrs stored as an array of values; None by default."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        # Recursive members are stored by pointer when the render info has
        # no op (see struct_member()).
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """How presence is tracked: 'present' (one bit) by default."""
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return the presence-struct member declaration for this attr.

        Only emitted when presence_type() equals *type_filter*; otherwise
        None. In 'user' space the type is spelled __u32.
        """
        if self.presence_type() != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if self.presence_type() == 'present':
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        """C type for non-trivial members; None means use arg_member()."""
        return None

    def free_needs_iter(self):
        """Whether the generated free code needs a loop counter 'i'."""
        return False

    def _free_lines(self, ri, var, ref):
        """C lines freeing the heap storage this attr owns in *var*."""
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the free code for this attr."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """
        Return the C parameter declarations used to pass this attr.

        Raises:
            Exception: the subclass provides no member type.
        """
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the C struct member declaration(s) for this attr."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the kernel policy entry for this attr."""
        policy = f'NLA_{c_upper(self.type)}'
        # Big-endian u16/u32 have dedicated policy types.
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the user-space (YNL) type-policy entry for this attr."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line* as a statement, guarded by presence when tracked."""
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """
        Return (lines, init_lines, local_vars) for parsing this attr.

        lines is a list of C lines or a single string; init_lines is a list,
        a string or None; local_vars is a list or None.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the parse branch for this attr; *first* picks 'if' vs 'else if'.

        Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Single-value attrs are validated here and, when presence is a bit,
        # marked present.
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a static inline C setter for this attr.

        *ref* is the path of enclosing nest member names, outermost first.
        The setter sets the presence bit of every enclosing nest (and of the
        attr itself when it uses bit presence), frees any old value, then
        stores the new one. Setters that free without allocating get a '__'
        name prefix.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        for i in range(len(ref)):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'present':
                presence = f"{var}->{'.'.join(ref[:i] + [''])}_{self.presence_type()}.{ref[i]}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in line for line in code)
        alloc = any('alloc(' in line for line in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
311
312
class TypeUnused(Type):
    """
    Attribute that generates no code: no struct member, kernel policy
    entry, put, parse branch or setter. Its user-space type-policy entry
    is YNL_PT_REJECT.
    """

    def presence_type(self):
        """No presence tracking."""
        return ''

    def arg_member(self, ri):
        """No setter or struct arguments."""
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        return (['return YNL_PARSE_CB_ERROR;'], None, None)

    def attr_policy(self, cw):
        """Emit no kernel policy entry."""

    def attr_put(self, ri, var):
        """Emit no put code."""

    def attr_get(self, ri, var, first):
        """Emit no parse branch."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit no setter."""
337
338
class TypePad(Type):
    """
    Pad attribute: only a YNL_PT_IGNORE user-space type-policy entry is
    generated; no member, kernel policy, put, parse branch or setter.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def presence_type(self):
        """No presence tracking."""
        return ''

    def arg_member(self, ri):
        """No setter or struct arguments."""
        return []

    def attr_policy(self, cw):
        """Emit no kernel policy entry."""

    def attr_put(self, ri, var):
        """Emit no put code."""

    def attr_get(self, ri, var, first):
        """Emit no parse branch."""

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit no setter."""
360
361
class TypeScalar(Type):
    """Integer scalar attribute, with optional enum / range validation."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Annotate user-facing arguments with the declared byte order.
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve the enum name, whether this is a bitfield, and the C type."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """
        Derive kernel policy checks (min/max/range/sparse/full-range).

        Mutates self.checks in place. NOTE: self.checks may be the spec's
        own 'checks' dict, so the spec data is updated too.

        Raises:
            Exception: min > max, or a negative limit on an unsigned type.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                # No bounds available: the policy falls back to a
                # validation function (see _attr_policy()).
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    # An unsigned minimum of 0 needs no explicit check.
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range switch to a separate full-range
        # struct (presumably because the plain range macro cannot hold them).
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        """Wrap *policy* (e.g. NLA_U32) in the most specific policy macro."""
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
            else:
                enum = self.family.consts[self.checks['flags-mask']]
            # Build the mask from the entries' actual bit positions instead
            # of assuming they are numbered contiguously from bit 0.
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types reuse the YNL_PT_U* entries ('s32' -> YNL_PT_U32).
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
459
460
class TypeFlag(Type):
    """Valueless flag attribute; only its presence is recorded."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter takes no value argument.
        return []

    def _setter_lines(self, ri, member, presence):
        # Nothing to store: the base setter already sets the presence bit.
        return []

    def _attr_get(self, ri, var):
        return [], None, None

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)
476
477
class TypeString(Type):
    """
    String attribute. The C member is a heap-allocated, NUL-terminated char
    buffer; its length (without the NUL) is tracked in _len.
    """

    def arg_member(self, ri):
        # Setter argument: the caller's NUL-terminated string.
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        # Presence is recorded as a length, not a bit.
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        typol = f'.type = YNL_PT_NUL_STR, '
        # Selector attributes are flagged in the user-space type policy.
        if self.is_selector:
            typol += '.is_selector = 1, '
        return typol

    def _attr_policy(self, policy):
        # exact-len maps to NLA_POLICY_EXACT_LEN(); otherwise a plain entry
        # of the given type, with .len set from max-len when present.
        if 'exact-len' in self.checks:
            mem = 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        else:
            mem = '{ .type = ' + policy
            if 'max-len' in self.checks:
                mem += ', .len = ' + self.get_limit_str('max-len')
            mem += ', }'
        return mem

    def attr_policy(self, cw):
        # Kernel policy type: NLA_STRING when the spec sets
        # 'unterminated-ok', NLA_NUL_STRING otherwise.
        if self.checks.get('unterminated-ok', False):
            policy = 'NLA_STRING'
        else:
            policy = 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        # Length comes from strnlen() bounded by the payload length; the
        # string is copied into a fresh len + 1 buffer and NUL-terminated.
        len_mem = var + '->_len.' + self.c_name
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len + 1);",
                f"memcpy({var}->{self.c_name}, ynl_attr_get_str(attr), len);",
                f"{var}->{self.c_name}[len] = 0;"], \
               ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Generated setter: record strlen() in _len, then store a private
        # NUL-terminated copy of the caller's string.
        return [f"{presence} = strlen({self.c_name});",
                f"{member} = malloc({presence} + 1);",
                f'memcpy({member}, {self.c_name}, {presence});',
                f'{member}[{presence}] = 0;']
530
531
class TypeBinary(Type):
    """
    Opaque binary attribute, stored as a heap-allocated copy of the payload
    with its byte length tracked in _len.
    """

    def arg_member(self, ri):
        # Setter arguments: data pointer and its length in bytes.
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return f'.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # Accept at most one check, and only a length check; the *policy*
        # argument is not used (the plain entry is always NLA_BINARY).
        if len(self.checks) == 0:
            pass
        elif len(self.checks) == 1:
            check_name = list(self.checks)[0]
            if check_name not in {'exact-len', 'min-len', 'max-len'}:
                raise Exception('Unsupported check for binary type: ' + check_name)
        else:
            raise Exception('More than one check for binary type not implemented, yet')

        # Map the single check (if any) onto a kernel policy entry.
        if len(self.checks) == 0:
            mem = '{ .type = NLA_BINARY, }'
        elif 'exact-len' in self.checks:
            mem = 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        elif 'min-len' in self.checks:
            mem = '{ .len = ' + self.get_limit_str('min-len') + ', }'
        elif 'max-len' in self.checks:
            mem = 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'

        return mem

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->{self.c_name}, {var}->_len.{self.c_name})")

    def _attr_get(self, ri, var):
        # Copy the whole payload into a fresh buffer and record its length.
        len_mem = var + '->_len.' + self.c_name
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Generated setter stores a private copy of the caller's buffer.
        return [f"{presence} = len;",
                f"{member} = malloc({presence});",
                f'memcpy({member}, {self.c_name}, {presence});']
582
583
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is the C struct named by 'struct'."""

    def struct_member(self, ri):
        ri.cw.p(f'struct {c_lower(self.get("struct"))} *{self.c_name};')

    def _attr_get(self, ri, var):
        # A payload shorter than the struct gets a zero-filled buffer of the
        # full struct size (so the struct can be accessed in full); otherwise
        # a buffer of the payload length. The payload is then copied in.
        struct_sz = 'sizeof(struct ' + c_lower(self.get("struct")) + ')'
        len_mem = var + '->_' + self.presence_type() + '.' + self.c_name
        return [f"{len_mem} = len;",
                f"if (len < {struct_sz})",
                f"{var}->{self.c_name} = calloc(1, {struct_sz});",
                "else",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']
599
600
class TypeBinaryScalarArray(TypeBinary):
    """
    Binary attribute holding a packed array of 'sub-type' scalars. The
    element count (not the byte length) is tracked in _count.
    """

    def arg_member(self, ri):
        return [f'__{self.get("sub-type")} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        ri.cw.p(f'__{self.get("sub-type")} *{self.c_name};')

    def attr_put(self, ri, var):
        # The generated code reuses 'i' to hold the payload size in bytes.
        presence = self.presence_type()
        ri.cw.block_start(line=f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"i = {var}->_{presence}.{self.c_name} * sizeof(__{self.get('sub-type')});")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, " +
                f"{var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        # Count whole elements, then copy only that many bytes (a trailing
        # partial element is dropped).
        len_mem = var + '->_count.' + self.c_name
        return [f"{len_mem} = len / sizeof(__{self.get('sub-type')});",
                f"len = {len_mem} * sizeof(__{self.get('sub-type')});",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Record the element count, then convert 'count' to bytes in place
        # for the allocation and copy.
        return [f"{presence} = count;",
                f"count *= sizeof(__{self.get('sub-type')});",
                f"{member} = malloc(count);",
                f'memcpy({member}, {self.c_name}, count);']
633
634
class TypeBitfield32(Type):
    """Attribute stored as a struct nla_bitfield32."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """
        Return NLA_POLICY_BITFIELD32() with the flag mask of the attr's enum.

        Raises:
            Exception: the attribute has no 'enum' to derive the mask from.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
658
659
class TypeNest(Type):
    """
    Attribute holding a nested attribute set, stored as the struct generated
    for that set (by pointer when is_recursive_for_op() is true).
    """

    def is_recursive(self):
        # Recursion is recorded per nested struct in the family model.
        return self.family.pure_nested_structs[self.nested_attrs].recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        # Free via the nested type's _free() helper. Pointer (recursive)
        # members are passed directly and skipped when NULL.
        lines = []
        at = '&'
        if self.is_recursive_for_op(ri):
            at = ''
            lines += [f'if ({var}->{ref}{self.c_name})']
        lines += [f'{self.nested_render_name}_free({at}{var}->{ref}{self.c_name});']
        return lines

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        # Kernel policy delegates to the nested set's policy table.
        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'

    def attr_put(self, ri, var):
        # Serialize through the nested type's _put() helper.
        at = '' if self.is_recursive_for_op(ri) else '&'
        self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                            f"{self.enum_name}, {at}{var}->{self.c_name})")

    def _attr_get(self, ri, var):
        # Parse through the nested type's _parse() helper, passing the
        # nest's external selectors (read from *var*) as extra arguments.
        pns = self.family.pure_nested_structs[self.nested_attrs]
        args = ["&parg", "attr"]
        for sel in pns.external_selectors():
            args.append(f'{var}->{sel.name}')
        get_lines = [f"if ({self.nested_render_name}_parse({', '.join(args)}))",
                     "return YNL_PARSE_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # No setter for the nest itself: emit setters for each non-recursive
        # member of the nested set, with this attr appended to the path.
        ref = (ref if ref else []) + [self.c_name]

        for _, attr in ri.family.pure_nested_structs[self.nested_attrs].member_list():
            if attr.is_recursive():
                continue
            attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref)
705
706
class TypeMultiAttr(Type):
    """
    Attribute that may occur several times in one message. All occurrences
    are kept in a single array member whose element count is tracked in
    _count; the element type comes from the attr's 'type'.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        # Handler for a single element; kernel and user-space policies are
        # delegated to it (see _attr_policy() / _attr_typol()).
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # C type of one array element; None defers to arg_member().
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['type'] == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        elif self.attr['type'] == 'string':
            return 'struct ynl_string *'
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def arg_member(self, ri):
        # Arrays of binary structs: typed struct pointer plus element count.
        if self.type == 'binary' and 'struct' in self.attr:
            return [f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def free_needs_iter(self):
        # Nest and string elements are freed one by one, which needs 'i'.
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        # Free per-element allocations (strings, nests) before the array.
        lines = []
        if self.attr['type'] in scalars:
            lines += [f"free({var}->{ref}{self.c_name});"]
        elif self.attr['type'] == 'binary':
            lines += [f"free({var}->{ref}{self.c_name});"]
        elif self.attr['type'] == 'string':
            lines += [
                f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
                f"free({var}->{ref}{self.c_name}[i]);",
                f"free({var}->{ref}{self.c_name});",
            ]
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            lines += [
                f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
                f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);',
                f"free({var}->{ref}{self.c_name});",
            ]
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")
        return lines

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only counts occurrences into a generated local n_<name>.
        # NOTE(review): the array itself is presumably filled by code
        # emitted elsewhere in the generator - confirm.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        # Emit one attribute per array element.
        if self.attr['type'] in scalars:
            put_type = self.type
            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif self.attr['type'] == 'binary' and 'struct' in self.attr:
            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i], sizeof(struct {c_lower(self.attr['struct'])}));")
        elif self.attr['type'] == 'string':
            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {var}->{self.c_name}[i]->str);")
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # Store the caller's array pointer as-is (no copy) and its length.
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
793
794
class TypeArrayNest(Type):
    """
    Indexed array: a single nest attribute whose children are the array
    elements, each emitted with its array index as the attribute type.

    Elements are scalars, fixed-length binaries (requires 'exact-len') or
    nests. A spec without 'sub-type' is treated as an array of nests
    everywhere in this class.
    """

    def _sub_type(self):
        """Return the element type, defaulting to 'nest' when unspecified."""
        return self.attr.get('sub-type', 'nest')

    def _exact_len(self):
        """Return the 'exact-len' check as C text (an int or a constant name)."""
        return self.get_limit_str('exact-len')

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self._sub_type()
        if sub_type == 'nest':
            return self.nested_struct_type
        elif sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {sub_type} not supported yet")

    def arg_member(self, ri):
        # Fixed-length binaries are passed as a pointer to byte arrays.
        if self._sub_type() == 'binary' and 'exact-len' in self.checks:
            return [f'unsigned char (*{self.c_name})[{self._exact_len()}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        sub_type = self._sub_type()
        if sub_type in scalars:
            # Signed types reuse the YNL_PT_U* entries ('s32' -> U32).
            return f'.type = YNL_PT_U{c_upper(sub_type[1:])}, '
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self._exact_len()}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the nest in a generated local attr_<name> (presumably
        # consumed by a later pass) and count its validated elements.
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->_count.{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        """
        Emit code opening the array nest, putting each element with its
        index as the type, and closing the nest.

        Raises:
            Exception: the element type has no put implementation.
        """
        sub_type = self._sub_type()
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if sub_type in scalars:
            put_type = sub_type
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self._exact_len()});")
        elif sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put for ArrayNest sub-type {sub_type} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        # Store the caller's array pointer as-is (no copy) and its length.
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
857
858
class TypeNestTypeValue(Type):
    """Nest whose layers carry values in their attribute *types* ('type-value').

    For each 'type-value' entry the generated parser steps into the current
    attribute's data (ynl_attr_data) and stores that attribute's type in a
    local __u32. Those values are then passed as extra arguments to the
    nested parser.
    """
    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        cursor = 'attr'
        extra_args = ''
        if 'type-value' in self.attr:
            names = [c_lower(level) for level in self.attr["type-value"]]
            local_vars.append('const struct nlattr *attr_' + ', *attr_'.join(names) + ';')
            local_vars.append('__u32 ' + ', '.join(names) + ';')
            for name in names:
                get_lines.append(f'attr_{name} = ynl_attr_data({cursor});')
                get_lines.append(f'{name} = ynl_attr_type(attr_{name});')
                cursor = 'attr_' + name
            extra_args = ', ' + ', '.join(names)

        # NOTE(review): unlike TypeSubMessage, the return value of the nested
        # _parse() call is not checked here; confirm this is intended.
        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
887
888
class TypeSubMessage(TypeNest):
    """Sub-message attribute: a nest whose format is picked by a selector.

    The selector is either a sibling attribute in the same set or, when
    external, a value handed down from the parent as _sel_<name>.
    """
    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        pieces = [f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
                  '.is_submsg = 1, ']
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            selector_attr = self.attr_set[self["selector"]]
            pieces.append(f'.selector_type = {selector_attr.value} ')
        return ''.join(pieces)

    def _attr_get(self, ri, var):
        sel = c_lower(self['selector'])
        # External selectors arrive as the _sel_<name> local; internal ones
        # are read from the struct being filled in.
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"
        get_lines = [
            f'if (!{sel_var})',
            f'return ynl_submsg_failed(yarg, "{self.name}", "{self["selector"]}");',
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None
919
920
class Selector:
    """Selector of a sub-message attribute.

    If the selector attribute lives in the same attribute set, it is bound
    immediately and flagged with is_selector. Otherwise it is external:
    attr stays None until set_attr() binds the parent's attribute.
    """
    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]
        self._external = self.name not in attr_set

        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
        else:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        """Bind the attribute that carries the selector value."""
        self.attr = attr

    def is_external(self):
        """True when the selector is not part of the sub-message's own set."""
        return self._external
939
940
class Struct:
    """C struct rendered for an attribute set (or a listed subset of it).

    Args:
        family: the Family that owns the attribute set.
        space_name: name of the attribute set.
        type_list: attribute names to include, in order. None means "every
            attribute of the set" and marks the struct as a pure nest.
        fixed_header: optional fixed header name, kept as 'struct <name>'.
        inherited: members handed down from the parent, checked later by
            set_inherited().
        submsg: the sub-message spec when this struct models one.
    """
    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.fixed_header = ('struct ' + c_lower(fixed_header)) if fixed_header else None
        self.submsg = submsg

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)
        self.struct_name = 'struct ' + self.render_name
        # A nested struct whose set name is also a definition (family.consts)
        # gets a trailing '_' to keep the C type names distinct.
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        # Set when used by a multi-attr or an indexed-array attribute
        self.in_multi_val = False

        selected = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in selected]
        self.attrs = dict(self.attr_list)

        # Attribute with the highest value; on ties the later one wins.
        self.attr_max_val = None
        top = 0
        for _, attr in self.attr_list:
            if attr.value >= top:
                top = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited member names (c_lower'ed and sorted).

        Raises Exception when the struct was created with a different
        inherited collection.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def external_selectors(self):
        """Selectors of sub-message members that must come from a parent."""
        return [attr.selector for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        """True if any member's free code needs a loop iterator."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1015
1016
class EnumEntry(SpecEnumEntry):
    """Enum entry with its C identifier and an explicit-value marker."""
    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change: True when the value is not the implicit C successor
        # (previous value + 1, or 0 for the first entry), and always for flags.
        implicit = prev.value + 1 if prev else 0
        self.value_change = self.value != implicit or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
1035
1036
class EnumSet(SpecEnumSet):
    """Enum / flags definition with its C naming.

    Attributes set here:
        render_name: c_lower('<ident_name>-<name>').
        enum_name: 'enum <name>', or None when the spec gives an empty
            'enum-name'.
        user_type: enum_name, or 'int' when there is no enum name.
        value_pfx: prefix for entry names ('name-prefix' or a default).
        header, enum_cnt_name: copied from the spec (None if absent).
    """
    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            # Explicitly empty 'enum-name': no C enum type name
            self.enum_name = None
        self.user_type = self.enum_name or 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        # NOTE(review): the base constructor runs last, after the attributes
        # above exist -- presumably because it builds entries via new_entry()
        # and EnumEntry reads enum_set attributes; keep this order.
        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) when the entry values are contiguous.

        Returns (None, None) when the number of entries does not match
        high - low + 1 (e.g. the values have gaps).
        """
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            return None, None

        return low, high
1072
1073
class AttrSet(SpecAttrSet):
    """Attribute set with C naming (enum prefix, max/count names) and typed attrs."""
    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the enum naming of the set it is carved from
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        # C keywords get a '_' suffix; a set named like the family gets an
        # empty c_name.
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching elem['type'].

        Multi-attr elements are wrapped in TypeMultiAttr. Raises Exception
        for unknown types and unsupported indexed-array sub-types.
        """
        kind = elem['type']
        if kind in scalars:
            cls = TypeScalar
        elif kind == 'binary':
            if 'struct' in elem:
                cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                cls = TypeBinaryScalarArray
            else:
                cls = TypeBinary
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            simple_types = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'nest-type-value': TypeNestTypeValue,
                'sub-message': TypeSubMessage,
            }
            if kind not in simple_types:
                raise Exception(f"No typed class for type {elem['type']}")
            cls = simple_types[kind]

        t = cls(self.family, self, elem, value)
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)
        return t
1142
1143
class Operation(SpecOperation):
    """Operation with its C enum name and rendering flags."""
    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # every existing request/reply gets at least an empty attribute list.
        for mode in ('do', 'dump', 'event'):
            for direction in ('request', 'reply'):
                try:
                    msg = yaml[mode][direction]
                except KeyError:
                    continue
                msg.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True when both 'do' and 'dump' define a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that another op names this one as its 'notify' target."""
        self.has_ntf = True
1177
1178
class SubMessage(SpecSubMessage):
    """Sub-message spec with its C-style render name ('<ident_name>_<name>')."""
    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        self.render_name = c_lower(f"{family.ident_name}-{yaml['name']}")

    def resolve(self):
        self.resolve_up(super())
1187
1188
class Family(SpecFamily):
    """Netlink family spec extended with C code-generation metadata.

    On top of SpecFamily this derives C names (family/version keys, op enum
    prefixes, uAPI header) and, in resolve(), works out:
    - which attribute sets are roots and which are pure nested structs;
    - whether each struct is used in requests, replies or both;
    - how sub-message selectors are passed down;
    - the pre/post hook names.
    """
    def __init__(self, file_name, exclude_ops):
        """Load the family spec from *file_name*.

        *exclude_ops* is forwarded to SpecFamily (presumably naming ops to
        leave out -- confirm in lib).
        """
        # Added by resolve:
        # Each attribute below is assigned and immediately deleted, so it does
        # not exist (reads raise AttributeError) until resolution fills it in.
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # C names for the family name / version; overridable via
        # c-family-name / c-version-name, defaulting to
        # <NAME>_FAMILY_NAME / <NAME>_FAMILY_VERSION.
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" -> "<name>"; any other header path falls back to
        # the family's ident_name.
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Compute resolve-time C names and per-attribute-set metadata.

        The _load_*() steps depend on their order. Nested structs are
        discovered from the root sets. Attribute use is derived from those
        structs. Selector passing walks the final struct set.
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode]: 'set' de-duplicates hook names, 'list' keeps
        # them in first-seen order (both filled by _load_hooks()).
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        # 'global' merges the request attrs of all ops into one policy
        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """True for classic (non-generic) netlink families ('netlink-raw')."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Mark every op that some message names as its 'notify' target."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect the request and reply attributes for each op-level attr set.

        Request/reply attributes of 'do' and 'dump' and the attributes of
        'event' are merged across all ops sharing the same attribute set.
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct follows the ones it nests.

        If a struct nests a (non-recursive) set that has not been placed yet,
        it is moved to the end and retried. The loop is capped at len**2
        rounds, so an ordering that cannot be satisfied cannot loop forever.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Create (once) the Struct for a 'nested-attributes' attr; return its set name.

        *inherit* is handed to Struct before it is filled in. A freshly
        created struct keeps that same set object as its _inherited, so
        set_inherited() compares it with itself. A struct created earlier
        by another attribute must match the new members exactly.

        Raises Exception if the nested set is also used as a root set.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            # NOTE(review): unreachable -- the branch above already raised
            # for sets in root_sets.
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Build a fake attr set and Struct for a sub-message; return its name.

        Each sub-message format becomes one pseudo-attribute:
        - a nest, when the format has an 'attribute-set';
        - a binary struct, when it has only a 'fixed-header';
        - a flag, when it has neither.
        """
        # Fake the struct type for the sub-message itself:
        # it's not an attr_set, but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr |= {
                    "type": "nest",
                    "nested-attributes": fmt['attribute-set'],
                }
                if 'fixed-header' in fmt:
                    attr |= { "fixed-header": fmt["fixed-header"] }
            elif 'fixed-header' in fmt:
                attr |= {
                    "type": "binary",
                    "struct": fmt["fixed-header"],
                }
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        # NOTE(review): AttrSet.__init__ reads 'name-prefix', not 'name-pfx',
        # so this prefix appears to be ignored. The default
        # "<ident_name>-a-<set>-" prefix is used instead. Confirm the intent
        # before renaming the key, since that would change generated names.
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Discover all pure nested structs and propagate their usage flags.

        The nested sets reachable from the root sets are walked
        breadth-first. Request/reply/multi-val flags are then seeded from the
        root sets and pushed down the nesting hierarchy, and recursive
        structs are detected.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # (after sorting, nested structs generally precede the structs using
        # them, so reversed() visits containing structs first)
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark each attribute as used in requests and/or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Bind each external sub-message selector to its parent's attribute.

        Only one level of passing is supported: the selector must be an
        attribute of the immediate parent set. Otherwise Exception is raised.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                # NOTE(review): assumes every nested set has a Struct in
                # pure_nested_structs; a None child would raise AttributeError.
                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """Build one request policy covering the request attrs of all ops.

        All ops with an attribute set must share the same set (Exception
        otherwise). global_policy keeps the set's attribute order.
        NOTE(review): if no op has an 'attribute-set', attr_set_name stays
        None and self.attr_sets[None] will raise.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per do/dump mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1549
1550
class RenderInfo:
    """Context for rendering one operation (or one bare attribute set).

    Args:
        cw: CodeWriter to emit into; its nlib is exposed as self.nl.
        family: the Family being rendered.
        ku_space: which side the code is for (e.g. 'user').
        op: the Operation, or a falsy value when rendering an attr set alone.
        op_mode: 'do', 'dump', 'notify' or 'event'.
        attr_set: attribute set name; defaults to op['attribute-set'].

    Raises:
        Exception: if a non-classic family uses a per-op fixed header that
            differs from the family's.
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            if op.fixed_header != family.fixed_header:
                if family.is_classic():
                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
                else:
                    raise Exception("Per-op fixed header not supported, yet")

        # type_consistent: 'do' and 'dump' replies are identical, so their
        # response parsing can be shared.
        # type_oneside: the op has 'dump' but no 'do'.
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_oneside = True

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        # Notifications render like the 'do' (or, failing that, 'dump') mode;
        # self.op_mode keeps the original 'notify'.
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """True if the *key* struct ('request'/'reply') has no attributes and no fixed header."""
        struct = self.struct[key]
        return len(struct.attr_list) == 0 and struct.fixed_header is None

    def needs_nlflags(self, direction):
        """True for 'do' requests of classic netlink families."""
        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1618
1619
class CodeWriter:
    """Incrementally emit C source text with tab indentation.

    Output goes to stdout or, when *out_file* is given, is staged in a
    temporary file and copied to *out_file* by close_out_file() (also run
    from __del__).  With overwrite=False an existing destination whose
    contents already match the staged output is left untouched.

    Closing braces are emitted lazily so that a following "else" can be
    joined onto the same line ("} else").
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # emit a blank line before the next line
        self._block_end = False     # a closing '}' is pending
        self._silent_block = False  # previous line opened a brace-less block
        self._ind = 0               # current indentation level, in tabs
        self._ifdef_block = None    # CONFIG_ option of the open #ifdef, if any
        # True while output sits in a temp file awaiting close_out_file()
        self._staged = out_file is not None
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy staged output to the destination file and revert to stdout.

        Idempotent: only the first call after construction does any work.
        The temporary file is always closed, including when the destination
        is left alone because its contents already match.
        """
        # getattr: __del__ may run on an instance whose __init__ failed
        if not getattr(self, '_staged', False):
            return
        self._staged = False
        staged, self._out = self._out, sys.stdout
        try:
            staged.flush()
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(staged.name, self._out_file, shallow=False))
            if not unchanged:
                with open(self._out_file, 'w+') as out_file:
                    staged.seek(0)
                    shutil.copyfileobj(staged, out_file)
        finally:
            staged.close()

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        A pending '}' is flushed first (joined as "} else ..." when *line*
        starts with "else"), then a pending blank line.  Lines ending in ':'
        (labels) are outdented one level, lines starting with '#' go to
        column 0, and the line after a brace-less if/while/for/else gets one
        extra level.  An empty *line* writes a bare newline.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        if line:
            self._out.write('\t' * ind + line)
        self._out.write('\n')

    def nl(self):
        """Request a blank line before the next emitted line."""
        self._nl = True

    def block_start(self, line=''):
        """Emit "<line> {" and increase the indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block, optionally followed by *line* (e.g. ';')."""
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Emit *doc* as ' *'-prefixed comment lines wrapped before column 79."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Emit a function prototype, wrapping arguments past 80 columns."""
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines align under the first argument (tabs + spaces)
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Emit local declarations longest-first, then a blank line.

        *local_vars* may be a single string or an iterable of strings; the
        caller's container is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() is stable, so equal-length declarations keep input order
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Emit a complete function whose statements are the lines of *body*."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Emit '#define NAME value' lines with tab-aligned values.

        *defines* is a sequence of (name, value) pairs; str values are
        quoted, int values written as-is.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Emit '.field = value,' designated initializers, tab-aligned.

        An empty *members* sequence emits nothing.
        """
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active '#ifdef CONFIG_<config>' guard (None closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1816
1817
# Scalar (integer) netlink attribute types: fixed-width u/s 8..64 plus the
# variable-width uint/sint.
scalars = {f'{sign}{bits}' for sign in 'us' for bits in (8, 16, 32, 64)}
scalars.update(('uint', 'sint'))

# Name suffix for each message direction ('' = no direction).
direction_to_suffix = dict(reply='_rsp', request='_req')
direction_to_suffix[''] = ''

# Name suffix op_prefix() uses for reply types in each op mode.
op_mode_to_wrapper = dict(do='', dump='_list', notify='_ntf', event='')

# C keywords; presumably consulted so generated identifiers don't collide
# with them (usage not visible here).
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1869
1870
def rdir(direction):
    """Return the opposite message direction.

    'request' and 'reply' are swapped; anything else is returned as-is.
    """
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1877
1878
def op_prefix(ri, direction, deref=False):
    """Return the C name prefix ("<family>_<type>[suffixes]") for a type.

    "do" (or mode-less) types just get the direction suffix.  In other
    modes, requests become "_req" (plus "_dump" unless only a dump exists).
    Replies use the mode wrapper when do/dump replies are consistent, or
    "_rsp_dump"/"_rsp_list" otherwise.  *deref* selects the element type
    name rather than the list wrapper for replies.
    """
    parts = [ri.family.c_name, '_', ri.type_name]
    if not ri.op_mode or ri.op_mode == 'do':
        parts.append(direction_to_suffix[direction])
    elif direction == 'request':
        parts.append('_req' if ri.type_oneside else '_req_dump')
    elif not ri.type_consistent:
        parts.append('_rsp_dump' if deref else '_rsp_list')
    elif deref:
        parts.append(direction_to_suffix[direction])
    else:
        parts.append(op_mode_to_wrapper[ri.op_mode])
    return ''.join(parts)
1900
1901
def type_name(ri, direction, deref=False):
    """Return the C type ("struct <prefix>") for *direction* of *ri*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1904
1905
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-facing do/dump function for *ri*.

    The function takes the request struct when the mode has a request, and
    returns a pointer to the reply struct when it has a reply (else int).
    Dump functions get a "_dump" name suffix.  With *terminate* the
    prototype ends in ';' (declaration) instead of being a definition head.
    """
    mode_spec = ri.op[ri.op_mode]
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1922
1923
def print_req_prototype(ri):
    """Emit the "do" request prototype, preceded by the op's doc comment.

    Operations whose spec has no 'doc' key get no doc comment instead of
    aborting generation with a KeyError.
    """
    doc = ri.op['doc'] if 'doc' in ri.op else None
    print_prototype(ri, "request", doc=doc)
1926
1927
def print_dump_prototype(ri):
    """Emit the dump request prototype (no doc comment)."""
    print_prototype(ri, direction="request")
1930
1931
def put_typol_submsg(cw, struct):
    """Emit the policy table and nest descriptor for a sub-message struct.

    Each member gets a YNL_PT_SUBMSG entry indexed by its position in the
    member list; 'nest' members also point at their nest descriptor.
    """
    policy = f'{struct.render_name}_policy'
    members = list(struct.member_list())

    cw.block_start(line=f'const struct ynl_policy_attr {policy}[] =')
    for idx, (name, arg) in enumerate(members):
        if arg.type == 'nest':
            nest = f" .nest = &{arg.nested_render_name}_nest,"
        else:
            nest = ""
        cw.p(f'[{idx}] = {{ .type = YNL_PT_SUBMSG, .name = "{name}",{nest} }},')
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {len(members) - 1},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1952
1953
def put_typol_fwd(cw, struct):
    """Emit an extern declaration of the struct's nest policy descriptor."""
    decl = f'extern const struct ynl_policy_nest {struct.render_name}_nest;'
    cw.p(decl)
1956
1957
def put_typol(cw, struct):
    """Emit the policy table and nest descriptor for an attribute-set struct.

    Sub-message structs are delegated to put_typol_submsg().  The table is
    sized by the attribute set's max attribute name; each member fills in
    its own entry via attr_typol().
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    max_name = struct.attr_set.max_name
    name = struct.render_name

    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_name} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    cw.p(f'.max_attr = {max_name},')
    cw.p(f'.table = {name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1977
1978
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit "<render_name>_str()", mapping a value to its name via *map_name*.

    The argument is typed as the enum's user type when *enum* is given
    (plain int otherwise).  For 'flags' enums the bit value is first turned
    into a bit index with ffs().  Out-of-range values return NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    is_flags = enum and enum.type == 'flags'
    if is_flags:
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')
    cw.block_end()
    cw.nl()
1992
1993
def put_op_name_fwd(family, cw):
    """Emit the prototype of "<family>_op_str()"."""
    fname = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', fname, args=['int op'], suffix=';')
1996
1997
def put_op_name(family, cw):
    """Emit the op-value -> name string map and "<family>_op_str()".

    Only ops with a reply value get an entry.  The entry index is the op's
    enum name when request and reply values match, otherwise the raw reply
    value.  An op whose reply value is claimed by another op in
    family.rsp_by_value is emitted as a comment instead.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
        elif op.req_value == op.rsp_value:
            cw.p(f'[{op.enum_name}] = "{op_name}",')
        else:
            cw.p(f'[{op.rsp_value}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2017
2018
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the prototype of "<enum>_str()"; *family* is not used."""
    cw.write_func_prot('const char *', f'{enum.render_name}_str',
                       [f'{enum.user_type} value'], suffix=';')
2022
2023
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string map for *enum* and its lookup function."""
    map_name = f'{enum.render_name}_strmap'
    entries = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for entry_line in entries:
        cw.p(entry_line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
2033
2034
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of "<struct>_put(nlh, attr_type, obj)"."""
    fname = f'{struct.render_name}_put'
    params = ['struct nlmsghdr *nlh', 'unsigned int attr_type']
    params.append(f'{struct.ptr_name}obj')
    ri.cw.write_func_prot('int', fname, params, suffix=suffix)
2042
2043
def put_req_nested(ri, struct):
    """Emit "<struct>_put()", which serializes *obj* into *nlh*.

    Regular structs are wrapped in a nest opened with attr_type; sub-message
    structs are written without a nest.  A fixed header, when present, is
    copied from obj->_hdr before the attributes.  The generated function
    returns 0.
    """
    is_nest = struct.submsg is None
    members = list(struct.member_list())
    decls = []
    prologue = []

    if is_nest:
        decls.append('struct nlattr *nest;')
        prologue.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        hdr_size = f'sizeof({struct.fixed_header})'
        decls.append('void *hdr;')
        prologue += [f"hdr = ynl_nlmsg_put_extra_header(nlh, {hdr_size});",
                     f"memcpy(hdr, &obj->_hdr, {hdr_size});"]

    if any(arg.type == 'indexed-array' for _, arg in members):
        decls.append('struct nlattr *array;')
    if any(arg.presence_type() == 'count' for _, arg in members):
        decls.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    cw = ri.cw
    cw.block_start()
    cw.write_func_lvar(decls)

    for stmt in prologue:
        cw.p(stmt)

    for _, arg in members:
        arg.attr_put(ri, "obj")

    if is_nest:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2084
2085
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body (braces included) of a generated attribute parser.

    The caller has already emitted the prototype.  The generated code, in
    order: local declarations, *init_lines*, copies of inherited arguments
    into dst, the fixed-header copy (if any), "already present" checks for
    multi-valued members, one loop over all attributes dispatching to each
    member's attr_get(), then per-member passes that allocate and fill
    indexed-array and multi-attr members.

    Args:
        ri: render info; code is written through ri.cw.
        struct: the struct whose members are parsed.
        init_lines: C statements emitted right after the declarations.
            This function may append to it (the caller's list changes).
        local_vars: C declarations for the function.  This function appends
            to it (the caller's list changes).

    Raises:
        Exception: for a per-op fixed header in a non-classic family, an
            unsupported indexed-array sub-type, or an unsupported member
            type in the indexed-array / multi-attr passes.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    # Choose the attribute iterator: nested structs walk the attributes
    # inside `nested` (skipping a fixed header if there is one), top-level
    # messages walk `nlh` past the family / per-op header.
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Classify members: indexed arrays and multi-attr members need a
    # counter and a second pass; nested / sub-message members need parg.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per multi-valued member; presumably incremented
    # by the member's attr_get() code in the main loop (not visible here).
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # Copy the fixed header into dst->_hdr from where it sits: the nest
    # payload, the classic message payload, or after the genetlink header.
    if struct.fixed_header:
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Multi-valued members are allocated below; refuse to parse if the
    # destination already holds one.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # Main loop: every member emits its own handling for the attribute type.
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Indexed arrays: allocate n_<name> elements and walk the nest saved in
    # attr_<name> (presumably recorded by attr_get() in the main loop).
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Multi-attr members: allocate n_<name> elements, then re-walk all
    # attributes with iter_line and pick out the ones of this type.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most sizeof(element); a shorter payload leaves the
            # calloc()-zeroed remainder in place.
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Each string becomes a heap-allocated ynl_string with an
            # explicit length and NUL terminator.
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message callbacks return YNL_PARSE_CB_OK.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2237
2238
def parse_rsp_submsg(ri, struct):
    """Emit the parse function for a sub-message struct.

    The generated C function compares the selector string `sel` against
    each sub-message format name in an if / else-if chain and runs the
    matching member's parsing code, then returns 0.

    Local declarations from all members are deduplicated through a set and
    sorted before emission.  Set iteration order for strings depends on the
    per-process hash seed, and write_func_lvar() orders only by length, so
    without sorting, equal-length declarations would come out in a
    different order from run to run.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    local_vars = {'const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;'}

    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        local_vars.update(l_vars or ())

    ri.cw.block_start()
    # Deterministic order for reproducible generated code
    ri.cw.write_func_lvar(sorted(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    first = True
    for name, arg in struct.member_list():
        kw = 'if' if first else 'else if'
        first = False

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2273
2274
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of "<struct>_parse()".

    Parameter order: yarg, the selector string (sub-messages only), the
    nested attribute, one "_sel_<name>" string per external selector, and
    one __u32 per inherited argument.
    """
    func_args = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        func_args.append('const char *sel')
    func_args.append('const struct nlattr *nested')
    func_args += [f'const char *_sel_{sel.name}' for sel in struct.external_selectors()]
    func_args += [f'__u32 {arg}' for arg in struct.inherited]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
                          suffix=suffix)
2287
2288
def parse_rsp_nested(ri, struct):
    """Emit "<struct>_parse()" for a nested struct.

    Sub-messages go through parse_rsp_submsg(); a struct without members
    gets a stub that returns 0.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']
    _multi_parse(ri, struct, [], decls)
2307
2308
def parse_rsp_msg(ri, deref=False):
    """Emit "<prefix>_parse()", the callback that parses a reply message.

    Nothing is emitted when the op mode has no reply, unless it is an
    event.  *deref* is forwarded to op_prefix()/type_name() to select the
    element (rather than list) type name.
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    reply = ri.struct["reply"]
    dst_type = type_name(ri, "reply", deref=deref)
    prefix = op_prefix(ri, "reply", deref=deref)
    ri.cw.write_func_prot('int', f'{prefix}_parse',
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    decls = [f'{dst_type} *dst;', 'const struct nlattr *attr;']
    _multi_parse(ri, reply, ['dst = yarg->data;'], decls)
2330
2331
def print_req(ri):
    """Emit the definition of the blocking "do" request function.

    The generated function builds the request (optional fixed header plus
    attributes), runs it with ynl_exec() and, when the op has a reply,
    returns a newly allocated reply (NULL on error, after freeing it).
    Without a reply it returns 0, or -1 on error.
    """
    direction = "request"
    req_struct = ri.struct["request"]
    has_reply = 'reply' in ri.op[ri.op_mode]

    decls = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
             'struct nlmsghdr *nlh;',
             'int err;']
    if has_reply:
        decls.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if req_struct.fixed_header:
        decls.extend(['size_t hdr_len;', 'void *hdr;'])
    if any(attr.presence_type() == 'count' for _, attr in req_struct.member_list()):
        decls.append('unsigned int i;')

    cw = ri.cw
    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(decls)

    if ri.family.is_classic():
        cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{req_struct.render_name}_nest;")
    cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    if req_struct.fixed_header:
        cw.p("hdr_len = sizeof(req->_hdr);")
        cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        cw.nl()

    for _, attr in req_struct.member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {'rsp' if has_reply else '0'};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(call_free(ri, rdir(direction), 'rsp'))
        cw.p('return NULL;')

    cw.block_end()
2405
2406
def print_dump(ri):
    """Emit the definition of the dump request function.

    The generated function configures a ynl_dump_state for the reply type,
    builds the dump request (optional fixed header plus request attributes)
    and returns the list from ynl_exec_dump(), or NULL after freeing it on
    error.
    """
    direction = "request"
    req_struct = ri.struct['request']
    cw = ri.cw

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    decls = ['struct ynl_dump_state yds = {};',
             'struct nlmsghdr *nlh;',
             'int err;']
    if req_struct.fixed_header:
        decls += ['size_t hdr_len;', 'void *hdr;']
    cw.write_func_lvar(decls)

    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    setup = ('yds.yarg.ys = ys;',
             f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;",
             "yds.yarg.data = NULL;",
             f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});",
             f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;",
             f'yds.rsp_cmd = {rsp_cmd};')
    for stmt in setup:
        cw.p(stmt)
    cw.nl()

    if ri.family.is_classic():
        cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if req_struct.fixed_header:
        cw.p("hdr_len = sizeof(req->_hdr);")
        cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        cw.nl()

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{req_struct.render_name}_nest;")
        cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        cw.nl()
        for _, attr in req_struct.member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
2461
2462
def call_free(ri, direction, var):
    """Return the C statement freeing *var* with the direction's _free()."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
2465
2466
def free_arg_name(direction):
    """Return the parameter name for a _free() function.

    "req"/"rsp" for request/reply directions, "obj" when there is none.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2471
2472
def print_alloc_wrapper(ri, direction):
    """Emit a static inline "<prefix>_alloc()" returning calloc()ed storage."""
    name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {name}));')
    cw.block_end()
2479
2480
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of "<prefix>_free()".

    When the type name conflicts with a family constant the struct name
    carries a trailing '_'; the function name does not.
    """
    name = op_prefix(ri, direction)
    struct_name = f'{name}_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
2488
2489
def print_nlflags_set(ri, direction):
    """Emit "<prefix>_set_nlflags()", storing nl_flags in req->_nlmsg_flags."""
    name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot('static inline void', f"{name}_set_nlflags",
                       [f"struct {name} *req", "__u16 nl_flags"])
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2498
2499
2500def _print_type(ri, direction, struct):
2501    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
2502    if not direction and ri.type_name_conflict:
2503        suffix += '_'
2504
2505    if ri.op_mode == 'dump' and not ri.type_oneside:
2506        suffix += '_dump'
2507
2508    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")
2509
2510    if ri.needs_nlflags(direction):
2511        ri.cw.p('__u16 _nlmsg_flags;')
2512        ri.cw.nl()
2513    if struct.fixed_header:
2514        ri.cw.p(struct.fixed_header + ' _hdr;')
2515        ri.cw.nl()
2516
2517    for type_filter in ['present', 'len', 'count']:
2518        meta_started = False
2519        for _, attr in struct.member_list():
2520            line = attr.presence_member(ri.ku_space, type_filter)
2521            if line:
2522                if not meta_started:
2523                    ri.cw.block_start(line=f"struct")
2524                    meta_started = True
2525                ri.cw.p(line)
2526        if meta_started:
2527            ri.cw.block_end(line=f'_{type_filter};')
2528    ri.cw.nl()
2529
2530    for arg in struct.inherited:
2531        ri.cw.p(f"__u32 {arg};")
2532
2533    for _, attr in struct.member_list():
2534        attr.struct_member(ri)
2535
2536    ri.cw.block_end(line=';')
2537    ri.cw.nl()
2538
2539
2540def print_type(ri, direction):
2541    _print_type(ri, direction, ri.struct[direction])
2542
2543
2544def print_type_full(ri, struct):
2545    _print_type(ri, "", struct)
2546
2547
2548def print_type_helpers(ri, direction, deref=False):
2549    print_free_prototype(ri, direction)
2550    ri.cw.nl()
2551
2552    if ri.needs_nlflags(direction):
2553        print_nlflags_set(ri, direction)
2554
2555    if ri.ku_space == 'user' and direction == 'request':
2556        for _, attr in ri.struct[direction].member_list():
2557            attr.setter(ri, ri.attr_set, direction, deref=deref)
2558    ri.cw.nl()
2559
2560
2561def print_req_type_helpers(ri):
2562    if ri.type_empty("request"):
2563        return
2564    print_alloc_wrapper(ri, "request")
2565    print_type_helpers(ri, "request")
2566
2567
2568def print_rsp_type_helpers(ri):
2569    if 'reply' not in ri.op[ri.op_mode]:
2570        return
2571    print_type_helpers(ri, "reply")
2572
2573
2574def print_parse_prototype(ri, direction, terminate=True):
2575    suffix = "_rsp" if direction == "reply" else "_req"
2576    term = ';' if terminate else ''
2577
2578    ri.cw.write_func_prot('void', f"{ri.op.render_name}{suffix}_parse",
2579                          ['const struct nlattr **tb',
2580                           f"struct {ri.op.render_name}{suffix} *req"],
2581                          suffix=term)
2582
2583
2584def print_req_type(ri):
2585    if ri.type_empty("request"):
2586        return
2587    print_type(ri, "request")
2588
2589
2590def print_req_free(ri):
2591    if 'request' not in ri.op[ri.op_mode]:
2592        return
2593    _free_type(ri, 'request', ri.struct['request'])
2594
2595
2596def print_rsp_type(ri):
2597    if (ri.op_mode == 'do' or ri.op_mode == 'dump') and 'reply' in ri.op[ri.op_mode]:
2598        direction = 'reply'
2599    elif ri.op_mode == 'event':
2600        direction = 'reply'
2601    else:
2602        return
2603    print_type(ri, direction)
2604
2605
2606def print_wrapped_type(ri):
2607    ri.cw.block_start(line=f"{type_name(ri, 'reply')}")
2608    if ri.op_mode == 'dump':
2609        ri.cw.p(f"{type_name(ri, 'reply')} *next;")
2610    elif ri.op_mode == 'notify' or ri.op_mode == 'event':
2611        ri.cw.p('__u16 family;')
2612        ri.cw.p('__u8 cmd;')
2613        ri.cw.p('struct ynl_ntf_base_type *next;')
2614        ri.cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
2615    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
2616    ri.cw.block_end(line=';')
2617    ri.cw.nl()
2618    print_free_prototype(ri, 'reply')
2619    ri.cw.nl()
2620
2621
2622def _free_type_members_iter(ri, struct):
2623    if struct.free_needs_iter():
2624        ri.cw.p('unsigned int i;')
2625        ri.cw.nl()
2626
2627
2628def _free_type_members(ri, var, struct, ref=''):
2629    for _, attr in struct.member_list():
2630        attr.free(ri, var, ref)
2631
2632
2633def _free_type(ri, direction, struct):
2634    var = free_arg_name(direction)
2635
2636    print_free_prototype(ri, direction, suffix='')
2637    ri.cw.block_start()
2638    _free_type_members_iter(ri, struct)
2639    _free_type_members(ri, var, struct)
2640    if direction:
2641        ri.cw.p(f'free({var});')
2642    ri.cw.block_end()
2643    ri.cw.nl()
2644
2645
2646def free_rsp_nested_prototype(ri):
2647        print_free_prototype(ri, "")
2648
2649
2650def free_rsp_nested(ri, struct):
2651    _free_type(ri, "", struct)
2652
2653
2654def print_rsp_free(ri):
2655    if 'reply' not in ri.op[ri.op_mode]:
2656        return
2657    _free_type(ri, 'reply', ri.struct['reply'])
2658
2659
2660def print_dump_type_free(ri):
2661    sub_type = type_name(ri, 'reply')
2662
2663    print_free_prototype(ri, 'reply', suffix='')
2664    ri.cw.block_start()
2665    ri.cw.p(f"{sub_type} *next = rsp;")
2666    ri.cw.nl()
2667    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
2668    _free_type_members_iter(ri, ri.struct['reply'])
2669    ri.cw.p('rsp = next;')
2670    ri.cw.p('next = rsp->next;')
2671    ri.cw.nl()
2672
2673    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2674    ri.cw.p(f'free(rsp);')
2675    ri.cw.block_end()
2676    ri.cw.block_end()
2677    ri.cw.nl()
2678
2679
2680def print_ntf_type_free(ri):
2681    print_free_prototype(ri, 'reply', suffix='')
2682    ri.cw.block_start()
2683    _free_type_members_iter(ri, ri.struct['reply'])
2684    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2685    ri.cw.p(f'free(rsp);')
2686    ri.cw.block_end()
2687    ri.cw.nl()
2688
2689
2690def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
2691    if terminate and ri and policy_should_be_static(struct.family):
2692        return
2693
2694    if terminate:
2695        prefix = 'extern '
2696    else:
2697        if ri and policy_should_be_static(struct.family):
2698            prefix = 'static '
2699        else:
2700            prefix = ''
2701
2702    suffix = ';' if terminate else ' = {'
2703
2704    max_attr = struct.attr_max_val
2705    if ri:
2706        name = ri.op.render_name
2707        if ri.op.dual_policy:
2708            name += '_' + ri.op_mode
2709    else:
2710        name = struct.render_name
2711    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2712
2713
2714def print_req_policy(cw, struct, ri=None):
2715    if ri and ri.op:
2716        cw.ifdef_block(ri.op.get('config-cond', None))
2717    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
2718    for _, arg in struct.member_list():
2719        arg.attr_policy(cw)
2720    cw.p("};")
2721    cw.ifdef_block(None)
2722    cw.nl()
2723
2724
2725def kernel_can_gen_family_struct(family):
2726    return family.proto == 'genetlink'
2727
2728
2729def policy_should_be_static(family):
2730    return family.kernel_policy == 'split' or kernel_can_gen_family_struct(family)
2731
2732
2733def print_kernel_policy_ranges(family, cw):
2734    first = True
2735    for _, attr_set in family.attr_sets.items():
2736        if attr_set.subset_of:
2737            continue
2738
2739        for _, attr in attr_set.items():
2740            if not attr.request:
2741                continue
2742            if 'full-range' not in attr.checks:
2743                continue
2744
2745            if first:
2746                cw.p('/* Integer value ranges */')
2747                first = False
2748
2749            sign = '' if attr.type[0] == 'u' else '_signed'
2750            suffix = 'ULL' if attr.type[0] == 'u' else 'LL'
2751            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
2752            members = []
2753            if 'min' in attr.checks:
2754                members.append(('min', attr.get_limit_str('min', suffix=suffix)))
2755            if 'max' in attr.checks:
2756                members.append(('max', attr.get_limit_str('max', suffix=suffix)))
2757            cw.write_struct_init(members)
2758            cw.block_end(line=';')
2759            cw.nl()
2760
2761
2762def print_kernel_policy_sparse_enum_validates(family, cw):
2763    first = True
2764    for _, attr_set in family.attr_sets.items():
2765        if attr_set.subset_of:
2766            continue
2767
2768        for _, attr in attr_set.items():
2769            if not attr.request:
2770                continue
2771            if not attr.enum_name:
2772                continue
2773            if 'sparse' not in attr.checks:
2774                continue
2775
2776            if first:
2777                cw.p('/* Sparse enums validation callbacks */')
2778                first = False
2779
2780            sign = '' if attr.type[0] == 'u' else '_signed'
2781            suffix = 'ULL' if attr.type[0] == 'u' else 'LL'
2782            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
2783                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
2784            cw.block_start()
2785            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
2786            enum = family.consts[attr['enum']]
2787            first_entry = True
2788            for entry in enum.entries.values():
2789                if first_entry:
2790                    first_entry = False
2791                else:
2792                    cw.p('fallthrough;')
2793                cw.p(f'case {entry.c_name}:')
2794            cw.p('return 0;')
2795            cw.block_end()
2796            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
2797            cw.p('return -EINVAL;')
2798            cw.block_end()
2799            cw.nl()
2800
2801
2802def print_kernel_op_table_fwd(family, cw, terminate):
2803    exported = not kernel_can_gen_family_struct(family)
2804
2805    if not terminate or exported:
2806        cw.p(f"/* Ops table for {family.ident_name} */")
2807
2808        pol_to_struct = {'global': 'genl_small_ops',
2809                         'per-op': 'genl_ops',
2810                         'split': 'genl_split_ops'}
2811        struct_type = pol_to_struct[family.kernel_policy]
2812
2813        if not exported:
2814            cnt = ""
2815        elif family.kernel_policy == 'split':
2816            cnt = 0
2817            for op in family.ops.values():
2818                if 'do' in op:
2819                    cnt += 1
2820                if 'dump' in op:
2821                    cnt += 1
2822        else:
2823            cnt = len(family.ops)
2824
2825        qual = 'static const' if not exported else 'const'
2826        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
2827        if terminate:
2828            cw.p(f"extern {line};")
2829        else:
2830            cw.block_start(line=line + ' =')
2831
2832    if not terminate:
2833        return
2834
2835    cw.nl()
2836    for name in family.hooks['pre']['do']['list']:
2837        cw.write_func_prot('int', c_lower(name),
2838                           ['const struct genl_split_ops *ops',
2839                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2840    for name in family.hooks['post']['do']['list']:
2841        cw.write_func_prot('void', c_lower(name),
2842                           ['const struct genl_split_ops *ops',
2843                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2844    for name in family.hooks['pre']['dump']['list']:
2845        cw.write_func_prot('int', c_lower(name),
2846                           ['struct netlink_callback *cb'], suffix=';')
2847    for name in family.hooks['post']['dump']['list']:
2848        cw.write_func_prot('int', c_lower(name),
2849                           ['struct netlink_callback *cb'], suffix=';')
2850
2851    cw.nl()
2852
2853    for op_name, op in family.ops.items():
2854        if op.is_async:
2855            continue
2856
2857        if 'do' in op:
2858            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
2859            cw.write_func_prot('int', name,
2860                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2861
2862        if 'dump' in op:
2863            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
2864            cw.write_func_prot('int', name,
2865                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
2866    cw.nl()
2867
2868
2869def print_kernel_op_table_hdr(family, cw):
2870    print_kernel_op_table_fwd(family, cw, terminate=True)
2871
2872
2873def print_kernel_op_table(family, cw):
2874    print_kernel_op_table_fwd(family, cw, terminate=False)
2875    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
2876        for op_name, op in family.ops.items():
2877            if op.is_async:
2878                continue
2879
2880            cw.ifdef_block(op.get('config-cond', None))
2881            cw.block_start()
2882            members = [('cmd', op.enum_name)]
2883            if 'dont-validate' in op:
2884                members.append(('validate',
2885                                ' | '.join([c_upper('genl-dont-validate-' + x)
2886                                            for x in op['dont-validate']])), )
2887            for op_mode in ['do', 'dump']:
2888                if op_mode in op:
2889                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2890                    members.append((op_mode + 'it', name))
2891            if family.kernel_policy == 'per-op':
2892                struct = Struct(family, op['attribute-set'],
2893                                type_list=op['do']['request']['attributes'])
2894
2895                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2896                members.append(('policy', name))
2897                members.append(('maxattr', struct.attr_max_val.enum_name))
2898            if 'flags' in op:
2899                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
2900            cw.write_struct_init(members)
2901            cw.block_end(line=',')
2902    elif family.kernel_policy == 'split':
2903        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
2904                    'dump': {'pre': 'start', 'post': 'done'}}
2905
2906        for op_name, op in family.ops.items():
2907            for op_mode in ['do', 'dump']:
2908                if op.is_async or op_mode not in op:
2909                    continue
2910
2911                cw.ifdef_block(op.get('config-cond', None))
2912                cw.block_start()
2913                members = [('cmd', op.enum_name)]
2914                if 'dont-validate' in op:
2915                    dont_validate = []
2916                    for x in op['dont-validate']:
2917                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
2918                            continue
2919                        if op_mode == "dump" and x == 'strict':
2920                            continue
2921                        dont_validate.append(x)
2922
2923                    if dont_validate:
2924                        members.append(('validate',
2925                                        ' | '.join([c_upper('genl-dont-validate-' + x)
2926                                                    for x in dont_validate])), )
2927                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2928                if 'pre' in op[op_mode]:
2929                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
2930                members.append((op_mode + 'it', name))
2931                if 'post' in op[op_mode]:
2932                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
2933                if 'request' in op[op_mode]:
2934                    struct = Struct(family, op['attribute-set'],
2935                                    type_list=op[op_mode]['request']['attributes'])
2936
2937                    if op.dual_policy:
2938                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
2939                    else:
2940                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2941                    members.append(('policy', name))
2942                    members.append(('maxattr', struct.attr_max_val.enum_name))
2943                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
2944                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
2945                cw.write_struct_init(members)
2946                cw.block_end(line=',')
2947    cw.ifdef_block(None)
2948
2949    cw.block_end(line=';')
2950    cw.nl()
2951
2952
2953def print_kernel_mcgrp_hdr(family, cw):
2954    if not family.mcgrps['list']:
2955        return
2956
2957    cw.block_start('enum')
2958    for grp in family.mcgrps['list']:
2959        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp['name']},")
2960        cw.p(grp_id)
2961    cw.block_end(';')
2962    cw.nl()
2963
2964
2965def print_kernel_mcgrp_src(family, cw):
2966    if not family.mcgrps['list']:
2967        return
2968
2969    cw.block_start('static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] =')
2970    for grp in family.mcgrps['list']:
2971        name = grp['name']
2972        grp_id = c_upper(f"{family.ident_name}-nlgrp-{name}")
2973        cw.p('[' + grp_id + '] = { "' + name + '", },')
2974    cw.block_end(';')
2975    cw.nl()
2976
2977
2978def print_kernel_family_struct_hdr(family, cw):
2979    if not kernel_can_gen_family_struct(family):
2980        return
2981
2982    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
2983    cw.nl()
2984    if 'sock-priv' in family.kernel_family:
2985        cw.p(f'void {family.c_name}_nl_sock_priv_init({family.kernel_family["sock-priv"]} *priv);')
2986        cw.p(f'void {family.c_name}_nl_sock_priv_destroy({family.kernel_family["sock-priv"]} *priv);')
2987        cw.nl()
2988
2989
2990def print_kernel_family_struct_src(family, cw):
2991    if not kernel_can_gen_family_struct(family):
2992        return
2993
2994    if 'sock-priv' in family.kernel_family:
2995        # Generate "trampolines" to make CFI happy
2996        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
2997                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
2998                      ["void *priv"])
2999        cw.nl()
3000        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
3001                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
3002                      ["void *priv"])
3003        cw.nl()
3004
3005    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")
3006    cw.p('.name\t\t= ' + family.fam_key + ',')
3007    cw.p('.version\t= ' + family.ver_key + ',')
3008    cw.p('.netnsok\t= true,')
3009    cw.p('.parallel_ops\t= true,')
3010    cw.p('.module\t\t= THIS_MODULE,')
3011    if family.kernel_policy == 'per-op':
3012        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
3013        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
3014    elif family.kernel_policy == 'split':
3015        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
3016        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
3017    if family.mcgrps['list']:
3018        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
3019        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
3020    if 'sock-priv' in family.kernel_family:
3021        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
3022        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
3023        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
3024    cw.block_end(';')
3025
3026
3027def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
3028    start_line = 'enum'
3029    if enum_name in obj:
3030        if obj[enum_name]:
3031            start_line = 'enum ' + c_lower(obj[enum_name])
3032    elif ckey and ckey in obj:
3033        start_line = 'enum ' + family.c_name + '_' + c_lower(obj[ckey])
3034    cw.block_start(line=start_line)
3035
3036
3037def render_uapi_unified(family, cw, max_by_define, separate_ntf):
3038    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
3039    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
3040    max_value = f"({cnt_name} - 1)"
3041
3042    uapi_enum_start(family, cw, family['operations'], 'enum-name')
3043    val = 0
3044    for op in family.msgs.values():
3045        if separate_ntf and ('notify' in op or 'event' in op):
3046            continue
3047
3048        suffix = ','
3049        if op.value != val:
3050            suffix = f" = {op.value},"
3051            val = op.value
3052        cw.p(op.enum_name + suffix)
3053        val += 1
3054    cw.nl()
3055    cw.p(cnt_name + ('' if max_by_define else ','))
3056    if not max_by_define:
3057        cw.p(f"{max_name} = {max_value}")
3058    cw.block_end(line=';')
3059    if max_by_define:
3060        cw.p(f"#define {max_name} {max_value}")
3061    cw.nl()
3062
3063
3064def render_uapi_directional(family, cw, max_by_define):
3065    max_name = f"{family.op_prefix}USER_MAX"
3066    cnt_name = f"__{family.op_prefix}USER_CNT"
3067    max_value = f"({cnt_name} - 1)"
3068
3069    cw.block_start(line='enum')
3070    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
3071    val = 0
3072    for op in family.msgs.values():
3073        if 'do' in op and 'event' not in op:
3074            suffix = ','
3075            if op.value and op.value != val:
3076                suffix = f" = {op.value},"
3077                val = op.value
3078            cw.p(op.enum_name + suffix)
3079            val += 1
3080    cw.nl()
3081    cw.p(cnt_name + ('' if max_by_define else ','))
3082    if not max_by_define:
3083        cw.p(f"{max_name} = {max_value}")
3084    cw.block_end(line=';')
3085    if max_by_define:
3086        cw.p(f"#define {max_name} {max_value}")
3087    cw.nl()
3088
3089    max_name = f"{family.op_prefix}KERNEL_MAX"
3090    cnt_name = f"__{family.op_prefix}KERNEL_CNT"
3091    max_value = f"({cnt_name} - 1)"
3092
3093    cw.block_start(line='enum')
3094    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
3095    val = 0
3096    for op in family.msgs.values():
3097        if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
3098            enum_name = op.enum_name
3099            if 'event' not in op and 'notify' not in op:
3100                enum_name = f'{enum_name}_REPLY'
3101
3102            suffix = ','
3103            if op.value and op.value != val:
3104                suffix = f" = {op.value},"
3105                val = op.value
3106            cw.p(enum_name + suffix)
3107            val += 1
3108    cw.nl()
3109    cw.p(cnt_name + ('' if max_by_define else ','))
3110    if not max_by_define:
3111        cw.p(f"{max_name} = {max_value}")
3112    cw.block_end(line=';')
3113    if max_by_define:
3114        cw.p(f"#define {max_name} {max_value}")
3115    cw.nl()
3116
3117
3118def render_uapi(family, cw):
3119    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
3120    hdr_prot = hdr_prot.replace('/', '_')
3121    cw.p('#ifndef ' + hdr_prot)
3122    cw.p('#define ' + hdr_prot)
3123    cw.nl()
3124
3125    defines = [(family.fam_key, family["name"]),
3126               (family.ver_key, family.get('version', 1))]
3127    cw.writes_defines(defines)
3128    cw.nl()
3129
3130    defines = []
3131    for const in family['definitions']:
3132        if const.get('header'):
3133            continue
3134
3135        if const['type'] != 'const':
3136            cw.writes_defines(defines)
3137            defines = []
3138            cw.nl()
3139
3140        # Write kdoc for enum and flags (one day maybe also structs)
3141        if const['type'] == 'enum' or const['type'] == 'flags':
3142            enum = family.consts[const['name']]
3143
3144            if enum.header:
3145                continue
3146
3147            if enum.has_doc():
3148                if enum.has_entry_doc():
3149                    cw.p('/**')
3150                    doc = ''
3151                    if 'doc' in enum:
3152                        doc = ' - ' + enum['doc']
3153                    cw.write_doc_line(enum.enum_name + doc)
3154                else:
3155                    cw.p('/*')
3156                    cw.write_doc_line(enum['doc'], indent=False)
3157                for entry in enum.entries.values():
3158                    if entry.has_doc():
3159                        doc = '@' + entry.c_name + ': ' + entry['doc']
3160                        cw.write_doc_line(doc)
3161                cw.p(' */')
3162
3163            uapi_enum_start(family, cw, const, 'name')
3164            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
3165            for entry in enum.entries.values():
3166                suffix = ','
3167                if entry.value_change:
3168                    suffix = f" = {entry.user_value()}" + suffix
3169                cw.p(entry.c_name + suffix)
3170
3171            if const.get('render-max', False):
3172                cw.nl()
3173                cw.p('/* private: */')
3174                if const['type'] == 'flags':
3175                    max_name = c_upper(name_pfx + 'mask')
3176                    max_val = f' = {enum.get_mask()},'
3177                    cw.p(max_name + max_val)
3178                else:
3179                    cnt_name = enum.enum_cnt_name
3180                    max_name = c_upper(name_pfx + 'max')
3181                    if not cnt_name:
3182                        cnt_name = '__' + name_pfx + 'max'
3183                    cw.p(c_upper(cnt_name) + ',')
3184                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
3185            cw.block_end(line=';')
3186            cw.nl()
3187        elif const['type'] == 'const':
3188            defines.append([c_upper(family.get('c-define-name',
3189                                               f"{family.ident_name}-{const['name']}")),
3190                            const['value']])
3191
3192    if defines:
3193        cw.writes_defines(defines)
3194        cw.nl()
3195
3196    max_by_define = family.get('max-by-define', False)
3197
3198    for _, attr_set in family.attr_sets.items():
3199        if attr_set.subset_of:
3200            continue
3201
3202        max_value = f"({attr_set.cnt_name} - 1)"
3203
3204        val = 0
3205        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
3206        for _, attr in attr_set.items():
3207            suffix = ','
3208            if attr.value != val:
3209                suffix = f" = {attr.value},"
3210                val = attr.value
3211            val += 1
3212            cw.p(attr.enum_name + suffix)
3213        if attr_set.items():
3214            cw.nl()
3215        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
3216        if not max_by_define:
3217            cw.p(f"{attr_set.max_name} = {max_value}")
3218        cw.block_end(line=';')
3219        if max_by_define:
3220            cw.p(f"#define {attr_set.max_name} {max_value}")
3221        cw.nl()
3222
3223    # Commands
3224    separate_ntf = 'async-prefix' in family['operations']
3225
3226    if family.msg_id_model == 'unified':
3227        render_uapi_unified(family, cw, max_by_define, separate_ntf)
3228    elif family.msg_id_model == 'directional':
3229        render_uapi_directional(family, cw, max_by_define)
3230    else:
3231        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')
3232
3233    if separate_ntf:
3234        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
3235        for op in family.msgs.values():
3236            if separate_ntf and not ('notify' in op or 'event' in op):
3237                continue
3238
3239            suffix = ','
3240            if 'value' in op:
3241                suffix = f" = {op['value']},"
3242            cw.p(op.enum_name + suffix)
3243        cw.block_end(line=';')
3244        cw.nl()
3245
3246    # Multicast
3247    defines = []
3248    for grp in family.mcgrps['list']:
3249        name = grp['name']
3250        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
3251                        f'{name}'])
3252    cw.nl()
3253    if defines:
3254        cw.writes_defines(defines)
3255        cw.nl()
3256
3257    cw.p(f'#endif /* {hdr_prot} */')
3258
3259
3260def _render_user_ntf_entry(ri, op):
3261    if not ri.family.is_classic():
3262        ri.cw.block_start(line=f"[{op.enum_name}] = ")
3263    else:
3264        crud_op = ri.family.req_by_value[op.rsp_value]
3265        ri.cw.block_start(line=f"[{crud_op.enum_name}] = ")
3266    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
3267    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
3268    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
3269    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
3270    ri.cw.block_end(line=',')
3271
3272
3273def render_user_family(family, cw, prototype):
3274    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
3275    if prototype:
3276        cw.p(f'extern {symbol};')
3277        return
3278
3279    if family.ntfs:
3280        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
3281        for ntf_op_name, ntf_op in family.ntfs.items():
3282            if 'notify' in ntf_op:
3283                op = family.ops[ntf_op['notify']]
3284                ri = RenderInfo(cw, family, "user", op, "notify")
3285            elif 'event' in ntf_op:
3286                ri = RenderInfo(cw, family, "user", ntf_op, "event")
3287            else:
3288                raise Exception('Invalid notification ' + ntf_op_name)
3289            _render_user_ntf_entry(ri, ntf_op)
3290        for op_name, op in family.ops.items():
3291            if 'event' not in op:
3292                continue
3293            ri = RenderInfo(cw, family, "user", op, "event")
3294            _render_user_ntf_entry(ri, op)
3295        cw.block_end(line=";")
3296        cw.nl()
3297
3298    cw.block_start(f'{symbol} = ')
3299    cw.p(f'.name\t\t= "{family.c_name}",')
3300    if family.is_classic():
3301        cw.p(f'.is_classic\t= true,')
3302        cw.p(f'.classic_id\t= {family.get("protonum")},')
3303    if family.is_classic():
3304        if family.fixed_header:
3305            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
3306    elif family.fixed_header:
3307        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
3308    else:
3309        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
3310    if family.ntfs:
3311        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
3312        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
3313    cw.block_end(line=';')
3314
3315
3316def family_contains_bitfield32(family):
3317    for _, attr_set in family.attr_sets.items():
3318        if attr_set.subset_of:
3319            continue
3320        for _, attr in attr_set.items():
3321            if attr.type == "bitfield32":
3322                return True
3323    return False
3324
3325
3326def find_kernel_root(full_path):
3327    sub_path = ''
3328    while True:
3329        sub_path = os.path.join(os.path.basename(full_path), sub_path)
3330        full_path = os.path.dirname(full_path)
3331        maintainers = os.path.join(full_path, "MAINTAINERS")
3332        if os.path.exists(maintainers):
3333            return full_path, sub_path[:-1]
3334
3335
def main():
    """
    Command-line entry point of the netlink code generator.

    Parses the arguments, loads the spec given by --spec and renders one
    output through CodeWriter into the -o file (with --cmp-out an identical
    existing file is not overwritten). Depending on --mode it renders a
    uapi header, or a kernel/user header (--header) or source (--source).

    Exits with status 1 (diagnostics on stderr) when the spec is not valid
    YAML or does not carry the required license.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination; the None default is how
    # we detect that neither flag was given.
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        # Report on stderr so it never mixes with generated code on stdout.
        print(exc, file=sys.stderr)
        sys.exit(1)

    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print('License must be:', required_license, file=sys.stderr)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Echo the extra generator arguments into the output's header comment.
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.out_file:
        # Companion header name: replace the two-character extension of the
        # output file (presumably ".c") with ".h".
        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, preserving first-seen order.
    seen_header = []
    for one in headers:
        if one not in seen_header:
            cw.p(f"#include <{one}>")
            seen_header.append(one)
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            # Kernel header: forward declarations only.
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for _, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            # Kernel source: policy definitions, op table, mcast groups, family.
            print_kernel_policy_ranges(parsed, cw)
            print_kernel_policy_sparse_enum_validates(parsed, cw)

            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for _, op in parsed.ops.items():
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            # User header: type definitions and prototypes.
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for _, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)
                if struct.request and struct.in_multi_val:
                    free_rsp_nested_prototype(ri)
                    cw.nl()

            for _, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent or ri.type_oneside:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for _, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            # User source: policies, (de)serializers and free helpers.
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for _, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for struct in parsed.pure_nested_structs.values():
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            if has_recursive_nests:
                # Recursive nests reference each other, so emit all
                # prototypes before any of the definitions below.
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for _, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent or ri.type_oneside:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for _, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    # The event's parser is rendered from a "do"-mode
                    # RenderInfo; its free helper from an "event"-mode one.
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3647
3648
if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()
3651