xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 8be4d31cb8aaeea27bde4b7ddb26e28a89062ebf)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17from lib import SpecSubMessage, SpecSubMessageFormat
18
19
def c_upper(name):
    """Return *name* upper-cased with every '-' turned into '_' (C macro style)."""
    upper = name.upper()
    return upper.replace('-', '_')
22
23
def c_lower(name):
    """Return *name* lower-cased with every '-' turned into '_' (C identifier style)."""
    lowered = name.lower()
    return lowered.replace('-', '_')
26
27
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value.

    The first character selects unsigned ('u') or signed ('s'); the digits
    between it and the four-character '-min'/'-max' suffix give the width.
    """
    unsigned = name[0] == 'u'
    signed = name[0] == 's'
    wants_min = name.endswith('-min')

    if unsigned and wants_min:
        return 0

    bits = int(name[1:-4])
    if signed:
        bits -= 1  # one bit is used for the sign
    magnitude = (1 << bits) - 1
    if signed and wants_min:
        return -magnitude - 1
    return magnitude
41
42
class BaseNlLib:
    """C expressions the generated code uses to reach YNL library state."""

    def get_family_id(self):
        """Return the C expression that holds the resolved family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
46
47
class Type(SpecAttr):
    """
    Base class for C code generation of a single attribute.

    Subclasses override the hooks (_attr_policy, _attr_typol, attr_put,
    _attr_get, _setter_lines, _complex_member_type, ...) for their type.
    The defaults here either render generic output or raise to flag a
    missing implementation.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # A nest named like a family const gets a trailing '_' on its
            # struct name (presumably to avoid clashing with the const's type)
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make the member name a valid C identifier
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests; subsets also mark the real attr."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies; subsets also mark the real attr."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numeric value of check *limit*.

        The check may be an int, the name of a family const, or a string
        understood by limit_to_number() (e.g. 'u32-max'). Returns None if
        the check is absent and no *default* is given.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check *limit* rendered as C source text, '' if absent.

        Ints are printed (with *suffix*), consts become their C macro
        name, anything else is upper-cased as a C macro name.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name; reject subset attrs whose checks differ from the real attr."""
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        # Falsy by default; array-like subclasses return True
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return the C declaration of this attr's presence member, or None.

        Only renders when presence_type() equals *type_filter*; 'present'
        uses a 1-bit field, 'len'/'count' a full u32.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'present':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        # True when _free_lines() uses the loop iterator 'i'
        return False

    def _free_lines(self, ri, var, ref):
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the C parameter declarations used to pass this attr to a setter."""
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attr."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attr's entry in the kernel nla_policy array."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's entry in the user-space YNL policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Only emit the attribute when its presence member says it is set
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars); lines/init_lines may be a str."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the `if`/`else if (type == ENUM)` branch parsing this attr.

        Returns True so callers can track that a branch was emitted.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # _attr_get() may return a single line as a plain string
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """
        Emit a static inline setter for this attribute.

        *ref* is the path of enclosing nest member names, outermost first.
        Presence bits are set for every enclosing nest; the last element
        (this attr) is handed to _setter_lines() when it is not bit-present.
        """
        ref = (ref if ref else []) + [self.c_name]
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        for i in range(len(ref)):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'present':
                presence = f"{var}->{'.'.join(ref[:i] + [''])}_{self.presence_type()}.{ref[i]}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # A setter which frees the old value but never allocates stores the
        # caller's pointer as-is (e.g. multi-attr arrays), i.e. it takes
        # ownership; such setters get a '__' prefix.
        free = any('free(' in line for line in code)
        alloc = any('alloc(' in line for line in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
310
311
class TypeUnused(Type):
    """
    Placeholder for an unused attribute slot.

    It has no struct member, no kernel policy entry and no put/get code;
    the user-space typol marks it YNL_PT_REJECT.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        return ['return YNL_PARSE_CB_ERROR;'], None, None

    def attr_get(self, ri, var, first):
        """Emit no parse branch."""

    def attr_put(self, ri, var):
        """Emit no put code."""

    def attr_policy(self, cw):
        """Emit no kernel policy entry."""

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        """Emit no setter."""
336
337
class TypePad(Type):
    """
    Padding attribute.

    It has no struct member, no kernel policy entry and no put/get code;
    the user-space typol marks it YNL_PT_IGNORE.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """Emit no kernel policy entry."""

    def attr_get(self, ri, var, first):
        """Emit no parse branch."""

    def attr_put(self, ri, var):
        """Emit no put code."""

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        """Emit no setter."""
359
360
class TypeScalar(Type):
    """
    Integer attribute.

    Kernel policies are derived from the attr's 'checks' and, for enum-typed
    attributes, from the enum's value range or flag mask. Auto-sized scalars
    (``is_auto_scalar``) use a 64-bit C member.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Decide whether the value is a flags bitfield and pick its C type."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """
        Fill in implied checks (enum range, 'range', 'full-range', 'sparse').

        Raises if min > max, or if an unsigned type gets a negative limit.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                # No bounds from the enum; _attr_policy() emits a validate fn
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    # A 0 minimum is implied for unsigned types
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            lim_min = self.get_limit('min')
            lim_max = self.get_limit('max')
            if lim_min > lim_max:
                raise Exception(f'Invalid limit for "{self.name}" min: {lim_min} max: {lim_max}')
            self.checks['range'] = True

        lim_min = self.get_limit('min', 0)
        lim_max = self.get_limit('max', 0)
        low = min(lim_min, lim_max)
        high = max(lim_min, lim_max)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside s16 need NLA_POLICY_FULL_RANGE; the inline min/max
        # fields of the kernel's struct nla_policy are s16
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        # The first matching check decides the policy macro
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
458
459
class TypeFlag(Type):
    """Flag attribute: it carries no payload, only its presence matters."""

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to copy; the base attr_get() records presence
        return list(), None, None

    def _setter_lines(self, ri, member, presence):
        return list()
