xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision d103f26a5c8599385acb2d2e01dfbaedb00fdc0a)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import filecmp
6import pathlib
7import os
8import re
9import shutil
10import sys
11import tempfile
12import yaml
13
14sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
15from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
16from lib import SpecSubMessage
17
18
def c_upper(name):
    """Return *name* as an upper-case C identifier (dashes become underscores)."""
    underscored = name.replace('-', '_')
    return underscored.upper()
21
22
def c_lower(name):
    """Return *name* as a lower-case C identifier (dashes become underscores)."""
    underscored = name.replace('-', '_')
    return underscored.lower()
25
26
def limit_to_number(name):
    """
    Convert a symbolic limit such as ``u32-max`` or ``s64-min`` to an int.

    The first character selects the signedness ('u' unsigned, 's' signed).
    The characters between it and the trailing four-character suffix
    (``-min`` / ``-max``) are parsed as the bit width.
    """
    sign = name[0]
    is_min = name.endswith('-min')
    if sign == 'u' and is_min:
        return 0

    bits = int(name[1:-4])
    signed = sign == 's'
    # A signed type spends one bit on the sign.
    largest = (1 << (bits - 1 if signed else bits)) - 1
    return -largest - 1 if signed and is_min else largest
40
41
class BaseNlLib:
    """Names of helpers/fields of the user-space netlink library used by codegen."""

    def get_family_id(self):
        """Return the C expression that evaluates to the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """Code-generation wrapper around a single spec attribute.

    Holds the C-facing names of the attribute (``c_name``, ``enum_name`` and,
    for nests/sub-messages, ``nested_render_name`` / ``nested_struct_type``).
    Also provides default implementations of the hooks used to render:
    struct members, kernel policies (``attr_policy``), user-space type
    policies (``attr_typol``), put/parse code and setters.  Subclasses
    override the pieces that differ per attribute type; hooks a subclass must
    provide raise ``Exception`` here.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Flipped to True by set_request() / set_reply().
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # A nest named like an entry of family.consts gets a trailing '_'
            # on its struct name (NOTE(review): presumably to avoid clashing
            # with the C type generated for that definition).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Turn the attribute name into a usable C member name: names found in
        # _C_KW get a '_' suffix, a leading digit gets a '_' prefix.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attribute (and its real attr, for subsets) as used in requests."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attribute (and its real attr, for subsets) as used in replies."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return the numeric value of check *limit*, or *default* if unset.

        Integers are returned as-is, names of family consts resolve to the
        const's value, anything else goes through limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check *limit* rendered as C source text ('' if unset).

        Integers are printed (with *suffix*); family consts become their
        upper-case C macro name; any other string is upper-cased as-is.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute ``enum_name``; reject subsets that override checks."""
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Kind of presence tracking: 'present' (bit), 'len', 'count' or ''."""
        return 'present'

    def presence_member(self, space, type_filter):
        """Return the presence struct member line, if presence matches *type_filter*."""
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'present':
            return f"{pfx}u32 {self.c_name}:1;"
        if presence in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def _free_lines(self, ri, var, ref):
        # Multi-value and length/count-tracked members are heap allocated.
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the C statements releasing this member of ``var``."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the C parameter declarations used to pass this attribute."""
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the kernel nla_policy entry for this attribute."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the user-space type-policy entry for this attribute."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the presence field when one exists.
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attribute; returns True.

        ``_attr_get()`` supplies (lines, init_lines, local_vars); the first
        two may be a single string instead of a list.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """Emit a ``static inline`` setter for this attribute.

        *ref* is the list of enclosing nest member names; the setter marks
        every enclosing nest present, frees the old value and stores the
        new one.
        """
        ref = (ref if ref else []) + [self.c_name]
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, name in enumerate(ref):
            # Path to the struct holding level i, with a trailing '.' when
            # non-empty (e.g. "req->a.b." for i == 2).
            prefix = f"{var}->{'.'.join(ref[:i] + [''])}"
            presence = f"{prefix}_present.{name}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{prefix}_{self.presence_type()}.{name}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # A setter that frees but never allocates gets a '__' prefix
        # (NOTE(review): presumably marks it as an internal helper - confirm).
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
309
310
class TypeUnused(Type):
    """Attribute kind for which no code is generated.

    The only output is a user-space type-policy entry that rejects the
    attribute, and a parse body that returns an error.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_get(self, ri, var):
        reject = ['return YNL_PARSE_CB_ERROR;']
        return reject, None, None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    # No kernel policy, put code, parse branch or setter is emitted.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        return None
334        pass
335
336
class TypePad(Type):
    """Padding attribute.

    It is ignored by the user-space type policy (YNL_PT_IGNORE). No kernel
    policy, put code, parse branch or setter is emitted for it.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def attr_policy(self, cw):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        return None
