xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 8df78d97e49816eacb7a8c4da63bf1e8be4e20ac)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import filecmp
6import pathlib
7import os
8import re
9import shutil
10import sys
11import tempfile
12import yaml
13
14sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
15from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
16from lib import SpecSubMessage
17
18
def c_upper(name):
    """Return *name* upper-cased with every '-' turned into '_' (C macro style)."""
    return '_'.join(name.upper().split('-'))
21
22
def c_lower(name):
    """Return *name* lower-cased with every '-' turned into '_' (C identifier style)."""
    underscored = name.replace('-', '_')
    return underscored.lower()
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The limit must have the form <s|u><width>-<min|max>, for example:
    u8-max -> 255, s16-min -> -32768, u64-min -> 0.

    Raises:
        ValueError: if name does not follow that form (e.g. a misspelled
            constant name that fell through to this helper).
    """
    match = re.fullmatch(r'([su])([1-9][0-9]*)-(min|max)', name)
    if not match:
        raise ValueError(f'Invalid limit "{name}", expected e.g. "u32-max" or "s8-min"')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u':
        return 0 if bound == 'min' else (1 << width) - 1

    # Signed: one bit of the width is used by the sign.
    high = (1 << (width - 1)) - 1
    return -high - 1 if bound == 'min' else high
40
41
class BaseNlLib:
    """Provides C expressions that generated code uses to reach library state."""

    def get_family_id(self):
        """Return the C expression that refers to the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Base code-generation wrapper for a single attribute of an attribute set.

    Subclasses specialize how one attribute type is rendered into C: struct
    members, setter functions, kernel policy entries (attr_policy), type
    policy entries (attr_typol) and the put / parse code.  Public methods
    that emit C write it through ri.cw (or the cw passed in); most of the
    '_'-prefixed helpers return strings or lists of C lines instead.
    """
    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE: when the spec has 'checks' this is the spec's own dict, not a
        # copy, so subclasses that add derived checks also modify the spec.
        self.checks = attr.get('checks', {})

        # Set by set_request() / set_reply().
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        # Nests and sub-messages both name another set to render as a struct.
        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # A trailing '_' is added when family.consts has a definition of
            # the same name (presumably to avoid a C type name clash).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make the member name a valid C identifier: no C keywords and no
        # leading digit.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        # (assigned and immediately deleted, presumably so the attribute is
        # declared here while access before resolve() still fails)
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests (also marks the real attr of a subset)."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies (also marks the real attr of a subset)."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numeric value of check *limit* (or of *default*).

        Ints are returned as-is, names of family constants resolve to the
        constant's "value", anything else is parsed by limit_to_number().
        Returns None when neither the check nor a default is present.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check *limit* (or *default*) rendered as a C expression.

        Ints get *suffix* appended; family constants become their upper-case
        name, prefixed with the family name unless the constant has a
        'header' key; other strings are upper-cased.  Returns '' when neither
        the check nor a default is present.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """
        Compute enum_name, the C enum constant naming this attr.

        The name comes from the parent sub-message's enum_name, or else from
        an explicit 'name-prefix', or else from the attr set's prefix.

        Raises if an attr in a subset overrides the checks of the real attr.
        """
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        # Falsy by default; types holding several values override this.
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        # Recursive types only count as such outside of a specific op.
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """
        Return how presence is tracked in the generated struct.

        'present' is a one-bit flag in _present; subclasses may return 'len'
        or 'count' (a u32 member in _len / _count) or '' for no tracking.
        """
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return the presence-struct member declaration for this attr, or None
        when its presence type is not *type_filter*.  User-space declarations
        use the '__u32' spelling.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'present':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() in {'len', 'count'}:
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name};"

    def _complex_member_type(self, ri):
        # None means "not a complex type": struct_member() then falls back
        # to arg_member().
        return None

    def free_needs_iter(self):
        # True when _free_lines() uses a loop counter 'i'.
        return False

    def _free_lines(self, ri, var, ref):
        # Only multi-value members and 'count'/'len' tracked members are
        # freed with a plain free().
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the C lines that free this member of *var*."""
        lines = self._free_lines(ri, var, ref)
        for line in lines:
            ri.cw.p(line)

    def arg_member(self, ri):
        """
        Return the C parameter declaration(s) used to pass this attr, e.g.
        '<type> *<name>' plus 'unsigned int n_<name>' for 'count' presence.

        Raises if the type provides neither a complex member type nor its
        own arg_member().
        """
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """
        Emit the struct member declaration(s) for this attr.  Multi-value
        and recursive-for-op members are declared as pointers.
        """
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        # Minimal '{ .type = <policy>, }' initializer; subclasses add checks.
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """
        Emit this attr's entry of the kernel policy table.  Big-endian u16 /
        u32 attrs use NLA_BE16 / NLA_BE32 instead of NLA_U16 / NLA_U32.
        """
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attr's entry of the type-policy table (name + type info)."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the presence check for 'present' / 'len' types;
        # other presence types are emitted unconditionally.
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        # Emit a ynl_attr_put_<put_type>() of the member's value.
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        # Subclasses return (lines, init_lines, local_vars); the first two
        # may be a single string instead of a list.
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit one 'if' / 'else if' (per *first*) branch of the parse code for
        this attr.

        Non-multi-value attrs are checked with ynl_attr_validate() and, for
        'present' presence, marked present before the type-specific
        init_lines and lines are emitted.  Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if type(lines) is str:
            lines = [lines]
        if type(init_lines) is str:
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """
        Emit a 'static inline void' setter for this attr.

        *ref* is the list of enclosing nest member names; the function is
        named '<op prefix>_set_<ref..._name>'.  Every enclosing nest is
        marked present; the attr's own presence expression is handed to
        _setter_lines().  Any previous value is freed first.  Setters that
        free but never allocate get a '__' name prefix.
        """
        ref = (ref if ref else []) + [self.c_name]
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        for i in range(0, len(ref)):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == len(ref) - 1 and self.presence_type() != 'present':
                presence = f"{var}->{'.'.join(ref[:i] + [''])}_{self.presence_type()}.{ref[i]}"
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = bool([x for x in code if 'free(' in x])
        alloc = bool([x for x in code if 'alloc(' in x])
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
309
310
class TypeUnused(Type):
    """
    Placeholder for an attribute of the 'unused' kind.

    It gets no struct member, no kernel policy entry, no put / parse code
    and no setter.  Its type-policy entry marks it YNL_PT_REJECT.
    """

    # Generators that deliberately emit nothing.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        return None

    # Descriptors.
    def presence_type(self):
        # No presence tracking at all.
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        reject = ['return YNL_PARSE_CB_ERROR;']
        return reject, None, None
335
336
class TypePad(Type):
    """
    Padding attribute.

    It has no struct member, no kernel policy entry, no put / parse code and
    no setter; its type-policy entry marks it YNL_PT_IGNORE.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def presence_type(self):
        # Nothing to track.
        return ''

    def arg_member(self, ri):
        return []

    # Generators that deliberately emit nothing.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        return None
358
359
class TypeScalar(Type):
    """
    Integer attribute.

    self.type names the width and signedness (e.g. 'u32', 's64'); auto
    scalars (is_auto_scalar) are rendered with a 64-bit C type.  Besides
    put / parse code this type derives kernel policy checks (mask, range,
    min, max, sparse) from the spec.
    """
    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # C comment appended to the argument declaration, e.g. " /* big-endian */".
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        # (assigned and immediately deleted, presumably so the attributes are
        # declared here while access before resolve() still fails)
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """
        Resolve the base attr, then compute:
          - is_bitfield: the attr is 'enum-as-flags' or uses an enum of type 'flags';
          - type_name: the C type of the member - the enum's user_type for a
            non-flags enum, '__u64' / '__s64' for auto scalars, otherwise
            '__<type>'.
        """
        self.resolve_up(super())

        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if not self.is_bitfield and 'enum' in self.attr:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """
        Derive kernel policy checks and store them in self.checks (which may
        be the spec's own 'checks' dict).

        - Enum attrs: 'sparse' when the enum's value_range() is (None, None);
          otherwise 'min' and 'max' unless given explicitly ('min' is skipped
          for unsigned types whose low bound is 0).
        - 'range' when both 'min' and 'max' are present.
        - 'full-range' when a bound lies outside -32768..32767.

        Raises if min > max, or if an unsigned type has a negative limit.
        """
        if 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                if 'min' not in self.checks:
                    if low != 0 or self.type[0] == 's':
                        self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            self.checks['range'] = True

        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Bounds outside the s16 range are rendered via NLA_POLICY_FULL_RANGE
        # with a separate *_range definition (see _attr_policy); presumably
        # the compact range form only holds 16-bit bounds - TODO confirm.
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        # Pick the first applicable policy form, in this priority order:
        # mask (flags-mask check or flags enum), full range, range, min, max,
        # sparse validate function, plain type.
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
                mask = enum.get_mask(as_flags=True)
            else:
                flags = self.family.consts[self.checks['flags-mask']]
                flag_cnt = len(flags['entries'])
                mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        elif 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        elif 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        elif 'sparse' in self.checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types also map to a YNL_PT_U* entry (only the width is used).
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
457
458
class TypeFlag(Type):
    """
    Flag attribute: it is put with an empty (NULL, 0) payload, so its
    presence alone carries the information.
    """

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # Setting a flag takes no value argument.
        return []

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to copy: the generic attr_get() records presence.
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        return []
474
475
class TypeString(Type):
    """
    String attribute.

    Stored as a heap-allocated, NUL-terminated 'char *' whose length is kept
    in the _len presence struct.
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        selector = '.is_selector = 1, ' if self.is_selector else ''
        return '.type = YNL_PT_NUL_STR, ' + selector

    def _attr_policy(self, policy):
        # An exact length check replaces the plain initializer entirely.
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.get_limit_str('exact-len')})"
        fields = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + self.get_limit_str('max-len'))
        return ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        # Length is bounded by the attribute payload size.
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = presence
        return [
            f"{length} = strlen({self.c_name});",
            f"{member} = malloc({length} + 1);",
            f"memcpy({member}, {self.c_name}, {length});",
            f"{member}[{length}] = 0;",
        ]