475
476
class TypeString(Type):
    """String attribute, kept as a NUL-terminated heap copy plus its length."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        parts = ['.type = YNL_PT_NUL_STR, ']
        if self.is_selector:
            parts.append('.is_selector = 1, ')
        return ''.join(parts)

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        mem = '{ .type = ' + policy
        if 'max-len' in self.checks:
            mem += ', .len = ' + self.get_limit_str('max-len')
        return mem + ', }'

    def attr_policy(self, cw):
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_len.{self.c_name}"
        # Copy at most the attribute payload and always NUL-terminate
        get_lines = [
            f"{len_mem} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        return [
            f"{presence} = strlen({self.c_name});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {self.c_name}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
529
530
class TypeBinary(Type):
    """Opaque binary attribute, kept as a heap copy plus its length."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # Trailing space matches the typol strings of all other types
        return '.type = YNL_PT_BINARY, '

    def _attr_policy(self, policy):
        """
        Render the kernel policy for a binary attribute.

        At most one check is supported: 'exact-len', 'min-len' or 'max-len'.
        Raises for any other check or for more than one check.
        """
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not self.checks:
            return '{ .type = NLA_BINARY, }'

        check_name = next(iter(self.checks))
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            # .type left as NLA_UNSPEC, for which .len is the minimum length
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if check_name == 'max-len':
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        raise Exception('Unsupported check for binary type: ' + check_name)

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->{self.c_name}, {var}->_len.{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_len.' + self.c_name
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [f"{presence} = len;",
                f"{member} = malloc({presence});",
                f'memcpy({member}, {self.c_name}, {presence});']
581
582
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is the C struct named by 'struct'."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        # A payload shorter than the struct lands in a zeroed, full-sized buffer
        get_lines = [
            f"{len_mem} = len;",
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
598
599
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute holding a packed array of 'sub-type' scalars."""

    def arg_member(self, ri):
        return [f'__{self.get("sub-type")} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        ri.cw.p(f'__{self.get("sub-type")} *{self.c_name};')

    def attr_put(self, ri, var):
        count_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        ri.cw.block_start(line=f"if ({count_mem})")
        # 'i' (presumably the put function's iterator) holds the byte size
        ri.cw.p(f"i = {count_mem} * {elem_sz};")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        count_mem = f"{var}->_count.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        dst = f"{var}->{self.c_name}"
        # Round the payload down to a whole number of elements
        get_lines = [
            f"{count_mem} = len / {elem_sz};",
            f"len = {count_mem} * {elem_sz};",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        return [
            f"{presence} = count;",
            f"count *= {elem_sz};",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
632
633
class TypeBitfield32(Type):
    """
    struct nla_bitfield32 attribute.

    The kernel policy needs an 'enum' whose flag mask gives the valid bits.
    """

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """Return NLA_POLICY_BITFIELD32 with the enum's flag mask; raise without an enum."""
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
657
658
class TypeNest(Type):
    """
    Nested attribute set, rendered as an embedded struct.

    When the nest is recursive (and not tied to an op), the struct member is
    a pointer instead; see is_recursive_for_op().
    """

    def is_recursive(self):
        return self.family.pure_nested_structs[self.nested_attrs].recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        target = f"{var}->{ref}{self.c_name}"
        if not self.is_recursive_for_op(ri):
            return [f'{self.nested_render_name}_free(&{target});']
        # Pointer member: guard against NULL before freeing
        return [f'if ({target})', f'{self.nested_render_name}_free({target});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        pns = self.family.pure_nested_structs[self.nested_attrs]
        parse_args = ["&parg", "attr"]
        parse_args += [f'{var}->{sel.name}' for sel in pns.external_selectors()]
        get_lines = [
            f"if ({self.nested_render_name}_parse({', '.join(parse_args)}))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        # One setter per non-recursive member of the nested struct
        path = (ref if ref else []) + [self.c_name]
        for _, attr in ri.family.pure_nested_structs[self.nested_attrs].member_list():
            if not attr.is_recursive():
                attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path,
                            var=var)
705
706
class TypeMultiAttr(Type):
    """
    Attribute which may appear multiple times; instances are kept in an array.

    *base_type* is the Type object for a single instance and supplies the
    policy and typol rendering.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        kind = self.attr['type']
        if kind == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if kind == 'string':
            return 'struct ynl_string *'
        if kind in scalars:
            return ('__' if ri.ku_space == 'user' else '') + kind
        raise Exception(f"Sub-type {kind} not supported yet")

    def arg_member(self, ri):
        if self.type == 'binary' and 'struct' in self.attr:
            return [
                f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
                f'unsigned int n_{self.c_name}',
            ]
        return super().arg_member(ri)

    def free_needs_iter(self):
        # The string and nest branches of _free_lines() loop with 'i'
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        kind = self.attr['type']
        arr = f"{var}->{ref}{self.c_name}"
        loop = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if kind in scalars or kind == 'binary':
            return [f"free({arr});"]
        if kind == 'string':
            return [loop, f"free({arr}[i]);", f"free({arr});"]
        if kind == 'nest':
            return [loop, f'{self.nested_render_name}_free(&{arr}[i]);', f"free({arr});"]
        raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only count instances here; the array is filled in a later pass
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        kind = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        elem = f"{var}->{self.c_name}[i]"
        if kind in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {elem});")
        elif kind == 'binary' and 'struct' in self.attr:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{elem}, sizeof(struct {c_lower(self.attr['struct'])}));")
        elif kind == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {elem}->str);")
        elif kind == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{elem})")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [
            f"{member} = {self.c_name};",
            f"{presence} = n_{self.c_name};",
        ]
793
794
class TypeArrayNest(Type):
    """
    Indexed array: a nest whose members are the array elements.

    Elements are scalars, fixed-length binaries (with an 'exact-len' check)
    or nested attribute sets.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def free_needs_iter(self):
        # _free_lines() loops over nested elements with 'i'
        return self.sub_type == 'nest'

    def _free_lines(self, ri, var, ref):
        """
        Free the element array.

        Nested elements are structs that may own allocations of their own,
        so each one is freed before the array itself.
        """
        lines = []
        if self.sub_type == 'nest':
            lines += [
                f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
                f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);',
            ]
        lines.append(f"free({var}->{ref}{self.c_name});")
        return lines

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def arg_member(self, ri):
        if self.sub_type == 'binary' and 'exact-len' in self.checks:
            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        if self.attr['sub-type'] in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Only validate and count elements here; attr_<name> is kept so the
        # elements can be parsed in a later pass
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->_count.{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        # Elements are emitted with their array index as the attribute type
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            put_type = self.sub_type
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put for ArrayNest sub-type {self.sub_type} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
857
858
class TypeNestTypeValue(Type):
    """Attribute of type ``nest-type-value``.

    The optional ``type-value`` list names nesting levels whose attribute
    *type* carries data.  For each level, the generated parser takes
    ynl_attr_data() of the previous attribute and stores ynl_attr_type() of
    the result in a ``__u32`` local.  All of those locals are then passed to
    the nested ``<name>_parse()`` function.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return (get_lines, init_lines, local_vars) for parsing the nest."""
        prev = 'attr'
        tv_args = ''
        get_lines = []
        local_vars = []
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        if 'type-value' in self.attr:
            tv_names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars += [f'const struct nlattr *attr_{", *attr_".join(tv_names)};']
            local_vars += [f'__u32 {", ".join(tv_names)};']
            for level in tv_names:
                get_lines += [f'attr_{level} = ynl_attr_data({prev});']
                get_lines += [f'{level} = ynl_attr_type(attr_{level});']
                prev = 'attr_' + level

            tv_args = f", {', '.join(tv_names)}"

        get_lines += [f"{self.nested_render_name}_parse(&parg, {prev}{tv_args});"]
        return get_lines, init_lines, local_vars
887
888
class TypeSubMessage(TypeNest):
    """Attribute of type ``sub-message``.

    The payload format is chosen by the value of a *selector* attribute.
    The selector is either part of the same attribute set or, when it is
    not, external (see Selector).
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        typol = f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        typol += '.is_submsg = 1, '
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            typol += f'.selector_type = {self.attr_set[self["selector"]].value} '
        return typol

    def _attr_get(self, ri, var):
        """Return (get_lines, init_lines, local_vars) for parsing the sub-message.

        The generated code fails through ynl_submsg_failed() when the selector
        value is unset.  Otherwise it hands the selector value and the
        attribute to the sub-message parse function.
        """
        sel = c_lower(self['selector'])
        # External selectors are referenced via a _sel_<name> variable that is
        # expected to be in scope in the generated parser; internal ones via
        # the struct member holding the selector value.
        if self.selector.is_external():
            sel_var = f"_sel_{sel}"
        else:
            sel_var = f"{var}->{sel}"
        get_lines = [f'if (!{sel_var})',
                     f'return ynl_submsg_failed(yarg, "{self.name}", "{self["selector"]}");',
                     f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
                     "return YNL_PARSE_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None
919
920
class Selector:
    """Selector attribute of a sub-message attribute.

    The selector is internal when its name is found in the sub-message's own
    attribute set.  That attribute is then flagged with ``is_selector``.
    Otherwise it is external: ``attr`` stays None until set_attr() binds it
    to the attribute that provides the value.
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]

        if self.name in attr_set:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True
            self._external = False
        else:
            # The selector will need to get passed down thru the structs
            self.attr = None
            self._external = True

    def set_attr(self, attr):
        """Bind this selector to the attribute providing its value."""
        self.attr = attr

    def is_external(self):
        """Return True if the selector is not in the sub-message's own set."""
        return self._external
939
940
class Struct:
    """A C struct rendered for (a subset of) one attribute set.

    :param family: the Family being generated
    :param space_name: name of the attribute set backing the struct
    :param type_list: attribute names to include; None means all attributes
        of the set and marks the struct as ``nested``
    :param fixed_header: name of a fixed header struct, if any
    :param inherited: names of members expected to be inherited from the
        parent nest; checked later by set_inherited()
    :param submsg: the sub-message this struct was created for, if any
    """

    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use a list as the default: a list never compares equal to a set, so
        # set_inherited() on a struct created without `inherited` raises.
        self._inherited = inherited if inherited is not None else []
        self.inherited = []
        self.fixed_header = None
        if fixed_header:
            self.fixed_header = 'struct ' + c_lower(fixed_header)
        self.submsg = submsg

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = c_lower(family.ident_name)
        else:
            self.render_name = c_lower(family.ident_name + '-' + space_name)
        self.struct_name = 'struct ' + self.render_name
        # Nested structs whose set name is also in family.consts get a
        # trailing underscore (presumably to avoid a C name clash).
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or by legacy arrays

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(t, self.attr_set[t]) for t in names]
        self.attrs = dict()

        # Remember the attribute with the highest value (ties: the later one).
        max_val = 0
        self.attr_max_val = None
        for name, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members.

        Raises if *new_inherited* differs from the set given at construction.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def external_selectors(self):
        """Return the Selectors of sub-message members that are external."""
        return [attr.selector for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        """Return True if any member's free code needs an iterator."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1015
1016
class EnumEntry(SpecEnumEntry):
    """Enum/flags entry extended with its rendered C name."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change is True when the value does not follow from the previous
        # entry (+1) or, for the first entry, from 0.  It is always True for
        # flags.  Presumably this decides whether an explicit value is
        # rendered.
        if prev:
            self.value_change = (self.value != prev.value + 1)
        else:
            self.value_change = (self.value != 0)
        self.value_change = self.value_change or self.enum_set['type'] == 'flags'

        # Added by resolve:
        # (assigned and deleted so the attribute is listed here; it does not
        # exist until resolve() sets it)
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1035
1036
class EnumSet(SpecEnumSet):
    """Enum or flags definition with its C rendering names."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # 'enum-name' absent -> default "enum <render_name>";
        # present but empty -> anonymous enum, values typed as int.
        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        if self.enum_name:
            self.user_type = self.enum_name
        else:
            self.user_type = 'int'

        # Set before super().__init__(), which presumably creates the entries
        # (via new_entry()) that read these.
        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if entry values look contiguous, else (None, None).

        Values are treated as contiguous when high - low + 1 equals the number
        of entries.  Raises ValueError if the set has no entries.
        """
        values = [x.value for x in self.entries.values()]
        low = min(values)
        high = max(values)

        if high - low + 1 != len(self.entries):
            return None, None

        return low, high
1072
1073
class AttrSet(SpecAttrSet):
    """Attribute set with C enum naming and C-generator attribute types."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is None:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))
        else:
            # Subsets reuse the enum naming of the set they are a subset of.
            self.name_prefix = family.attr_sets[self.subset_of].name_prefix
            self.max_name = family.attr_sets[self.subset_of].max_name
            self.cnt_name = family.attr_sets[self.subset_of].cnt_name

        # Added by resolve:
        # (assigned and deleted so the attribute is listed here; it does not
        # exist until resolve() sets it)
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.c_name = c_lower(self.name)
        # Avoid C keywords as identifiers.
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name == self.family.c_name:
            self.c_name = ''

    def new_attr(self, elem, value):
        """Create the C-generator Type subclass matching elem['type'].

        Multi-attr attributes are wrapped in TypeMultiAttr.  Raises for
        unsupported types or indexed-array sub-types.
        """
        if elem['type'] in scalars:
            t = TypeScalar(self.family, self, elem, value)
        elif elem['type'] == 'unused':
            t = TypeUnused(self.family, self, elem, value)
        elif elem['type'] == 'pad':
            t = TypePad(self.family, self, elem, value)
        elif elem['type'] == 'flag':
            t = TypeFlag(self.family, self, elem, value)
        elif elem['type'] == 'string':
            t = TypeString(self.family, self, elem, value)
        elif elem['type'] == 'binary':
            if 'struct' in elem:
                t = TypeBinaryStruct(self.family, self, elem, value)
            elif elem.get('sub-type') in scalars:
                t = TypeBinaryScalarArray(self.family, self, elem, value)
            else:
                t = TypeBinary(self.family, self, elem, value)
        elif elem['type'] == 'bitfield32':
            t = TypeBitfield32(self.family, self, elem, value)
        elif elem['type'] == 'nest':
            t = TypeNest(self.family, self, elem, value)
        elif elem['type'] == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] in ['binary', 'nest', 'u32']:
                t = TypeArrayNest(self.family, self, elem, value)
            else:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
        elif elem['type'] == 'nest-type-value':
            t = TypeNestTypeValue(self.family, self, elem, value)
        elif elem['type'] == 'sub-message':
            t = TypeSubMessage(self.family, self, elem, value)
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1142
1143
class Operation(SpecOperation):
    """Operation with C naming and a few derived properties."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # every existing mode/direction gets an (empty) 'attributes' list.
        for mode in ['do', 'dump', 'event']:
            for direction in ['request', 'reply']:
                try:
                    yaml[mode][direction].setdefault('attributes', [])
                except KeyError:
                    # Mode or direction not present for this op.
                    pass

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True when both 'do' and 'dump' have a request.
        self.dual_policy = ('do' in yaml and 'request' in yaml['do']) and \
                         ('dump' in yaml and 'request' in yaml['dump'])

        self.has_ntf = False

        # Added by resolve:
        # (assigned and deleted so the attribute is listed here; it does not
        # exist until resolve() sets it)
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        if not self.is_async:
            self.enum_name = self.family.op_prefix + c_upper(self.name)
        else:
            self.enum_name = self.family.async_op_prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that another op names this one in its 'notify' key."""
        self.has_ntf = True
1177
1178
class SubMessage(SpecSubMessage):
    """Sub-message definition extended with its C render name."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

    def resolve(self):
        self.resolve_up(super())
1187
1188
class Family(SpecFamily):
    """Family spec plus the C-generator specific analysis.

    On top of SpecFamily this computes C naming (c_name, op prefixes, uapi
    header name).  In resolve() it also works out:
      - which attribute sets are message roots and which are nested structs;
      - in which direction(s) each nested struct is used;
      - the order in which nested structs must be rendered;
      - how external sub-message selectors are bound;
      - the pre/post hooks.
    """

    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        # (assigned and deleted so they are listed here; the attributes do
        # not exist until they are set again)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" -> "<name>"
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode]: 'set' for dedup, 'list' keeps first-seen order
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        # Order matters: each pass consumes the results of the previous ones.
        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """Return True for classic (netlink-raw) families."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        # Flag every op that some other message names as its 'notify' target.
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Collect, per attribute set used directly by ops, the attributes used
        in requests and in replies (across do, dump and event)."""
        for op in self.msgs.values():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct comes after the
        non-recursive nests it references.

        Dict insertion order is the render order.  A struct that references a
        nest not yet placed is moved to the end of the dict and retried.
        Recursive nests are skipped since they are rendered as pointers.  The
        number of rounds is capped at len(structs) ** 2.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if not pns_key_list:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Create (once) the Struct for a 'nested-attributes' reference and
        record the members it inherits; returns the nested set name."""
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                # The Struct keeps a reference to `inherit`, so filling it
                # below also sets the Struct's expected inherited members.
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Synthesize an attribute set and Struct for a sub-message, with one
        member per format; returns the sub-message name."""
        # Fake the struct type for the sub-message itself
        # its not a attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr |= {
                    "type": "nest",
                    "nested-attributes": fmt['attribute-set'],
                }
                if 'fixed-header' in fmt:
                    attr |= { "fixed-header": fmt["fixed-header"] }
            elif 'fixed-header' in fmt:
                attr |= {
                    "type": "binary",
                    "struct": fmt["fixed-header"],
                }
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        # NOTE(review): AttrSet.__init__ reads 'name-prefix', not 'name-pfx',
        # so this prefix is currently ignored and the default
        # "<family>-a-<set>-" prefix is used.  Confirm which is intended
        # before renaming the key, as that would change generated names.
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Discover all nested structs reachable from the root sets, work out
        their request/reply/multi-val/recursive flags and render order."""
        # Breadth-first walk from the root sets, creating Structs on the way.
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while attr_set_queue:
            a_set = attr_set_queue.pop(0)
            for _, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        # Direct children of root sets inherit the direction of their attr.
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # After sorting, a struct comes after the nests it embeds, so walking
        # in reverse visits parents before their children.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        # Recursive flags may have changed, re-sort.
        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark each attribute as used in requests and/or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Bind each external sub-message selector to the attribute of the
        same name in the directly enclosing attribute set."""
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """Build the single global policy used with kernel-policy: global.

        All ops with an attribute set must share the same set.  The policy
        holds that set's attributes used in any do/dump request, in set
        order.
        """
        global_set = set()
        attr_set_name = None
        for op in self.ops.values():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        # Previously this fell through to attr_sets[None] and died with a
        # bare KeyError; report the real problem instead.
        if attr_set_name is None:
            raise Exception('For a global policy at least one op must use an attribute set')

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per do/dump, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1549
1550
class RenderInfo:
    """Context for rendering one operation, or one attribute set when *op*
    is falsy.

    :param cw: CodeWriter receiving the output (its nlib is exposed as .nl)
    :param family: the Family being rendered
    :param ku_space: rendering side; compared against 'user' by type helpers
    :param op: the Operation, or a falsy value when rendering an attr set
    :param op_mode: 'do', 'dump', 'notify', 'event', ...
    :param attr_set: attribute set name; defaults to op['attribute-set']
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            if op.fixed_header != family.fixed_header:
                # Only classic families may use a per-op fixed header.
                if family.is_classic():
                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
                else:
                    raise Exception("Per-op fixed header not supported, yet")

        # type_consistent: 'do' and 'dump' response parsing is identical
        # type_oneside: the op has a 'dump' but no 'do'
        self.type_consistent = True
        self.type_oneside = False
        if op and op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_oneside = True

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        # Notifications reuse the do (or dump) message layout; only the local
        # op_mode is remapped, self.op_mode keeps the caller's value.
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        if op:
            for op_dir in ['request', 'reply']:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """Return True if the *key* ('request'/'reply') struct has neither
        attributes nor a fixed header."""
        return len(self.struct[key].attr_list) == 0 and \
            self.struct[key].fixed_header is None

    def needs_nlflags(self, direction):
        """Return True when rendering a classic family's 'do' request."""
        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1618