357        pass
358
359
class TypeScalar(Type):
    """Integer scalar attribute.

    Also derives kernel policy checks (min/max/range/sparse) from the
    attribute's ``checks`` and, for enum-typed attributes, from the enum's
    value range.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Compute ``is_bitfield`` and the C ``type_name`` of the member."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """Fill in derived entries of ``self.checks``.

        Adds min/max from an enum's value range, or 'sparse' when the enum
        has no range. Adds 'range' when both limits exist, and 'full-range'
        when a limit falls outside [-32768, 32767].

        Raises:
            Exception: if min > max, or a negative limit is used on an
                unsigned type.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            min_val = self.get_limit('min')
            max_val = self.get_limit('max')
            if min_val > max_val:
                raise Exception(f'Invalid limit for "{self.name}" min: {min_val} max: {max_val}')
            self.checks['range'] = True

        bounds = (self.get_limit('min', 0), self.get_limit('max', 0))
        low, high = min(bounds), max(bounds)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside this window need the full-range policy form.
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
            else:
                enum = self.family.consts[self.checks['flags-mask']]
            # Build the mask from the entries' actual bit positions rather
            # than assuming they occupy bits 0..n-1.
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
456        return [f"{member} = {self.c_name};"]
457
458
class TypeFlag(Type):
    """Flag attribute: put with no payload; only its presence is recorded."""

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to copy; the generic attr_get() sets the presence bit.
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        return []
473        return []
474
475
class TypeString(Type):
    """String attribute held as a heap-allocated, NUL-terminated ``char *``.

    Its length is tracked in the ``_len`` presence field.
    """

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        selector = '.is_selector = 1, ' if self.is_selector else ''
        return '.type = YNL_PT_NUL_STR, ' + selector

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'

        parts = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            parts.append(', .len = ' + self.get_limit_str('max-len'))
        parts.append(', }')
        return ''.join(parts)

    def attr_policy(self, cw):
        unterminated = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        member = f"{var}->{self.c_name}"
        length = f"{var}->_len.{self.c_name}"
        get_lines = [
            f"{length} = len;",
            f"{member} = malloc(len + 1);",
            f"memcpy({member}, ynl_attr_get_str(attr), len);",
            f"{member}[len] = 0;",
        ]
        # Length is bounded by the attribute payload size.
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        src = self.c_name
        return [
            f"{presence} = strlen({src});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {src}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
527                f'{member}[{presence}] = 0;']
528
529
class TypeBinary(Type):
    """Opaque binary attribute held as a heap copy plus its length (``_len``)."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # At most one length check is supported; map it to its policy macro.
        macros = {
            'exact-len': 'NLA_POLICY_EXACT_LEN',
            'min-len': 'NLA_POLICY_MIN_LEN',
            'max-len': 'NLA_POLICY_MAX_LEN',
        }
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not self.checks:
            return '{ .type = NLA_BINARY, }'

        check_name = next(iter(self.checks))
        if check_name not in macros:
            raise Exception('Unsupported check for binary type: ' + check_name)
        return macros[check_name] + '(' + self.get_limit_str(check_name) + ')'

    def attr_put(self, ri, var):
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                    f"{var}->{self.c_name}, {var}->_len.{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        member = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{member} = malloc(len);",
            f"memcpy({member}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        src = self.c_name
        return [
            f"{presence} = len;",
            f"{member} = malloc({presence});",
            f'memcpy({member}, {src}, {presence});',
        ]
579                f'memcpy({member}, {self.c_name}, {presence});']
580
581
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is a C struct named by ``struct``.

    On parse, payloads shorter than the struct are zero-filled via calloc.
    """

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        member = f"{var}->{self.c_name}"
        len_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        get_lines = [
            f"{len_mem} = len;",
            f"if (len < {struct_sz})",
            f"{member} = calloc(1, {struct_sz});",
            "else",
            f"{member} = malloc(len);",
            f"memcpy({member}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
596               ['unsigned int len;']
597
598
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute carrying a packed array of ``sub-type`` scalars.

    The element count is tracked in ``_count``.
    """

    def arg_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        return [f'{elem} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        elem = f'__{self.get("sub-type")}'
        ri.cw.p(f'{elem} *{self.c_name};')

    def attr_put(self, ri, var):
        count = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_size = f"sizeof(__{self.get('sub-type')})"
        ri.cw.block_start(line=f"if ({count})")
        ri.cw.p(f"i = {count} * {elem_size};")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        count = f"{var}->_count.{self.c_name}"
        member = f"{var}->{self.c_name}"
        elem_size = f"sizeof(__{self.get('sub-type')})"
        get_lines = [
            f"{count} = len / {elem_size};",
            # Round the copy length down to whole elements.
            f"len = {count} * {elem_size};",
            f"{member} = malloc(len);",
            f"memcpy({member}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem_size = f"sizeof(__{self.get('sub-type')})"
        return [
            f"{presence} = count;",
            f"count *= {elem_size};",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
630                f'memcpy({member}, {self.c_name}, count);']
631
632
class TypeBitfield32(Type):
    """``struct nla_bitfield32`` attribute.

    Its kernel policy is restricted to the flag mask of the enum named by
    ``enum``.
    """

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        size = 'sizeof(struct nla_bitfield32)'
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, {size})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        size = 'sizeof(struct nla_bitfield32)'
        copy = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), {size});"
        return copy, None, None

    def _setter_lines(self, ri, member, presence):
        size = 'sizeof(struct nla_bitfield32)'
        return [f"memcpy(&{member}, {self.c_name}, {size});"]
655        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
656
657
class TypeNest(Type):
    """Nested attribute set, stored as an embedded struct.

    When the nest is recursive for the current op, it is stored as a
    pointer instead.
    """

    def is_recursive(self):
        return self.family.pure_nested_structs[self.nested_attrs].recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        target = f'{var}->{ref}{self.c_name}'
        free_fn = f'{self.nested_render_name}_free'
        if not self.is_recursive_for_op(ri):
            return [f'{free_fn}(&{target});']
        # Recursive nests are held by pointer; only free when set.
        return [f'if ({target})', f'{free_fn}({target});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        if self.is_recursive_for_op(ri):
            target = f"{var}->{self.c_name}"
        else:
            target = f"&{var}->{self.c_name}"
        self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, {self.enum_name}, {target})")

    def _attr_get(self, ri, var):
        pns = self.family.pure_nested_structs[self.nested_attrs]
        # Selectors living outside the nest are passed down to its parser.
        args = ["&parg", "attr"] + [f'{var}->{sel.name}' for sel in pns.external_selectors()]
        get_lines = [
            f"if ({self.nested_render_name}_parse({', '.join(args)}))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        # A nest has no setter of its own; emit one per non-recursive member.
        path = [*(ref or []), self.c_name]
        members = ri.family.pure_nested_structs[self.nested_attrs].member_list()
        for _, attr in members:
            if not attr.is_recursive():
                attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path,
                            var=var)
703                        var=var)
704
705
class TypeMultiAttr(Type):
    """Attribute that may repeat; its occurrences are collected into an array.

    ``base_type`` is the Type for a single occurrence and supplies the
    kernel policy and type-policy entries.  The element count lives in the
    ``_count`` presence field.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        kind = self.attr.get('type')
        if kind is None or kind == 'nest':
            return self.nested_struct_type
        if kind == 'binary' and 'struct' in self.attr:
            return None  # layout described by arg_member()
        if kind == 'string':
            return 'struct ynl_string *'
        if kind in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + kind
        raise Exception(f"Sub-type {kind} not supported yet")

    def arg_member(self, ri):
        if self.type != 'binary' or 'struct' not in self.attr:
            return super().arg_member(ri)
        return [f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
                f'unsigned int n_{self.c_name}']

    def free_needs_iter(self):
        # Element-wise freeing needs a loop counter.
        return self.attr['type'] in ('nest', 'string')

    def _free_lines(self, ri, var, ref):
        kind = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"
        loop = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"
        if kind in scalars or kind == 'binary':
            return [f"free({array});"]
        if kind == 'string':
            return [loop, f"free({array}[i]);", f"free({array});"]
        if kind == 'nest':
            return [loop,
                    f'{self.nested_render_name}_free(&{array}[i]);',
                    f"free({array});"]
        raise Exception(f"Free of MultiAttr sub-type {kind} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Occurrences are only counted here.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        kind = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        item = f"{var}->{self.c_name}[i]"
        if kind in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {item});")
        elif kind == 'binary' and 'struct' in self.attr:
            struct_name = c_lower(self.attr['struct'])
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{item}, sizeof(struct {struct_name}));")
        elif kind == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {item}->str);")
        elif kind == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, "
                                f"{self.enum_name}, &{item})")
        else:
            raise Exception(f"Put of MultiAttr sub-type {kind} not supported yet")

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
791                f"{presence} = n_{self.c_name};"]
792
793
class TypeArrayNest(Type):
    """Array carried inside a nest.

    Each element is put with its index as the attribute type. All elements
    share ``sub-type``, which may be a scalar, fixed-length binary (needs
    ``exact-len``) or a nest.  The element count lives in ``_count``.
    """

    def _sub_type(self):
        # A missing sub-type means 'nest' everywhere, matching the struct
        # member chosen by _complex_member_type().
        return self.attr.get('sub-type', 'nest')

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self._sub_type()
        if sub_type == 'nest':
            return self.nested_struct_type
        elif sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {sub_type} not supported yet")

    def arg_member(self, ri):
        if self._sub_type() == 'binary' and 'exact-len' in self.checks:
            # Render the length like other binary types do, so a const name
            # becomes its C macro instead of a raw spec string.
            return [f'unsigned char (*{self.c_name})[{self.get_limit_str("exact-len")}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_typol(self):
        sub_type = self._sub_type()
        if sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(sub_type[1:])}, '
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.get_limit_str("exact-len")}, '
        else:
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Only validates and counts elements here.
        # NOTE(review): attr_{c_name} / n_{c_name} are presumably consumed by
        # code generated elsewhere - confirm.
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\tn_{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        sub_type = self._sub_type()
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if sub_type in scalars:
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{sub_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.get_limit_str('exact-len')});")
        elif sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put for ArrayNest sub-type {sub_type} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
855                f"{presence} = n_{self.c_name};"]
856
857
class TypeNestTypeValue(Type):
    """Attribute of type 'nest-type-value'.

    Parsing descends through one nesting level per entry of the optional
    'type-value' list, capturing each level's attribute type into a __u32
    local that is passed on to the nested struct's parse function.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'   # C expression for the attribute currently reached
        extra_args = ''   # captured type values appended to the parse call

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for tv in levels:
                get_lines.append(f'attr_{tv} = ynl_attr_data({cursor});')
                get_lines.append(f'{tv} = ynl_attr_type(attr_{tv});')
                cursor = f'attr_{tv}'
            extra_args = ', ' + ', '.join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
886
887
class TypeSubMessage(TypeNest):
    """Attribute of type 'sub-message'.

    Rendered as a nest whose parse function also receives the value of a
    selector attribute.  The selector either lives in the same attribute
    set or is external and must be passed down through the structs (see
    Selector).
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)
        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        pieces = [
            f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
            '.is_submsg = 1, ',
        ]
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            selector_attr = self.attr_set[self["selector"]]
            pieces.append(f'.selector_type = {selector_attr.value} ')
        return ''.join(pieces)

    def _attr_get(self, ri, var):
        """Emit parse code that fails if the selector value is zero/unset,
        then parses the nest using that selector value."""
        selector_name = self['selector']
        sel = c_lower(selector_name)
        # External selectors arrive in a _sel_<name> variable; local ones
        # are read from the struct being filled in.
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"
        parse_fn = f"{self.nested_render_name}_parse"

        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [
            f'if (!{sel_var})',
            f'return ynl_submsg_failed(yarg, "{self.name}", "{selector_name}");',
            f"if ({parse_fn}(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None
918
919
class Selector:
    """Selector attribute that determines how a sub-message is parsed.

    When the attribute named by the sub-message's "selector" key belongs to
    the same attribute set, it is resolved immediately and flagged with
    ``is_selector``.  Otherwise the selector is external: ``attr`` stays None
    until set_attr() binds it, and the value will need to be passed down
    through the structs.
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]
        self._external = self.name not in attr_set
        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
        else:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        """Bind the resolved selector attribute."""
        self.attr = attr

    def is_external(self):
        """True if the selector was not found in the sub-message's set."""
        return self._external
938
939
class Struct:
    """C struct describing the members of one attribute set.

    With ``type_list`` None the struct is a nested struct covering every
    attribute of the set; otherwise only the listed attributes are members,
    in the given order.

    ``request``, ``reply``, ``recursive``, ``in_multi_val`` and
    ``child_nests`` start out cleared; the family-level analysis fills
    them in.
    """

    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.fixed_header = ('struct ' + c_lower(fixed_header)) if fixed_header else None
        self.submsg = submsg

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            stem = family.ident_name
        else:
            stem = family.ident_name + '-' + space_name
        self.render_name = c_lower(stem)

        # Nested structs whose set name also appears in family.consts get a
        # trailing '_' on the struct name.
        name_clash = self.nested and space_name in family.consts
        self.struct_name = 'struct ' + self.render_name + ('_' if name_clash else '')
        self.ptr_name = self.struct_name + ' *'

        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or legacy arrays

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(t, self.attr_set[t]) for t in member_names]

        # Index members by name and remember the one with the highest value
        # (ties go to the later attribute).
        self.attrs = {}
        self.attr_max_val = None
        highest = 0
        for attr_name, attr in self.attr_list:
            self.attrs[attr_name] = attr
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        for attr_name in self.attrs:
            yield attr_name

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in member order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Check ``new_inherited`` against the members given at construction
        and expose them, sorted and C-lowered, as ``inherited``."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = list(map(c_lower, sorted(self._inherited)))

    def external_selectors(self):
        """Return Selectors of sub-message members whose selector lives
        outside this attribute set."""
        return [attr.selector for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        """True if any member needs an iterator when freeing the struct."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1014
1015
class EnumEntry(SpecEnumEntry):
    """Single enum/flags entry with C rendering details."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change is True when the value is not the implicit successor
        # of the previous entry (prev + 1, or 0 for the first entry), and
        # always for flags sets.  Presumably used to decide whether the C
        # enum needs an explicit "= value" -- confirm at the use site.
        implicit_value = prev.value + 1 if prev else 0
        self.value_change = (self.value != implicit_value
                             or self.enum_set['type'] == 'flags')

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.enum_set.value_pfx
        self.c_name = c_upper(prefix + self.name)
1034
1035
class EnumSet(SpecEnumSet):
    """Enum or flags definition with C rendering details.

    Attributes (assigned before SpecEnumSet.__init__ runs):
      render_name   - lower-case C identifier stem "<family>_<name>"
      enum_name     - "enum <name>" C type, or None when the spec gives an
                      empty 'enum-name'
      user_type     - C type for values: enum_name, or 'int' without one
      value_pfx     - prefix of entry constant names ('name-prefix' key,
                      default "<family>-<name>-")
      header        - the spec's 'header' key, if any
      enum_cnt_name - the spec's 'enum-cnt-name' key, if any
    """

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if the entry values form one contiguous run.

        Returns (None, None) when high - low + 1 differs from the number of
        entries (i.e. the values have gaps), and also when the set has no
        entries at all -- min()/max() would raise ValueError on that.
        """
        values = [entry.value for entry in self.entries.values()]
        if not values:
            return None, None

        low = min(values)
        high = max(values)
        if high - low + 1 != len(values):
            return None, None

        return low, high
1071
1072
class AttrSet(SpecAttrSet):
    """Attribute set with C enum naming and typed attribute objects.

    ``name_prefix``, ``max_name`` and ``cnt_name`` are the upper-case C
    names for the attribute enum prefix and its max / count entries.  A
    subset shares them with the set it is a subset of.
    """

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            default_max = f"{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', default_max))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', '__' + default_max))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        # A set whose C name equals the family's C name gets an empty one.
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Build the Type object for one attribute spec, wrapping it in
        TypeMultiAttr when the spec sets 'multi-attr'."""
        base = self._new_typed_attr(elem, value)
        if elem.get('multi-attr'):
            return TypeMultiAttr(self.family, self, elem, value, base)
        return base

    def _new_typed_attr(self, elem, value):
        """Instantiate the Type subclass matching the spec's 'type'."""
        kind = elem['type']
        args = (self.family, self, elem, value)

        if kind in scalars:
            return TypeScalar(*args)
        if kind == 'unused':
            return TypeUnused(*args)
        if kind == 'pad':
            return TypePad(*args)
        if kind == 'flag':
            return TypeFlag(*args)
        if kind == 'string':
            return TypeString(*args)
        if kind == 'binary':
            if 'struct' in elem:
                return TypeBinaryStruct(*args)
            if elem.get('sub-type') in scalars:
                return TypeBinaryScalarArray(*args)
            return TypeBinary(*args)
        if kind == 'bitfield32':
            return TypeBitfield32(*args)
        if kind == 'nest':
            return TypeNest(*args)
        if kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            return TypeArrayNest(*args)
        if kind == 'nest-type-value':
            return TypeNestTypeValue(*args)
        if kind == 'sub-message':
            return TypeSubMessage(*args)
        raise Exception(f"No typed class for type {elem['type']}")
1141
1142
class Operation(SpecOperation):
    """Netlink operation with C naming details.

    ``render_name`` is the lower-case C identifier stem.  ``dual_policy`` is
    True when both 'do' and 'dump' carry a request.  ``enum_name`` (set in
    resolve()) is the op's C command constant.
    """

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs)
        for mode in ('do', 'dump', 'event'):
            for direction in ('request', 'reply'):
                try:
                    msg = yaml[mode][direction]
                except KeyError:
                    continue
                msg.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        # Async ops use the family's async prefix (which defaults to op_prefix).
        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that another message names this op in its 'notify' key."""
        self.has_ntf = True
1176
1177
class SubMessage(SpecSubMessage):
    """Sub-message definition with its C rendering name
    ("<family>_<name>", lower-cased)."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)
        self.render_name = c_lower(f"{family.ident_name}-{yaml['name']}")

    def resolve(self):
        self.resolve_up(super())