528
529
class TypeBinary(Type):
    """
    Opaque binary attribute.

    Stored as a heap-allocated 'void *' whose length is kept in the _len
    presence struct.
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # NOTE: no trailing space here, unlike most other typol strings.
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # At most one length check is supported.
        n_checks = len(self.checks)
        if n_checks > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if n_checks == 0:
            return '{ .type = NLA_BINARY, }'

        (check_name,) = self.checks
        len_macros = {
            'exact-len': 'NLA_POLICY_EXACT_LEN',
            'min-len': 'NLA_POLICY_MIN_LEN',
            'max-len': 'NLA_POLICY_MAX_LEN',
        }
        if check_name not in len_macros:
            raise Exception('Unsupported check for binary type: ' + check_name)
        return f"{len_macros[check_name]}({self.get_limit_str(check_name)})"

    def attr_put(self, ri, var):
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                    f"{var}->{self.c_name}, {var}->_len.{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        length = presence
        return [
            f"{length} = len;",
            f"{member} = malloc({length});",
            f"memcpy({member}, {self.c_name}, {length});",
        ]
580
581
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is the C struct named by the spec's 'struct' key."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_{self.presence_type()}.{self.c_name} = len;",
            # A payload shorter than the struct gets a zeroed, full-size
            # buffer, so fields beyond the received length read as zero.
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']
597
598
class TypeBinaryScalarArray(TypeBinary):
    """
    Binary attribute carrying a packed array of scalars of the spec's
    'sub-type'.  The element count is kept in the _count presence struct.
    """

    def presence_type(self):
        return 'count'

    def arg_member(self, ri):
        return [f'__{self.get("sub-type")} *{self.c_name}', 'size_t count']

    def struct_member(self, ri):
        ri.cw.p(f'__{self.get("sub-type")} *{self.c_name};')

    def attr_put(self, ri, var):
        count = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        ri.cw.block_start(line=f"if ({count})")
        # 'i' holds the payload size in bytes; it is assumed to be a local
        # declared by the surrounding generated function.
        ri.cw.p(f"i = {count} * {elem_sz};")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        count = f"{var}->_count.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        get_lines = [
            f"{count} = len / {elem_sz};",
            # Round the copied length down to a whole number of elements.
            f"len = {count} * {elem_sz};",
            f"{var}->{self.c_name} = malloc(len);",
            f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        return [
            f"{presence} = count;",
            f"count *= {elem_sz};",
            f"{member} = malloc(count);",
            f"memcpy({member}, {self.c_name}, count);",
        ]
631
632
class TypeBitfield32(Type):
    """
    Attribute held in a 'struct nla_bitfield32'.  The kernel policy mask
    comes from the enum named by the spec's 'enum' key.
    """

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        size = 'sizeof(struct nla_bitfield32)'
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, {size})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        size = 'sizeof(struct nla_bitfield32)'
        copy_line = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), {size});"
        return copy_line, None, None

    def _setter_lines(self, ri, member, presence):
        size = 'sizeof(struct nla_bitfield32)'
        return [f"memcpy(&{member}, {self.c_name}, {size});"]
656
657
class TypeNest(Type):
    """
    Attribute nesting another attribute set, rendered as an embedded struct
    (or as a pointer for recursive nests outside of an op).
    """

    def is_recursive(self):
        return self.family.pure_nested_structs[self.nested_attrs].recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        member = f'{var}->{ref}{self.c_name}'
        free_fn = f'{self.nested_render_name}_free'
        if not self.is_recursive_for_op(ri):
            return [f'{free_fn}(&{member});']
        # Recursive nests are held by pointer, so guard against NULL.
        return [f'if ({member})', f'{free_fn}({member});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {addr_of}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        nested = self.family.pure_nested_structs[self.nested_attrs]
        # Selectors living outside the nest are passed to its parser.
        parse_args = ["&parg", "attr"] + [f'{var}->{sel.name}'
                                          for sel in nested.external_selectors()]
        get_lines = [
            f"if ({self.nested_render_name}_parse({', '.join(parse_args)}))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        # A nest has no setter of its own; emit one per non-recursive member
        # of the nested struct, with this member appended to the path.
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member_attr in nested.member_list():
            if not member_attr.is_recursive():
                member_attr.setter(ri, self.nested_attrs, direction,
                                   deref=deref, ref=path, var=var)
704
705
class TypeMultiAttr(Type):
    """
    Multi-attr attribute: the put code emits one attribute per array element.

    Values are kept in an array whose element count lives in the _count
    presence struct (and is passed as n_<name> to setters).  Policy and
    type-policy rendering is delegated to base_type.
    """
    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        # Type wrapper for a single element; used for _attr_policy / _attr_typol.
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # NOTE(review): Type.__init__ reads attr['type'], so the
        # "'type' not in self.attr" tests here never match.
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['type'] == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        elif self.attr['type'] == 'string':
            return 'struct ynl_string *'
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def arg_member(self, ri):
        # Arrays of binary structs are passed as a struct pointer + count.
        if self.type == 'binary' and 'struct' in self.attr:
            return [f'struct {c_lower(self.attr["struct"])} *{self.c_name}',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def free_needs_iter(self):
        # Nest and string elements are freed one by one in a loop over 'i'.
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        lines = []
        if self.attr['type'] in scalars:
            lines += [f"free({var}->{ref}{self.c_name});"]
        elif self.attr['type'] == 'binary':
            lines += [f"free({var}->{ref}{self.c_name});"]
        elif self.attr['type'] == 'string':
            lines += [
                f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
                f"free({var}->{ref}{self.c_name}[i]);",
                f"free({var}->{ref}{self.c_name});",
            ]
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            lines += [
                f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)",
                f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);',
                f"free({var}->{ref}{self.c_name});",
            ]
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")
        return lines

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only counts occurrences into a local n_<name>; copying the values is
        # presumably emitted elsewhere once the count is known.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        # Emit one attribute per array element; 'i' is assumed to be declared
        # by the surrounding generated function.
        if self.attr['type'] in scalars:
            put_type = self.type
            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif self.attr['type'] == 'binary' and 'struct' in self.attr:
            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i], sizeof(struct {c_lower(self.attr['struct'])}));")
        elif self.attr['type'] == 'string':
            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {var}->{self.c_name}[i]->str);")
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(f"for (i = 0; i < {var}->_count.{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # The setter takes ownership of the caller's array (no copy is made).
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
792
793
class TypeArrayNest(Type):
    """
    Indexed-array attribute: a nest whose children carry the array index as
    their attribute type (the put code uses 'i' as the type).

    Values are kept in an array whose element count lives in the _count
    presence struct.  Supported sub-types: scalars, binary with an
    'exact-len' check, and nests.
    """
    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # NOTE(review): this method treats a missing 'sub-type' as 'nest', but
        # _attr_policy / _attr_typol index self.attr['sub-type'] directly and
        # attr_put reads self.sub_type - the two spellings are mixed here.
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def arg_member(self, ri):
        # Fixed-size binary elements are passed as a pointer to
        # 'unsigned char[exact-len]' arrays plus an element count.
        if self.sub_type == 'binary' and 'exact-len' in self.checks:
            return [f'unsigned char (*{self.c_name})[{self.checks["exact-len"]}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_policy(self, policy):
        if self.attr['sub-type'] == 'nest':
            return f'NLA_POLICY_NESTED_ARRAY({self.nested_render_name}_nl_policy)'
        return super()._attr_policy(policy)

    def _attr_typol(self):
        if self.attr['sub-type'] in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        elif self.attr['sub-type'] == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        elif self.attr['sub-type'] == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        else:
            raise Exception(f"Typol for ArrayNest sub-type {self.attr['sub-type']} not supported, yet")

    def _attr_get(self, ri, var):
        # Remember the nest in attr_<name> and count (validated) elements
        # into n_<name>; copying the values presumably happens elsewhere.
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\tn_{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        # 'array' and 'i' are assumed to be locals declared by the
        # surrounding generated function.
        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            put_type = self.sub_type
            ri.cw.block_start(line=f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put_{put_type}(nlh, i, {var}->{self.c_name}[i]);")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"ynl_attr_put(nlh, i, {var}->{self.c_name}[i], {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(f'for (i = 0; i < {var}->_count.{self.c_name}; i++)')
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put for ArrayNest sub-type {self.attr['sub-type']} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        # The setter takes ownership of the caller's array (no copy is made).
        return [f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
863
864
class TypeNestTypeValue(Type):
    """Attribute of type 'nest-type-value'.

    When the spec lists 'type-value' levels, the parser descends one nest
    per level and passes each level's attribute type to the nested parse
    function as an extra argument.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = list(map(c_lower, self.attr["type-value"]))
            ptr_list = ", *attr_".join(levels)
            local_vars.append(f'const struct nlattr *attr_{ptr_list};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for lvl in levels:
                # Step one nest down and record that level's attr type
                get_lines.append(f'attr_{lvl} = ynl_attr_data({cursor});')
                get_lines.append(f'{lvl} = ynl_attr_type(attr_{lvl});')
                cursor = f'attr_{lvl}'
            extra_args = ", " + ", ".join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
893
894
class TypeSubMessage(TypeNest):
    """Attribute of type 'sub-message'.

    The format of the payload is chosen at parse time by the value of the
    'selector' attribute, tracked here by a Selector object.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        pieces = [f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
                  '.is_submsg = 1, ']
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            sel_value = self.attr_set[self["selector"]].value
            pieces.append(f'.selector_type = {sel_value} ')
        return ''.join(pieces)

    def _attr_get(self, ri, var):
        sel_name = self['selector']
        sel = c_lower(sel_name)
        # External selectors arrive as a local "_sel_<name>" variable,
        # internal ones are read from the already-parsed struct.
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"
        get_lines = [
            f'if (!{sel_var})',
            f'return ynl_submsg_failed(yarg, "{self.name}", "{sel_name}");',
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None
925
926
class Selector:
    """The attribute whose value picks the format of a sub-message.

    If the selector attribute lives in the same attribute set as the
    sub-message attribute, it is "internal": it is looked up right away and
    marked with ``is_selector``. Otherwise it is "external": ``attr`` stays
    None until ``set_attr()`` is called.
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]
        self._external = self.name not in attr_set
        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
            return
        self.attr = attr_set[self.name]
        self.attr.is_selector = True

    def set_attr(self, attr):
        """Bind the (external) selector to the attribute that carries it."""
        self.attr = attr

    def is_external(self):
        """True when the selector is not in the sub-message's own attr set."""
        return self._external