1619
class CodeWriter:
    """Write generated C code with automatic indentation and brace handling.

    Output goes to stdout or, when out_file is given, to a temporary file
    that close_out_file() copies over out_file.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        """
        nlib:      stored as self.nlib for callers; not used by this class
        out_file:  destination path; None writes straight to stdout
        overwrite: when False, an existing out_file whose contents already
                   match the generated text is not rewritten
        """
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # a blank line is pending before the next line
        self._block_end = False     # a closing '}' is pending (may merge with "else")
        self._silent_block = False  # last line opened a brace-less if/while/for/else
        self._ind = 0               # current indentation depth, in tabs
        self._ifdef_block = None    # CONFIG_ option of the open #ifdef, if any

        # Start in stdout mode so close_out_file() (and thus __del__) stays
        # safe even if creating the temporary file below fails.
        self._out = sys.stdout
        self._out_file = None
        if out_file is not None:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy the buffered output to out_file and switch back to stdout.

        With overwrite=False an existing out_file with identical contents
        is left untouched.  The temporary file is closed on every path.
        Calling this in stdout mode, or a second time, does nothing.
        """
        if self._out_file is None:
            return
        tmp = self._out
        try:
            tmp.flush()
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(tmp.name, self._out_file, shallow=False))
            if not unchanged:
                with open(self._out_file, 'w+') as out_file:
                    tmp.seek(0)
                    shutil.copyfileobj(tmp, out_file)
        finally:
            # Previously the "unchanged" path returned with the temp file
            # still open and _out still pointing at it.
            tmp.close()
            self._out = sys.stdout
            self._out_file = None

    @classmethod
    def _is_cond(cls, line):
        """Return True if line starts with "if", "while" or "for"."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation level.

        A '}' left pending by block_end() is written first, or merged into
        the line as "} else..." when the line starts with "else".  Then any
        blank line requested by nl() is written.  Lines ending in ':' (e.g.
        goto labels) are outdented one level.  The line after a brace-less
        if/while/for/else gets one extra level.  Lines starting with '#'
        go to column 0.  add_ind adds further levels.  An empty line is
        written as a bare newline.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        if not line:
            # Previously an IndexError; emit a clean blank line instead.
            ind = 0
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write `line {` (or just `{`) and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block.

        With no line, the '}' is held back so that a following "else" can
        be joined onto it.  Otherwise '}' is written immediately, followed
        by line: ';' and ',' are appended directly, anything else after a
        space.  A pending nl() is cancelled.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write doc as ' *'-prefixed comment lines wrapped before column 79.

        Continuation lines get two extra spaces when indent is True.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a C function prototype/definition header.

        qual_ret: return type (with qualifiers); no space is added after a
                  trailing '*'
        args:     argument declarations; defaults to ['void']
        doc:      optional one-line block comment written above it
        suffix:   appended after ')' (e.g. ';' for a declaration)

        Short prototypes go on one line.  Otherwise a return type longer
        than 3 characters goes on its own line, and arguments wrap aligned
        under the opening parenthesis.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        local_vars may be a single string or a list of strings.  The
        caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() (stable, like list.sort) on a copy: the caller's list
        # used to be reordered in place.
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, locals and body lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define NAME<tabs>VALUE' lines with tab-aligned values.

        defines holds (name, value) pairs: int values are written as-is,
        str values in double quotes, anything else is omitted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write '.field<tabs>= value,' designated initializers, tab-aligned.

        members holds (field, value) pairs; an empty sequence writes nothing
        (it used to raise ValueError from max()).
        """
        if not members:
            return
        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Make the following output conditional on CONFIG_<config>.

        Closes the currently open #ifdef (if it is for a different option)
        and opens a new one.  A falsy config just closes the open block.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1816