1186
1187
class Family(SpecFamily):
    """Netlink family spec extended with the state needed for C codegen.

    On top of SpecFamily this computes C naming (op prefixes, the
    'c-family-name' / 'c-version-name' keys, the uapi header), collects the
    per-op-mode pre/post hooks, and analyses attribute-set usage:
      - which sets ops use directly ("root" sets);
      - which sets are only reachable as nests ("pure nested" structs);
      - whether each set is used in requests and/or replies;
      - which sets are recursive;
      - where external sub-message selectors come from.
    """
    def __init__(self, file_name, exclude_ops):
        """Load the spec from ``file_name`` (``exclude_ops`` is forwarded
        to SpecFamily) and derive the family-level C naming keys."""
        # Added by resolve:
        # (each name is declared and immediately deleted, presumably so the
        # attribute names are listed here while the values come later)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Family name / version key names, overridable from the spec.
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        # The uapi header defaults to linux/<ident_name>.h.  uapi_header_name
        # is its bare name when it matches linux/*.h, else the ident_name.
        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Resolve C naming and run the attribute-set usage analysis.

        The _load_* passes run in a fixed order. Later passes use state built
        by earlier ones; for example, _load_nested_sets walks out from the
        root sets collected by _load_root_sets.
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        # Prefix for command constants.  Ops with is_async set use
        # async_op_prefix, which defaults to op_prefix.
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': unique hook names,
        #                          'list': the same names in first-seen order}
        # filled in by _load_hooks().
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    # Factories for this module's subclasses (presumably invoked by
    # SpecFamily while parsing the spec).
    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """True for families whose protocol is 'netlink-raw'."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Flag every op that some message references via 'notify'."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """For every attribute set used directly by a message, collect the
        attribute names used in requests and in replies (events count as
        replies) into self.root_sets."""
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder self.pure_nested_structs so that structs tend to come
        after the non-recursive structs they nest.

        A struct whose child has not been placed yet is moved to the end of
        the dict and retried later.  The loop is capped at len(structs)**2
        rounds.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register the Struct for a 'nested-attributes' attr and return the
        nested set name.

        Members inherited from the parent are the 'type-value' levels, or
        'idx' for indexed arrays.  ``inherit`` is handed to a newly created
        Struct and only filled in afterwards.  Struct keeps a reference to
        that same set, so the first reference defines the inherited members.
        On later references, set_inherited() raises if a different set is
        given.

        Raises if the nested set is also used as a root set.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            # NOTE(review): the root-set case already raised above, so this
            # check cannot trigger.
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Synthesize an attribute set (and Struct) for a sub-message and
        return its name.

        Each format becomes one attribute:
          - a nest if the format names an 'attribute-set';
          - a binary struct if it only has a 'fixed-header';
          - a flag otherwise.
        """
        # Fake the struct type for the sub-message itself
        # its not a attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr |= {
                    "type": "nest",
                    "nested-attributes": fmt['attribute-set'],
                }
                if 'fixed-header' in fmt:
                    attr |= { "fixed-header": fmt["fixed-header"] }
            elif 'fixed-header' in fmt:
                attr |= {
                    "type": "binary",
                    "struct": fmt["fixed-header"],
                }
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        # NOTE(review): AttrSet.__init__ looks for 'name-prefix', not
        # 'name-pfx', so this prefix appears to be ignored and the default
        # "<family>-a-<set>-" prefix is used instead -- confirm intent.
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Discover every struct reachable from the root sets and analyse it.

        1. Walk breadth-first from the root sets over 'nested-attributes'
           and 'sub-message' attrs, creating a Struct per nested set.
        2. Mark direct children of root sets as request/reply according to
           how the root set uses them, and flag multi-value uses.
        3. Sort, then visit structs in reverse order.  Propagate request/reply
           and in_multi_val to children and accumulate child_nests.  A struct
           that appears in its own child_nests is marked recursive.
        4. Sort again, now that the recursive flags are known.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark attributes as used in requests and/or replies.

        Members of pure nested structs inherit their struct's flags.
        Root-set attributes are marked according to self.root_sets.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Bind external sub-message selectors to the parent's attribute.

        Structs are visited in this order: pure nested structs in reverse,
        then the root sets.  For each nested child's external selectors, the
        attribute of the same name in the parent set is bound to the
        selector.  Passing a selector through more than one level raises.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """Build the single kernel policy for 'kernel-policy: global'.

        All ops with an attribute set must share the same one.  The policy
        lists, in attribute-set order, every attribute used in any 'do' or
        'dump' request.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        # NOTE(review): if no op has an 'attribute-set', attr_set_name is
        # still None here and the attr_sets lookup below fails.
        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique 'pre'/'post' hook names per op mode ('do'/'dump'),
        keeping them in first-seen order in the 'list' entry."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1548
1549
class RenderInfo:
    """Context handed to the C renderers for one op (or attribute set).

    Args:
        cw: CodeWriter receiving the output; its ``nlib`` is kept as ``nl``.
        family: the Family being rendered.
        ku_space: which side the code is rendered for (e.g. 'user').
        op: the Operation to render, or a falsy value when rendering a bare
            attribute set (``attr_set`` must then be given).
        op_mode: 'do', 'dump', 'notify' or 'event'.
        attr_set: attribute set name; defaults to the op's 'attribute-set'.

    When ``op`` is given, ``struct`` maps 'request' and 'reply' to the
    Structs of the selected op mode.  'notify' renders the op's 'do' (or,
    failing that, 'dump') layout.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            if op.fixed_header != family.fixed_header:
                if family.is_classic():
                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
                else:
                    raise Exception("Per-op fixed header not supported, yet")

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                # Dump-only op: no 'do' reply to compare against.
                self.type_oneside = True

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """Return True if the ``key`` ('request' or 'reply') struct has
        neither attributes nor a fixed header."""
        struct = self.struct[key]
        return len(struct.attr_list) == 0 and struct.fixed_header is None

    def needs_nlflags(self, direction):
        """Netlink flags are rendered only for 'do' requests of classic
        ('netlink-raw') families."""
        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1617