945
946
class Struct:
    """One C struct rendered for an attribute set.

    When ``type_list`` is given, only those attribute names become members
    (e.g. an op's request or reply). When it is None, every attribute of the
    set is a member and the struct is considered "nested".
    """

    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Default to a list (not a set) so that comparing it with an empty
        # set in set_inherited() is detected as a mismatch.
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.fixed_header = 'struct ' + c_lower(fixed_header) if fixed_header else None
        self.submsg = submsg

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)
        self.struct_name = 'struct ' + self.render_name
        # Avoid clashing with a same-named definition from the spec
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # set when used by a MultiAttr or an array

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(n, self.attr_set[n]) for n in member_names]

        # Track the member with the highest value; ties go to the later one
        self.attrs = {}
        self.attr_max_val = None
        highest = 0
        for attr_name, attr_obj in self.attr_list:
            self.attrs[attr_name] = attr_obj
            if attr_obj.value >= highest:
                highest = attr_obj.value
                self.attr_max_val = attr_obj

    def __iter__(self):
        for key in self.attrs:
            yield key

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Finalize inherited members; they must match what was passed at init."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = list(map(c_lower, sorted(self._inherited)))

    def external_selectors(self):
        """Selectors of sub-message members that live outside this attr set."""
        return [attr.selector for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        """True if any member needs an iterator to be freed."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1021
1022
class EnumEntry(SpecEnumEntry):
    """Enum entry extended with its C name and a value-change marker."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change marks entries whose value breaks implicit C enum
        # numbering (previous + 1, or 0 for the first entry); every entry
        # of a 'flags' set is marked.
        implicit_value = prev.value + 1 if prev else 0
        self.value_change = (self.value != implicit_value
                             or self.enum_set['type'] == 'flags')

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(f"{self.enum_set.value_pfx}{self.name}")
1041
1042
class EnumSet(SpecEnumSet):
    """Enum or flags set with the C naming details used by the generator.

    ``enum_name`` is the C type name. It is None when the spec sets an empty
    'enum-name', which means an anonymous enum. ``user_type`` is the C type
    used for values of the set: the enum type, or ``int`` for anonymous
    enums.
    """

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # Anonymous enums carry their values as plain int
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) if the entry values are dense, else (None, None).

        "Dense" means high - low + 1 equals the number of entries, which
        implies a gap-free range as long as the values are unique.
        Raises ValueError if the set has no entries.
        """
        values = [x.value for x in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            return None, None

        return low, high
1078
1079
class AttrSet(SpecAttrSet):
    """Attribute set with C enum naming and generator-specific attr types."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset shares the enum naming of the set it is carved from
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            default_max = f"{self.name_prefix}max"
            default_cnt = f"__{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', default_max))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', default_cnt))

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        # The family's own set gets an empty C name
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching elem's 'type' (and sub-type)."""
        attr_type = elem['type']
        ctor_args = (self.family, self, elem, value)

        if attr_type in scalars:
            t = TypeScalar(*ctor_args)
        elif attr_type == 'binary':
            if 'struct' in elem:
                t = TypeBinaryStruct(*ctor_args)
            elif elem.get('sub-type') in scalars:
                t = TypeBinaryScalarArray(*ctor_args)
            else:
                t = TypeBinary(*ctor_args)
        elif attr_type == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            t = TypeArrayNest(*ctor_args)
        else:
            # Types that map one-to-one onto a class
            simple_types = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'nest-type-value': TypeNestTypeValue,
                'sub-message': TypeSubMessage,
            }
            if attr_type not in simple_types:
                raise Exception(f"No typed class for type {elem['type']}")
            t = simple_types[attr_type](*ctor_args)

        # multi-attr wraps whatever base type was chosen above
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1148
1149
class Operation(SpecOperation):
    """Operation spec with its C render name and op enum name."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs):
        # any request/reply present without 'attributes' gets an empty list.
        mode_dirs = [(m, d) for m in ('do', 'dump', 'event')
                     for d in ('request', 'reply')]
        for mode, direction in mode_dirs:
            try:
                msg = yaml[mode][direction]
            except KeyError:
                continue
            msg.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True when both 'do' and 'dump' define a request
        self.dual_policy = all(m in yaml and 'request' in yaml[m]
                               for m in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        del self.enum_name

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that another message uses this op as its 'notify' target."""
        self.has_ntf = True
1183
1184
class SubMessage(SpecSubMessage):
    """Sub-message spec with a C render name (<family>_<sub-message name>)."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        full_name = '-'.join((family.ident_name, yaml['name']))
        self.render_name = c_lower(full_name)

    def resolve(self):
        self.resolve_up(super())
1193
1194
class Family(SpecFamily):
    """Netlink family spec extended with the state needed by the C generator.

    On top of what SpecFamily parses, this class derives C-level names:
    family name/version macros, the uapi header, and op enum prefixes. In
    resolve() it also builds the bookkeeping that decides which C structs
    are rendered: root attribute sets, pure nested structs, request/reply
    use, selector passing, hooks and, optionally, a global policy.
    """
    def __init__(self, file_name, exclude_ops):
        """Load the family spec; exclude_ops is forwarded to SpecFamily.

        The attributes below are declared and then immediately deleted, so
        reading them before they are filled in raises AttributeError rather
        than silently returning None.
        """
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # C macro names for the family name / version, overridable in the spec
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # "linux/<name>.h" -> "<name>"; any other header path falls back to ident_name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Compute C naming and the struct bookkeeping for the family.

        The _mark/_mock/_load_* steps run in this order because later steps
        read state built by earlier ones. For example, _load_nested_sets()
        walks self.root_sets, and _load_attr_use() and
        _load_selector_passing() read self.pure_nested_structs.
        """
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] holds hook names both as a set (dedup) and a
        # list (first-seen order); populated by _load_hooks()
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """True for 'netlink-raw' (classic, non-genetlink) families."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Call mark_has_ntf() on every op named in some message's 'notify'."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        # Note: this mutates the entries of self.yaml['operations']['list'] in place
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Record, per attr set used directly by an op, which attrs appear in
        requests and which in replies (do/dump replies and events)."""
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            # Several ops may share one attr set: union their usage
            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct follows the nests it uses.

        What gets reordered is the dict's insertion order. A struct with a
        nested, non-recursive child that has not been placed yet is moved to
        the end and retried later. The number of rounds is capped at n**2.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register a Struct for spec's 'nested-attributes' set; return its name.

        Raises if that set is also used as a root set, or (via
        Struct.set_inherited) if the same set is nested elsewhere with
        different inherited members.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                # Struct stores `inherit` by reference (not a copy), so the
                # members added below also land in the new struct's
                # _inherited, and set_inherited() then sees an equal set.
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        if 'type-value' in spec:
            # NOTE(review): already excluded by the raise above; kept as a guard
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Build a synthetic attr set + Struct for spec's sub-message; return its name.

        Each format of the sub-message becomes one fake attribute:
        - a nest if the format names an attribute-set,
        - a binary struct if it only has a fixed header,
        - otherwise a flag.
        """
        # Fake the struct type for the sub-message itself:
        # it's not an attr_set, but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr |= {
                    "type": "nest",
                    "nested-attributes": fmt['attribute-set'],
                }
                if 'fixed-header' in fmt:
                    attr |= { "fixed-header": fmt["fixed-header"] }
            elif 'fixed-header' in fmt:
                attr |= {
                    "type": "binary",
                    "struct": fmt["fixed-header"],
                }
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        # NOTE(review): AttrSet.__init__ looks for 'name-prefix', not
        # 'name-pfx', so this value appears to be ignored and the default
        # "<family>-a-<name>-" prefix is used instead -- confirm intent.
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Discover all attr sets reachable from the root sets and flag them.

        First a breadth-first walk over nests and sub-messages creates a
        Struct for every reachable set. Then the request / reply /
        in_multi_val / recursive flags and child_nests are computed, and
        pure_nested_structs is sorted by dependency.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        # Direct children of root sets inherit request/reply use from the
        # attributes that reference them
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        # Iterating the sorted dict in reverse presumably visits containers
        # before the nests they use, so flags flow downward in one pass.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                # A set found among its own (transitive) children is recursive
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Call set_request()/set_reply() on attributes according to their use.

        Nested structs mark all their members; root sets mark only the
        attributes listed by ops.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Bind each external sub-message selector to the parent's attribute.

        An external selector must be an attribute of the set that directly
        contains the nest holding the sub-message. Deeper passing raises.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """Collect the request attributes of all ops into one global policy.

        Every op with an attribute-set must use the same set. The result
        keeps the attr set's declaration order.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        # NOTE(review): if no op has an 'attribute-set', attr_set_name stays
        # None and the lookup below is self.attr_sets[None] -- likely a
        # KeyError; confirm whether that case can occur.
        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Gather unique pre/post hook names per mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1555
1556
class RenderInfo:
    """Context for rendering one operation (or one attribute set) in one mode.

    Holds the code writer, the family, the kernel/user space selector, the
    op and op mode, the attribute set name, and, when an op is given, the
    request and reply Struct objects for that op.
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header:
            if op.fixed_header != family.fixed_header:
                # Only classic (netlink-raw) families may override the
                # family-wide fixed header per op
                if family.is_classic():
                    self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"
                else:
                    raise Exception("Per-op fixed header not supported, yet")

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        # `op` may be falsy when rendering a bare attribute set; check it
        # first so a None op does not raise TypeError on the `in` test.
        if op and op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_oneside = True

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = dict()
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """True if the `key` ('request'/'reply') struct has no members at all.

        Checks both the attribute list and the fixed header of that same
        struct.
        """
        return len(self.struct[key].attr_list) == 0 and \
            self.struct[key].fixed_header is None

    def needs_nlflags(self, direction):
        """Classic families' 'do' requests carry netlink flags."""
        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1624