1817
# Attribute types handled as plain integer scalars (read with
# ynl_attr_get_<type>()): u8..u64, s8..s64, plus variable-width uint/sint.
scalars = {f'{sign}{bits}' for sign in 'us' for bits in (8, 16, 32, 64)} | {'uint', 'sint'}

# Struct/function name suffix for each message direction ('' = no direction).
direction_to_suffix = {'reply': '_rsp', 'request': '_req', '': ''}

# Name suffix of the reply wrapper type for each op mode.
op_mode_to_wrapper = dict(do='', dump='_list', notify='_ntf', event='')

# C keywords; presumably used elsewhere to avoid emitting identifiers that
# clash with them.
_C_KW = set('''
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
'''.split())
1869
1870
def rdir(direction):
    """Return the opposite message direction ('reply' <-> 'request').

    Any other value (e.g. '') is returned unchanged.
    """
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1877
1878
def op_prefix(ri, direction, deref=False):
    """Return the C name prefix '<family>_<type_name><tail>' for ri/direction.

    The tail depends on the op mode:
      - no op mode: nothing
      - 'do': the direction suffix ('_req' / '_rsp')
      - other modes, request side: '_req', plus '_dump' unless the op is
        dump-only (type_oneside)
      - other modes, reply side: if do/dump replies share a type
        (type_consistent), deref=True gives the direction suffix and
        deref=False the mode's wrapper suffix ('_list', '_ntf', ...);
        otherwise '_rsp_dump' (deref) or '_rsp_list'.
    """
    mode = ri.op_mode
    if not mode:
        tail = ''
    elif mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req' if ri.type_oneside else '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1902
1903
def type_name(ri, direction, deref=False):
    """Return the C struct type ("struct <op_prefix>") for ri and direction."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1906
1907
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-space call for ri.op.

    The function is named ri.op.render_name, with '_dump' appended in dump
    mode.  It always takes the ynl socket.  It also takes a pointer to the
    request struct when the op mode has a request.  It returns a pointer to
    the reply struct when the op mode has a reply, and int otherwise.
    terminate=True ends it with ';' (a declaration).
    """
    mode_spec = ri.op[ri.op_mode]

    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1924
1925
def print_req_prototype(ri):
    """Emit the request-side prototype declaration with the op's doc above it."""
    op_doc = ri.op['doc']
    print_prototype(ri, direction="request", doc=op_doc)
1928
1929
def print_dump_prototype(ri):
    """Emit the request-side prototype declaration without a doc comment."""
    print_prototype(ri, direction="request")