1618
class CodeWriter:
    """Write C source text with automatic indentation and brace handling.

    Output goes to stdout, or - when *out_file* is given - to a temporary
    file whose contents are copied to *out_file* by close_out_file().
    With ``overwrite=False`` an existing *out_file* whose contents already
    match is left untouched.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # '}' pending; may merge with a following 'else'
        self._silent_block = False  # previous line was a brace-less if/for/while/else
        self._ind = 0               # current indentation depth, in tabs
        self._ifdef_block = None    # CONFIG_ option of the currently open #ifdef
        if out_file is None:
            self._out = sys.stdout
            self._out_file = None
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Flush pending output and publish the temporary file.

        Safe to call repeatedly: afterwards the writer points at stdout and
        further calls only flush a pending closing brace.
        """
        # block_end() defers its '}' in case an 'else' follows; if nothing
        # follows, write it now instead of silently dropping it.
        if self._block_end:
            self._block_end = False
            self._out.write('\t' * self._ind + '}\n')
        # Writing straight to stdout (or already closed): nothing to publish.
        if getattr(self, '_out_file', None) is None:
            return

        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and
                     os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        # Release the temporary file whether or not it was copied.
        self._out.close()
        self._out = sys.stdout
        self._out_file = None

    @classmethod
    def _is_cond(cls, line):
        """Tell whether *line* starts with if/while/for."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line of C at the current indentation.

        Lines ending in ':' (goto labels such as 'err_free:') are outdented
        one level, lines starting with '#' go to column 0, and the line after
        a brace-less if/for/while/else gets one extra level.  *add_ind* adds
        further levels.  An empty line is written without indentation.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        if not line:
            # Never emit trailing whitespace on a blank line.
            ind = 0
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Open a '{' block, optionally preceded by *line* (e.g. 'if (x)')."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block.

        With no *line* the '}' is deferred, so that a following 'else' is
        joined onto it as '} else'.  Otherwise *line* is appended to the
        brace (';' and ',' directly, anything else after a space).
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write *doc* as ' *'-prefixed comment lines wrapped before col 79.

        Continuation lines get two extra spaces when *indent* is true.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function header, wrapping arguments to fit 80 columns.

        When wrapping, a return type longer than 3 characters goes on its own
        line and continuation lines align under the opening parenthesis.
        *doc*, if given, is written as a one-line block comment above.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation indent: tabs then spaces up to the '(' column.
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        # Extra display width of the tabs in ind (8 columns each, counted as 1).
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Declare local variables, longest first, then a blank line.

        *local_vars* may be a single string or an iterable of strings; the
        caller's list is not modified.  Equal-length entries keep their order.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: header, local variables, body lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define NAME VALUE' lines with tab-aligned values.

        *defines* is a sequence of (name, value) pairs; str values are quoted,
        int values written as is, anything else is omitted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write designated initializers '.name = value,' with aligned '='.

        *members* is a sequence of (name, value) pairs; an empty sequence
        writes nothing.
        """
        if not members:
            return
        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active '#ifdef CONFIG_<X>' guard.

        Closes the previously open guard (if different); a falsy *config*
        just closes it.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1815
1816
1817scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}
1818
1819direction_to_suffix = {
1820    'reply': '_rsp',
1821    'request': '_req',
1822    '': ''
1823}
1824
1825op_mode_to_wrapper = {
1826    'do': '',
1827    'dump': '_list',
1828    'notify': '_ntf',
1829    'event': '',
1830}
1831
1832_C_KW = {
1833    'auto',
1834    'bool',
1835    'break',
1836    'case',
1837    'char',
1838    'const',
1839    'continue',
1840    'default',
1841    'do',
1842    'double',
1843    'else',
1844    'enum',
1845    'extern',
1846    'float',
1847    'for',
1848    'goto',
1849    'if',
1850    'inline',
1851    'int',
1852    'long',
1853    'register',
1854    'return',
1855    'short',
1856    'signed',
1857    'sizeof',
1858    'static',
1859    'struct',
1860    'switch',
1861    'typedef',
1862    'union',
1863    'unsigned',
1864    'void',
1865    'volatile',
1866    'while'
1867}
1868
1869
def rdir(direction):
    """Return the opposite message direction ('request' <-> 'reply').

    Any other value (such as '' for direction-less types) is returned as is.
    """
    opposite = {'request': 'reply', 'reply': 'request'}
    return opposite.get(direction, direction)
1876
1877
def op_prefix(ri, direction, deref=False):
    """Return the C identifier prefix '<family>_<type>[<suffix>]'.

    The suffix depends on the op mode and *direction*:
      * no op mode: none;
      * 'do': the plain direction suffix ('_req' / '_rsp');
      * other modes, request: '_req', plus '_dump' unless the type is
        one-sided;
      * other modes, reply: for consistent do/dump types the direction
        suffix when *deref* is set, else the mode's wrapper suffix; for
        inconsistent types '_rsp_dump' (deref) or '_rsp_list'.
    """
    mode = ri.op_mode
    tail = ''
    if mode == 'do':
        tail = f"{direction_to_suffix[direction]}"
    elif mode and direction == 'request':
        tail = '_req' if ri.type_oneside else '_req_dump'
    elif mode and not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif mode:
        if deref:
            tail = f"{direction_to_suffix[direction]}"
        else:
            tail = op_mode_to_wrapper[mode]
    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1901