1625
class CodeWriter:
    """Emit C source text with automatic indentation and brace placement.

    Output goes to stdout when *out_file* is None. Otherwise it is buffered
    in a temporary file and copied to *out_file* by close_out_file(); with
    ``overwrite=False`` an existing *out_file* with identical contents is
    left untouched.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # a blank line is owed before the next line
        self._block_end = False     # a '}' is owed; it may merge with "else"
        self._silent_block = False  # previous line was a brace-less if/for/while/else
        self._ind = 0               # current indentation depth, in tabs
        self._ifdef_block = None    # CONFIG_ option of the open #ifdef, if any
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy buffered output to the target file and release the buffer.

        No-op when writing to stdout (including after a previous close).
        """
        if self._out is sys.stdout:
            return
        # block_end() defers the closing brace until the next line arrives so
        # it can merge with "else"; don't drop it when nothing follows.
        if self._block_end:
            self._block_end = False
            self._out.write('\t' * self._ind + '}\n')
        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = (not self._overwrite and os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        # Release the temporary file on both paths, not only after copying.
        self._out.close()
        self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        """Return True if *line* starts like an if/while/for statement."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Flushes a pending '}' (merged into "} else" when *line* starts with
        "else") and a pending blank line first. Labels (trailing ':') are
        outdented by one level, preprocessor lines start at column 0, and the
        line after a brace-less if/for/while/else is indented one extra level.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        if not line:
            # Empty line: emit it without indentation instead of crashing on
            # line[-1] / line[0]; keep any brace-less-body state for the next
            # real line.
            self._out.write('\n')
            return

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write *line* followed by '{' and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block.

        With no *line*, the '}' is deferred so that a following "else" can
        share its line; otherwise '}' plus *line* (e.g. ';') is written now.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write *doc* as ' *'-prefixed comment lines wrapped before column 79."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping arguments past ~80 columns."""
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines align with the opening parenthesis using tabs
        # followed by spaces.
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local declarations, longest first, then a blank line.

        *local_vars* is a single string or any iterable of strings; the
        caller's container is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, locals and body lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write ``#define`` lines with values aligned on a tab stop.

        *defines* holds (name, value) pairs; int values are written as-is,
        str values quoted, anything else omitted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write designated initializers ``.name = value,`` aligned on a tab
        stop. An empty *members* writes nothing."""
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the surrounding ``#ifdef CONFIG_<config>`` guard.

        Closes the currently open guard (if different) and opens a new one
        unless *config* is falsy.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1822
1823
# Netlink attribute types rendered as plain C integers.
scalars = set('u8 u16 u32 u64 s8 s16 s32 s64 uint sint'.split())

# Struct / function name suffix for each message direction.
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': '',
}

# Name suffix of the reply wrapper type for each op mode.
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords that cannot be used verbatim as identifiers.
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1875
1876
def rdir(direction):
    """Return the opposite of 'request'/'reply'; pass anything else through."""
    if direction not in ('reply', 'request'):
        return direction
    return 'reply' if direction == 'request' else 'request'
1884
def op_prefix(ri, direction, deref=False):
    """Return the C name stem ``<family>_<type>[suffix]`` for *direction*.

    'do' types take the plain direction suffix. For other op modes, requests
    become ``_req`` (plus ``_dump`` unless the op is dump-only). Replies use
    the op mode's wrapper suffix, or the direction suffix when *deref* is
    set. If the 'do' and 'dump' replies differ, replies use ``_rsp_dump`` or
    ``_rsp_list`` instead.
    """
    base = f"{ri.family.c_name}_{ri.type_name}"
    mode = ri.op_mode

    if not mode:
        return base
    if mode == 'do':
        return base + direction_to_suffix[direction]
    if direction == 'request':
        return base + ('_req' if ri.type_oneside else '_req_dump')
    if not ri.type_consistent:
        return base + ('_rsp_dump' if deref else '_rsp_list')
    if deref:
        return base + direction_to_suffix[direction]
    return base + op_mode_to_wrapper[mode]
1908
1909
def type_name(ri, direction, deref=False):
    """Return the C struct type, ``struct <op_prefix>``, for *direction*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return "struct " + prefix
1912
1913
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the C prototype of the user-facing call for *ri*'s op mode.

    The function is named after the op (``_dump`` appended in dump mode),
    takes the socket plus a request pointer when the mode has a request, and
    returns a reply pointer when the mode has a reply (``int`` otherwise).
    With *terminate* the prototype ends in ';'.
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        args.append(f"{type_name(ri, direction)} *{direction_to_suffix[direction][1:]}")

    ret = f"{type_name(ri, rdir(direction))} *" if 'reply' in mode_spec else 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1930
1931
def print_req_prototype(ri):
    """Emit the request call's prototype, preceded by the op's doc comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, direction="request", doc=op_doc)
1934
1935
def print_dump_prototype(ri):
    """Emit the dump call's prototype (without a doc comment)."""
    print_prototype(ri, direction="request")
1938
1939
def put_typol_submsg(cw, struct):
    """Emit the policy table and nest descriptor for a sub-message struct.

    Each member becomes a YNL_PT_SUBMSG entry indexed by its position in the
    member list; 'nest'-typed members also point at their nest policy.
    """
    name = struct.render_name
    members = list(struct.member_list())

    cw.block_start(line=f'const struct ynl_policy_attr {name}_policy[] =')
    for idx, (member, arg) in enumerate(members):
        nest = f" .nest = &{arg.nested_render_name}_nest," if arg.type == 'nest' else ""
        cw.p(f'[{idx}] = {{ .type = YNL_PT_SUBMSG, .name = "{member}",{nest} }},')
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {name}_nest =')
    cw.p(f'.max_attr = {len(members) - 1},')
    cw.p(f'.table = {name}_policy,')
    cw.block_end(line=';')
    cw.nl()
1960
1961
def put_typol_fwd(cw, struct):
    """Emit an extern declaration of the struct's nest policy."""
    decl = f'extern const struct ynl_policy_nest {struct.render_name}_nest;'
    cw.p(decl)
1964
1965
def put_typol(cw, struct):
    """Emit the attribute policy table and nest descriptor for *struct*.

    Sub-message structs are delegated to put_typol_submsg(). Otherwise the
    table is sized by the attribute set's max attribute, and each member
    writes its own entry via ``attr_typol()``.
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    type_max = struct.attr_set.max_name
    policy = f'{struct.render_name}_policy'

    cw.block_start(line=f'const struct ynl_policy_attr {policy}[{type_max} + 1] =')
    for _, arg in struct.member_list():
        arg.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {type_max},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1985
1986
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit ``<render_name>_str()``: a bounds-checked lookup in *map_name*.

    The argument is typed ``int`` unless *enum* supplies a user type. For
    flags enums the value is first converted to its bit index with ffs().
    Out-of-range values yield NULL.
    """
    arg_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{arg_type} {arg_name}'])

    body = []
    if enum and enum.type == 'flags':
        body.append(f'{arg_name} = ffs({arg_name}) - 1;')
    body += [f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
             'return NULL;',
             f'return {map_name}[{arg_name}];']

    cw.block_start()
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
2000
2001
def put_op_name_fwd(family, cw):
    """Emit the prototype of the family's ``<family>_op_str()`` helper."""
    fn_name = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', fn_name, ['int op'], suffix=';')
2004
2005
def put_op_name(family, cw):
    """Emit the op-name string map and its ``<family>_op_str()`` lookup.

    Only messages with a response value get an entry. If several ops share
    a response value, only the one recorded in ``family.rsp_by_value`` is
    kept and the others are noted in a comment. Entries are indexed by the
    op's enum name when request and response values match, else by the raw
    response value.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2025
2026
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the prototype of ``<enum>_str()``; *family* is not used."""
    cw.write_func_prot('const char *', f'{enum.render_name}_str',
                       [f'{enum.user_type} value'], suffix=';')
2030
2031
def put_enum_to_str(family, cw, enum):
    """Emit ``<enum>_strmap`` (entry value -> name) and its ``_str()`` lookup.

    *family* is not used.
    """
    map_name = f'{enum.render_name}_strmap'
    entries = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in entries:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
2041
2042
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of ``<render_name>_put(nlh, attr_type, obj)``."""
    params = ['struct nlmsghdr *nlh', 'unsigned int attr_type']
    params.append(f'{struct.ptr_name}obj')
    ri.cw.write_func_prot('int', f'{struct.render_name}_put', params, suffix=suffix)
2050
2051
def put_req_nested(ri, struct):
    """Emit ``<render_name>_put()``, which writes *obj* into ``nlh``.

    Non-sub-message structs are wrapped in a nest of type ``attr_type``. A
    fixed header, if any, is copied from ``obj->_hdr`` first, then each
    member emits its own put code.
    """
    wrap_in_nest = struct.submsg is None
    decls = []
    prologue = []

    if wrap_in_nest:
        decls.append('struct nlattr *nest;')
        prologue.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        hdr_size = f'sizeof({struct.fixed_header})'
        decls.append('void *hdr;')
        prologue += [f"hdr = ynl_nlmsg_put_extra_header(nlh, {hdr_size});",
                     f"memcpy(hdr, &obj->_hdr, {hdr_size});"]

    members = [arg for _, arg in struct.member_list()]
    if any(arg.type == 'indexed-array' for arg in members):
        decls.append('struct nlattr *array;')
    if any(arg.presence_type() == 'count' for arg in members):
        decls.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    ri.cw.block_start()
    ri.cw.write_func_lvar(decls)

    for line in prologue:
        ri.cw.p(line)

    for _, arg in struct.member_list():
        arg.attr_put(ri, "obj")

    if wrap_in_nest:
        ri.cw.p("ynl_attr_nest_end(nlh, nest);")

    ri.cw.nl()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2092
2093
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a C attribute parser for *struct*.

    The caller has already printed the parser's prototype; this emits the
    opening brace through the closing brace. The generated body:

    * declares *local_vars* (plus any added here) and runs *init_lines*,
    * copies inherited arguments and the fixed header (if any) into ``dst``,
    * fails if an indexed-array / multi-attr member of ``dst`` is already
      set,
    * walks all attributes once, emitting each member's getter code,
    * then, for each indexed-array and each multi-attr member, allocates the
      destination array and fills it in a second walk.

    Args:
        ri: RenderInfo; supplies the code writer (``ri.cw``), op and family.
        struct: Struct whose members are parsed.
        init_lines: C statements emitted right after the declarations;
            may be appended to here.
        local_vars: C local declarations; appended to here.

    Raises:
        Exception: for per-op fixed headers in non-classic families, or
            unsupported member (sub-)types.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    # Pick the C loop header used to walk the attributes
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        # Whole message: attributes follow the family header, or the op's
        # own fixed header when it differs (classic families only)
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception("Per-op fixed header not supported, yet")

    # Classify members: indexed arrays and multi-attrs are gathered into
    # arrays after the main loop; nested / sub-message members need a parg.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            # attr_<name> holds the array's nest attribute; presumably it is
            # assigned by the member's getter code in the main loop - confirm
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # Per-member element counters; presumably incremented by the getter code
    # emitted in the main loop (not visible here) - confirm
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if struct.fixed_header:
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Reject a destination whose array member is already populated
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # Main loop: every member emits its handling code for the current attr
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate, then walk the saved nest
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            # ynl_attr_type(attr) is passed as the nested parser's extra __u32
            # argument; presumably the element index (inherited arg) - confirm
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: re-walk all attributes and pick
    # those whose type matches the member
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # Copy at most one element's worth of payload
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Each element is a separately allocated, NUL-terminated ynl_string
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0; message-level parsers return a callback status
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2245
2246
def parse_rsp_submsg(ri, struct):
    """Emit the C parser for a sub-message struct.

    The generated ``<render_name>_parse()`` compares the ``sel`` string with
    each member name and runs the matching member's getter on ``nested``,
    marking it present when the member uses presence bits. The local
    declarations needed by all members' getters are deduplicated and
    declared once at the top.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    # Ordered de-duplication (dict keys) instead of a set: write_func_lvar()
    # only sorts by length, so equal-length declarations would otherwise come
    # out in hash-randomized set order and the generated C would differ from
    # run to run.
    local_vars = dict.fromkeys(['const struct nlattr *attr = nested;',
                                f'{struct.ptr_name}{var} = yarg->data;',
                                'struct ynl_parse_arg parg;'])

    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        local_vars.update(dict.fromkeys(l_vars or []))

    ri.cw.block_start()
    ri.cw.write_func_lvar(list(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    first = True
    for name, arg in struct.member_list():
        kw = 'if' if first else 'else if'
        first = False

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2281
2282
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the prototype of ``<render_name>_parse()`` for a nested struct.

    Arguments are: yarg, the ``sel`` string (sub-messages only), the nest
    attribute, one ``_sel_<name>`` string per external selector, and one
    ``__u32`` per inherited name.
    """
    func_args = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        func_args.append('const char *sel')
    func_args.append('const struct nlattr *nested')
    func_args += ['const char *_sel_' + sel.name for sel in struct.external_selectors()]
    func_args += ['__u32 ' + arg for arg in struct.inherited]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
                          suffix=suffix)
2295
2296
def parse_rsp_nested(ri, struct):
    """Emit the C parser for a nested struct.

    Sub-messages use parse_rsp_submsg(). A struct without members gets a
    trivial body returning 0; otherwise _multi_parse() emits the body.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return None

    _multi_parse(ri, struct, [], decls)
    return None
2315
2316
def parse_rsp_msg(ri, deref=False):
    """Emit the C callback that parses a reply (or event) message.

    Nothing is emitted for op modes without a reply, unless the mode is
    'event'. An empty reply struct gets a body returning YNL_PARSE_CB_OK.
    """
    is_event = ri.op_mode == 'event'
    if 'reply' not in ri.op[ri.op_mode] and not is_event:
        return

    decls = [f'{type_name(ri, "reply", deref=deref)} *dst;',
             'const struct nlattr *attr;']
    prologue = ['dst = yarg->data;']

    parser = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parser,
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, prologue, decls)
2338
2339
def print_req(ri):
    """Emit the C implementation of a 'do' request call.

    The generated function builds the message (fixed header, then the
    request attributes), executes it with ynl_exec() and, when the op has a
    reply, returns the parsed reply (freed and NULL on error). Otherwise it
    returns 0 on success and -1 on error.
    """
    direction = "request"
    req_struct = ri.struct["request"]
    has_reply = 'reply' in ri.op[ri.op_mode]

    decls = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
             'struct nlmsghdr *nlh;',
             'int err;']
    if has_reply:
        ret_ok, ret_err = 'rsp', 'NULL'
        decls.append(f'{type_name(ri, rdir(direction))} *rsp;')
    else:
        ret_ok, ret_err = '0', '-1'
    if req_struct.fixed_header:
        decls.extend(['size_t hdr_len;', 'void *hdr;'])
    if any(attr.presence_type() == 'count' for _, attr in req_struct.member_list()):
        decls.append('unsigned int i;')

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(decls)

    if ri.family.is_classic():
        start = f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);"
    else:
        start = f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);"
    ri.cw.p(start)

    ri.cw.p(f"ys->req_policy = &{req_struct.render_name}_nest;")
    ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if has_reply:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if req_struct.fixed_header:
        for line in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(line)
        ri.cw.nl()

    for _, attr in req_struct.member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        ri.cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        ri.cw.nl()

    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto err_free;' if has_reply else f'return {ret_err};')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    if has_reply:
        ri.cw.p('err_free:')
        ri.cw.p(call_free(ri, rdir(direction), 'rsp'))
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
2413
2414
def print_dump(ri):
    """Emit the C implementation of a dump call.

    The generated function sets up the dump state (reply policy, element
    size, parse callback, expected reply command), builds the request, runs
    ynl_exec_dump() and returns ``yds.first``. On error it frees that list
    and returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()

    req_struct = ri.struct['request']
    decls = ['struct ynl_dump_state yds = {};',
             'struct nlmsghdr *nlh;',
             'int err;']
    if req_struct.fixed_header:
        decls.extend(['size_t hdr_len;', 'void *hdr;'])
    ri.cw.write_func_lvar(decls)

    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    for line in ('yds.yarg.ys = ys;',
                 f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;",
                 "yds.yarg.data = NULL;",
                 f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});",
                 f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;",
                 f'yds.rsp_cmd = {rsp_cmd};'):
        ri.cw.p(line)
    ri.cw.nl()

    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if req_struct.fixed_header:
        for line in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            ri.cw.p(line)
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{req_struct.render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in req_struct.member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    for line in ('err = ynl_exec_dump(ys, nlh, &yds);',
                 'if (err < 0)',
                 'goto free_list;'):
        ri.cw.p(line)
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2469
2470
def call_free(ri, direction, var):
    """Return the C statement freeing *var* with the type's ``_free()``."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