1932
1933
def put_typol_submsg(cw, struct):
    """Emit the policy table and nest descriptor for a sub-message struct.

    Each member becomes a YNL_PT_SUBMSG entry indexed by its position in
    struct.member_list().  'nest' members also reference their own nest
    descriptor.  max_attr is the last index used (-1 for no members).
    """
    cw.block_start(line=f'const struct ynl_policy_attr {struct.render_name}_policy[] =')

    count = 0
    for index, (name, arg) in enumerate(struct.member_list()):
        nest = ""
        if arg.type == 'nest':
            nest = f" .nest = &{arg.nested_render_name}_nest,"
        cw.p(f'[{index}] = {{ .type = YNL_PT_SUBMSG, .name = "{name}",{nest} }},')
        count = index + 1

    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {count - 1},')
    cw.p(f'.table = {struct.render_name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1954
1955
def put_typol_fwd(cw, struct):
    """Emit an extern declaration of struct's policy nest descriptor."""
    decl = f'extern const struct ynl_policy_nest {struct.render_name}_nest;'
    cw.p(decl)
1958
1959
def put_typol(cw, struct):
    """Emit the attribute policy table and nest descriptor for struct.

    Sub-message structs are handled by put_typol_submsg().  Otherwise the
    table is sized by the attribute set's max_name, and each member writes
    its own entry via attr_typol().
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    max_name = struct.attr_set.max_name
    policy = f'{struct.render_name}_policy'

    cw.block_start(line=f'const struct ynl_policy_attr {policy}[{max_name} + 1] =')
    for member in [attr for _, attr in struct.member_list()]:
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {max_name},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1979
1980
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit '<render_name>_str()', which looks a value up in map_name.

    The argument is typed enum.user_type when an enum is given, int
    otherwise.  For 'flags' enums the generated code first converts the
    value to a bit index with ffs() - 1.  Out-of-range indexes return NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')
    cw.block_end()
    cw.nl()
1994
1995
def put_op_name_fwd(family, cw):
    """Emit the declaration of '<family>_op_str(int op)'."""
    fname = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', fname, ['int op'], suffix=';')
1998
1999
def put_op_name(family, cw):
    """Emit the op-value -> name string map and '<family>_op_str()'.

    Only messages with a reply value get an entry.  If several commands
    share a reply value (legacy families), only the one registered in
    family.rsp_by_value is listed; the others become '// skip' comments.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        if family.rsp_by_value[op.rsp_value] != op:
            # Several commands produce the same response; list it once.
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2019
2020
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the declaration of '<enum>_str()'; family is unused."""
    params = [f'{enum.user_type} value']
    cw.write_func_prot('const char *', f'{enum.render_name}_str', params, suffix=';')
2024
2025
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string map for enum and its '_str()' lookup."""
    map_name = f'{enum.render_name}_strmap'
    entry_lines = [f'[{e.value}] = "{e.name}",' for e in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in entry_lines:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
2035
2036
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<struct>_put()', which writes obj into nlh."""
    ri.cw.write_func_prot('int', f'{struct.render_name}_put',
                          ['struct nlmsghdr *nlh',
                           'unsigned int attr_type',
                           f'{struct.ptr_name}obj'],
                          suffix=suffix)
2044
2045
def put_req_nested(ri, struct):
    """Emit '<struct>_put()', which serializes obj into nlh.

    Regular structs are wrapped in a nest attribute of type attr_type.
    Sub-message structs are written without a wrapping nest.  A fixed
    header, if any, is copied from obj->_hdr before the attributes.
    """
    wraps_in_nest = struct.submsg is None

    local_vars = []
    init_lines = []
    if wraps_in_nest:
        local_vars.append('struct nlattr *nest;')
        init_lines.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        hdr_size = f'sizeof({struct.fixed_header})'
        local_vars.append('void *hdr;')
        init_lines += [f"hdr = ynl_nlmsg_put_extra_header(nlh, {hdr_size});",
                       f"memcpy(hdr, &obj->_hdr, {hdr_size});"]

    # (type, presence) for every member decides which helper locals we need.
    traits = [(arg.type, arg.presence_type()) for _, arg in struct.member_list()]
    if any(kind == 'indexed-array' for kind, _ in traits):
        local_vars.append('struct nlattr *array;')
    if any(presence == 'count' for _, presence in traits):
        local_vars.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)

    for _, arg in struct.member_list():
        arg.attr_put(ri, "obj")

    if wraps_in_nest:
        ri.cw.p("ynl_attr_nest_end(nlh, nest);")

    ri.cw.nl()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2086
2087
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parse function for struct, from the opening brace
    to the closing one (the prototype is written by the caller).

    ri:         render info; supplies the code writer and op/family data
    struct:     the Struct whose attributes are parsed into 'dst'
    init_lines: C statements emitted right after the local declarations;
                this function appends to the caller's list
    local_vars: C local declarations; this function appends to the
                caller's list

    Ordinary attributes are decoded by each member's attr_get() inside one
    loop over all attributes.  indexed-array and multi-attr members are
    decoded in a second pass, after arrays of n_<name> elements have been
    calloc()ed.  The n_<name> counters start at 0 here.  They are presumably
    incremented by the members' attr_get() code (not visible here).

    Raises Exception for a per-op fixed header in a non-classic family, for
    an unsupported indexed-array sub-type, and for an unsupported
    multi-attr or nested-array element type.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    # Choose the attribute iterator.  Nested structs walk 'nested' (skipping
    # a fixed header if present).  Top-level messages walk nlh past the
    # family header, or past the op's own fixed header in classic families.
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception(f"Per-op fixed header not supported, yet")

    # Classify members.  Indexed arrays get an attr_<name> local (presumably
    # recorded by attr_get()), and multi-attrs are re-scanned later.  Both
    # need counting and a second pass.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        # Nested parsers and sub-messages are invoked through a parse arg.
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # sorted() keeps the emitted order independent of set iteration order.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Values passed in as extra __u32 arguments are stored into dst.
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # Copy the fixed header into dst->_hdr from wherever it sits.
    if struct.fixed_header:
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Generated C refuses to parse into dst if an array member is already
    # populated.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # First pass: a single loop over all attributes, dispatching on type.
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate n_<name> elements and fill
    # them by walking the saved attr_<name> nest.
    # NOTE(review): the generated code does not check calloc() for NULL.
    # The nested _parse() call here passes ynl_attr_type(attr) as an extra
    # argument, unlike the multi-attr path below; presumably this matches an
    # inherited argument of the nested prototype -- confirm.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attrs: allocate n_<name> elements, then re-walk
    # all attributes and decode the ones of this member's type in order.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most one element's worth of payload.
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Each element is a separately malloc()ed, NUL-terminated
            # struct ynl_string.
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; top-level message callbacks return
    # YNL_PARSE_CB_OK.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2239
2240
def parse_rsp_submsg(ri, struct):
    """Emit '<struct>_parse()' for a sub-message struct.

    The generated function compares the selector string 'sel' against each
    member name in an if / else-if chain. It decodes 'nested' only as the
    matching member, and ends with 'return 0;'.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    local_vars = {'const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;'}

    # Union of the locals every member's decoder needs; a set so that
    # declarations shared by several members are emitted once.
    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        local_vars |= set(l_vars) if l_vars else set()

    ri.cw.block_start()
    # Iteration order of a set of strings depends on per-process hash
    # randomization, and write_func_lvar() only orders by length.  Sort
    # first so equal-length declarations come out in the same order on
    # every run and the generated code is reproducible.
    ri.cw.write_func_lvar(sorted(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    first = True
    for name, arg in struct.member_list():
        kw = 'if' if first else 'else if'
        first = False

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2275
2276
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<struct>_parse()'.

    Arguments, in order: yarg; 'sel' (sub-messages only); nested; one
    'const char *_sel_<name>' per external selector; and one '__u32 <name>'
    per inherited value.
    """
    func_args = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        func_args.append('const char *sel')
    func_args.append('const struct nlattr *nested')
    func_args += [f'const char *_sel_{sel.name}' for sel in struct.external_selectors()]
    func_args += [f'__u32 {arg}' for arg in struct.inherited]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
                          suffix=suffix)
2289
2290
def parse_rsp_nested(ri, struct):
    """Emit '<struct>_parse()' for a nested attribute set.

    Sub-message structs are handled by parse_rsp_submsg().  A struct
    without members gets a stub that just returns 0.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return None

    local_vars = ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;']
    _multi_parse(ri, struct, [], local_vars)
    return None
2309
2310
def parse_rsp_msg(ri, deref=False):
    """Emit the top-level reply parser '<op_prefix>_parse()' for ri.

    Nothing is emitted unless the op mode has a reply or the op is an
    event.  A reply without members gets a stub returning YNL_PARSE_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
                  'const struct nlattr *attr;']
    parser_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parser_name,
                          ['const struct nlmsghdr *nlh', 'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'], local_vars)
