xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 02962ddb3936f91780c2e05662e4b20c7b35a2df)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
def c_upper(name):
    """Return *name* spelled as a C macro identifier ('-' -> '_', upper-case)."""
    underscored = name.replace('-', '_')
    return underscored.upper()
22
23
def c_lower(name):
    """Return *name* spelled as a C identifier ('-' -> '_', lower-case)."""
    underscored = name.replace('-', '_')
    return underscored.lower()
26
27
def limit_to_number(name):
    """
    Convert a symbolic limit such as ``u32-max`` or ``s64-min`` to an int.

    The first character selects signedness ('s' means signed), the text
    between it and the 4-character ``-min``/``-max`` suffix is the bit
    width. Unsigned minimums are always 0.
    """
    signed = name[0] == 's'
    is_min = name.endswith('-min')
    if is_min and name[0] == 'u':
        return 0
    # A signed type loses one bit of magnitude to the sign.
    bits = int(name[1:-4]) - (1 if signed else 0)
    top = (1 << bits) - 1
    return -top - 1 if (signed and is_min) else top
41
42
class BaseNlLib:
    """Supplies C expressions that generated code uses to reach library state."""

    def get_family_id(self):
        """Return the C expression ``ys->family_id`` used for the family id."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
46
47
class Type(SpecAttr):
    """
    Base class for C code generation of one attribute of an attribute set.

    Wraps the spec attribute and provides the shared emitters (kernel policy
    entry, YNL type-policy entry, put/parse code, setters, free code).
    Subclasses customise behaviour by overriding the hooks
    ``_complex_member_type``, ``_free_lines``, ``_attr_policy``,
    ``_attr_typol``, ``_attr_get`` and ``_setter_lines``, and/or the
    public emitters themselves.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply().
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        # Nests and sub-messages both refer to another set by name.
        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            # A set named after the family renders as just the family name.
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # When the set name is also a const name, the struct tag gets a
            # trailing '_' (presumably to avoid a C name clash; TODO confirm).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make the member name a legal C identifier: avoid names listed in
        # _C_KW and a leading digit.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests, mirrored onto the real attr of a subset."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies, mirrored onto the real attr of a subset."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numeric value of check *limit* (or *default* if absent).

        Integers pass through, names of family consts resolve to the const's
        value, and anything else is parsed by limit_to_number()
        (e.g. 'u32-max'). Returns None when neither is available.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check *limit* (or *default*) as a C expression, '' if absent.

        Integers get *suffix* appended. Consts render as their macro name,
        prefixed with the family name unless the const has a 'header'.
        Other strings are upper-cased with c_upper().
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """
        Compute self.enum_name, the C enum constant for this attribute.

        With a 'parent-sub-message' the parent's enum_name is reused.
        Otherwise the name is the attr's 'name-prefix' (or the set's
        name_prefix) followed by the attr name, upper-cased.

        Raises:
            Exception: a subset attr has checks differing from the real attr.
        """
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        # Falsy here; array-like subclasses return True.
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        """True when the attr is recursive and *ri* is not for an operation."""
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Name of the presence struct tracking this attr ('present', 'len', 'count' or '')."""
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return this attr's member declaration for presence struct *type_filter*.

        'present' uses a 1-bit field, 'len'/'count' a full u32. *space*
        'user' selects the ``__u32`` spelling. Returns None when the attr's
        presence type is not *type_filter*.
        """
        if self.presence_type() != type_filter:
            return None

        if self.presence_type() == 'present':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() in {'len', 'count'}:
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name};"

        return None

    def _complex_member_type(self, ri):
        # Hook: C type of the member when it is not a plain arg_member().
        return None

    def free_needs_iter(self):
        # Hook: True when generated free code uses an 'i' loop counter.
        return False

    def _free_lines(self, ri, var, ref):
        """C lines freeing the member at ``var->ref<c_name>`` (heap-backed members only)."""
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the lines from _free_lines() through ``ri.cw``."""
        lines = self._free_lines(ri, var, ref)
        for line in lines:
            ri.cw.p(line)

    def arg_member(self, ri):
        """
        Return the setter argument declarations for this attr.

        Raises:
            Exception: the subclass provides no member type.
        """
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit this attr's member declaration(s) for the generated C struct."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            # Recursive nests are held by pointer outside of operations.
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's kernel nla_policy array entry."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's entry in the YNL type-policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line*, guarded by the presence/len field when the attr has one."""
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        # Hook: return (lines, init_lines, local_vars); each may be a str, a
        # list of str, or None.
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the parse branch for this attr inside the attribute loop.

        Produces an ``if``/``else if (type == ENUM)`` block with locals, a
        ynl_attr_validate() check and presence marking (single-value attrs
        only), then the init and get lines from _attr_get().

        Returns:
            True, signalling that a branch was emitted.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """
        Emit a ``static inline`` C setter for this attr.

        *ref* is the list of enclosing nest member names. Every enclosing
        level gets its presence bit set, the old value is freed and the new
        value assigned via _setter_lines(). A setter that frees but never
        allocates is emitted with a ``__`` name prefix.
        """
        ref = (ref if ref else []) + [self.c_name]
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, name in enumerate(ref):
            prefix = f"{var}->{'.'.join(ref[:i] + [''])}"
            presence = f"{prefix}_present.{name}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{prefix}_{self.presence_type()}.{name}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in line for line in code)
        alloc = any('alloc(' in line for line in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
310
311
class TypeUnused(Type):
    """
    Attribute that generates no struct member, policy, put, parse or setter.

    Only a YNL_PT_REJECT type-policy entry is produced for it.
    """

    def presence_type(self):
        # No presence tracking for a slot that carries no data.
        return ''

    def arg_member(self, ri):
        return []

    def _attr_get(self, ri, var):
        # Parse hook: emit an unconditional parse-error return.
        error_lines = ['return YNL_PARSE_CB_ERROR;']
        return error_lines, None, None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    # The emitters below intentionally produce no output.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        return None
336
337
class TypePad(Type):
    """
    Padding attribute: only a YNL_PT_IGNORE type-policy entry is generated.

    No member, kernel policy, put, parse or setter code is emitted.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The emitters below intentionally produce no output.
    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def attr_policy(self, cw):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        return None
359
360
class TypeScalar(Type):
    """
    Integer attributes.

    Besides the common plumbing, derives kernel policy checks (min/max,
    range, full-range, sparse-enum validation, flag masks) from the spec.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Appended to the member/argument declaration, e.g. " /* big-endian */".
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve the enum name (via the base class), bitfield-ness and C type name."""
        self.resolve_up(super())

        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            # Auto scalars are stored in a 64-bit member of matching signedness.
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """
        Fill in self.checks for the kernel policy.

        Bounds not set explicitly come from the attribute's enum. When the
        enum's value_range() is (None, None) (presumably a sparse enum), a
        'sparse' validator is requested instead. Adds 'range' when both
        bounds exist and 'full-range' when a bound lies outside
        [-32768, 32767].

        Raises:
            Exception: min > max, or a negative bound on an unsigned type.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Bounds beyond s16 need the out-of-line range form (the kernel's
        # inline nla_policy min/max are 16-bit; TODO confirm for new kernels).
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        """Return the nla_policy initializer selected by the checks computed above."""
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the unsigned YNL_PT_U* entry of the same width.
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
458
459
class TypeFlag(Type):
    """Flag attribute: presence only, emitted as a zero-length attribute."""

    def arg_member(self, ri):
        # A flag has no value to pass to its setter.
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing beyond the presence bit that the common attr_get() sets.
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        no_lines = []
        return no_lines
475
476
class TypeString(Type):
    """String attribute, held as a heap-allocated NUL-terminated copy plus '_len'."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        # Strings track their length instead of a presence bit.
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        parts = ['.type = YNL_PT_NUL_STR, ']
        if self.is_selector:
            parts.append('.is_selector = 1, ')
        return ''.join(parts)

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        max_len = ''
        if 'max-len' in self.checks:
            max_len = ', .len = ' + self.get_limit_str('max-len')
        return '{ .type = ' + policy + max_len + ', }'

    def attr_policy(self, cw):
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        field = f"{var}->{self.c_name}"
        get_lines = [f"{var}->_len.{self.c_name} = len;",
                     f"{field} = malloc(len + 1);",
                     f"memcpy({field}, ynl_attr_get_str(attr), len);",
                     f"{field}[len] = 0;"]
        # Length is bounded by the attribute payload size.
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = presence
        return [f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
529
530
531class TypeBinary(Type):
532    def arg_member(self, ri):
533        return [f"const void *{self.c_name}", 'size_t len']
534
535    def presence_type(self):
536        return 'len'
537
538    def struct_member(self, ri):
539        ri.cw.p(f"void *{self.c_name};")
540
541    def _attr_typol(self):
542        return f'.type = YNL_PT_BINARY,'
543
544    def _attr_policy(self, policy):
545        if len(self.checks) == 0:
546            pass
547        elif len(self.checks) == 1:
548            check_name = list(self.checks)[0]
549            if check_name not in {'exact-len', 'min-len', 'max-len'}:
550                raise Exception('Unsupported check for binary type: ' + check_name)
551        else:
552            raise Exception('More than one check for binary type not implemented, yet')
553
554        if len(self.checks) == 0:
555            mem = '{ .type = NLA_BINARY, }'
556        elif 'exact-len' in self.checks:
557            mem = 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
558        elif 'min-len' in self.checks:
559            mem = 'NLA_POLICY_MIN_LEN(' + self.get_limit_str('min-len') + ')'
560        elif 'max-len' in self.checks:
561            mem = 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
562
563        return mem
564
565    def attr_put(self, ri, var):
566        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
567                            f"{var}->{self.c_name}, {var}->_len.{self.c_name})")
568
569    def _attr_get(self, ri, var):
570        len_mem = var + '->_len.' + self.c_name
571        return [f"{len_mem} = len;",
572                f"{var}->{self.c_name} = malloc(len);",
573                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
574               ['len = ynl_attr_data_len(attr);'], \
575               ['unsigned int len;']
576
577    def _setter_lines(self, ri, member, presence):
578        return [f"{presence} = len;",
579                f"{member} = malloc({presence});",
580                f'memcpy({member}, {self.c_name}, {presence});']
581
582
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is the C struct named by the attr's 'struct'."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        field = f"{var}->{self.c_name}"
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        len_mem = f'{var}->_{self.presence_type()}.{self.c_name}'
        get_lines = [
            f"{len_mem} = len;",
            # A short payload gets a zeroed full-size struct; the memcpy
            # below then fills only the bytes that arrived.
            f"if (len < {struct_sz})",
            f"{field} = calloc(1, {struct_sz});",
            "else",
            f"{field} = malloc(len);",
            f"memcpy({field}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
598
599
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute carrying a packed array of 'sub-type' scalars, tracked by '_count'."""

    def arg_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        return [f'{elem} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem} *{self.c_name};')

    def attr_put(self, ri, var):
        count = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        ri.cw.block_start(line=f"if ({count})")
        # 'i' (a local declared by the enclosing generated function) is used
        # here as a scratch byte count.
        ri.cw.p(f"i = {count} * {elem_sz};")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        count = f"{var}->_count.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        get_lines = [f"{count} = len / {elem_sz};",
                     # Round down to a whole number of elements before copying.
                     f"len = {count} * {elem_sz};",
                     f"{var}->{self.c_name} = malloc(len);",
                     f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        return [f"{presence} = count;",
                f"count *= {elem_sz};",
                f"{member} = malloc(count);",
                f'memcpy({member}, {self.c_name}, count);']
632
633
class TypeBitfield32(Type):
    """``bitfield32`` attribute, stored as a ``struct nla_bitfield32``."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """
        Return NLA_POLICY_BITFIELD32() with the flag mask of the attr's enum.

        Raises:
            Exception: the attribute has no 'enum'.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        # The setter argument is a pointer (see Type.arg_member); copy by value.
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
657
658
class TypeNest(Type):
    """Nested attribute set, rendered as an embedded (or, when recursive, pointed-to) struct."""

    def is_recursive(self):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        return nested.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        field = f'{var}->{ref}{self.c_name}'
        if not self.is_recursive_for_op(ri):
            return [f'{self.nested_render_name}_free(&{field});']
        # Held by pointer here, so guard against it being unset.
        return [f'if ({field})',
                f'{self.nested_render_name}_free({field});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        # Values named by external_selectors() become extra parse arguments.
        args = ["&parg", "attr"] + [f'{var}->{sel.name}'
                                     for sel in nested.external_selectors()]
        get_lines = [f"if ({self.nested_render_name}_parse({', '.join(args)}))",
                     "return YNL_PARSE_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        # A nest has no setter of its own; emit one per member, extending the path.
        path = (ref if ref else []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            if member.is_recursive():
                continue  # recursive members are skipped
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path,
                          var=var)
705
706
class TypeMultiAttr(Type):
    """
    Attribute that may occur several times in one message.

    Occurrences are collected into an array member tracked by '_count'.
    *base_type* is the single-instance wrapper whose policy and type-policy
    hooks are reused.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        kind = self.attr.get('type')
        if 'type' not in self.attr or kind == 'nest':
            return self.nested_struct_type
        if kind == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if kind == 'string':
            return 'struct ynl_string *'
        if kind in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + kind
        raise Exception(f"Sub-type {kind} not supported yet")

    def arg_member(self, ri):
        if self.type != 'binary' or 'struct' not in self.attr:
            return super().arg_member(ri)
        struct_name = c_lower(self.attr["struct"])
        return [f'struct {struct_name} *{self.c_name}',
                f'unsigned int n_{self.c_name}']

    def free_needs_iter(self):
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        kind = self.attr['type']
        field = f"{var}->{ref}{self.c_name}"
        loop = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if kind in scalars or kind == 'binary':
            # Flat array: a single free() releases it.
            return [f"free({field});"]
        if kind == 'string':
            return [loop, f"free({field}[i]);", f"free({field});"]
        if kind == 'nest':
            return [loop,
                    f'{self.nested_render_name}_free(&{field}[i]);',
                    f"free({field});"]
        raise Exception(f"Free of MultiAttr sub-type {kind} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count occurrences here.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        kind = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        if kind in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif kind == 'binary' and 'struct' in self.attr:
            ri.cw.p(loop)
            struct_sz = f"sizeof(struct {c_lower(self.attr['struct'])})"
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i], {struct_sz});")
        elif kind == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {var}->{self.c_name}[i]->str);")
        elif kind == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, "
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {kind} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
793
794
class TypeArrayNest(Type):
    """
    Indexed-array attribute: a nest whose entries use their array index as
    the attribute type. The entries are held in an array member tracked by
    '_count'.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # A missing 'sub-type' is treated as a nest.
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def arg_member(self, ri):
        if self.sub_type == 'binary' and 'exact-len' in self.checks:
            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        # Use .get() so a spec without 'sub-type' falls through to the nest
        # entry, consistent with _complex_member_type(), instead of raising
        # KeyError.
        sub_type = self.attr.get('sub-type')
        if sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(sub_type[1:])}, '
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the array attr and count (validated) entries; the
        # entries themselves are handled by code emitted elsewhere.
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\tn_{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        """
        Emit a nest holding one entry per array element, typed by index.

        Raises:
            Exception: the sub-type has no put implementation.
        """
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            put_type = self.sub_type
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            # .get() so a missing 'sub-type' reports this error, not KeyError.
            raise Exception(f"Put for ArrayNest sub-type {self.attr.get('sub-type')} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
857
858
class TypeNestTypeValue(Type):
    """Nest whose 'type-value' levels encode values in nested attr types.

    Each listed level descends one attribute deeper and records that
    attribute's type; the collected values are passed to the nest's parser.
    """
    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]

        if 'type-value' not in self.attr:
            parse = f"{self.nested_render_name}_parse(&parg, attr);"
            return [parse], init_lines, []

        levels = [c_lower(x) for x in self.attr["type-value"]]
        local_vars = [f'const struct nlattr *attr_{", *attr_".join(levels)};',
                      f'__u32 {", ".join(levels)};']

        get_lines = []
        parent = 'attr'
        for level in levels:
            get_lines.append(f'attr_{level} = ynl_attr_data({parent});')
            get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
            parent = f'attr_{level}'

        extra_args = f", {', '.join(levels)}"
        get_lines.append(f"{self.nested_render_name}_parse(&parg, {parent}{extra_args});")
        return get_lines, init_lines, local_vars
887
888
class TypeSubMessage(TypeNest):
    """Sub-message attribute; its format is chosen by a selector attribute."""
    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        pieces = [f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
                  '.is_submsg = 1, ']
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            sel_value = self.attr_set[self["selector"]].value
            pieces.append(f'.selector_type = {sel_value} ')
        return ''.join(pieces)

    def _attr_get(self, ri, var):
        selector = self['selector']
        sel = c_lower(selector)
        # External selectors arrive as a _sel_<name> variable, internal ones
        # are read from the struct being filled.
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"

        get_lines = [
            f'if (!{sel_var})',
            'return ynl_submsg_failed(yarg, "%s", "%s");' % (self.name, selector),
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None
919
920
class Selector:
    """The attribute that picks which format a sub-message uses.

    If the selector attribute is in the same attribute set as the
    sub-message, it is internal: its attr is looked up immediately and
    flagged with is_selector. Otherwise it is external: attr stays None until
    set_attr() binds it.
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]

        self._external = self.name not in attr_set
        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
        else:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        """Bind the selector attribute (used for external selectors)."""
        self.attr = attr

    def is_external(self):
        """Return True if the selector lives outside the sub-message's set."""
        return self._external
939
940
class Struct:
    """A C struct generated for an attribute set.

    With an explicit type_list the struct holds only those attributes (for
    example an op's request or reply). Without one it covers the whole set
    and is treated as a nest (self.nested).
    """
    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Default to a list rather than a set. set_inherited() then rejects
        # any request, even an empty set, on structs built without one.
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.fixed_header = 'struct ' + c_lower(fixed_header) if fixed_header else None
        self.submsg = submsg

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)
        self.struct_name = 'struct ' + self.render_name
        # A nested struct whose set name is also in family.consts gets a '_'
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # set when used by a MultiAttr or by legacy arrays

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]

        # attr_max_val: the member with the highest value (the later one wins ties)
        self.attrs = {}
        self.attr_max_val = None
        highest = 0
        for name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        """Iterate over member names."""
        yield from self.attrs

    def __getitem__(self, key):
        """Return the member attribute named key."""
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in attr_list order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record members inherited from the enclosing nest (e.g. 'idx').

        Raises if new_inherited differs from what the struct was built with.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def external_selectors(self):
        """Return Selectors of sub-message members with an external selector."""
        return [attr.selector for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        """Return True if any member's free_needs_iter() is true."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1015
1016
class EnumEntry(SpecEnumEntry):
    """Enum/flags entry with its C constant name.

    value_change is True when the value differs from the implicit one
    (prev + 1, or 0 for the first entry), and always for flags sets.
    """
    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        implicit_value = prev.value + 1 if prev else 0
        self.value_change = (self.value != implicit_value) or \
            self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1035
1036
class EnumSet(SpecEnumSet):
    """Spec enum/flags set with its C naming attached.

    enum_name is the C enum type ("enum <name>"), or None when the spec sets
    an empty 'enum-name'. In that case user_type falls back to 'int'.
    """
    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                # Explicitly empty: no named C enum type is used
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # Single place deciding the user-facing type
        if self.enum_name:
            self.user_type = self.enum_name
        else:
            self.user_type = 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if entry values are contiguous, else (None, None)."""
        low = min(x.value for x in self.entries.values())
        high = max(x.value for x in self.entries.values())

        if high - low + 1 != len(self.entries):
            return None, None

        return low, high
1072
1073
class AttrSet(SpecAttrSet):
    """Attribute set with C naming (enum prefix, max and count names)."""
    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are taken from
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            default_max = f"{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', default_max))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{default_max}"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        # Assignment order is kept: this may be retried if family.c_name is
        # not yet available.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'  # avoid clashing with a C keyword
        if self.c_name == self.family.c_name:
            self.c_name = ''

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching elem['type'].

        Multi-attr elements are additionally wrapped in TypeMultiAttr.
        """
        attr_type = elem['type']
        plain_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
            'sub-message': TypeSubMessage,
        }

        if attr_type in scalars:
            t = TypeScalar(self.family, self, elem, value)
        elif attr_type in plain_types:
            t = plain_types[attr_type](self.family, self, elem, value)
        elif attr_type == 'binary':
            if 'struct' in elem:
                binary_cls = TypeBinaryStruct
            elif elem.get('sub-type') in scalars:
                binary_cls = TypeBinaryScalarArray
            else:
                binary_cls = TypeBinary
            t = binary_cls(self.family, self, elem, value)
        elif attr_type == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            t = TypeArrayNest(self.family, self, elem, value)
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1142
1143
class Operation(SpecOperation):
    """Spec operation with its C render name and command enum name."""
    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # any request/reply without 'attributes' gets an empty list.
        for mode in ('do', 'dump', 'event'):
            for direction in ('request', 'reply'):
                try:
                    msg = yaml[mode][direction]
                except KeyError:
                    continue
                msg.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True when both 'do' and 'dump' define a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Flag that some other op names this one in its 'notify'."""
        self.has_ntf = True
1177
1178
class SubMessage(SpecSubMessage):
    """Spec sub-message with its C render name."""
    def __init__(self, family, yaml):
        super().__init__(family, yaml)
        self.render_name = c_lower(f"{family.ident_name}-{yaml['name']}")

    def resolve(self):
        self.resolve_up(super())
1187
1188
class Family(SpecFamily):
    """Parsed netlink spec family plus the state the C generator needs.

    On top of the generic spec handling in SpecFamily, resolve() does the
    following:
    - derives C names and op prefixes;
    - finds which attribute sets are request/reply roots and which are only
      nested;
    - builds and orders the nested Structs;
    - wires up sub-message selectors;
    - collects the user pre/post hooks.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve: declared (then deleted) so that reading them
        # before they are set raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" yields <name>; anything else falls back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Derive the generator-specific state once the spec is resolved."""
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': unique names, 'list': first-seen order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        # Order matters: each step consumes the results of the previous ones
        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """Return True for 'netlink-raw' (classic, non-genetlink) families."""
        return self.proto == 'netlink-raw'

    @staticmethod
    def _nested_set_name(spec):
        """Return the attr set / sub-message that spec nests into, or None."""
        if 'nested-attributes' in spec:
            return spec['nested-attributes']
        if 'sub-message' in spec:
            return spec.sub_message
        return None

    def _mark_notify(self):
        """Mark every op that another message lists under 'notify'."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect request/reply attribute names for each op's attribute set.

        Fills self.root_sets (attr-set name -> {'request': set, 'reply': set}),
        merging do/dump/event of all ops that use the same set.
        """
        for op in self.msgs.values():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct follows its nested children.

        Recursive children are skipped because they are referenced through a
        pointer. The number of passes is bounded by len(keys) ** 2.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                nested = self._nested_set_name(spec)
                if nested is None:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register the Struct for a 'nested-attributes' attr; return the set name.

        Raises if the nested set is also used as a root set. The struct
        inherits the 'type-value' names, or 'idx' for indexed arrays.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                # Struct keeps a reference to 'inherit', so the updates
                # below also show up in its _inherited.
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Synthesize an attr set + Struct for a sub-message; return its name.

        Each format becomes one member: a nest when it has an attribute-set,
        a binary struct when it has only a fixed header, otherwise a flag.
        """
        # Fake the struct type for the sub-message itself:
        # it's not an attr_set, but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr |= {
                    "type": "nest",
                    "nested-attributes": fmt['attribute-set'],
                }
                if 'fixed-header' in fmt:
                    attr |= { "fixed-header": fmt["fixed-header"] }
            elif 'fixed-header' in fmt:
                attr |= {
                    "type": "binary",
                    "struct": fmt["fixed-header"],
                }
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        # NOTE(review): AttrSet reads 'name-prefix', not 'name-pfx', so this
        # prefix is currently ignored and AttrSet's default applies. Left as-is
        # to keep the generated output unchanged.
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Build Structs for every nested set reachable from the root sets.

        Also propagates request/reply usage, multi-val use and recursion from
        parents to children, then orders the structs.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        # Breadth-first walk over everything reachable from the roots
        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for _, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        # Seed request/reply usage from the roots' direct children
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                nested = self._nested_set_name(spec)

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                child_name = self._nested_set_name(spec)
                if child_name is None:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark each attribute as used in requests and/or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Bind external sub-message selectors to an attr in the parent set.

        Raises if the selector is not in the immediate parent set; passing it
        through more than one level is not supported.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                child_name = self._nested_set_name(spec)
                if child_name is None:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """Build the single kernel policy used when kernel-policy is 'global'.

        All ops with an attribute-set must share the same one. global_policy
        lists that set's attributes used in any do/dump request, in set order.
        """
        global_set = set()
        attr_set_name = None
        for op in self.ops.values():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        if attr_set_name is None:
            # Previously surfaced as an opaque KeyError(None) below
            raise Exception('Global policy requires at least one op with an attribute-set')

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per do/dump, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1549
1550
class RenderInfo:
    """Context for rendering one op in one mode, or one nested attr set.

    Holds the family, the writer, the kernel/user space flag and, for ops,
    the request/reply Structs built from the op's attribute lists.
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            if op.fixed_header != family.fixed_header:
                # Only classic (netlink-raw) families support a per-op header
                if family.is_classic():
                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
                else:
                    raise Exception("Per-op fixed header not supported, yet")

        # type_consistent: 'do' and 'dump' replies are identical, so response
        # parsing can be shared.
        # type_oneside: a non-'do' render of an op with 'dump' but no 'do'.
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_oneside = True

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        # Notifications use the structs of the op's 'do' (else 'dump') mode
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """Return True if struct[key] has no attributes and no fixed header."""
        return len(self.struct[key].attr_list) == 0 and \
            self.struct[key].fixed_header is None

    def needs_nlflags(self, direction):
        """Return True for 'do' requests of classic families."""
        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1618
1619
class CodeWriter:
    """Emit tab-indented C code line by line.

    When *out_file* is None, output goes to stdout. Otherwise it is staged
    in a temporary file and copied to *out_file* by close_out_file(). With
    ``overwrite=False``, an existing *out_file* with identical contents is
    left untouched.

    A closing brace emitted by block_end() without trailing text is held
    back until the next line is printed. This lets a following ``else`` be
    joined onto the same line as the ``}``.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # a blank line is pending
        self._block_end = False     # a closing '}' is pending
        self._silent_block = False  # last line was a brace-less if/for/while/else
        self._ind = 0               # current indentation level, in tabs
        self._ifdef_block = None    # CONFIG_* symbol of the open #ifdef, if any
        self._staged = False        # True while output goes to a temp file
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file
            self._staged = True

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy staged output to the destination file and release the temp file.

        Does nothing when writing to stdout or when already closed. Any later
        output goes to stdout.
        """
        # getattr(): __del__ may run on an object whose __init__ failed early
        if not getattr(self, '_staged', False):
            return
        tmp = self._out
        self._staged = False
        self._out = sys.stdout
        try:
            tmp.flush()
            # Avoid modifying the file if contents didn't change
            if (self._overwrite or not os.path.isfile(self._out_file) or
                    not filecmp.cmp(tmp.name, self._out_file, shallow=False)):
                with open(self._out_file, 'w+') as out_file:
                    tmp.seek(0)
                    shutil.copyfileobj(tmp, out_file)
        finally:
            # Close (and so delete) the temp file, also when the copy was
            # skipped because the contents were unchanged
            tmp.close()

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Print *line* at the current indentation, flushing pending output first."""
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # Lines ending in ':' (labels such as "err_free:") are outdented
        if line.endswith(':'):
            ind -= 1
        # The previous line was a brace-less if/for/while/else: indent its body
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        # Preprocessor directives always start in column 0
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next printed line."""
        self._nl = True

    def block_start(self, line=''):
        """Print ``<line> {`` and increase the indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block, optionally followed by *line* (e.g. ';')."""
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Print *doc* as ' *'-prefixed comment lines wrapped before 79 columns."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Print a function prototype, optionally preceded by a /* doc */ comment.

        The argument list is wrapped and aligned under the opening
        parenthesis when the one-line form would reach 80 characters.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines line up with the character after '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Print local variable declarations, longest first, then a blank line."""
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() leaves the caller's list untouched (the sort is stable)
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Print a complete function made of the lines in *body*."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Print ``#define name value`` lines with tab-aligned values.

        int values are printed bare, str values in double quotes.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Print designated initializers ``.name = value,`` with aligned '='."""
        if not members:
            return
        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the open ``#ifdef CONFIG_*`` guard to *config* (None closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1816
1817
# Spec attribute types rendered as plain C integers
scalars = {f'{sign}{bits}' for sign in 'us' for bits in (8, 16, 32, 64)} | {'uint', 'sint'}

# Struct/function name suffix for each message direction
direction_to_suffix = {'reply': '_rsp', 'request': '_req', '': ''}

# Reply type-name suffix used by op_prefix() for each operation mode
op_mode_to_wrapper = dict(do='', dump='_list', notify='_ntf', event='')

# C keywords (presumably consulted elsewhere to avoid emitting them as names)
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1869
1870
def rdir(direction):
    """Return the opposite message direction ('reply' <-> 'request').

    Any other value, including the empty string, is returned unchanged.
    """
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1877
1878
def op_prefix(ri, direction, deref=False):
    """Return the C name prefix ``<family>_<type><suffix>`` for *ri* in *direction*.

    The suffix depends on the operation mode:
      - no mode: nothing;
      - 'do': the direction suffix (_req/_rsp);
      - other modes, request: ``_req``, plus ``_dump`` unless ri.type_oneside;
      - other modes, reply with consistent do/dump types: the direction
        suffix when *deref* is set, else the per-mode wrapper suffix;
      - other modes, reply with inconsistent types: ``_rsp_dump`` when
        *deref* is set, else ``_rsp_list``.
    """
    mode = ri.op_mode
    if not mode:
        tail = ''
    elif mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req' if ri.type_oneside else '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1902
1903
def type_name(ri, direction, deref=False):
    """Return the C struct type (``struct <prefix>``) for *ri* in *direction*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1906
1907
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-facing function for a 'do' or 'dump' op.

    The function takes the socket and, when the op defines a request, a
    pointer to the request struct. It returns a pointer to the reply struct
    when the op defines a reply, and ``int`` otherwise. With *terminate*
    the prototype ends in ';' (a declaration).
    """
    mode_spec = ri.op[ri.op_mode]
    func_name = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    params = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        var_name = direction_to_suffix[direction][1:]
        params.append(f"{type_name(ri, direction)} *{var_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, func_name, params, doc=doc,
                          suffix=';' if terminate else '')
1924
1925
def print_req_prototype(ri):
    """Emit the declaration of a 'do' request function, with the op's doc.

    Do not assume the op spec carries a 'doc' key. When it is absent, emit
    the prototype without a comment instead of failing with KeyError.
    """
    doc = ri.op['doc'] if 'doc' in ri.op else None
    print_prototype(ri, "request", doc=doc)
1928
1929
def print_dump_prototype(ri):
    """Emit the declaration of a dump request function (no doc comment)."""
    print_prototype(ri, direction="request")
1932
1933
def put_typol_submsg(cw, struct):
    """Emit the policy table and policy nest for a sub-message *struct*.

    Entries are indexed by member position (not by attribute type) and
    are all YNL_PT_SUBMSG. Nested members also reference their own nest.
    """
    policy = f'{struct.render_name}_policy'
    cw.block_start(line=f'const struct ynl_policy_attr {policy}[] =')

    count = 0
    for idx, (name, arg) in enumerate(struct.member_list()):
        nest = f" .nest = &{arg.nested_render_name}_nest," if arg.type == 'nest' else ""
        cw.p(f'[{idx}] = {{ .type = YNL_PT_SUBMSG, .name = "{name}",{nest} }},')
        count = idx + 1

    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {count - 1},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1953    cw.nl()
1954
1955
def put_typol_fwd(cw, struct):
    """Emit an extern declaration of the policy nest of *struct*."""
    decl = 'extern const struct ynl_policy_nest ' + struct.render_name + '_nest;'
    cw.p(decl)
1958
1959
def put_typol(cw, struct):
    """Emit the policy table and policy nest for *struct*.

    Sub-message structs are handed to put_typol_submsg(). Otherwise the
    table is sized by the attribute set's max value and each member writes
    its own entry via attr_typol().
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    name = struct.render_name
    max_attr = struct.attr_set.max_name
    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _member_name, attr in struct.member_list():
        attr.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    cw.p(f'.max_attr = {max_attr},')
    cw.p(f'.table = {name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1979
1980
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit ``<render_name>_str()``, which looks a value up in *map_name*.

    The argument is an ``int`` unless an *enum* supplies its own user type.
    For 'flags' enums, the value is first converted with ffs() to the index
    of its lowest set bit. Out-of-range values yield NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', render_name + '_str', [f'{arg_type} {arg_name}'])

    body = []
    if enum and enum.type == 'flags':
        body.append(f'{arg_name} = ffs({arg_name}) - 1;')
    body += [f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
             'return NULL;',
             f'return {map_name}[{arg_name}];']

    cw.block_start()
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
1994
1995
def put_op_name_fwd(family, cw):
    """Emit the declaration of ``<family>_op_str()``."""
    func_name = family.c_name + '_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1998
1999
def put_op_name(family, cw):
    """Emit the op-value -> op-name string table and its lookup function.

    Entries are keyed by the reply message value. Ops whose rsp_value is
    falsy are skipped. When several ops share a reply value, only the op
    registered in family.rsp_by_value is listed; the others get a comment.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2019
2020
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the declaration of ``<enum>_str()``; *family* is unused."""
    cw.write_func_prot('const char *', enum.render_name + '_str',
                       [f'{enum.user_type} value'], suffix=';')
2024
2025
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string table of *enum* and its lookup function."""
    map_name = enum.render_name + '_strmap'
    rows = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for row in rows:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
2035
2036
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of ``<struct>_put()``, which writes *struct* into *nlh*."""
    params = ['struct nlmsghdr *nlh',
              'unsigned int attr_type',
              struct.ptr_name + 'obj']
    ri.cw.write_func_prot('int', struct.render_name + '_put', params, suffix=suffix)
2044
2045
def put_req_nested(ri, struct):
    """Emit ``<struct>_put()``, which serializes *struct* into a message.

    Regular structs are wrapped in a nest attribute of type ``attr_type``.
    Sub-message structs are written without a nest. A fixed header, if
    any, is copied from ``obj->_hdr`` before the attributes.
    """
    is_nest = struct.submsg is None
    decls = []
    prologue = []

    if is_nest:
        decls.append('struct nlattr *nest;')
        prologue.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        hdr_size = f'sizeof({struct.fixed_header})'
        decls.append('void *hdr;')
        prologue += [f"hdr = ynl_nlmsg_put_extra_header(nlh, {hdr_size});",
                     f"memcpy(hdr, &obj->_hdr, {hdr_size});"]

    # Extra locals needed by the members' generated put code
    members = [arg for _, arg in struct.member_list()]
    if any(arg.type == 'indexed-array' for arg in members):
        decls.append('struct nlattr *array;')
    if any(arg.presence_type() == 'count' for arg in members):
        decls.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    cw = ri.cw
    cw.block_start()
    cw.write_func_lvar(decls)

    for line in prologue:
        cw.p(line)

    for _, arg in struct.member_list():
        arg.attr_put(ri, "obj")

    if is_nest:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2086
2087
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parser for *struct*, after an already-printed prototype.

    Used for nested attribute sets (struct.nested) and for top-level
    message replies. The generated C first walks all attributes once,
    calling each member's attr_get() code. Indexed-array and multi-attr
    members are then allocated with calloc() sized by their n_<name>
    counters and filled in a second pass over the attributes. (The
    counters are presumably incremented by the attr_get() output of those
    members; confirm in the attribute type classes.)

    Args:
        ri: render info; supplies the code writer (ri.cw) and, for
            top-level messages, the op/family fixed-header details.
        struct: the Struct being parsed.
        init_lines: C statements printed right after the local variable
            declarations. Appended to in place.
        local_vars: C local variable declarations. Appended to in place.

    Raises:
        Exception: per-op fixed header in a non-classic family, or an
            unsupported indexed-array sub-type / multi-attr type.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    # Choose the attribute iterator; attributes start after any fixed header
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        # The op's fixed header differs from the family default: classic
        # families iterate past the op's own header size instead
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Members needing a count pass + a fill pass, keyed by member name (arg)
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        # Nested and sub-message members are parsed via a ynl_parse_arg
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # sorted() keeps the generated code stable across runs
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if struct.fixed_header:
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Fail if an array member is already set (presumably so an earlier
    # allocation is not overwritten)
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # First pass over all attributes
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: walk the elements inside the nest
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: re-walk the attributes and pick
    # those matching the member's attribute type
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most one element's worth of the attribute payload
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Allocate a ynl_string with room for a terminating NUL
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message parsers return a YNL callback status
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2238    ri.cw.nl()
2239
2240
def parse_rsp_submsg(ri, struct):
    """Emit the parser for a sub-message *struct*.

    The generated function compares the ``sel`` selector string against
    each member name and runs the matching member's getter code on the
    attribute. It returns 0 whether or not a member matched.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    local_vars = {'const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;'}

    # Collect (and de-duplicate) the locals every member getter needs
    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        local_vars |= set(l_vars) if l_vars else set()

    ri.cw.block_start()
    # A set of strings iterates in hash order, which changes between runs
    # (hash randomization). write_func_lvar() only orders by length, so sort
    # first to make equal-length declarations, and the generated code,
    # reproducible.
    ri.cw.write_func_lvar(sorted(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    first = True
    for name, arg in struct.member_list():
        kw = 'if' if first else 'else if'
        first = False

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2275
2276
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of ``<struct>_parse()`` for a nested or sub-message struct.

    Parameters, in order: the parse argument; the selector string (for
    sub-messages only); the nest attribute; one ``_sel_<name>`` string per
    external selector; one ``__u32`` per inherited value.
    """
    selector_args = ['const char *_sel_' + sel.name
                     for sel in struct.external_selectors()]
    lead = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        lead.append('const char *sel')
    lead.append('const struct nlattr *nested')
    inherited_args = ['__u32 ' + arg for arg in struct.inherited]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse',
                          lead + selector_args + inherited_args,
                          suffix=suffix)
2289
2290
def parse_rsp_nested(ri, struct):
    """Emit the parser for nested *struct*; sub-messages are delegated."""
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    if not struct.member_list():
        # Empty nest: nothing to parse
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']
    _multi_parse(ri, struct, [], decls)
2308        ri.cw.nl()
2309
2310
def parse_rsp_msg(ri, deref=False):
    """Emit the reply-message parse callback for *ri*.

    Nothing is emitted for ops without a reply, except in 'event' mode.
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse',
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: nothing to parse
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    decls = [f'{type_name(ri, "reply", deref=deref)} *dst;',
             'const struct nlattr *attr;']
    _multi_parse(ri, reply, ['dst = yarg->data;'], decls)
2331        ri.cw.nl()
2332
2333
def print_req(ri):
    """Emit the C implementation of the 'do' request function for ri.op.

    The generated function builds the request (fixed header if any, then
    the request attributes) and runs it with ynl_exec(). If the op has a
    'reply', it returns the calloc()ed, parsed reply, freeing it and
    returning NULL on error. Otherwise it returns 0 on success and -1 on
    error.
    """
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.struct["request"].fixed_header:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # 'count' members need a loop index in the generated code (presumably
    # used by their attr_put() output)
    for _, attr in ri.struct["request"].member_list():
        if attr.presence_type() == 'count':
            local_vars += ['unsigned int i;']
            break

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    # Classic families take the caller's nlmsg flags; others the family id
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    # NOTE(review): this dereferences req, but print_prototype() declares a
    # req parameter only when the op has a 'request' section. Confirm every
    # op with a fixed header (or a classic family op) has one.
    if ri.struct['request'].fixed_header:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    # Reply parsing: ops without their own value match on rsp_value
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
2407
2408
def print_dump(ri):
    """Emit the C implementation of the dump request function for ri.op.

    The generated function sets up a ynl_dump_state whose callback parses
    each reply into a newly allocated reply struct. It sends the dump
    request and returns yds.first, or frees the list and returns NULL on
    error.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.struct['request'].fixed_header:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    # Per-reply allocation size and parse callback for ynl_exec_dump()
    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # NOTE(review): the fixed header is copied from req even outside the
    # '"request" in ...' check below. print_prototype() declares req only
    # when the op has a 'request' section; confirm that is always the case.
    if ri.struct['request'].fixed_header:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2463
2464
def call_free(ri, direction, var):
    """Return the C statement that frees *var* with its ``_free()`` helper."""
    func = op_prefix(ri, direction) + '_free'
    return f"{func}({var});"
2467
2468
def free_arg_name(direction):
    """Return the parameter name used by the generated ``_free()`` helpers.

    It is the direction suffix without its underscore ('req'/'rsp'), or
    'obj' when *direction* is empty.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2473
2474
def print_alloc_wrapper(ri, direction, struct=None):
    """Emit an inline ``_alloc()`` helper that calloc()s the struct.

    When *struct* is used as a multi-value element, the helper takes a
    count ``n`` and allocates that many elements. If the type name
    conflicts with a family constant, the struct name gets a trailing '_'.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name

    if struct and struct.in_multi_val:
        params, count = ["unsigned int n"], "n"
    else:
        params, count = ["void"], "1"

    cw = ri.cw
    cw.write_func_prot(f'static inline struct {struct_name} *',
                       f"{name}_alloc", params)
    cw.block_start()
    cw.p(f'return calloc({count}, sizeof(struct {struct_name}));')
    cw.block_end()
2492
2493
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of the ``_free()`` helper for *direction* of *ri*."""
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', name + '_free', [param], suffix=suffix)
2501
2502
def print_nlflags_set(ri, direction):
    """Emit an inline setter for the request struct's ``_nlmsg_flags`` member."""
    name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot('static inline void', name + '_set_nlflags',
                       ['struct ' + name + ' *req', '__u16 nl_flags'])
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2511
2512
def _print_type(ri, direction, struct):
    """
    Emit the C struct definition for @struct.

    The name is <family>_<type><direction suffix>, plus '_' for conflicting
    direction-less names and '_dump' for dump ops whose type is not one-sided.
    Member order: optional _nlmsg_flags, optional fixed header (_hdr), the
    anonymous _present/_len/_count bookkeeping sub-structs (each emitted only
    when some member needs it), inherited __u32 arguments, then the members.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.type_name_conflict and not direction:
        suffix += '_'
    if not ri.type_oneside and ri.op_mode == 'dump':
        suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if struct.fixed_header:
        cw.p(struct.fixed_header + ' _hdr;')
        cw.nl()

    for type_filter in ('present', 'len', 'count'):
        presence = [attr.presence_member(ri.ku_space, type_filter)
                    for _, attr in struct.member_list()]
        presence = [line for line in presence if line]
        if not presence:
            continue
        cw.block_start(line="struct")
        for line in presence:
            cw.p(line)
        cw.block_end(line=f'_{type_filter};')
    cw.nl()

    for arg in struct.inherited:
        cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2551
2552
def print_type(ri, direction):
    """Emit the struct for the op's @direction type (e.g. "request"/"reply")."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2555
2556
def print_type_full(ri, struct):
    """
    Emit a direction-less (nested) struct.

    For request types used as multi-value attributes, also emit the alloc
    helper, the free prototype and, in user space, the member setters.
    """
    _print_type(ri, "", struct)

    if not (struct.request and struct.in_multi_val):
        return

    print_alloc_wrapper(ri, "", struct)
    ri.cw.nl()
    free_rsp_nested_prototype(ri)
    ri.cw.nl()

    # Name conflicts are too hard to deal with with the current code base,
    # they are very rare so don't bother printing setters in that case.
    if ri.ku_space == 'user' and not ri.type_name_conflict:
        for _, member in struct.member_list():
            member.setter(ri, ri.attr_set, "", var="obj")
    ri.cw.nl()
2572
2573
def print_type_helpers(ri, direction, deref=False):
    """
    Emit the helpers for a directional type.

    This is the free prototype, the nlflags setter when needed and, for
    user-space requests only, one setter per member.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    emit_setters = ri.ku_space == 'user' and direction == 'request'
    if emit_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2585
2586
def print_req_type_helpers(ri):
    """Emit alloc/free/setter helpers for the request type, unless it is empty."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2592
2593
def print_rsp_type_helpers(ri):
    """Emit the reply type helpers when the current op mode defines a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2598
2599
def print_parse_prototype(ri, direction, terminate=True):
    """
    Emit the <op>_rsp_parse() / <op>_req_parse() prototype.

    "reply" selects the _rsp variant and anything else the _req one.
    With @terminate the prototype is closed with ';'.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
2608
2609
def print_req_type(ri):
    """Emit the request struct unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2614
2615
def print_req_free(ri):
    """Emit the request free() implementation when the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2620
2621
def print_rsp_type(ri):
    """Emit the reply struct for do/dump ops that have a reply, and for events."""
    has_reply = (ri.op_mode in ('do', 'dump')
                 and 'reply' in ri.op[ri.op_mode])
    if has_reply or ri.op_mode == 'event':
        print_type(ri, 'reply')
2630
2631
def print_wrapped_type(ri):
    """
    Emit the wrapper struct around the reply type, then its free prototype.

    Dumps get a "next" list pointer. Notifications and events get family,
    cmd, next and free members. In every case the reply object itself is
    embedded as an 8-byte aligned "obj" member.
    """
    wrapper = type_name(ri, 'reply')
    cw = ri.cw

    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p('struct ynl_ntf_base_type *next;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()

    print_free_prototype(ri, 'reply')
    cw.nl()
2646
2647
def _free_type_members_iter(ri, struct):
    """Declare the loop index "i" when freeing @struct requires iteration."""
    if not struct.free_needs_iter():
        return
    ri.cw.p('unsigned int i;')
    ri.cw.nl()
2652
2653
def _free_type_members(ri, var, struct, ref=''):
    """Have each member of @struct emit its own free code for @var (prefix @ref)."""
    members = (attr for _, attr in struct.member_list())
    for member in members:
        member.free(ri, var, ref)
2657
2658
def _free_type(ri, direction, struct):
    """
    Emit the full <prefix>_free() implementation for @struct.

    All types free their members. Directional types additionally free() the
    pointer itself; direction-less ("") types do not.
    """
    arg = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2670
2671
def free_rsp_nested_prototype(ri):
    """Emit the free() prototype for a direction-less (nested) type."""
    print_free_prototype(ri, direction="")
2674
2675
def free_rsp_nested(ri, struct):
    """Emit the free() implementation for a direction-less (nested) type."""
    _free_type(ri, direction="", struct=struct)
2678
2679
def print_rsp_free(ri):
    """Emit the reply free() implementation when the op mode defines a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2684
2685
def print_dump_type_free(ri):
    """
    Emit the free() for a dump reply list.

    The generated code walks the "next" chain until YNL_LIST_END, freeing
    each element's members (reached through "obj.") and then the element.
    """
    elem_type = type_name(ri, 'reply')
    reply = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{elem_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
2704
2705
def print_ntf_type_free(ri):
    """Emit the free() for a notification: its members via "obj.", then the wrapper."""
    reply = ri.struct['reply']
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2714
2715
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """
    Emit the nla_policy array header for @struct.

    With @terminate, emit an "extern ...;" declaration, except when @ri is
    given and the family's policies are static (then nothing is emitted).
    Without @terminate, open the definition with " = {", marked static under
    the same condition.

    The array is named after the op (@ri given, with a _<mode> suffix for
    dual-policy ops) or else after @struct.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    elif ri.op.dual_policy:
        name = ri.op.render_name + '_' + ri.op_mode
    else:
        name = ri.op.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2738
2739
def print_req_policy(cw, struct, ri=None):
    """
    Emit the complete nla_policy array definition for @struct.

    When rendering for an op, the op's "config-cond" is passed to
    cw.ifdef_block() first. cw.ifdef_block(None) is called unconditionally
    afterwards, presumably to close any open conditional.
    """
    rendering_op = ri and ri.op
    if rendering_op:
        cw.ifdef_block(ri.op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for member in (attr for _, attr in struct.member_list()):
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2749
2750
def kernel_can_gen_family_struct(family):
    """Return True if the struct genl_family can be generated (genetlink only)."""
    proto = family.proto
    return proto == 'genetlink'
2753
2754
def policy_should_be_static(family):
    """Policies are static for split-policy families, or when the family struct is generated."""
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2757
2758
def print_kernel_policy_ranges(family, cw):
    """
    Emit a netlink_range_validation object for each "full-range" request attribute.

    Only request attributes outside subset attribute-sets are considered.
    Unsigned types get the plain struct with ULL limits; others get the
    _signed variant with LL limits. A header comment precedes the first one.
    """
    header_done = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        ranged = [attr for _, attr in attr_set.items()
                  if attr.request and 'full-range' in attr.checks]
        for attr in ranged:
            if not header_done:
                cw.p('/* Integer value ranges */')
                header_done = True

            unsigned = attr.type[0] == 'u'
            sign = '' if unsigned else '_signed'
            suffix = 'ULL' if unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(limit, attr.get_limit_str(limit, suffix=suffix))
                       for limit in ('min', 'max') if limit in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2786
2787
def print_kernel_policy_sparse_enum_validates(family, cw):
    """
    Emit a <enum>_validate() callback for each "sparse"-checked enum request attribute.

    The callback switches over the attribute value and returns 0 for any
    listed enum entry. Anything else sets an extack message and returns
    -EINVAL. A header comment precedes the first callback.
    """
    header_done = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            wanted = attr.request and attr.enum_name and 'sparse' in attr.checks
            if not wanted:
                continue

            if not header_done:
                cw.p('/* Sparse enums validation callbacks */')
                header_done = True

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # Chain every case label into the single shared "return 0".
            for idx, entry in enumerate(enum.entries.values()):
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2824
2825
def print_kernel_op_table_fwd(family, cw, terminate):
    """
    Emit the start of the kernel ops table, or its header-side declarations.

    With @terminate False (source side), open the <family>_nl_ops[]
    definition and return. With @terminate True (header side), emit an extern
    declaration of the array when it is exported (i.e. the genl_family struct
    is not generated). After that, emit prototypes for the pre/post hooks and
    for every non-async op's doit/dumpit handler.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        # Only exported arrays carry an explicit size: split policies have
        # one entry per do/dump handler, the others one entry per op.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if not terminate:
            cw.block_start(line=decl + ' =')
            return
        cw.p(f"extern {decl};")

    cw.nl()
    do_hook_args = ['const struct genl_split_ops *ops',
                    'struct sk_buff *skb', 'struct genl_info *info']
    dump_hook_args = ['struct netlink_callback *cb']
    hook_kinds = (('pre', 'do', 'int', do_hook_args),
                  ('post', 'do', 'void', do_hook_args),
                  ('pre', 'dump', 'int', dump_hook_args),
                  ('post', 'dump', 'int', dump_hook_args))
    for when, mode, ret, args in hook_kinds:
        for hook in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret, c_lower(hook), list(args), suffix=';')

    cw.nl()

    handler_args = {'do': ['struct sk_buff *skb', 'struct genl_info *info'],
                    'dump': ['struct sk_buff *skb', 'struct netlink_callback *cb']}
    for op_name, op in family.ops.items():
        if op.is_async:
            continue
        for mode in ('do', 'dump'):
            if mode in op:
                handler = c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it")
                cw.write_func_prot('int', handler, list(handler_args[mode]), suffix=';')
    cw.nl()
2891
2892
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops table declaration and handler/hook prototypes."""
    header_side = True
    print_kernel_op_table_fwd(family, cw, header_side)
2895
2896
def print_kernel_op_table(family, cw):
    """
    Emit the definition of the <family>_nl_ops[] array.

    print_kernel_op_table_fwd(terminate=False) opens the array. One
    initializer is then written per op ("global"/"per-op" policies) or per
    op mode ("split" policy), each wrapped via cw.ifdef_block() with the op's
    "config-cond". Async ops are skipped. Finally the array is closed.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        # One entry per op, carrying both doit and dumpit if present.
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): the policy is taken from the op's "do" request
                # unconditionally; an op without do/request would raise KeyError.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Split ops name their pre/post hooks differently for do and dump.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        # One entry per (op, mode) pair.
        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Drop flags that only apply to the other mode: dump
                    # flags on the do entry, 'strict' on the dump entry.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry is tagged with the mode it handles.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2975
2976
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum of multicast group IDs (<FAMILY>_NLGRP_<NAME>)."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2987
2988
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by each group's enum ID."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
3000
3001
def print_kernel_family_struct_hdr(family, cw):
    """
    Emit the extern genl_family declaration.

    When the family defines "sock-priv", also emit the prototypes of its
    socket-private init/destroy functions.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()
    if 'sock-priv' not in family.kernel_family:
        return

    priv_type = family.kernel_family["sock-priv"]
    for hook in ('init', 'destroy'):
        cw.p(f'void {family.c_name}_nl_sock_priv_{hook}({priv_type} *priv);')
    cw.nl()
3012
3013
def print_kernel_family_struct_src(family, cw):
    """
    Emit the definition of the family's struct genl_family.

    Only done when kernel_can_gen_family_struct() holds. If the family
    declares a "sock-priv" type, static void-pointer trampolines are
    generated around <c_name>_nl_sock_priv_init()/_destroy() and wired into
    the struct.

    The struct is named <c_name>_nl_family so that it matches the extern
    declaration emitted by print_kernel_family_struct_hdr().
    """
    if not kernel_can_gen_family_struct(family):
        return

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Use c_name (not ident_name) so the definition always matches the
    # "extern struct genl_family <c_name>_nl_family;" header declaration.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
3049
3050
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """
    Open a uapi enum block and choose its tag.

    If @obj has an @enum_name key, its value is used as the tag; a falsy
    value there makes the enum anonymous and @ckey is not consulted.
    Otherwise, if @ckey is present in @obj, the tag is <family>_<obj[@ckey]>.
    Failing both, the enum is anonymous.
    """
    tag = None
    if enum_name in obj:
        if obj[enum_name]:
            tag = c_lower(obj[enum_name])
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])

    cw.block_start(line='enum' if tag is None else 'enum ' + tag)
3059
3060
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """
    Render the single command enum of the "unified" message-ID model.

    Entries are emitted implicitly while values stay consecutive, and as
    "= N" whenever an op's value breaks the sequence. With @separate_ntf,
    notifications and events are left out. The enum ends with the count
    entry and the max (an enum member, or a #define with @max_by_define).
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(op.enum_name + ',')
        else:
            cw.p(f"{op.enum_name} = {op.value},")
            expected = op.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
3086
3087
def render_uapi_directional(family, cw, max_by_define):
    """
    Render the two command enums of the "directional" message-ID model.

    The USER enum lists requests (ops with "do" that are not events). The
    KERNEL enum lists replies (suffixed _REPLY), notifications and events.
    Each enum starts with a *_NONE = 0 entry and ends with its own count and
    max (an enum member, or a #define with @max_by_define).
    """
    def emit_entry(name, value, expected):
        # Write an explicit "= N" only for a non-zero value that breaks the
        # running sequence; return the next expected value.
        if value and value != expected:
            cw.p(f"{name} = {value},")
            return value + 1
        cw.p(name + ',')
        return expected + 1

    def close_enum(direction):
        max_name = f"{family.op_prefix}{direction}_MAX"
        cnt_name = f"__{family.op_prefix}{direction}_CNT"
        max_value = f"({cnt_name} - 1)"
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {max_name} {max_value}")
        cw.nl()

    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    expected = 0
    for op in family.msgs.values():
        if 'do' in op and 'event' not in op:
            expected = emit_entry(op.enum_name, op.value, expected)
    close_enum('USER')

    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    expected = 0
    for op in family.msgs.values():
        is_ntf = 'notify' in op or 'event' in op
        if not (('do' in op and 'reply' in op['do']) or is_ntf):
            continue
        name = op.enum_name if is_ntf else f'{op.enum_name}_REPLY'
        expected = emit_entry(name, op.value, expected)
    close_enum('KERNEL')
3140
3141
def render_uapi(family, cw):
    """
    Render the complete uapi header for @family.

    Emits, in order:
      - the include guard;
      - the family name/version defines;
      - the definitions: consts as #defines, enums/flags as C enums with
        optional kdoc and an optional private max/mask entry;
      - one attribute enum per attribute set (subsets are skipped);
      - the command enum(s), plus a separate notification enum when
        "async-prefix" is set;
      - the multicast group name defines.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive consts are batched and flushed as one #define group
    # whenever a non-const definition is reached.
    defines = []
    for const in family['definitions']:
        # Definitions provided by an existing header are not re-emitted.
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                # Full kdoc when entries are documented, else a plain comment.
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The define name override belongs to the individual definition
            # (as it does for multicast groups below); looking it up on the
            # family would give every const the same name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # Values are implicit while consecutive, explicit ("= N") on gaps.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        # Notifications and events get their own enum in this case.
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3282
3283
def _render_user_ntf_entry(ri, op):
    """
    Emit one designated-initializer entry of the ynl_ntf_info table.

    Non-classic families index the entry by @op's own enum name. Classic
    families index it by the request op whose value matches @op.rsp_value.
    """
    if ri.family.is_classic():
        index = ri.family.req_by_value[op.rsp_value].enum_name
    else:
        index = op.enum_name

    cw = ri.cw
    cw.block_start(line=f"[{index}] = ")
    cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    cw.block_end(line=',')
3295
3296
def render_user_family(family, cw, prototype):
    """
    Emit the user-space ynl_<family>_family descriptor.

    With @prototype only the extern declaration is written. Otherwise:
      - if the family has notifications, emit the ynl_ntf_info table first
        (notify entries, then event entries, then event ops);
      - then emit the descriptor: name, classic protocol info, header length
        and the notification table.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                notified_op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", notified_op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3338
3339
3340def family_contains_bitfield32(family):
3341    for _, attr_set in family.attr_sets.items():
3342        if attr_set.subset_of:
3343            continue
3344        for _, attr in attr_set.items():
3345            if attr.type == "bitfield32":
3346                return True
3347    return False
3348
3349
3350def find_kernel_root(full_path):
3351    sub_path = ''
3352    while True:
3353        sub_path = os.path.join(os.path.basename(full_path), sub_path)
3354        full_path = os.path.dirname(full_path)
3355        maintainers = os.path.join(full_path, "MAINTAINERS")
3356        if os.path.exists(maintainers):
3357            return full_path, sub_path[:-1]
3358
3359
3360def main():
3361    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
3362    parser.add_argument('--mode', dest='mode', type=str, required=True,
3363                        choices=('user', 'kernel', 'uapi'))
3364    parser.add_argument('--spec', dest='spec', type=str, required=True)
3365    parser.add_argument('--header', dest='header', action='store_true', default=None)
3366    parser.add_argument('--source', dest='header', action='store_false')
3367    parser.add_argument('--user-header', nargs='+', default=[])
3368    parser.add_argument('--cmp-out', action='store_true', default=None,
3369                        help='Do not overwrite the output file if the new output is identical to the old')
3370    parser.add_argument('--exclude-op', action='append', default=[])
3371    parser.add_argument('-o', dest='out_file', type=str, default=None)
3372    args = parser.parse_args()
3373
3374    if args.header is None:
3375        parser.error("--header or --source is required")
3376
3377    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
3378
3379    try:
3380        parsed = Family(args.spec, exclude_ops)
3381        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
3382            print('Spec license:', parsed.license)
3383            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
3384            os.sys.exit(1)
3385    except yaml.YAMLError as exc:
3386        print(exc)
3387        os.sys.exit(1)
3388        return
3389
3390    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
3391
3392    _, spec_kernel = find_kernel_root(args.spec)
3393    if args.mode == 'uapi' or args.header:
3394        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
3395    else:
3396        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
3397    cw.p("/* Do not edit directly, auto-generated from: */")
3398    cw.p(f"/*\t{spec_kernel} */")
3399    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
3400    if args.exclude_op or args.user_header:
3401        line = ''
3402        line += ' --user-header '.join([''] + args.user_header)
3403        line += ' --exclude-op '.join([''] + args.exclude_op)
3404        cw.p(f'/* YNL-ARG{line} */')
3405    cw.nl()
3406
3407    if args.mode == 'uapi':
3408        render_uapi(parsed, cw)
3409        return
3410
3411    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
3412    if args.header:
3413        cw.p('#ifndef ' + hdr_prot)
3414        cw.p('#define ' + hdr_prot)
3415        cw.nl()
3416
3417    if args.out_file:
3418        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
3419    else:
3420        hdr_file = "generated_header_file.h"
3421
3422    if args.mode == 'kernel':
3423        cw.p('#include <net/netlink.h>')
3424        cw.p('#include <net/genetlink.h>')
3425        cw.nl()
3426        if not args.header:
3427            if args.out_file:
3428                cw.p(f'#include "{hdr_file}"')
3429            cw.nl()
3430        headers = ['uapi/' + parsed.uapi_header]
3431        headers += parsed.kernel_family.get('headers', [])
3432    else:
3433        cw.p('#include <stdlib.h>')
3434        cw.p('#include <string.h>')
3435        if args.header:
3436            cw.p('#include <linux/types.h>')
3437            if family_contains_bitfield32(parsed):
3438                cw.p('#include <linux/netlink.h>')
3439        else:
3440            cw.p(f'#include "{hdr_file}"')
3441            cw.p('#include "ynl.h"')
3442        headers = []
3443    for definition in parsed['definitions'] + parsed['attribute-sets']:
3444        if 'header' in definition:
3445            headers.append(definition['header'])
3446    if args.mode == 'user':
3447        headers.append(parsed.uapi_header)
3448    seen_header = []
3449    for one in headers:
3450        if one not in seen_header:
3451            cw.p(f"#include <{one}>")
3452            seen_header.append(one)
3453    cw.nl()
3454
3455    if args.mode == "user":
3456        if not args.header:
3457            cw.p("#include <linux/genetlink.h>")
3458            cw.nl()
3459            for one in args.user_header:
3460                cw.p(f'#include "{one}"')
3461        else:
3462            cw.p('struct ynl_sock;')
3463            cw.nl()
3464            render_user_family(parsed, cw, True)
3465        cw.nl()
3466
3467    if args.mode == "kernel":
3468        if args.header:
3469            for _, struct in sorted(parsed.pure_nested_structs.items()):
3470                if struct.request:
3471                    cw.p('/* Common nested types */')
3472                    break
3473            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3474                if struct.request:
3475                    print_req_policy_fwd(cw, struct)
3476            cw.nl()
3477
3478            if parsed.kernel_policy == 'global':
3479                cw.p(f"/* Global operation policy for {parsed.name} */")
3480
3481                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3482                print_req_policy_fwd(cw, struct)
3483                cw.nl()
3484
3485            if parsed.kernel_policy in {'per-op', 'split'}:
3486                for op_name, op in parsed.ops.items():
3487                    if 'do' in op and 'event' not in op:
3488                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
3489                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
3490                        cw.nl()
3491
3492            print_kernel_op_table_hdr(parsed, cw)
3493            print_kernel_mcgrp_hdr(parsed, cw)
3494            print_kernel_family_struct_hdr(parsed, cw)
3495        else:
3496            print_kernel_policy_ranges(parsed, cw)
3497            print_kernel_policy_sparse_enum_validates(parsed, cw)
3498
3499            for _, struct in sorted(parsed.pure_nested_structs.items()):
3500                if struct.request:
3501                    cw.p('/* Common nested types */')
3502                    break
3503            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
3504                if struct.request:
3505                    print_req_policy(cw, struct)
3506            cw.nl()
3507
3508            if parsed.kernel_policy == 'global':
3509                cw.p(f"/* Global operation policy for {parsed.name} */")
3510
3511                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
3512                print_req_policy(cw, struct)
3513                cw.nl()
3514
3515            for op_name, op in parsed.ops.items():
3516                if parsed.kernel_policy in {'per-op', 'split'}:
3517                    for op_mode in ['do', 'dump']:
3518                        if op_mode in op and 'request' in op[op_mode]:
3519                            cw.p(f"/* {op.enum_name} - {op_mode} */")
3520                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
3521                            print_req_policy(cw, ri.struct['request'], ri=ri)
3522                            cw.nl()
3523
3524            print_kernel_op_table(parsed, cw)
3525            print_kernel_mcgrp_src(parsed, cw)
3526            print_kernel_family_struct_src(parsed, cw)
3527
3528    if args.mode == "user":
3529        if args.header:
3530            cw.p('/* Enums */')
3531            put_op_name_fwd(parsed, cw)
3532
3533            for name, const in parsed.consts.items():
3534                if isinstance(const, EnumSet):
3535                    put_enum_to_str_fwd(parsed, cw, const)
3536            cw.nl()
3537
3538            cw.p('/* Common nested types */')
3539            for attr_set, struct in parsed.pure_nested_structs.items():
3540                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3541                print_type_full(ri, struct)
3542
3543            for op_name, op in parsed.ops.items():
3544                cw.p(f"/* ============== {op.enum_name} ============== */")
3545
3546                if 'do' in op and 'event' not in op:
3547                    cw.p(f"/* {op.enum_name} - do */")
3548                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3549                    print_req_type(ri)
3550                    print_req_type_helpers(ri)
3551                    cw.nl()
3552                    print_rsp_type(ri)
3553                    print_rsp_type_helpers(ri)
3554                    cw.nl()
3555                    print_req_prototype(ri)
3556                    cw.nl()
3557
3558                if 'dump' in op:
3559                    cw.p(f"/* {op.enum_name} - dump */")
3560                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
3561                    print_req_type(ri)
3562                    print_req_type_helpers(ri)
3563                    if not ri.type_consistent or ri.type_oneside:
3564                        print_rsp_type(ri)
3565                    print_wrapped_type(ri)
3566                    print_dump_prototype(ri)
3567                    cw.nl()
3568
3569                if op.has_ntf:
3570                    cw.p(f"/* {op.enum_name} - notify */")
3571                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3572                    if not ri.type_consistent:
3573                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3574                    print_wrapped_type(ri)
3575
3576            for op_name, op in parsed.ntfs.items():
3577                if 'event' in op:
3578                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
3579                    cw.p(f"/* {op.enum_name} - event */")
3580                    print_rsp_type(ri)
3581                    cw.nl()
3582                    print_wrapped_type(ri)
3583            cw.nl()
3584        else:
3585            cw.p('/* Enums */')
3586            put_op_name(parsed, cw)
3587
3588            for name, const in parsed.consts.items():
3589                if isinstance(const, EnumSet):
3590                    put_enum_to_str(parsed, cw, const)
3591            cw.nl()
3592
3593            has_recursive_nests = False
3594            cw.p('/* Policies */')
3595            for struct in parsed.pure_nested_structs.values():
3596                if struct.recursive:
3597                    put_typol_fwd(cw, struct)
3598                    has_recursive_nests = True
3599            if has_recursive_nests:
3600                cw.nl()
3601            for struct in parsed.pure_nested_structs.values():
3602                put_typol(cw, struct)
3603            for name in parsed.root_sets:
3604                struct = Struct(parsed, name)
3605                put_typol(cw, struct)
3606
3607            cw.p('/* Common nested types */')
3608            if has_recursive_nests:
3609                for attr_set, struct in parsed.pure_nested_structs.items():
3610                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3611                    free_rsp_nested_prototype(ri)
3612                    if struct.request:
3613                        put_req_nested_prototype(ri, struct)
3614                    if struct.reply:
3615                        parse_rsp_nested_prototype(ri, struct)
3616                cw.nl()
3617            for attr_set, struct in parsed.pure_nested_structs.items():
3618                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
3619
3620                free_rsp_nested(ri, struct)
3621                if struct.request:
3622                    put_req_nested(ri, struct)
3623                if struct.reply:
3624                    parse_rsp_nested(ri, struct)
3625
3626            for op_name, op in parsed.ops.items():
3627                cw.p(f"/* ============== {op.enum_name} ============== */")
3628                if 'do' in op and 'event' not in op:
3629                    cw.p(f"/* {op.enum_name} - do */")
3630                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3631                    print_req_free(ri)
3632                    print_rsp_free(ri)
3633                    parse_rsp_msg(ri)
3634                    print_req(ri)
3635                    cw.nl()
3636
3637                if 'dump' in op:
3638                    cw.p(f"/* {op.enum_name} - dump */")
3639                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
3640                    if not ri.type_consistent or ri.type_oneside:
3641                        parse_rsp_msg(ri, deref=True)
3642                    print_req_free(ri)
3643                    print_dump_type_free(ri)
3644                    print_dump(ri)
3645                    cw.nl()
3646
3647                if op.has_ntf:
3648                    cw.p(f"/* {op.enum_name} - notify */")
3649                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
3650                    if not ri.type_consistent:
3651                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
3652                    print_ntf_type_free(ri)
3653
3654            for op_name, op in parsed.ntfs.items():
3655                if 'event' in op:
3656                    cw.p(f"/* {op.enum_name} - event */")
3657
3658                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
3659                    parse_rsp_msg(ri)
3660
3661                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
3662                    print_ntf_type_free(ri)
3663            cw.nl()
3664            render_user_family(parsed, cw, False)
3665
3666    if args.header:
3667        cw.p(f'#endif /* {hdr_prot} */')
3668
3669
3670if __name__ == "__main__":
3671    main()
3672