2473
2474
def free_arg_name(direction):
    """Return the parameter name for a _free() helper: 'req'/'rsp' for a
    direction, 'obj' when there is none."""
    return direction_to_suffix[direction][1:] if direction else 'obj'
2479
2480
def print_alloc_wrapper(ri, direction, struct=None):
    """Emit a static inline ``<prefix>_alloc()`` returning zeroed memory.

    Allocates one object, or ``n`` of them when *struct* is used as a
    multi-value member (``struct.in_multi_val``). On a type-name conflict,
    the struct type gets a trailing '_'.
    """
    name = op_prefix(ri, direction)
    struct_name = f"{name}_" if ri.type_name_conflict else name

    if struct and struct.in_multi_val:
        params, count = ["unsigned int n"], "n"
    else:
        params, count = ["void"], "1"

    ri.cw.write_func_prot(f'static inline struct {struct_name} *',
                          f"{name}_alloc", params)
    ri.cw.block_start()
    ri.cw.p(f'return calloc({count}, sizeof(struct {struct_name}));')
    ri.cw.block_end()
2498
2499
2500def print_free_prototype(ri, direction, suffix=';'):
2501    name = op_prefix(ri, direction)
2502    struct_name = name
2503    if ri.type_name_conflict:
2504        struct_name += '_'
2505    arg = free_arg_name(direction)
2506    ri.cw.write_func_prot('void', f"{name}_free", [f"struct {struct_name} *{arg}"], suffix=suffix)
2507
2508
2509def print_nlflags_set(ri, direction):
2510    name = op_prefix(ri, direction)
2511    ri.cw.write_func_prot('static inline void', f"{name}_set_nlflags",
2512                          [f"struct {name} *req", "__u16 nl_flags"])
2513    ri.cw.block_start()
2514    ri.cw.p('req->_nlmsg_flags = nl_flags;')
2515    ri.cw.block_end()
2516    ri.cw.nl()
2517
2518
2519def _print_type(ri, direction, struct):
2520    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
2521    if not direction and ri.type_name_conflict:
2522        suffix += '_'
2523
2524    if ri.op_mode == 'dump' and not ri.type_oneside:
2525        suffix += '_dump'
2526
2527    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")
2528
2529    if ri.needs_nlflags(direction):
2530        ri.cw.p('__u16 _nlmsg_flags;')
2531        ri.cw.nl()
2532    if struct.fixed_header:
2533        ri.cw.p(struct.fixed_header + ' _hdr;')
2534        ri.cw.nl()
2535
2536    for type_filter in ['present', 'len', 'count']:
2537        meta_started = False
2538        for _, attr in struct.member_list():
2539            line = attr.presence_member(ri.ku_space, type_filter)
2540            if line:
2541                if not meta_started:
2542                    ri.cw.block_start(line="struct")
2543                    meta_started = True
2544                ri.cw.p(line)
2545        if meta_started:
2546            ri.cw.block_end(line=f'_{type_filter};')
2547    ri.cw.nl()
2548
2549    for arg in struct.inherited:
2550        ri.cw.p(f"__u32 {arg};")
2551
2552    for _, attr in struct.member_list():
2553        attr.struct_member(ri)
2554
2555    ri.cw.block_end(line=';')
2556    ri.cw.nl()
2557
2558
2559def print_type(ri, direction):
2560    _print_type(ri, direction, ri.struct[direction])
2561
2562
2563def print_type_full(ri, struct):
2564    _print_type(ri, "", struct)
2565
2566    if struct.request and struct.in_multi_val:
2567        print_alloc_wrapper(ri, "", struct)
2568        ri.cw.nl()
2569        free_rsp_nested_prototype(ri)
2570        ri.cw.nl()
2571
2572        # Name conflicts are too hard to deal with with the current code base,
2573        # they are very rare so don't bother printing setters in that case.
2574        if ri.ku_space == 'user' and not ri.type_name_conflict:
2575            for _, attr in struct.member_list():
2576                attr.setter(ri, ri.attr_set, "", var="obj")
2577        ri.cw.nl()
2578
2579
2580def print_type_helpers(ri, direction, deref=False):
2581    print_free_prototype(ri, direction)
2582    ri.cw.nl()
2583
2584    if ri.needs_nlflags(direction):
2585        print_nlflags_set(ri, direction)
2586
2587    if ri.ku_space == 'user' and direction == 'request':
2588        for _, attr in ri.struct[direction].member_list():
2589            attr.setter(ri, ri.attr_set, direction, deref=deref)
2590    ri.cw.nl()
2591
2592
2593def print_req_type_helpers(ri):
2594    if ri.type_empty("request"):
2595        return
2596    print_alloc_wrapper(ri, "request")
2597    print_type_helpers(ri, "request")
2598
2599
2600def print_rsp_type_helpers(ri):
2601    if 'reply' not in ri.op[ri.op_mode]:
2602        return
2603    print_type_helpers(ri, "reply")
2604
2605
2606def print_parse_prototype(ri, direction, terminate=True):
2607    suffix = "_rsp" if direction == "reply" else "_req"
2608    term = ';' if terminate else ''
2609
2610    ri.cw.write_func_prot('void', f"{ri.op.render_name}{suffix}_parse",
2611                          ['const struct nlattr **tb',
2612                           f"struct {ri.op.render_name}{suffix} *req"],
2613                          suffix=term)
2614
2615
2616def print_req_type(ri):
2617    if ri.type_empty("request"):
2618        return
2619    print_type(ri, "request")
2620
2621
2622def print_req_free(ri):
2623    if 'request' not in ri.op[ri.op_mode]:
2624        return
2625    _free_type(ri, 'request', ri.struct['request'])
2626
2627
2628def print_rsp_type(ri):
2629    if (ri.op_mode == 'do' or ri.op_mode == 'dump') and 'reply' in ri.op[ri.op_mode]:
2630        direction = 'reply'
2631    elif ri.op_mode == 'event':
2632        direction = 'reply'
2633    else:
2634        return
2635    print_type(ri, direction)
2636
2637
2638def print_wrapped_type(ri):
2639    ri.cw.block_start(line=f"{type_name(ri, 'reply')}")
2640    if ri.op_mode == 'dump':
2641        ri.cw.p(f"{type_name(ri, 'reply')} *next;")
2642    elif ri.op_mode == 'notify' or ri.op_mode == 'event':
2643        ri.cw.p('__u16 family;')
2644        ri.cw.p('__u8 cmd;')
2645        ri.cw.p('struct ynl_ntf_base_type *next;')
2646        ri.cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
2647    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
2648    ri.cw.block_end(line=';')
2649    ri.cw.nl()
2650    print_free_prototype(ri, 'reply')
2651    ri.cw.nl()
2652
2653
2654def _free_type_members_iter(ri, struct):
2655    if struct.free_needs_iter():
2656        ri.cw.p('unsigned int i;')
2657        ri.cw.nl()
2658
2659
2660def _free_type_members(ri, var, struct, ref=''):
2661    for _, attr in struct.member_list():
2662        attr.free(ri, var, ref)
2663
2664
2665def _free_type(ri, direction, struct):
2666    var = free_arg_name(direction)
2667
2668    print_free_prototype(ri, direction, suffix='')
2669    ri.cw.block_start()
2670    _free_type_members_iter(ri, struct)
2671    _free_type_members(ri, var, struct)
2672    if direction:
2673        ri.cw.p(f'free({var});')
2674    ri.cw.block_end()
2675    ri.cw.nl()
2676
2677
2678def free_rsp_nested_prototype(ri):
2679        print_free_prototype(ri, "")
2680
2681
2682def free_rsp_nested(ri, struct):
2683    _free_type(ri, "", struct)
2684
2685
2686def print_rsp_free(ri):
2687    if 'reply' not in ri.op[ri.op_mode]:
2688        return
2689    _free_type(ri, 'reply', ri.struct['reply'])
2690
2691
2692def print_dump_type_free(ri):
2693    sub_type = type_name(ri, 'reply')
2694
2695    print_free_prototype(ri, 'reply', suffix='')
2696    ri.cw.block_start()
2697    ri.cw.p(f"{sub_type} *next = rsp;")
2698    ri.cw.nl()
2699    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
2700    _free_type_members_iter(ri, ri.struct['reply'])
2701    ri.cw.p('rsp = next;')
2702    ri.cw.p('next = rsp->next;')
2703    ri.cw.nl()
2704
2705    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2706    ri.cw.p('free(rsp);')
2707    ri.cw.block_end()
2708    ri.cw.block_end()
2709    ri.cw.nl()
2710
2711
2712def print_ntf_type_free(ri):
2713    print_free_prototype(ri, 'reply', suffix='')
2714    ri.cw.block_start()
2715    _free_type_members_iter(ri, ri.struct['reply'])
2716    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2717    ri.cw.p('free(rsp);')
2718    ri.cw.block_end()
2719    ri.cw.nl()
2720
2721
2722def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
2723    if terminate and ri and policy_should_be_static(struct.family):
2724        return
2725
2726    if terminate:
2727        prefix = 'extern '
2728    else:
2729        if ri and policy_should_be_static(struct.family):
2730            prefix = 'static '
2731        else:
2732            prefix = ''
2733
2734    suffix = ';' if terminate else ' = {'
2735
2736    max_attr = struct.attr_max_val
2737    if ri:
2738        name = ri.op.render_name
2739        if ri.op.dual_policy:
2740            name += '_' + ri.op_mode
2741    else:
2742        name = struct.render_name
2743    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2744
2745
2746def print_req_policy(cw, struct, ri=None):
2747    if ri and ri.op:
2748        cw.ifdef_block(ri.op.get('config-cond', None))
2749    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
2750    for _, arg in struct.member_list():
2751        arg.attr_policy(cw)
2752    cw.p("};")
2753    cw.ifdef_block(None)
2754    cw.nl()
2755
2756
2757def kernel_can_gen_family_struct(family):
2758    return family.proto == 'genetlink'
2759
2760
2761def policy_should_be_static(family):
2762    return family.kernel_policy == 'split' or kernel_can_gen_family_struct(family)
2763
2764
2765def print_kernel_policy_ranges(family, cw):
2766    first = True
2767    for _, attr_set in family.attr_sets.items():
2768        if attr_set.subset_of:
2769            continue
2770
2771        for _, attr in attr_set.items():
2772            if not attr.request:
2773                continue
2774            if 'full-range' not in attr.checks:
2775                continue
2776
2777            if first:
2778                cw.p('/* Integer value ranges */')
2779                first = False
2780
2781            sign = '' if attr.type[0] == 'u' else '_signed'
2782            suffix = 'ULL' if attr.type[0] == 'u' else 'LL'
2783            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
2784            members = []
2785            if 'min' in attr.checks:
2786                members.append(('min', attr.get_limit_str('min', suffix=suffix)))
2787            if 'max' in attr.checks:
2788                members.append(('max', attr.get_limit_str('max', suffix=suffix)))
2789            cw.write_struct_init(members)
2790            cw.block_end(line=';')
2791            cw.nl()
2792
2793
2794def print_kernel_policy_sparse_enum_validates(family, cw):
2795    first = True
2796    for _, attr_set in family.attr_sets.items():
2797        if attr_set.subset_of:
2798            continue
2799
2800        for _, attr in attr_set.items():
2801            if not attr.request:
2802                continue
2803            if not attr.enum_name:
2804                continue
2805            if 'sparse' not in attr.checks:
2806                continue
2807
2808            if first:
2809                cw.p('/* Sparse enums validation callbacks */')
2810                first = False
2811
2812            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
2813                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
2814            cw.block_start()
2815            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
2816            enum = family.consts[attr['enum']]
2817            first_entry = True
2818            for entry in enum.entries.values():
2819                if first_entry:
2820                    first_entry = False
2821                else:
2822                    cw.p('fallthrough;')
2823                cw.p(f'case {entry.c_name}:')
2824            cw.p('return 0;')
2825            cw.block_end()
2826            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
2827            cw.p('return -EINVAL;')
2828            cw.block_end()
2829            cw.nl()
2830
2831
2832def print_kernel_op_table_fwd(family, cw, terminate):
2833    exported = not kernel_can_gen_family_struct(family)
2834
2835    if not terminate or exported:
2836        cw.p(f"/* Ops table for {family.ident_name} */")
2837
2838        pol_to_struct = {'global': 'genl_small_ops',
2839                         'per-op': 'genl_ops',
2840                         'split': 'genl_split_ops'}
2841        struct_type = pol_to_struct[family.kernel_policy]
2842
2843        if not exported:
2844            cnt = ""
2845        elif family.kernel_policy == 'split':
2846            cnt = 0
2847            for op in family.ops.values():
2848                if 'do' in op:
2849                    cnt += 1
2850                if 'dump' in op:
2851                    cnt += 1
2852        else:
2853            cnt = len(family.ops)
2854
2855        qual = 'static const' if not exported else 'const'
2856        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
2857        if terminate:
2858            cw.p(f"extern {line};")
2859        else:
2860            cw.block_start(line=line + ' =')
2861
2862    if not terminate:
2863        return
2864
2865    cw.nl()
2866    for name in family.hooks['pre']['do']['list']:
2867        cw.write_func_prot('int', c_lower(name),
2868                           ['const struct genl_split_ops *ops',
2869                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2870    for name in family.hooks['post']['do']['list']:
2871        cw.write_func_prot('void', c_lower(name),
2872                           ['const struct genl_split_ops *ops',
2873                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2874    for name in family.hooks['pre']['dump']['list']:
2875        cw.write_func_prot('int', c_lower(name),
2876                           ['struct netlink_callback *cb'], suffix=';')
2877    for name in family.hooks['post']['dump']['list']:
2878        cw.write_func_prot('int', c_lower(name),
2879                           ['struct netlink_callback *cb'], suffix=';')
2880
2881    cw.nl()
2882
2883    for op_name, op in family.ops.items():
2884        if op.is_async:
2885            continue
2886
2887        if 'do' in op:
2888            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
2889            cw.write_func_prot('int', name,
2890                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2891
2892        if 'dump' in op:
2893            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
2894            cw.write_func_prot('int', name,
2895                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
2896    cw.nl()
2897
2898
2899def print_kernel_op_table_hdr(family, cw):
2900    print_kernel_op_table_fwd(family, cw, terminate=True)
2901
2902
2903def print_kernel_op_table(family, cw):
2904    print_kernel_op_table_fwd(family, cw, terminate=False)
2905    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
2906        for op_name, op in family.ops.items():
2907            if op.is_async:
2908                continue
2909
2910            cw.ifdef_block(op.get('config-cond', None))
2911            cw.block_start()
2912            members = [('cmd', op.enum_name)]
2913            if 'dont-validate' in op:
2914                members.append(('validate',
2915                                ' | '.join([c_upper('genl-dont-validate-' + x)
2916                                            for x in op['dont-validate']])), )
2917            for op_mode in ['do', 'dump']:
2918                if op_mode in op:
2919                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2920                    members.append((op_mode + 'it', name))
2921            if family.kernel_policy == 'per-op':
2922                struct = Struct(family, op['attribute-set'],
2923                                type_list=op['do']['request']['attributes'])
2924
2925                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2926                members.append(('policy', name))
2927                members.append(('maxattr', struct.attr_max_val.enum_name))
2928            if 'flags' in op:
2929                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
2930            cw.write_struct_init(members)
2931            cw.block_end(line=',')
2932    elif family.kernel_policy == 'split':
2933        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
2934                    'dump': {'pre': 'start', 'post': 'done'}}
2935
2936        for op_name, op in family.ops.items():
2937            for op_mode in ['do', 'dump']:
2938                if op.is_async or op_mode not in op:
2939                    continue
2940
2941                cw.ifdef_block(op.get('config-cond', None))
2942                cw.block_start()
2943                members = [('cmd', op.enum_name)]
2944                if 'dont-validate' in op:
2945                    dont_validate = []
2946                    for x in op['dont-validate']:
2947                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
2948                            continue
2949                        if op_mode == "dump" and x == 'strict':
2950                            continue
2951                        dont_validate.append(x)
2952
2953                    if dont_validate:
2954                        members.append(('validate',
2955                                        ' | '.join([c_upper('genl-dont-validate-' + x)
2956                                                    for x in dont_validate])), )
2957                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
2958                if 'pre' in op[op_mode]:
2959                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
2960                members.append((op_mode + 'it', name))
2961                if 'post' in op[op_mode]:
2962                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
2963                if 'request' in op[op_mode]:
2964                    struct = Struct(family, op['attribute-set'],
2965                                    type_list=op[op_mode]['request']['attributes'])
2966
2967                    if op.dual_policy:
2968                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
2969                    else:
2970                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
2971                    members.append(('policy', name))
2972                    members.append(('maxattr', struct.attr_max_val.enum_name))
2973                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
2974                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
2975                cw.write_struct_init(members)
2976                cw.block_end(line=',')
2977    cw.ifdef_block(None)
2978
2979    cw.block_end(line=';')
2980    cw.nl()
2981
2982
2983def print_kernel_mcgrp_hdr(family, cw):
2984    if not family.mcgrps['list']:
2985        return
2986
2987    cw.block_start('enum')
2988    for grp in family.mcgrps['list']:
2989        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp['name']},")
2990        cw.p(grp_id)
2991    cw.block_end(';')
2992    cw.nl()
2993
2994
2995def print_kernel_mcgrp_src(family, cw):
2996    if not family.mcgrps['list']:
2997        return
2998
2999    cw.block_start('static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] =')
3000    for grp in family.mcgrps['list']:
3001        name = grp['name']
3002        grp_id = c_upper(f"{family.ident_name}-nlgrp-{name}")
3003        cw.p('[' + grp_id + '] = { "' + name + '", },')
3004    cw.block_end(';')
3005    cw.nl()
3006
3007
3008def print_kernel_family_struct_hdr(family, cw):
3009    if not kernel_can_gen_family_struct(family):
3010        return
3011
3012    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
3013    cw.nl()
3014    if 'sock-priv' in family.kernel_family:
3015        cw.p(f'void {family.c_name}_nl_sock_priv_init({family.kernel_family["sock-priv"]} *priv);')
3016        cw.p(f'void {family.c_name}_nl_sock_priv_destroy({family.kernel_family["sock-priv"]} *priv);')
3017        cw.nl()
3018
3019
3020def print_kernel_family_struct_src(family, cw):
3021    if not kernel_can_gen_family_struct(family):
3022        return
3023
3024    if 'sock-priv' in family.kernel_family:
3025        # Generate "trampolines" to make CFI happy
3026        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
3027                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
3028                      ["void *priv"])
3029        cw.nl()
3030        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
3031                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
3032                      ["void *priv"])
3033        cw.nl()
3034
3035    cw.block_start(f"struct genl_family {family.ident_name}_nl_family __ro_after_init =")
3036    cw.p('.name\t\t= ' + family.fam_key + ',')
3037    cw.p('.version\t= ' + family.ver_key + ',')
3038    cw.p('.netnsok\t= true,')
3039    cw.p('.parallel_ops\t= true,')
3040    cw.p('.module\t\t= THIS_MODULE,')
3041    if family.kernel_policy == 'per-op':
3042        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
3043        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
3044    elif family.kernel_policy == 'split':
3045        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
3046        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
3047    if family.mcgrps['list']:
3048        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
3049        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
3050    if 'sock-priv' in family.kernel_family:
3051        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
3052        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
3053        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
3054    cw.block_end(';')
3055
3056
3057def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
3058    start_line = 'enum'
3059    if enum_name in obj:
3060        if obj[enum_name]:
3061            start_line = 'enum ' + c_lower(obj[enum_name])
3062    elif ckey and ckey in obj:
3063        start_line = 'enum ' + family.c_name + '_' + c_lower(obj[ckey])
3064    cw.block_start(line=start_line)
3065
3066
3067def render_uapi_unified(family, cw, max_by_define, separate_ntf):
3068    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
3069    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
3070    max_value = f"({cnt_name} - 1)"
3071
3072    uapi_enum_start(family, cw, family['operations'], 'enum-name')
3073    val = 0
3074    for op in family.msgs.values():
3075        if separate_ntf and ('notify' in op or 'event' in op):
3076            continue
3077
3078        suffix = ','
3079        if op.value != val:
3080            suffix = f" = {op.value},"
3081            val = op.value
3082        cw.p(op.enum_name + suffix)
3083        val += 1
3084    cw.nl()
3085    cw.p(cnt_name + ('' if max_by_define else ','))
3086    if not max_by_define:
3087        cw.p(f"{max_name} = {max_value}")
3088    cw.block_end(line=';')
3089    if max_by_define:
3090        cw.p(f"#define {max_name} {max_value}")
3091    cw.nl()
3092
3093
3094def render_uapi_directional(family, cw, max_by_define):
3095    max_name = f"{family.op_prefix}USER_MAX"
3096    cnt_name = f"__{family.op_prefix}USER_CNT"
3097    max_value = f"({cnt_name} - 1)"
3098
3099    cw.block_start(line='enum')
3100    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
3101    val = 0
3102    for op in family.msgs.values():
3103        if 'do' in op and 'event' not in op:
3104            suffix = ','
3105            if op.value and op.value != val:
3106                suffix = f" = {op.value},"
3107                val = op.value
3108            cw.p(op.enum_name + suffix)
3109            val += 1
3110    cw.nl()
3111    cw.p(cnt_name + ('' if max_by_define else ','))
3112    if not max_by_define:
3113        cw.p(f"{max_name} = {max_value}")
3114    cw.block_end(line=';')
3115    if max_by_define:
3116        cw.p(f"#define {max_name} {max_value}")
3117    cw.nl()
3118
3119    max_name = f"{family.op_prefix}KERNEL_MAX"
3120    cnt_name = f"__{family.op_prefix}KERNEL_CNT"
3121    max_value = f"({cnt_name} - 1)"
3122
3123    cw.block_start(line='enum')
3124    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
3125    val = 0
3126    for op in family.msgs.values():
3127        if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
3128            enum_name = op.enum_name
3129            if 'event' not in op and 'notify' not in op:
3130                enum_name = f'{enum_name}_REPLY'
3131
3132            suffix = ','
3133            if op.value and op.value != val:
3134                suffix = f" = {op.value},"
3135                val = op.value
3136            cw.p(enum_name + suffix)
3137            val += 1
3138    cw.nl()
3139    cw.p(cnt_name + ('' if max_by_define else ','))
3140    if not max_by_define:
3141        cw.p(f"{max_name} = {max_value}")
3142    cw.block_end(line=';')
3143    if max_by_define:
3144        cw.p(f"#define {max_name} {max_value}")
3145    cw.nl()
3146
3147
3148def render_uapi(family, cw):
3149    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
3150    hdr_prot = hdr_prot.replace('/', '_')
3151    cw.p('#ifndef ' + hdr_prot)
3152    cw.p('#define ' + hdr_prot)
3153    cw.nl()
3154
3155    defines = [(family.fam_key, family["name"]),
3156               (family.ver_key, family.get('version', 1))]
3157    cw.writes_defines(defines)
3158    cw.nl()
3159
3160    defines = []
3161    for const in family['definitions']:
3162        if const.get('header'):
3163            continue
3164
3165        if const['type'] != 'const':
3166            cw.writes_defines(defines)
3167            defines = []
3168            cw.nl()
3169
3170        # Write kdoc for enum and flags (one day maybe also structs)
3171        if const['type'] == 'enum' or const['type'] == 'flags':
3172            enum = family.consts[const['name']]
3173
3174            if enum.header:
3175                continue
3176
3177            if enum.has_doc():
3178                if enum.has_entry_doc():
3179                    cw.p('/**')
3180                    doc = ''
3181                    if 'doc' in enum:
3182                        doc = ' - ' + enum['doc']
3183                    cw.write_doc_line(enum.enum_name + doc)
3184                else:
3185                    cw.p('/*')
3186                    cw.write_doc_line(enum['doc'], indent=False)
3187                for entry in enum.entries.values():
3188                    if entry.has_doc():
3189                        doc = '@' + entry.c_name + ': ' + entry['doc']
3190                        cw.write_doc_line(doc)
3191                cw.p(' */')
3192
3193            uapi_enum_start(family, cw, const, 'name')
3194            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
3195            for entry in enum.entries.values():
3196                suffix = ','
3197                if entry.value_change:
3198                    suffix = f" = {entry.user_value()}" + suffix
3199                cw.p(entry.c_name + suffix)
3200
3201            if const.get('render-max', False):
3202                cw.nl()
3203                cw.p('/* private: */')
3204                if const['type'] == 'flags':
3205                    max_name = c_upper(name_pfx + 'mask')
3206                    max_val = f' = {enum.get_mask()},'
3207                    cw.p(max_name + max_val)
3208                else:
3209                    cnt_name = enum.enum_cnt_name
3210                    max_name = c_upper(name_pfx + 'max')
3211                    if not cnt_name:
3212                        cnt_name = '__' + name_pfx + 'max'
3213                    cw.p(c_upper(cnt_name) + ',')
3214                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
3215            cw.block_end(line=';')
3216            cw.nl()
3217        elif const['type'] == 'const':
3218            name_pfx = const.get('name-prefix', f"{family.ident_name}-")
3219            defines.append([c_upper(family.get('c-define-name',
3220                                               f"{name_pfx}{const['name']}")),
3221                            const['value']])
3222
3223    if defines:
3224        cw.writes_defines(defines)
3225        cw.nl()
3226
3227    max_by_define = family.get('max-by-define', False)
3228
3229    for _, attr_set in family.attr_sets.items():
3230        if attr_set.subset_of:
3231            continue
3232
3233        max_value = f"({attr_set.cnt_name} - 1)"
3234
3235        val = 0
3236        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
3237        for _, attr in attr_set.items():
3238            suffix = ','
3239            if attr.value != val:
3240                suffix = f" = {attr.value},"
3241                val = attr.value
3242            val += 1
3243            cw.p(attr.enum_name + suffix)
3244        if attr_set.items():
3245            cw.nl()
3246        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
3247        if not max_by_define:
3248            cw.p(f"{attr_set.max_name} = {max_value}")
3249        cw.block_end(line=';')
3250        if max_by_define:
3251            cw.p(f"#define {attr_set.max_name} {max_value}")
3252        cw.nl()
3253
3254    # Commands
3255    separate_ntf = 'async-prefix' in family['operations']
3256
3257    if family.msg_id_model == 'unified':
3258        render_uapi_unified(family, cw, max_by_define, separate_ntf)
3259    elif family.msg_id_model == 'directional':
3260        render_uapi_directional(family, cw, max_by_define)
3261    else:
3262        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')
3263
3264    if separate_ntf:
3265        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
3266        for op in family.msgs.values():
3267            if separate_ntf and not ('notify' in op or 'event' in op):
3268                continue
3269
3270            suffix = ','
3271            if 'value' in op:
3272                suffix = f" = {op['value']},"
3273            cw.p(op.enum_name + suffix)
3274        cw.block_end(line=';')
3275        cw.nl()
3276
3277    # Multicast
3278    defines = []
3279    for grp in family.mcgrps['list']:
3280        name = grp['name']
3281        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
3282                        f'{name}'])
3283    cw.nl()
3284    if defines:
3285        cw.writes_defines(defines)
3286        cw.nl()
3287
3288    cw.p(f'#endif /* {hdr_prot} */')
3289
3290
3291def _render_user_ntf_entry(ri, op):
3292    if not ri.family.is_classic():
3293        ri.cw.block_start(line=f"[{op.enum_name}] = ")
3294    else:
3295        crud_op = ri.family.req_by_value[op.rsp_value]
3296        ri.cw.block_start(line=f"[{crud_op.enum_name}] = ")
3297    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
3298    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
3299    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
3300    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
3301    ri.cw.block_end(line=',')
3302
3303
3304def render_user_family(family, cw, prototype):
3305    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
3306    if prototype:
3307        cw.p(f'extern {symbol};')
3308        return
3309
3310    if family.ntfs:
3311        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
3312        for ntf_op_name, ntf_op in family.ntfs.items():
3313            if 'notify' in ntf_op:
3314                op = family.ops[ntf_op['notify']]
3315                ri = RenderInfo(cw, family, "user", op, "notify")
3316            elif 'event' in ntf_op:
3317                ri = RenderInfo(cw, family, "user", ntf_op, "event")
3318            else:
3319                raise Exception('Invalid notification ' + ntf_op_name)
3320            _render_user_ntf_entry(ri, ntf_op)
3321        for op_name, op in family.ops.items():
3322            if 'event' not in op:
3323                continue
3324            ri = RenderInfo(cw, family, "user", op, "event")
3325            _render_user_ntf_entry(ri, op)
3326        cw.block_end(line=";")
3327        cw.nl()
3328
3329    cw.block_start(f'{symbol} = ')
3330    cw.p(f'.name\t\t= "{family.c_name}",')
3331    if family.is_classic():
3332        cw.p('.is_classic\t= true,')
3333        cw.p(f'.classic_id\t= {family.get("protonum")},')
3334    if family.is_classic():
3335        if family.fixed_header:
3336            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
3337    elif family.fixed_header:
3338        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
3339    else:
3340        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
3341    if family.ntfs:
3342        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
3343        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
3344    cw.block_end(line=';')
3345
3346
def family_contains_bitfield32(family):
    """Report whether *family* declares any ``bitfield32`` attribute.

    Attribute sets that are a ``subset_of`` another set are not scanned
    (presumably because their attributes are already covered by the parent
    set - TODO confirm).  main() uses the result to decide whether the
    generated user header must include ``<linux/netlink.h>``.
    """
    candidate_sets = (aset for _, aset in family.attr_sets.items()
                      if not aset.subset_of)
    return any(attr.type == "bitfield32"
               for aset in candidate_sets
               for _, attr in aset.items())