2332
2333
def print_req(ri):
    """Emit the definition of the request/response user-space call for ri.

    The generated C function:
      - builds a request message for ri.op;
      - copies req->_hdr as a fixed header (when the request struct has
        one) and puts every request attribute;
      - runs it with ynl_exec().
    With a reply, it returns a calloc()ed reply struct filled by the reply
    parser, or NULL after freeing it on error.  Without a reply, it returns
    0 on success and -1 on error.

    NOTE(review): the generated body refers to 'req' (classic nlmsg flags,
    fixed header copy).  print_prototype() only adds a 'req' argument when
    the op mode has a 'request' section.  This presumably assumes such ops
    always have one -- confirm.
    """
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.struct["request"].fixed_header:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # Loop counter is only needed if some request attribute is counted.
    for _, attr in ri.struct["request"].member_list():
        if attr.presence_type() == 'count':
            local_vars += ['unsigned int i;']
            break

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    # Classic families pass the caller's nlmsg flags; others use the
    # ynl_gemsg_* helper with the family id.
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.struct['request'].fixed_header:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    # Hook up the reply parser and the expected reply command.
    # NOTE(review): the generated code does not check calloc() for NULL.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    # Error path: release the partially filled reply.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
2407
2408
def print_dump(ri):
    """Emit the definition of the dump user-space call for ri.

    The generated C function sets up a ynl_dump_state whose callback is the
    deref'ed reply parser and whose element size is the reply list type.
    It builds the dump request, adding the fixed header and request
    attributes when present, and runs ynl_exec_dump().  It returns
    yds.first on success.  On error it frees the partial list and returns
    NULL.

    NOTE(review): the fixed header copy refers to 'req', but
    print_prototype() only adds a 'req' argument when the op mode has a
    'request' section.  This presumably assumes ops with a fixed header
    always have one -- confirm.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.struct['request'].fixed_header:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    # Dump state: policy, per-message parser, element size and reply cmd.
    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.struct['request'].fixed_header:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    # Request policy and attributes only exist for dumps that take a request.
    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    # Error path: release whatever part of the list was already built.
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2463
2464
def call_free(ri, direction, var):
    """Return the C statement that frees var with the direction's _free()."""
    prefix = op_prefix(ri, direction)
    return prefix + f"_free({var});"
2467
2468
def free_arg_name(direction):
    """Return the parameter name used in _free() prototypes.

    It is the direction suffix without its underscore ('req' / 'rsp'), or
    'obj' when direction is empty.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2473
2474
def print_alloc_wrapper(ri, direction, struct=None):
    """Emit a static inline '<op_prefix>_alloc()' returning a zeroed struct.

    If struct.in_multi_val is set, the helper takes a count 'n' and
    allocates an array.  A '_' is appended to the struct type name when
    ri.type_name_conflict is set.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name

    if struct and struct.in_multi_val:
        args, cnt = ["unsigned int n"], "n"
    else:
        args, cnt = ["void"], "1"

    ri.cw.write_func_prot(f'static inline struct {struct_name} *',
                          f"{name}_alloc", args)
    ri.cw.block_start()
    ri.cw.p(f'return calloc({cnt}, sizeof(struct {struct_name}));')
    ri.cw.block_end()
2492
2493
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of '<op_prefix>_free()' for the direction's struct."""
    name = op_prefix(ri, direction)
    struct_name = f'{name}_' if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
2501
2502
def print_nlflags_set(ri, direction):
    """Emit '<op_prefix>_set_nlflags()', storing nl_flags in req->_nlmsg_flags."""
    name = op_prefix(ri, direction)
    params = [f"struct {name} *req", "__u16 nl_flags"]
    ri.cw.write_func_prot('static inline void', f"{name}_set_nlflags", params)
    ri.cw.block_start()
    ri.cw.p('req->_nlmsg_flags = nl_flags;')
    ri.cw.block_end()
    ri.cw.nl()
2511
2512
def _print_type(ri, direction, struct):
    """Emit the C definition of @struct for the given message direction.

    The type is named <family c_name>_<type_name><direction suffix>. A
    trailing '_' is appended for direction-less types when
    ri.type_name_conflict is set. '_dump' is appended for dump types that
    are not one-sided.

    The body holds, in order:
    - the optional _nlmsg_flags field and fixed header;
    - one anonymous sub-struct per presence kind (_present, _len, _count)
      that has any entries;
    - the inherited __u32 arguments;
    - the members themselves.
    """
    type_suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        type_suffix += '_'
    if ri.op_mode == 'dump' and not ri.type_oneside:
        type_suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family.c_name}{type_suffix}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if struct.fixed_header:
        cw.p(struct.fixed_header + ' _hdr;')
        cw.nl()

    # Each presence kind gets its own sub-struct. It is opened lazily, so
    # kinds with no members produce no output at all.
    for kind in ('present', 'len', 'count'):
        opened = False
        for _, attr in struct.member_list():
            member_line = attr.presence_member(ri.ku_space, kind)
            if not member_line:
                continue
            if not opened:
                cw.block_start(line="struct")
                opened = True
            cw.p(member_line)
        if opened:
            cw.block_end(line=f'_{kind};')
    cw.nl()

    for arg in struct.inherited:
        cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2551
2552
def print_type(ri, direction):
    """Emit the struct for the @direction side of the current op."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2555
2556
def print_type_full(ri, struct):
    """Emit a nested (direction-less) struct and, if needed, its helpers.

    Structs used in requests and in multi-value attributes also get an
    alloc wrapper, a free() prototype and, in user space, member setters.
    """
    _print_type(ri, "", struct)

    if not (struct.request and struct.in_multi_val):
        return

    print_alloc_wrapper(ri, "", struct)
    ri.cw.nl()
    free_rsp_nested_prototype(ri)
    ri.cw.nl()

    # Name conflicts are too hard to deal with in the current code base.
    # They are very rare, so setters are simply not printed in that case.
    if ri.ku_space == 'user' and not ri.type_name_conflict:
        for _, member in struct.member_list():
            member.setter(ri, ri.attr_set, "", var="obj")
    ri.cw.nl()
2572
2573
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype, optional nlflags setter and, for user-space
    requests, the per-member setters of the @direction struct."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    emit_setters = ri.ku_space == 'user' and direction == 'request'
    if emit_setters:
        members = ri.struct[direction].member_list()
        for _, member in members:
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2585
2586
def print_req_type_helpers(ri):
    """Emit alloc and setter helpers for a non-empty request type."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2592
2593
def print_rsp_type_helpers(ri):
    """Emit reply helpers when the current op mode defines a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2598
2599
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of <op>_rsp_parse() or <op>_req_parse().

    Any direction other than "reply" is treated as a request. With
    terminate=False the trailing ';' is omitted so a body can follow.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    ri.cw.write_func_prot('void', f"{base}_parse",
                          ['const struct nlattr **tb', f"struct {base} *req"],
                          suffix=';' if terminate else '')
2608
2609
def print_req_type(ri):
    """Emit the request struct unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2614