1902
def type_name(ri, direction, deref=False):
    """Return the full C struct type ('struct <prefix>') for *direction*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1905
1906
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the op's request (or '_dump') function.

    The function always takes the socket, plus the request struct when the
    op mode has a request. It returns a pointer to the reply type when the
    mode has a reply, otherwise int.  With *terminate* the prototype ends in
    ';' (a declaration); otherwise it is left open for a definition.
    """
    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    mode_spec = ri.op[ri.op_mode]
    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        args.append(f"{type_name(ri, direction)} *{direction_to_suffix[direction][1:]}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1923
1924
def print_req_prototype(ri):
    """Emit the request function's declaration, preceded by the op's doc."""
    op_doc = ri.op['doc']
    print_prototype(ri, direction="request", doc=op_doc)
1927
1928
def print_dump_prototype(ri):
    """Emit the dump function's declaration (no doc comment)."""
    print_prototype(ri, direction="request")
1931
1932
def put_typol_submsg(cw, struct):
    """Emit the policy table and nest descriptor for a sub-message struct.

    Members are numbered in declaration order; every entry is typed
    YNL_PT_SUBMSG and nest members also point at their nest descriptor.
    max_attr is the highest index used (-1 for no members).
    """
    name = struct.render_name
    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[] =')

    members = list(struct.member_list())
    for idx, (member_name, member) in enumerate(members):
        nest = ""
        if member.type == 'nest':
            nest = f" .nest = &{member.nested_render_name}_nest,"
        cw.p(f'[{idx}] = {{ .type = YNL_PT_SUBMSG, .name = "{member_name}",{nest} }},')

    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    cw.p(f'.max_attr = {len(members) - 1},')
    cw.p(f'.table = {name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1953
1954
def put_typol_fwd(cw, struct):
    """Emit an extern declaration of the struct's nest policy descriptor."""
    decl = 'extern const struct ynl_policy_nest ' + struct.render_name + '_nest;'
    cw.p(decl)
1957
1958
def put_typol(cw, struct):
    """Emit the attribute policy table and nest descriptor for *struct*.

    Sub-message structs are delegated to put_typol_submsg().  Otherwise the
    table is sized by the attribute set's max constant and every member
    renders its own entry via attr_typol().
    """
    if not struct.submsg:
        max_attr = struct.attr_set.max_name
        table = f'{struct.render_name}_policy'
        cw.block_start(line=f'const struct ynl_policy_attr {table}[{max_attr} + 1] =')
        for _name, member in struct.member_list():
            member.attr_typol(cw)
        cw.block_end(line=';')
        cw.nl()

        cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
        cw.p(f'.max_attr = {max_attr},')
        cw.p(f'.table = {table},')
        cw.block_end(line=';')
        cw.nl()
    else:
        put_typol_submsg(cw, struct)
1978
1979
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit '<render_name>_str()', a bounds-checked lookup into *map_name*.

    The argument is typed as the enum's user type (int without *enum*).
    For flags enums it is first reduced to the 0-based index of its lowest
    set bit via ffs().  Out-of-range values return NULL.
    """
    if enum:
        param = f'{enum.user_type} {arg_name}'
    else:
        param = f'int {arg_name}'
    cw.write_func_prot('const char *', f'{render_name}_str', [param])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')
    cw.block_end()
    cw.nl()
1993
1994
def put_op_name_fwd(family, cw):
    """Emit the declaration of '<family>_op_str()'."""
    fn_name = family.c_name + '_op_str'
    cw.write_func_prot('const char *', fn_name, ['int op'], suffix=';')
1997
1998
def put_op_name(family, cw):
    """Emit the reply-value -> op-name string map and its lookup helper.

    Only messages with a (truthy) reply value are listed.  They are indexed
    by the op's enum name when request and reply values coincide, else by
    the raw reply value.  When several ops share a reply value, only the one
    recorded in ``family.rsp_by_value`` is listed.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        if family.rsp_by_value[op.rsp_value] != op:
            # Legacy families may have several commands producing the same
            # reply; keep only one entry per reply value.
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2018
2019
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the declaration of '<enum>_str()'; *family* is not used."""
    fn_name = f'{enum.render_name}_str'
    cw.write_func_prot('const char *', fn_name, [enum.user_type + ' value'], suffix=';')
2023
2024
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string map for *enum* and its lookup helper.

    *family* is not used.
    """
    strmap = enum.render_name + '_strmap'
    cw.block_start(line=f"static const char * const {strmap}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, strmap, 'value', enum=enum)
2034
2035
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<struct>_put()', the nested-struct serializer."""
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', struct.render_name + '_put', params,
                          suffix=suffix)
2043
2044
def put_req_nested(ri, struct):
    """Emit '<struct>_put()', which writes *obj* into the message *nlh*.

    Regular structs are wrapped in a nest attribute of type attr_type, while
    sub-message structs are written without a nest.  A fixed header, if any,
    is copied from obj->_hdr first; each member then emits its own put code.
    """
    is_nest = struct.submsg is None
    decls = []
    prologue = []

    if is_nest:
        decls.append('struct nlattr *nest;')
        prologue.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        hdr_size = f'sizeof({struct.fixed_header})'
        decls.append('void *hdr;')
        prologue.extend([
            f"hdr = ynl_nlmsg_put_extra_header(nlh, {hdr_size});",
            f"memcpy(hdr, &obj->_hdr, {hdr_size});",
        ])

    # Helper locals needed by indexed-array / count-type member code.
    uses_array = False
    uses_count = False
    for _, member in struct.member_list():
        uses_array |= member.type == 'indexed-array'
        uses_count |= member.presence_type() == 'count'
    if uses_array:
        decls.append('struct nlattr *array;')
    if uses_count:
        decls.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    cw = ri.cw
    cw.block_start()
    cw.write_func_lvar(decls)

    for stmt in prologue:
        cw.p(stmt)

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    if is_nest:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2085
2086
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a parser that walks the attributes of *struct*.

    The caller has already printed the function prototype; this emits the
    braces and everything inside them.  The generated C:
      * copies the fixed header (if any) into dst->_hdr,
      * rejects input when an array-valued member of dst is already set,
      * runs one attribute loop in which each member's attr_get() emits its
        own handling (for array-valued members presumably just counting
        into the n_<name> counters declared here - TODO confirm),
      * then, for indexed arrays and multi-attr members, calloc()s the dst
        arrays and walks the attributes again to fill them.

    Args:
        ri: render info of the op being generated.
        struct: the Struct to parse; struct.nested selects nest-level vs
            message-level iteration.
        init_lines: C statements printed after the locals; appended to in
            place.
        local_vars: C local declarations; appended to in place.

    Raises:
        Exception: for an unsupported indexed-array sub-type or multi-attr
            type, or a per-op fixed header in a non-classic family.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    # Choose the C loop header: inside a nest (skipping a fixed header if
    # present), or at message level after the family header - or after the
    # op's own fixed header when it differs (classic families only).
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception("Per-op fixed header not supported, yet")

    # Find members that need a second pass (indexed arrays keep a pointer
    # to their array attribute in attr_<name>; multi-attr members re-walk
    # all attributes) and whether any member recurses into a nested or
    # sub-message parser (which needs 'parg').
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            # Both supported sub-type groups need the same declaration.
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    # 'i' indexes the destination arrays in the second passes below.
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per array-valued member; sorted() keeps the
    # generated declarations in a stable order.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited values arrive as extra __u32 parser arguments.
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if struct.fixed_header:
        # The fixed header sits at the start of the nest payload, at the
        # start of the message payload (classic), or right after the
        # genetlink header.
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Array members are allocated below; refuse to parse into a dst whose
    # array is already set.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # Main pass: each member emits its own case for its attribute type.
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate n_<name> elements and fill
    # them from the entries of the array attribute.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: allocate, then walk all
    # attributes again and pick those of this member's type.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most one element's worth of payload.
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Each string gets its own ynl_string allocation, NUL-terminated.
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message-level parsers return a netlink
    # callback code.
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2238
2239
def parse_rsp_submsg(ri, struct):
    """Emit '<struct>_parse()' for a sub-message struct.

    The generated parser compares the selector string ``sel`` against each
    member name and decodes ``nested`` as the matching member; it always
    returns 0.

    Local declarations are collected in a set (several members may need the
    same one) and sorted before emission.  Iteration order of a set of
    strings changes between interpreter runs because of hash randomisation,
    and write_func_lvar() only orders by length, so without sorting the
    generated C would not be reproducible.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    local_vars = {'const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;'}

    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        if l_vars:
            local_vars.update(l_vars)

    ri.cw.block_start()
    ri.cw.write_func_lvar(sorted(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    first = True
    for name, arg in struct.member_list():
        kw = 'if' if first else 'else if'
        first = False

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2274
2275
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of '<struct>_parse()'.

    Parameters: yarg, the selector string (sub-messages only), the nested
    attribute, one '_sel_<name>' string per external selector, and one
    __u32 per inherited value.
    """
    selector_params = ['const char *_sel_' + sel.name
                       for sel in struct.external_selectors()]
    params = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        params.append('const char *sel')
    params.append('const struct nlattr *nested')
    params.extend(selector_params)
    params.extend('__u32 ' + name for name in struct.inherited)

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params,
                          suffix=suffix)
2288
2289
def parse_rsp_nested(ri, struct):
    """Emit the parser for a nested struct (sub-messages are delegated)."""
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest: the parser returns 0 without looking at the input.
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return None

    _multi_parse(ri, struct, [], decls)
    return None
2308
2309
def parse_rsp_msg(ri, deref=False):
    """Emit the message-level reply parser '<prefix>_parse()' for the op.

    Nothing is emitted when the op mode has no reply, unless it is an event.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    decls = [f'{type_name(ri, "reply", deref=deref)} *dst;',
             'const struct nlattr *attr;']
    params = ['const struct nlmsghdr *nlh',
              'struct ynl_parse_arg *yarg']
    fn_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', fn_name, params)

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: accept the message as is.
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'], decls)
2331
2332
def print_req(ri):
    """Emit the definition of the op's request function.

    The generated C builds the request (netlink header, optional fixed
    header, attributes), executes it with ynl_exec() and returns:
      * when the op mode has a reply: a calloc()'d reply struct, or NULL
        (after freeing the reply) on error;
      * otherwise: 0 on success, -1 on error.
    """
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.struct["request"].fixed_header:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # Declare 'i' once if any member is count-typed (presumably used by
    # that member's attr_put() loop - TODO confirm).
    for _, attr in ri.struct["request"].member_list():
        if attr.presence_type() == 'count':
            local_vars += ['unsigned int i;']
            break

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    # Classic families take caller-set netlink flags from the request;
    # other families pass the family id to the genetlink helper.
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    # NOTE(review): req->_hdr is used whenever the request struct has a fixed
    # header, even though print_prototype() only adds 'req' when the op mode
    # has a 'request' section - presumably such ops always have one; confirm.
    if ri.struct['request'].fixed_header:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Replies are matched on rsp_cmd: the op's enum when it has a single
        # value, else the explicit reply value (presumably the request and
        # reply commands differ in that case).
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    # Without a reply nothing was allocated, so fail directly.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
2406
2407
def print_dump(ri):
    """Emit the definition of the op's dump function.

    The generated C builds the dump request, runs it with ynl_exec_dump()
    and returns ``yds.first`` - the head of the collected reply list - or
    frees that list and returns NULL on error.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.struct['request'].fixed_header:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    # Presumably the allocation size of each reply list element (confirm
    # against ynl_exec_dump()).
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    # Same reply-command selection as in print_req().
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # NOTE(review): req->_hdr is used whenever the request struct has a fixed
    # header, but 'req' only exists when the op mode has a 'request' section
    # (see print_prototype()) - presumably always true here; confirm.
    if ri.struct['request'].fixed_header:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    # Request attributes (and their policy) only exist for filtered dumps.
    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2462
2463
def call_free(ri, direction, var):
    """Return the C statement that frees *var* for the given *direction*."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
2466
2467
def free_arg_name(direction):
    """Name of the pointer parameter of a generated _free() helper.

    'req' / 'rsp' for request / reply structs, 'obj' when there is no
    direction.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
2472
2473
def print_alloc_wrapper(ri, direction, struct=None):
    """Emit an inline '<prefix>_alloc()' helper that calloc()s the struct.

    Structs used in multi-value members take an element count 'n'; others
    allocate a single zeroed instance.  A trailing '_' is added to the
    struct name when the type name conflicts with a constant.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name

    if struct and struct.in_multi_val:
        params, count = ["unsigned int n"], "n"
    else:
        params, count = ["void"], "1"

    ri.cw.write_func_prot(f'static inline struct {struct_name} *',
                          f"{name}_alloc", params)
    ri.cw.block_start()
    ri.cw.p(f'return calloc({count}, sizeof(struct {struct_name}));')
    ri.cw.block_end()
2491
2492
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of '<prefix>_free()' for the *direction* struct."""
    name = op_prefix(ri, direction)
    conflict_tag = '_' if ri.type_name_conflict else ''
    param = f"struct {name}{conflict_tag} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
2500
2501
def print_nlflags_set(ri, direction):
    """Emit an inline setter for the request's ``_nlmsg_flags`` member."""
    struct_name = op_prefix(ri, direction)
    params = [f"struct {struct_name} *req", "__u16 nl_flags"]
    ri.cw.write_func_prot('static inline void',
                          f"{struct_name}_set_nlflags", params)
    ri.cw.block_start()
    ri.cw.p('req->_nlmsg_flags = nl_flags;')
    ri.cw.block_end()
    ri.cw.nl()
2510
2511
2512def _print_type(ri, direction, struct):
2513    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
2514    if not direction and ri.type_name_conflict:
2515        suffix += '_'
2516
2517    if ri.op_mode == 'dump' and not ri.type_oneside:
2518        suffix += '_dump'
2519
2520    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")
2521
2522    if ri.needs_nlflags(direction):
2523        ri.cw.p('__u16 _nlmsg_flags;')
2524        ri.cw.nl()
2525    if struct.fixed_header:
2526        ri.cw.p(struct.fixed_header + ' _hdr;')
2527        ri.cw.nl()
2528
2529    for type_filter in ['present', 'len', 'count']:
2530        meta_started = False
2531        for _, attr in struct.member_list():
2532            line = attr.presence_member(ri.ku_space, type_filter)
2533            if line:
2534                if not meta_started:
2535                    ri.cw.block_start(line="struct")
2536                    meta_started = True
2537                ri.cw.p(line)
2538        if meta_started:
2539            ri.cw.block_end(line=f'_{type_filter};')
2540    ri.cw.nl()
2541
2542    for arg in struct.inherited:
2543        ri.cw.p(f"__u32 {arg};")
2544
2545    for _, attr in struct.member_list():
2546        attr.struct_member(ri)
2547
2548    ri.cw.block_end(line=';')
2549    ri.cw.nl()
2550
2551
2552def print_type(ri, direction):
2553    _print_type(ri, direction, ri.struct[direction])
2554
2555
2556def print_type_full(ri, struct):
2557    _print_type(ri, "", struct)
2558
2559    if struct.request and struct.in_multi_val:
2560        print_alloc_wrapper(ri, "", struct)
2561        ri.cw.nl()
2562        free_rsp_nested_prototype(ri)
2563        ri.cw.nl()
2564
2565        # Name conflicts are too hard to deal with with the current code base,
2566        # they are very rare so don't bother printing setters in that case.
2567        if ri.ku_space == 'user' and not ri.type_name_conflict:
2568            for _, attr in struct.member_list():
2569                attr.setter(ri, ri.attr_set, "", var="obj")
2570        ri.cw.nl()
2571
2572
2573def print_type_helpers(ri, direction, deref=False):
2574    print_free_prototype(ri, direction)
2575    ri.cw.nl()
2576
2577    if ri.needs_nlflags(direction):
2578        print_nlflags_set(ri, direction)
2579
2580    if ri.ku_space == 'user' and direction == 'request':
2581        for _, attr in ri.struct[direction].member_list():
2582            attr.setter(ri, ri.attr_set, direction, deref=deref)
2583    ri.cw.nl()
2584
2585
2586def print_req_type_helpers(ri):
2587    if ri.type_empty("request"):
2588        return
2589    print_alloc_wrapper(ri, "request")
2590    print_type_helpers(ri, "request")
2591
2592
2593def print_rsp_type_helpers(ri):
2594    if 'reply' not in ri.op[ri.op_mode]:
2595        return
2596    print_type_helpers(ri, "reply")
2597
2598
2599def print_parse_prototype(ri, direction, terminate=True):
2600    suffix = "_rsp" if direction == "reply" else "_req"
2601    term = ';' if terminate else ''
2602
2603    ri.cw.write_func_prot('void', f"{ri.op.render_name}{suffix}_parse",
2604                          ['const struct nlattr **tb',
2605                           f"struct {ri.op.render_name}{suffix} *req"],
2606                          suffix=term)
2607
2608
2609def print_req_type(ri):
2610    if ri.type_empty("request"):
2611        return
2612    print_type(ri, "request")
2613
2614
2615def print_req_free(ri):
2616    if 'request' not in ri.op[ri.op_mode]:
2617        return
2618    _free_type(ri, 'request', ri.struct['request'])
2619
2620
2621def print_rsp_type(ri):
2622    if (ri.op_mode == 'do' or ri.op_mode == 'dump') and 'reply' in ri.op[ri.op_mode]:
2623        direction = 'reply'
2624    elif ri.op_mode == 'event':
2625        direction = 'reply'
2626    else:
2627        return
2628    print_type(ri, direction)
2629
2630
2631def print_wrapped_type(ri):
2632    ri.cw.block_start(line=f"{type_name(ri, 'reply')}")
2633    if ri.op_mode == 'dump':
2634        ri.cw.p(f"{type_name(ri, 'reply')} *next;")
2635    elif ri.op_mode == 'notify' or ri.op_mode == 'event':
2636        ri.cw.p('__u16 family;')
2637        ri.cw.p('__u8 cmd;')
2638        ri.cw.p('struct ynl_ntf_base_type *next;')
2639        ri.cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
2640    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
2641    ri.cw.block_end(line=';')
2642    ri.cw.nl()
2643    print_free_prototype(ri, 'reply')
2644    ri.cw.nl()
2645
2646
2647def _free_type_members_iter(ri, struct):
2648    if struct.free_needs_iter():
2649        ri.cw.p('unsigned int i;')
2650        ri.cw.nl()
2651
2652
2653def _free_type_members(ri, var, struct, ref=''):
2654    for _, attr in struct.member_list():
2655        attr.free(ri, var, ref)
2656
2657
2658def _free_type(ri, direction, struct):
2659    var = free_arg_name(direction)
2660
2661    print_free_prototype(ri, direction, suffix='')
2662    ri.cw.block_start()
2663    _free_type_members_iter(ri, struct)
2664    _free_type_members(ri, var, struct)
2665    if direction:
2666        ri.cw.p(f'free({var});')
2667    ri.cw.block_end()
2668    ri.cw.nl()
2669
2670
2671def free_rsp_nested_prototype(ri):
2672        print_free_prototype(ri, "")
2673
2674
2675def free_rsp_nested(ri, struct):
2676    _free_type(ri, "", struct)
2677
2678
2679def print_rsp_free(ri):
2680    if 'reply' not in ri.op[ri.op_mode]:
2681        return
2682    _free_type(ri, 'reply', ri.struct['reply'])
2683
2684
2685def print_dump_type_free(ri):
2686    sub_type = type_name(ri, 'reply')
2687
2688    print_free_prototype(ri, 'reply', suffix='')
2689    ri.cw.block_start()
2690    ri.cw.p(f"{sub_type} *next = rsp;")
2691    ri.cw.nl()
2692    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
2693    _free_type_members_iter(ri, ri.struct['reply'])
2694    ri.cw.p('rsp = next;')
2695    ri.cw.p('next = rsp->next;')
2696    ri.cw.nl()
2697
2698    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2699    ri.cw.p('free(rsp);')
2700    ri.cw.block_end()
2701    ri.cw.block_end()
2702    ri.cw.nl()
2703
2704
2705def print_ntf_type_free(ri):
2706    print_free_prototype(ri, 'reply', suffix='')
2707    ri.cw.block_start()
2708    _free_type_members_iter(ri, ri.struct['reply'])
2709    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2710    ri.cw.p('free(rsp);')
2711    ri.cw.block_end()
2712    ri.cw.nl()
2713
2714
2715def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
2716    if terminate and ri and policy_should_be_static(struct.family):
2717        return
2718
2719    if terminate:
2720        prefix = 'extern '
2721    else:
2722        if ri and policy_should_be_static(struct.family):
2723            prefix = 'static '
2724        else:
2725            prefix = ''
2726
2727    suffix = ';' if terminate else ' = {'
2728
2729    max_attr = struct.attr_max_val
2730    if ri:
2731        name = ri.op.render_name
2732        if ri.op.dual_policy:
2733            name += '_' + ri.op_mode
2734    else:
2735        name = struct.render_name
2736    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2737
2738
2739def print_req_policy(cw, struct, ri=None):
2740    if ri and ri.op:
2741        cw.ifdef_block(ri.op.get('config-cond', None))
2742    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
2743    for _, arg in struct.member_list():
2744        arg.attr_policy(cw)
2745    cw.p("};")
2746    cw.ifdef_block(None)
2747    cw.nl()
2748
2749
2750def kernel_can_gen_family_struct(family):
2751    return family.proto == 'genetlink'
2752
2753
2754def policy_should_be_static(family):
2755    return family.kernel_policy == 'split' or kernel_can_gen_family_struct(family)
2756
2757
2758def print_kernel_policy_ranges(family, cw):
2759    first = True
2760    for _, attr_set in family.attr_sets.items():
2761        if attr_set.subset_of:
2762            continue
2763
2764        for _, attr in attr_set.items():
2765            if not attr.request:
2766                continue
2767            if 'full-range' not in attr.checks:
2768                continue
2769
2770            if first:
2771                cw.p('/* Integer value ranges */')
2772                first = False
2773
2774            sign = '' if attr.type[0] == 'u' else '_signed'
2775            suffix = 'ULL' if attr.type[0] == 'u' else 'LL'
2776            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
2777            members = []
2778            if 'min' in attr.checks:
2779                members.append(('min', attr.get_limit_str('min', suffix=suffix)))
2780            if 'max' in attr.checks:
2781                members.append(('max', attr.get_limit_str('max', suffix=suffix)))
2782            cw.write_struct_init(members)
2783            cw.block_end(line=';')
2784            cw.nl()
2785
2786
2787def print_kernel_policy_sparse_enum_validates(family, cw):
2788    first = True
2789    for _, attr_set in family.attr_sets.items():
2790        if attr_set.subset_of:
2791            continue
2792
2793        for _, attr in attr_set.items():
2794            if not attr.request:
2795                continue
2796            if not attr.enum_name:
2797                continue
2798            if 'sparse' not in attr.checks:
2799                continue
2800
2801            if first:
2802                cw.p('/* Sparse enums validation callbacks */')
2803                first = False
2804
2805            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
2806                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
2807            cw.block_start()
2808            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
2809            enum = family.consts[attr['enum']]
2810            first_entry = True
2811            for entry in enum.entries.values():
2812                if first_entry:
2813                    first_entry = False
2814                else:
2815                    cw.p('fallthrough;')
2816                cw.p(f'case {entry.c_name}:')
2817            cw.p('return 0;')
2818            cw.block_end()
2819            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
2820            cw.p('return -EINVAL;')
2821            cw.block_end()
2822            cw.nl()
2823
2824
2825def print_kernel_op_table_fwd(family, cw, terminate):
2826    exported = not kernel_can_gen_family_struct(family)
2827
2828    if not terminate or exported:
2829        cw.p(f"/* Ops table for {family.ident_name} */")
2830
2831        pol_to_struct = {'global': 'genl_small_ops',
2832                         'per-op': 'genl_ops',
2833                         'split': 'genl_split_ops'}
2834        struct_type = pol_to_struct[family.kernel_policy]
2835
2836        if not exported:
2837            cnt = ""
2838        elif family.kernel_policy == 'split':
2839            cnt = 0
2840            for op in family.ops.values():
2841                if 'do' in op:
2842                    cnt += 1
2843                if 'dump' in op:
2844                    cnt += 1
2845        else:
2846            cnt = len(family.ops)
2847
2848        qual = 'static const' if not exported else 'const'
2849        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
2850        if terminate:
2851            cw.p(f"extern {line};")
2852        else:
2853            cw.block_start(line=line + ' =')
2854
2855    if not terminate:
2856        return
2857
2858    cw.nl()
2859    for name in family.hooks['pre']['do']['list']:
2860        cw.write_func_prot('int', c_lower(name),
2861                           ['const struct genl_split_ops *ops',
2862                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2863    for name in family.hooks['post']['do']['list']:
2864        cw.write_func_prot('void', c_lower(name),
2865                           ['const struct genl_split_ops *ops',
2866                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2867    for name in family.hooks['pre']['dump']['list']:
2868        cw.write_func_prot('int', c_lower(name),
2869                           ['struct netlink_callback *cb'], suffix=';')
2870    for name in family.hooks['post']['dump']['list']:
2871        cw.write_func_prot('int', c_lower(name),
2872                           ['struct netlink_callback *cb'], suffix=';')
2873
2874    cw.nl()
2875
2876    for op_name, op in family.ops.items():
2877        if op.is_async:
2878            continue
2879
2880        if 'do' in op:
2881            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
2882            cw.write_func_prot('int', name,
2883                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2884
2885        if 'dump' in op:
2886            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
2887            cw.write_func_prot('int', name,
2888                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
2889    cw.nl()
2890
2891
2892def print_kernel_op_table_hdr(family, cw):
2893    print_kernel_op_table_fwd(family, cw, terminate=True)
2894
2895
2896def print_kernel_op_table(family, cw):
2897    print_kernel_op_table_fwd(family, cw, terminate=False)
2898    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
2899        for op_name, op in family.ops.items():
2900            if op.is_async:
2901                continue
2902
2903            cw.ifdef_block(op.get('config-cond', None))
2904            cw.block_start()
2905            members = [('cmd', op.enum_name)]
2906            if 'dont-validate' in op:
2907                members.append(('validate',
2908                                ' | '.join([c_upper('genl-dont-validate-' + x)
2909                                            for x in op['dont-validate']])), )
2910            for op_mode in ['do', 'dump']:
2911                if op_mode in op:
2912                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2913                    members.append((op_mode + 'it', name))
2914            if family.kernel_policy == 'per-op':
2915                struct = Struct(family, op['attribute-set'],
2916                                type_list=op['do']['request']['attributes'])
2917
2918                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2919                members.append(('policy', name))
2920                members.append(('maxattr', struct.attr_max_val.enum_name))
2921            if 'flags' in op:
2922                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
2923            cw.write_struct_init(members)
2924            cw.block_end(line=',')
2925    elif family.kernel_policy == 'split':
2926        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
2927                    'dump': {'pre': 'start', 'post': 'done'}}
2928
2929        for op_name, op in family.ops.items():
2930            for op_mode in ['do', 'dump']:
2931                if op.is_async or op_mode not in op:
2932                    continue
2933
2934                cw.ifdef_block(op.get('config-cond', None))
2935                cw.block_start()
2936                members = [('cmd', op.enum_name)]
2937                if 'dont-validate' in op:
2938                    dont_validate = []
2939                    for x in op['dont-validate']:
2940                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
2941                            continue
2942                        if op_mode == "dump" and x == 'strict':
2943                            continue
2944                        dont_validate.append(x)
2945
2946                    if dont_validate:
2947                        members.append(('validate',
2948                                        ' | '.join([c_upper('genl-dont-validate-' + x)
2949                                                    for x in dont_validate])), )
2950                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2951                if 'pre' in op[op_mode]:
2952                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
2953                members.append((op_mode + 'it', name))
2954                if 'post' in op[op_mode]:
2955                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
2956                if 'request' in op[op_mode]:
2957                    struct = Struct(family, op['attribute-set'],
2958                                    type_list=op[op_mode]['request']['attributes'])
2959
2960                    if op.dual_policy:
2961                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
2962                    else:
2963                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2964                    members.append(('policy', name))
2965                    members.append(('maxattr', struct.attr_max_val.enum_name))
2966                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
2967                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
2968                cw.write_struct_init(members)
2969                cw.block_end(line=',')
2970    cw.ifdef_block(None)
2971
2972    cw.block_end(line=';')
2973    cw.nl()
2974
2975
2976def print_kernel_mcgrp_hdr(family, cw):
2977    if not family.mcgrps['list']:
2978        return
2979
2980    cw.block_start('enum')
2981    for grp in family.mcgrps['list']:
2982        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp['name']},")
2983        cw.p(grp_id)
2984    cw.block_end(';')
2985    cw.nl()
2986
2987
2988def print_kernel_mcgrp_src(family, cw):
2989    if not family.mcgrps['list']:
2990        return
2991
2992    cw.block_start('static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] =')
2993    for grp in family.mcgrps['list']:
2994        name = grp['name']
2995        grp_id = c_upper(f"{family.ident_name}-nlgrp-{name}")
2996        cw.p('[' + grp_id + '] = { "' + name + '", },')
2997    cw.block_end(';')
2998    cw.nl()
2999
3000
3001def print_kernel_family_struct_hdr(family, cw):
3002    if not kernel_can_gen_family_struct(family):
3003        return
3004
3005    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
3006    cw.nl()
3007    if 'sock-priv' in family.kernel_family:
3008        cw.p(f'void {family.c_name}_nl_sock_priv_init({family.kernel_family["sock-priv"]} *priv);')
3009        cw.p(f'void {family.c_name}_nl_sock_priv_destroy({family.kernel_family["sock-priv"]} *priv);')
3010        cw.nl()
3011
3012
3013def print_kernel_family_struct_src(family, cw):
3014    if not kernel_can_gen_family_struct(family):
3015        return
3016
3017    if 'sock-priv' in family.kernel_family:
3018        # Generate "trampolines" to make CFI happy
3019        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
3020                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
3021                      ["void *priv"])
3022        cw.nl()
3023        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
3024                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
3025                      ["void *priv"])
3026        cw.nl()
3027
3028    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")
3029    cw.p('.name\t\t= ' + family.fam_key + ',')
3030    cw.p('.version\t= ' + family.ver_key + ',')
3031    cw.p('.netnsok\t= true,')
3032    cw.p('.parallel_ops\t= true,')
3033    cw.p('.module\t\t= THIS_MODULE,')
3034    if family.kernel_policy == 'per-op':
3035        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
3036        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
3037    elif family.kernel_policy == 'split':
3038        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
3039        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
3040    if family.mcgrps['list']:
3041        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
3042        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
3043    if 'sock-priv' in family.kernel_family:
3044        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
3045        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
3046        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
3047    cw.block_end(';')
3048
3049
3050def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
3051    start_line = 'enum'
3052    if enum_name in obj:
3053        if obj[enum_name]:
3054            start_line = 'enum ' + c_lower(obj[enum_name])
3055    elif ckey and ckey in obj:
3056        start_line = 'enum ' + family.c_name + '_' + c_lower(obj[ckey])
3057    cw.block_start(line=start_line)
3058
3059
3060def render_uapi_unified(family, cw, max_by_define, separate_ntf):
3061    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
3062    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
3063    max_value = f"({cnt_name} - 1)"
3064
3065    uapi_enum_start(family, cw, family['operations'], 'enum-name')
3066    val = 0
3067    for op in family.msgs.values():
3068        if separate_ntf and ('notify' in op or 'event' in op):
3069            continue
3070
3071        suffix = ','
3072        if op.value != val:
3073            suffix = f" = {op.value},"
3074            val = op.value
3075        cw.p(op.enum_name + suffix)
3076        val += 1
3077    cw.nl()
3078    cw.p(cnt_name + ('' if max_by_define else ','))
3079    if not max_by_define:
3080        cw.p(f"{max_name} = {max_value}")
3081    cw.block_end(line=';')
3082    if max_by_define:
3083        cw.p(f"#define {max_name} {max_value}")
3084    cw.nl()
3085
3086
3087def render_uapi_directional(family, cw, max_by_define):
3088    max_name = f"{family.op_prefix}USER_MAX"
3089    cnt_name = f"__{family.op_prefix}USER_CNT"
3090    max_value = f"({cnt_name} - 1)"
3091
3092    cw.block_start(line='enum')
3093    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
3094    val = 0
3095    for op in family.msgs.values():
3096        if 'do' in op and 'event' not in op:
3097            suffix = ','
3098            if op.value and op.value != val:
3099                suffix = f" = {op.value},"
3100                val = op.value
3101            cw.p(op.enum_name + suffix)
3102            val += 1
3103    cw.nl()
3104    cw.p(cnt_name + ('' if max_by_define else ','))
3105    if not max_by_define:
3106        cw.p(f"{max_name} = {max_value}")
3107    cw.block_end(line=';')
3108    if max_by_define:
3109        cw.p(f"#define {max_name} {max_value}")
3110    cw.nl()
3111
3112    max_name = f"{family.op_prefix}KERNEL_MAX"
3113    cnt_name = f"__{family.op_prefix}KERNEL_CNT"
3114    max_value = f"({cnt_name} - 1)"
3115
3116    cw.block_start(line='enum')
3117    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
3118    val = 0
3119    for op in family.msgs.values():
3120        if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
3121            enum_name = op.enum_name
3122            if 'event' not in op and 'notify' not in op:
3123                enum_name = f'{enum_name}_REPLY'
3124
3125            suffix = ','
3126            if op.value and op.value != val:
3127                suffix = f" = {op.value},"
3128                val = op.value
3129            cw.p(enum_name + suffix)
3130            val += 1
3131    cw.nl()
3132    cw.p(cnt_name + ('' if max_by_define else ','))
3133    if not max_by_define:
3134        cw.p(f"{max_name} = {max_value}")
3135    cw.block_end(line=';')
3136    if max_by_define:
3137        cw.p(f"#define {max_name} {max_value}")
3138    cw.nl()
3139
3140
3141def render_uapi(family, cw):
3142    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
3143    hdr_prot = hdr_prot.replace('/', '_')
3144    cw.p('#ifndef ' + hdr_prot)
3145    cw.p('#define ' + hdr_prot)
3146    cw.nl()
3147
3148    defines = [(family.fam_key, family["name"]),
3149               (family.ver_key, family.get('version', 1))]
3150    cw.writes_defines(defines)
3151    cw.nl()
3152
3153    defines = []
3154    for const in family['definitions']:
3155        if const.get('header'):
3156            continue
3157
3158        if const['type'] != 'const':
3159            cw.writes_defines(defines)
3160            defines = []
3161            cw.nl()
3162
3163        # Write kdoc for enum and flags (one day maybe also structs)
3164        if const['type'] == 'enum' or const['type'] == 'flags':
3165            enum = family.consts[const['name']]
3166
3167            if enum.header:
3168                continue
3169
3170            if enum.has_doc():
3171                if enum.has_entry_doc():
3172                    cw.p('/**')
3173                    doc = ''
3174                    if 'doc' in enum:
3175                        doc = ' - ' + enum['doc']
3176                    cw.write_doc_line(enum.enum_name + doc)
3177                else:
3178                    cw.p('/*')
3179                    cw.write_doc_line(enum['doc'], indent=False)
3180                for entry in enum.entries.values():
3181                    if entry.has_doc():
3182                        doc = '@' + entry.c_name + ': ' + entry['doc']
3183                        cw.write_doc_line(doc)
3184                cw.p(' */')
3185
3186            uapi_enum_start(family, cw, const, 'name')
3187            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
3188            for entry in enum.entries.values():
3189                suffix = ','
3190                if entry.value_change:
3191                    suffix = f" = {entry.user_value()}" + suffix
3192                cw.p(entry.c_name + suffix)
3193
3194            if const.get('render-max', False):
3195                cw.nl()
3196                cw.p('/* private: */')
3197                if const['type'] == 'flags':
3198                    max_name = c_upper(name_pfx + 'mask')
3199                    max_val = f' = {enum.get_mask()},'
3200                    cw.p(max_name + max_val)
3201                else:
3202                    cnt_name = enum.enum_cnt_name
3203                    max_name = c_upper(name_pfx + 'max')
3204                    if not cnt_name:
3205                        cnt_name = '__' + name_pfx + 'max'
3206                    cw.p(c_upper(cnt_name) + ',')
3207                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
3208            cw.block_end(line=';')
3209            cw.nl()
3210        elif const['type'] == 'const':
3211            defines.append([c_upper(family.get('c-define-name',
3212                                               f"{family.ident_name}-{const['name']}")),
3213                            const['value']])
3214
3215    if defines:
3216        cw.writes_defines(defines)
3217        cw.nl()
3218
3219    max_by_define = family.get('max-by-define', False)
3220
3221    for _, attr_set in family.attr_sets.items():
3222        if attr_set.subset_of:
3223            continue
3224
3225        max_value = f"({attr_set.cnt_name} - 1)"
3226
3227        val = 0
3228        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
3229        for _, attr in attr_set.items():
3230            suffix = ','
3231            if attr.value != val:
3232                suffix = f" = {attr.value},"
3233                val = attr.value
3234            val += 1
3235            cw.p(attr.enum_name + suffix)
3236        if attr_set.items():
3237            cw.nl()
3238        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
3239        if not max_by_define:
3240            cw.p(f"{attr_set.max_name} = {max_value}")
3241        cw.block_end(line=';')
3242        if max_by_define:
3243            cw.p(f"#define {attr_set.max_name} {max_value}")
3244        cw.nl()
3245
3246    # Commands
3247    separate_ntf = 'async-prefix' in family['operations']
3248
3249    if family.msg_id_model == 'unified':
3250        render_uapi_unified(family, cw, max_by_define, separate_ntf)
3251    elif family.msg_id_model == 'directional':
3252        render_uapi_directional(family, cw, max_by_define)
3253    else:
3254        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')
3255
3256    if separate_ntf:
3257        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
3258        for op in family.msgs.values():
3259            if separate_ntf and not ('notify' in op or 'event' in op):
3260                continue
3261
3262            suffix = ','
3263            if 'value' in op:
3264                suffix = f" = {op['value']},"
3265            cw.p(op.enum_name + suffix)
3266        cw.block_end(line=';')
3267        cw.nl()
3268
3269    # Multicast
3270    defines = []
3271    for grp in family.mcgrps['list']:
3272        name = grp['name']
3273        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
3274                        f'{name}'])
3275    cw.nl()
3276    if defines:
3277        cw.writes_defines(defines)
3278        cw.nl()
3279
3280    cw.p(f'#endif /* {hdr_prot} */')
3281
3282
3283def _render_user_ntf_entry(ri, op):
3284    if not ri.family.is_classic():
3285        ri.cw.block_start(line=f"[{op.enum_name}] = ")
3286    else:
3287        crud_op = ri.family.req_by_value[op.rsp_value]
3288        ri.cw.block_start(line=f"[{crud_op.enum_name}] = ")
3289    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
3290    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
3291    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
3292    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
3293    ri.cw.block_end(line=',')
3294
3295
3296def render_user_family(family, cw, prototype):
3297    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
3298    if prototype:
3299        cw.p(f'extern {symbol};')
3300        return
3301
3302    if family.ntfs:
3303        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
3304        for ntf_op_name, ntf_op in family.ntfs.items():
3305            if 'notify' in ntf_op:
3306                op = family.ops[ntf_op['notify']]
3307                ri = RenderInfo(cw, family, "user", op, "notify")
3308            elif 'event' in ntf_op:
3309                ri = RenderInfo(cw, family, "user", ntf_op, "event")
3310            else:
3311                raise Exception('Invalid notification ' + ntf_op_name)
3312            _render_user_ntf_entry(ri, ntf_op)
3313        for op_name, op in family.ops.items():
3314            if 'event' not in op:
3315                continue
3316            ri = RenderInfo(cw, family, "user", op, "event")
3317            _render_user_ntf_entry(ri, op)
3318        cw.block_end(line=";")
3319        cw.nl()
3320
3321    cw.block_start(f'{symbol} = ')
3322    cw.p(f'.name\t\t= "{family.c_name}",')
3323    if family.is_classic():
3324        cw.p('.is_classic\t= true,')
3325        cw.p(f'.classic_id\t= {family.get("protonum")},')
3326    if family.is_classic():
3327        if family.fixed_header:
3328            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
3329    elif family.fixed_header:
3330        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
3331    else:
3332        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
3333    if family.ntfs:
3334        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
3335        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
3336    cw.block_end(line=';')
3337
3338
def family_contains_bitfield32(family):
    """Return True if any attribute of the family has type "bitfield32".

    Attribute sets that are a subset of another set (``subset_of`` set) are
    skipped; their attributes are found through the set they derive from.
    The search stops at the first match.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
3347
3348
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains ``full_path``.

    Walks up the directory hierarchy (lexically, via ``os.path.dirname``)
    until a directory containing a ``MAINTAINERS`` file is found.

    Returns:
        A ``(root, sub_path)`` tuple.  ``root`` is the directory holding
        MAINTAINERS (it may be ``''`` when ``full_path`` is relative and
        MAINTAINERS is in the current directory).  ``sub_path`` is
        ``full_path`` relative to ``root``.

    Raises:
        FileNotFoundError: if no ancestor directory contains MAINTAINERS.
    """
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # sub_path carries a trailing separator from the first join
            return parent, sub_path[:-1]
        # dirname() reaches a fixed point at '/' (absolute paths) or ''
        # (relative paths); without this check the loop never terminates.
        if parent == full_path:
            raise FileNotFoundError(
                f'Could not find kernel root (MAINTAINERS) above {sub_path[:-1]}')
        full_path = parent
3357
3358
def main():
    """Command-line entry point of the YNL C code generator.

    Parses the command line, loads the YAML spec (checking its license) and
    writes the requested output through a CodeWriter:

      --mode uapi             -> output of render_uapi()
      --mode kernel --header  -> kernel header (policy forward declarations,
                                 op table / mcgrp / family declarations)
      --mode kernel --source  -> kernel source (policies, op table, mcgrps,
                                 family struct)
      --mode user   --header  -> user-space library header
      --mode user   --source  -> user-space library source

    Exits with status 1 when the spec fails to parse as YAML or its license
    is not the required dual license.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination; default=None lets us
    # detect below that neither option was given.
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # Banner: license, spec path relative to the kernel tree, and the
    # generator arguments needed to regenerate this file.
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # Name of the companion header included by generated sources; derived
    # from the -o file by swapping its extension (e.g. foo.c -> foo.h).
    if args.out_file:
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Emit each header once, keeping first-seen order.
    seen_header = set()
    for one in headers:
        if one not in seen_header:
            cw.p(f"#include <{one}>")
            seen_header.add(one)
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)
            print_kernel_policy_sparse_enum_validates(parsed, cw)

            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent or ri.type_oneside:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            # Recursive nests need their policies forward-declared first.
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for struct in parsed.pure_nested_structs.values():
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            # With recursive nests, emit prototypes of all nested helpers
            # before any of their definitions.
            if has_recursive_nests:
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent or ri.type_oneside:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3670    main()
3671