3355
3356
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains *full_path*.

    Walks up the directory hierarchy from *full_path* until a directory
    holding a ``MAINTAINERS`` file is found.  For a relative path the search
    stops at the current working directory.

    Returns:
        A ``(root, sub_path)`` tuple: the directory that contains
        ``MAINTAINERS``, and *full_path* expressed relative to it.

    Raises:
        Exception: if no ancestor directory contains ``MAINTAINERS``.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Strip the trailing separator left by joining onto ''.
            return parent, sub_path[:-1]
        if parent == full_path:
            # dirname() reached a fixed point ('/' or ''): without this check
            # the loop would spin forever outside a kernel tree.
            raise Exception(f'Cannot find kernel source root (no MAINTAINERS file) for {start}')
        full_path = parent
3365
3366
def _main_kernel_nested_policies(parsed, cw, print_policy):
    """Emit policies (or their forward declarations) shared by the kernel
    header and source.

    *print_policy* is called as ``print_policy(cw, struct)``.  It is applied to:
      * each pure nested struct used in requests, in attribute-set name order;
      * the family-wide policy, when ``kernel_policy`` is ``'global'``.
    """
    nested = sorted(parsed.pure_nested_structs.items())
    if any(struct.request for _, struct in nested):
        cw.p('/* Common nested types */')
    for _, struct in nested:
        if struct.request:
            print_policy(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_policy(cw, struct)
        cw.nl()


def _main_kernel_header(parsed, cw, mode):
    """Emit the body of the kernel-mode header.

    Writes policy forward declarations, then the op table, multicast group
    and family struct declarations.
    """
    _main_kernel_nested_policies(parsed, cw, print_req_policy_fwd)

    if parsed.kernel_policy in {'per-op', 'split'}:
        for op in parsed.ops.values():
            if 'do' in op and 'event' not in op:
                ri = RenderInfo(cw, parsed, mode, op, "do")
                print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                cw.nl()

    print_kernel_op_table_hdr(parsed, cw)
    print_kernel_mcgrp_hdr(parsed, cw)
    print_kernel_family_struct_hdr(parsed, cw)


def _main_kernel_source(parsed, cw, mode):
    """Emit the body of the kernel-mode source.

    Writes policy ranges and validators, the attribute policies, the op
    table, multicast groups and the family struct.
    """
    print_kernel_policy_ranges(parsed, cw)
    print_kernel_policy_sparse_enum_validates(parsed, cw)

    _main_kernel_nested_policies(parsed, cw, print_req_policy)

    # Per-op request policies exist only for 'per-op' / 'split' policy modes.
    if parsed.kernel_policy in {'per-op', 'split'}:
        for op in parsed.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, mode, op, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _main_user_header(parsed, cw, mode):
    """Emit the body of the user-space (YNL library) header.

    Writes the family prototype, enum helper prototypes, nested types and
    per-op request/response/dump/notification types.
    """
    cw.p('struct ynl_sock;')
    cw.nl()
    render_user_family(parsed, cw, True)
    cw.nl()

    cw.p('/* Enums */')
    put_op_name_fwd(parsed, cw)

    for const in parsed.consts.values():
        if isinstance(const, EnumSet):
            put_enum_to_str_fwd(parsed, cw, const)
    cw.nl()

    cw.p('/* Common nested types */')
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
        print_type_full(ri, struct)

    for op in parsed.ops.values():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, 'dump')
            print_req_type(ri)
            print_req_type_helpers(ri)
            if not ri.type_consistent or ri.type_oneside:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_wrapped_type(ri)

    for op in parsed.ntfs.values():
        if 'event' in op:
            ri = RenderInfo(cw, parsed, mode, op, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)
    cw.nl()