2615
def print_req_free(ri):
    """Emit the request free() function when the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2620
2621
def print_rsp_type(ri):
    """Emit the reply struct.

    Events always get one; do/dump modes get one only if they define a
    reply.
    """
    has_reply = (ri.op_mode == 'event' or
                 (ri.op_mode in ('do', 'dump') and 'reply' in ri.op[ri.op_mode]))
    if has_reply:
        print_type(ri, 'reply')
2630
2631
def print_wrapped_type(ri):
    """Emit the wrapper struct around the reply object.

    Dumps get a 'next' list link; notifications/events get family, cmd,
    next and free fields. The reply object itself is stored as 'obj'.
    """
    reply_type = type_name(ri, 'reply')
    cw = ri.cw

    cw.block_start(line=f"{reply_type}")
    if ri.op_mode == 'dump':
        cw.p(f"{reply_type} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for field in ('__u16 family;', '__u8 cmd;', 'struct ynl_ntf_base_type *next;'):
            cw.p(field)
        cw.p(f"void (*free)({reply_type} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2646
2647
2648def _free_type_members_iter(ri, struct):
2649    if struct.free_needs_iter():
2650        ri.cw.p('unsigned int i;')
2651        ri.cw.nl()
2652
2653
2654def _free_type_members(ri, var, struct, ref=''):
2655    for _, attr in struct.member_list():
2656        attr.free(ri, var, ref)
2657
2658
def _free_type(ri, direction, struct):
    """Emit the full definition of the _free() function for @struct.

    Members are freed first. The object itself is free()d only for
    directional types; for nested types the caller owns the storage.
    """
    arg = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2670
2671
def free_rsp_nested_prototype(ri):
    """Emit the declaration of the _free() helper for a nested struct."""
    print_free_prototype(ri, "")
2674
2675
def free_rsp_nested(ri, struct):
    """Emit the _free() definition for a nested (direction-less) struct."""
    _free_type(ri, direction="", struct=struct)
2678
2679
def print_rsp_free(ri):
    """Emit the reply free() function when the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2684
2685
def print_dump_type_free(ri):
    """Emit the free() function for a dump reply list.

    The generated C walks the 'next' chain until YNL_LIST_END. For each
    node it frees the members of the embedded 'obj' and then the node.
    """
    node_type = type_name(ri, 'reply')
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()  # while
    cw.block_end()  # function
    cw.nl()
2704
2705
def print_ntf_type_free(ri):
    """Emit the free() function for a notification.

    It frees the members of rsp->obj, then rsp itself.
    """
    cw = ri.cw
    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2714
2715
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the nla_policy array declaration or the opening of its definition.

    terminate=True emits an 'extern' declaration, or nothing at all for
    op-specific (ri given) static policies. terminate=False opens the
    definition ("... = {"), marked static for static op policies.

    The array is named after the op (plus '_<op_mode>' for dual-policy
    ops) when ri is given, otherwise after the struct.
    """
    static_policy = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if static_policy:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if static_policy else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    elif ri.op.dual_policy:
        name = ri.op.render_name + '_' + ri.op_mode
    else:
        name = ri.op.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2738
2739
def print_req_policy(cw, struct, ri=None):
    """Emit a complete nla_policy array definition for @struct.

    For op policies the array is wrapped in the op's 'config-cond' #ifdef.
    """
    if ri and ri.op:
        cw.ifdef_block(ri.op.get('config-cond', None))
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.ifdef_block(None)
    cw.nl()
2749
2750
def kernel_can_gen_family_struct(family):
    """Return True when the generator owns the genl_family struct (genetlink only)."""
    is_genl = family.proto == 'genetlink'
    return is_genl
2753
2754
def policy_should_be_static(family):
    """Policies are file-local for split-policy families or generated family structs."""
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2757
2758
def print_kernel_policy_ranges(family, cw):
    """Emit netlink_range_validation structs for 'full-range' request attrs.

    Unsigned types get the plain struct with ULL limits. All other types
    get the _signed variant with LL limits. Only min/max present in the
    checks are initialized.
    """
    header_done = False
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if not header_done:
                cw.p('/* Integer value ranges */')
                header_done = True

            unsigned = attr.type[0] == 'u'
            sign = '' if unsigned else '_signed'
            suffix = 'ULL' if unsigned else 'LL'
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            limits = [(bound, attr.get_limit_str(bound, suffix=suffix))
                      for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(limits)
            cw.block_end(line=';')
            cw.nl()
2786
2787
def print_kernel_policy_sparse_enum_validates(family, cw):
    """Emit validation callbacks for sparse-enum request attributes.

    A callback is emitted for every request attribute, in a non-subset
    attribute set, that has an enum_name and a 'sparse' check. Each one is
    a static <enum_name>_validate() that switches on the attribute value.
    It returns 0 for any of the enum's listed entries; otherwise it sets
    an extack message and returns -EINVAL.
    """
    first = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request:
                continue
            if not attr.enum_name:
                continue
            if 'sparse' not in attr.checks:
                continue

            if first:
                cw.p('/* Sparse enums validation callbacks */')
                first = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            for idx, entry in enumerate(enum.entries.values()):
                # Chain the case labels so every listed value falls
                # through to the single shared "return 0".
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2826
2827
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops table declaration, or open its definition.

    terminate=False: open the "<c_name>_nl_ops[] =" definition (the caller
    emits the entries and closes it) and return.

    terminate=True (header): emit an extern declaration when the table is
    exported, i.e. when the family struct is not generated here. Then emit
    prototypes for all pre/post hooks and for the doit/dumpit handlers of
    every non-async op.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split tables carry one entry per do handler and per dump handler.
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if not terminate:
            cw.block_start(line=decl + ' =')
            return
        cw.p(f"extern {decl};")

    cw.nl()
    doit_args = ('const struct genl_split_ops *ops',
                 'struct sk_buff *skb', 'struct genl_info *info')
    dumpit_args = ('struct netlink_callback *cb', )
    hook_protos = (('pre', 'do', 'int', doit_args),
                   ('post', 'do', 'void', doit_args),
                   ('pre', 'dump', 'int', dumpit_args),
                   ('post', 'dump', 'int', dumpit_args))
    for stage, mode, ret, args in hook_protos:
        for hook in family.hooks[stage][mode]['list']:
            cw.write_func_prot(ret, c_lower(hook), list(args), suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            cw.write_func_prot('int', c_lower(f"{family.ident_name}-nl-{op_name}-doit"),
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            cw.write_func_prot('int', c_lower(f"{family.ident_name}-nl-{op_name}-dumpit"),
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2893
2894
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops table declaration and handler prototypes."""
    print_kernel_op_table_fwd(family, cw, True)
2897
2898
def print_kernel_op_table(family, cw):
    """Emit the kernel ops table definition for @family.

    The table is opened by print_kernel_op_table_fwd(terminate=False).
    This function fills in one initializer per entry and closes the table.

    - 'global' / 'per-op' policy: one entry per non-async op, carrying
      both its doit and dumpit handlers. 'per-op' also attaches the op's
      policy.
    - 'split' policy: a separate entry for each of an op's do and dump
      handlers, each with its own pre/post callbacks, optional policy and
      a GENL_CMD_CAP_* flag.

    Every entry is wrapped in the op's 'config-cond' #ifdef, if any.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): assumes every sync op in per-op mode has a
                # 'do' with a 'request' section; an op without one would
                # raise KeyError here. TODO confirm against the spec schema.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Split ops name the pre/post callbacks differently for do and dump.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the dont-validate flags that apply to this
                    # handler kind: 'dump'/'dump-strict' are dropped for do,
                    # 'strict' is dropped for dump.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have distinct do and dump policies, so
                    # the mode is part of the policy name.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Always mark which handler kind this entry serves
                # (GENL_CMD_CAP_DO / GENL_CMD_CAP_DUMP).
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2977
2978
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the enum of multicast group IDs; nothing if there are no groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2989
2990
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by group enum ID.

    Nothing is emitted if the family has no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
3002
3003
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern genl_family declaration and any sock-priv prototypes.

    Nothing is emitted unless the family struct is generated. With
    'sock-priv', prototypes of <c_name>_nl_sock_priv_init/destroy are
    added.
    """
    if not kernel_can_gen_family_struct(family):
        return

    c_name = family.c_name
    cw.p(f"extern struct genl_family {c_name}_nl_family;")
    cw.nl()
    if 'sock-priv' not in family.kernel_family:
        return

    priv_type = family.kernel_family["sock-priv"]
    for hook in ('init', 'destroy'):
        cw.p(f'void {c_name}_nl_sock_priv_{hook}({priv_type} *priv);')
    cw.nl()
3014
3015
def print_kernel_family_struct_src(family, cw):
    """Emit the genl_family definition (plus sock-priv trampolines).

    Only emitted when the family struct is generated here. The ops and
    multicast fields depend on the kernel policy and on whether groups
    exist.
    """
    if not kernel_can_gen_family_struct(family):
        return

    c_name = family.c_name
    has_sock_priv = 'sock-priv' in family.kernel_family

    if has_sock_priv:
        # Generate "trampolines" to make CFI happy
        for hook in ('init', 'destroy'):
            cw.write_func("static void", f"__{c_name}_nl_sock_priv_{hook}",
                          [f"{c_name}_nl_sock_priv_{hook}(priv);"],
                          ["void *priv"])
            cw.nl()

    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    for fixed in ('.netnsok\t= true,', '.parallel_ops\t= true,', '.module\t\t= THIS_MODULE,'):
        cw.p(fixed)

    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({c_name}_nl_ops),')

    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({c_name}_nl_mcgrps),')

    if has_sock_priv:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