def _main_user_source(parsed, cw, mode, user_headers):
    """Emit the body of the user-space (YNL library) source.

    Writes the extra includes, enum helpers, policies, nested
    put/parse/free helpers, per-op code and the family definition.
    """
    cw.p("#include <linux/genetlink.h>")
    cw.nl()
    for one in user_headers:
        cw.p(f'#include "{one}"')
    cw.nl()

    cw.p('/* Enums */')
    put_op_name(parsed, cw)

    for const in parsed.consts.values():
        if isinstance(const, EnumSet):
            put_enum_to_str(parsed, cw, const)
    cw.nl()

    has_recursive_nests = False
    cw.p('/* Policies */')
    for struct in parsed.pure_nested_structs.values():
        if struct.recursive:
            put_typol_fwd(cw, struct)
            has_recursive_nests = True
    if has_recursive_nests:
        cw.nl()
    for struct in parsed.pure_nested_structs.values():
        put_typol(cw, struct)
    for name in parsed.root_sets:
        struct = Struct(parsed, name)
        put_typol(cw, struct)

    cw.p('/* Common nested types */')
    # Recursive nests reference each other, so prototypes must come first.
    if has_recursive_nests:
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
            free_rsp_nested_prototype(ri)
            if struct.request:
                put_req_nested_prototype(ri, struct)
            if struct.reply:
                parse_rsp_nested_prototype(ri, struct)
        cw.nl()
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for op in parsed.ops.values():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_free(ri)
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, "dump")
            if not ri.type_consistent or ri.type_oneside:
                parse_rsp_msg(ri, deref=True)
            print_req_free(ri)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_ntf_type_free(ri)

    for op in parsed.ntfs.values():
        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")

            ri = RenderInfo(cw, parsed, mode, op, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, mode, op, "event")
            print_ntf_type_free(ri)
    cw.nl()
    render_user_family(parsed, cw, False)


def main():
    """Command-line entry point: generate C code from a YNL netlink spec.

    ``--mode`` selects uapi, kernel or user output.  ``--header`` and
    ``--source`` select which half of a kernel/user pair is produced, and
    exactly one of them is required.  Output goes to ``-o`` (via CodeWriter),
    or stdout when it is omitted.

    Exits with status 1 if the spec is not valid YAML, or if its license is
    not the one required below.
    """
    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    # --header / --source share one dest; None means neither was given.
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)
    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print(f'License must be: {required_license}', file=sys.stderr)
        sys.exit(1)

    # Resolve the spec's in-tree path before setting up any output, so a
    # lookup failure happens before the writer touches -o (NOTE(review):
    # CodeWriter's file handling is not visible here - confirm).
    _, spec_kernel = find_kernel_root(args.spec)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Record the extra arguments so the output can be regenerated.
        line = (''.join(f' --user-header {hdr}' for hdr in args.user_header) +
                ''.join(f' --exclude-op {expr}' for expr in args.exclude_op))
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # The source includes its companion header, named after the output file.
    if args.out_file:
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, keeping first-seen order.
    for one in dict.fromkeys(headers):
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "kernel":
        if args.header:
            _main_kernel_header(parsed, cw, args.mode)
        else:
            _main_kernel_source(parsed, cw, args.mode)
    elif args.mode == "user":
        if args.header:
            _main_user_header(parsed, cw, args.mode)
        else:
            _main_user_source(parsed, cw, args.mode, args.user_header)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3675
3676
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
3679