3051
3052
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uapi enum block, choosing its tag from @obj.

    An @enum_name key in @obj always wins; a falsy value there yields an
    anonymous enum. Otherwise, if @ckey is set and present, the tag is
    <family c_name>_<obj[ckey]>.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        tag = c_lower(explicit) if explicit else None
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])
    else:
        tag = None

    line = 'enum' if tag is None else 'enum ' + tag
    cw.block_start(line=line)
3061
3062
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Emit the single command enum used by the 'unified' message ID model.

    An explicit '= value' is written only where an op's value breaks the
    running sequence. With @separate_ntf, notify/event ops are skipped.
    The enum ends with the CNT entry and a MAX alias, which is emitted as
    an enum member or as a #define depending on @max_by_define.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(op.enum_name + ',')
        else:
            cw.p(f"{op.enum_name} = {op.value},")
        expected = op.value + 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
3088
3089
def render_uapi_directional(family, cw, max_by_define):
    """Emit the two enums of the 'directional' message ID model.

    The USER enum lists messages sent to the kernel: ops with 'do' that
    are not events. The KERNEL enum lists messages sent by the kernel:
    replies (suffixed _REPLY), notifications and events. Each enum starts
    with a *_NONE = 0 entry and ends with its CNT entry and MAX alias.
    """

    def emit_entries(entries):
        # An explicit value is written only for a non-zero value that
        # breaks the running sequence.
        nxt = 0
        for name, value in entries:
            if value and value != nxt:
                cw.p(f"{name} = {value},")
                nxt = value
            else:
                cw.p(name + ',')
            nxt += 1

    def close_enum(kind):
        # Terminate the enum with __<prefix><kind>_CNT and <prefix><kind>_MAX.
        cnt_name = f"__{family.op_prefix}{kind}_CNT"
        max_name = f"{family.op_prefix}{kind}_MAX"
        max_value = f"({cnt_name} - 1)"
        cw.nl()
        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    def kernel_entries():
        for op in family.msgs.values():
            if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
                if 'event' in op or 'notify' in op:
                    yield op.enum_name, op.value
                else:
                    yield f'{op.enum_name}_REPLY', op.value

    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    emit_entries((op.enum_name, op.value) for op in family.msgs.values()
                 if 'do' in op and 'event' not in op)
    close_enum('USER')

    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    emit_entries(kernel_entries())
    close_enum('KERNEL')
3142
3143
def render_uapi(family, cw):
    """Render the complete uapi header for @family into @cw.

    Emits, in order:
    - the include guard;
    - the family name/version defines;
    - constants, enums and flags from 'definitions' (entries with a
      'header' are skipped);
    - one enum per non-subset attribute set;
    - the command enum(s) for the family's message ID model;
    - the async (notification) enum when 'async-prefix' is set;
    - the multicast group name defines.

    Raises:
        Exception: if the family's message ID model is not 'unified' or
            'directional'.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are batched into one #define group.
    # The group is flushed whenever a non-const definition interrupts it.
    defines = []
    for const in family['definitions']:
        if const.get('header'):
            continue

        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The name override belongs to the individual definition, as
            # it does for the multicast group defines below. Reading it
            # from the family would give every constant the same name.
            define_name = const.get('c-define-name',
                                    f"{family.ident_name}-{const['name']}")
            defines.append([c_upper(define_name), const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            # Only notifications and events belong in the async enum.
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3284
3285
def _render_user_ntf_entry(ri, op):
    """Emit one ynl_ntf_info initializer for notification @op.

    For classic netlink families the entry is indexed by the enum of the
    request looked up via op.rsp_value. Otherwise it is indexed by @op's
    own enum.
    """
    if ri.family.is_classic():
        key_op = ri.family.req_by_value[op.rsp_value]
    else:
        key_op = op

    cw = ri.cw
    cw.block_start(line=f"[{key_op.enum_name}] = ")
    cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    cw.block_end(line=',')
3297
3298
def render_user_family(family, cw, prototype):
    """Emit the user-space ynl_family descriptor for @family.

    With @prototype only the extern declaration is written. Otherwise the
    notification info table (if any) is written first, followed by the
    descriptor definition.

    Raises:
        Exception: for an ntfs entry that has neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                ri = RenderInfo(cw, family, "user", family.ops[ntf_op['notify']], "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3340
3341
def family_contains_bitfield32(family):
    """
    Report whether any attribute in the family is of type "bitfield32".

    Attribute sets that are a subset of another set (``subset_of`` is set)
    are not inspected.
    """
    return any(
        spec_attr.type == "bitfield32"
        for candidate_set in family.attr_sets.values()
        if not candidate_set.subset_of
        for _, spec_attr in candidate_set.items()
    )
3350
3351
def find_kernel_root(full_path):
    """
    Find the kernel source tree that contains ``full_path``.

    Walks up from ``full_path`` one directory at a time until a directory
    holding a ``MAINTAINERS`` file is found.

    Returns a ``(root_dir, rel_path)`` tuple, where ``rel_path`` is
    ``full_path`` expressed relative to ``root_dir``.

    Raises FileNotFoundError when no ancestor contains ``MAINTAINERS``.
    Without that check the walk would loop forever once ``os.path.dirname``
    stops changing the path ('/' for absolute paths, '' for relative ones).
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # sub_path is built with a trailing separator; drop it.
            return parent, sub_path[:-1]
        if parent == full_path:
            raise FileNotFoundError(
                f"cannot find kernel tree root (MAINTAINERS) above {start!r}")
        full_path = parent
3360
3361
def main():
    """
    Command-line entry point: parse arguments, load the spec, emit C code.

    --mode selects what is generated: 'uapi' (rendered by render_uapi),
    'kernel' (policy / op tables) or 'user' (the YNL user-space library).
    For 'kernel' and 'user', --header / --source selects which half is
    written.

    Exits with status 1, after printing a message to stderr, when the spec
    fails YAML parsing or its license is not the required one.
    """
    # The only spec license accepted; it is also copied into the output's
    # SPDX tag (via parsed.license).
    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination; the None default lets us
    # detect that neither was given.
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print(f'License must be: {required_license}', file=sys.stderr)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # Banner: license tag, path of the spec relative to the kernel tree root,
    # and the generator mode.
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Echo each --user-header / --exclude-op value, prefixed by its flag.
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # Companion header name: the output file's basename with its last two
    # characters swapped for ".h" (presumably turning "x.c" into "x.h").
    if args.out_file:
        hdr_file = os.path.basename(args.out_file[:-2]) + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    # Headers requested by individual definitions / attribute sets.
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, keeping first-seen order.
    for one in dict.fromkeys(headers):
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            if any(struct.request for struct in parsed.pure_nested_structs.values()):
                cw.p('/* Common nested types */')
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op in parsed.ops.values():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)
            print_kernel_policy_sparse_enum_validates(parsed, cw)

            if any(struct.request for struct in parsed.pure_nested_structs.values()):
                cw.p('/* Common nested types */')
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            # Per-op request policies, for every do/dump that has a request.
            if parsed.kernel_policy in {'per-op', 'split'}:
                for op in parsed.ops.values():
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for const in parsed.consts.values():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op in parsed.ops.values():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent or ri.type_oneside:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op in parsed.ntfs.values():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for const in parsed.consts.values():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            # Recursive nests need their policies (and, below, their helper
            # prototypes) forward-declared before any definitions.
            has_recursive_nests = False
            cw.p('/* Policies */')
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for struct in parsed.pure_nested_structs.values():
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            if has_recursive_nests:
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op in parsed.ops.values():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent or ri.type_oneside:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op in parsed.ntfs.values():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    # The parser is rendered from a "do"-mode RenderInfo,
                    # the free helper from an "event"-mode one.
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')


if __name__ == "__main__":
    main()
3674