xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision d0bdfe36d77791a67f18d18e93f3fca9a2b10240)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import filecmp
6import pathlib
7import os
8import re
9import shutil
10import sys
11import tempfile
12import yaml
13
14sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
15from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
16from lib import SpecSubMessage
17
18
def c_upper(name):
    """Return *name* upper-cased with every '-' replaced by '_'."""
    return name.replace('-', '_').upper()
21
22
def c_lower(name):
    """Return *name* lower-cased with every '-' replaced by '_'."""
    return name.replace('-', '_').lower()
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The first character selects signedness ('u' or 's'), the digits that
    follow give the bit width, and the last four characters are expected
    to be '-min' or '-max'.
    """
    sign = name[0]
    wants_min = name.endswith('-min')

    # The smallest unsigned value is 0 regardless of width
    if sign == 'u' and wants_min:
        return 0

    bits = int(name[1:-4])
    if sign == 's':
        # One bit is taken by the sign
        bits -= 1

    top = (1 << bits) - 1
    return -top - 1 if sign == 's' and wants_min else top
40
41
class BaseNlLib:
    """Provides library-specific snippets for the emitted code."""

    def get_family_id(self):
        """Return the text 'ys->family_id' (presumably spliced into generated C)."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """Base class of the per-attribute code generation helpers.

    Wraps a single spec attribute and offers the hooks used to render its
    struct member, policy entries, put/get code and setter.  The public
    entry points (attr_policy, attr_typol, attr_get, setter, ...) emit the
    common scaffolding and delegate the type-specific part to the
    underscore-prefixed hooks, which subclasses override.  Hooks without a
    generic implementation raise an Exception here.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Direction markers, flipped by set_request() / set_reply()
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        # Both 'nested-attributes' and 'sub-message' name what this attr nests
        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if nested == family.name:
                render_name = f"{family.ident_name}"
            else:
                render_name = f"{family.ident_name}_{nested}"
            self.nested_render_name = c_lower(render_name)

            # The struct name gets a trailing '_' when the nested name is
            # also the name of one of the family's constants
            self.nested_struct_type = 'struct ' + self.nested_render_name
            if nested in self.family.consts:
                self.nested_struct_type += '_'

        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        # A leading digit gets prefixed with an underscore
        if c_name[0].isdigit():
            c_name = '_' + c_name
        self.c_name = c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in requests, also on the set it is a subset of."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in replies, also on the set it is a subset of."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """Return the value of check @limit as a number.

        Integers (and a missing check without a default) are returned as is,
        names of family constants resolve to the constant's "value", any
        other string is converted with limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None or isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """Return check @limit as text; '' when the check is absent.

        @suffix is only appended to plain integer values.  Family constants
        without a 'header' entry are prefixed with the family name.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        consts = self.family.consts
        if value in consts and not consts[value].get('header'):
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name and reject subsets which override checks."""
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of and self.checks != self._get_real_attr().checks:
            raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        """Falsy by default; multi-value types override this to return True."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Return how presence is tracked: 'present' here; subclasses
        return 'len', 'count' or ''."""
        return 'present'

    def presence_member(self, space, type_filter):
        """Return the presence member declaration, or None.

        None is returned when the attr's presence type differs from
        @type_filter or is not one of 'present', 'len', 'count'.
        """
        kind = self.presence_type()
        if kind != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if kind == 'present':
            return f"{pfx}u32 {self.c_name}:1;"
        if kind in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        """Return the C type of a non-trivial member, None if not applicable."""
        return None

    def free_needs_iter(self):
        """True if the free code needs the 'i' iterator variable."""
        return False

    def _free_lines(self, ri, var, ref):
        """Return the C lines releasing the member (may be empty)."""
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the lines produced by _free_lines()."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the list of C function arguments needed to pass the attr.

        Raises an Exception when the type has no complex member type and
        did not override this method.
        """
        member = self._complex_member_type(ri)
        if not member:
            raise Exception(f"Struct member not implemented for class type {self.type}")

        # No space between a pointer type and the extra '*'
        spc = '' if member[-1] == '*' else ' '
        args = [member + spc + '*' + self.c_name]
        if self.presence_type() == 'count':
            args.append('unsigned int n_' + self.c_name)
        return args

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for the attr."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            # Recursive nests are held by pointer
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = '' if member[-1] == '*' else ' '
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the policy array entry for the attr."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian' and self.type in {'u16', 'u32'}:
            policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the type policy entry for the attr."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit @line, guarded by a presence check for 'present'/'len' attrs."""
        presence = self.presence_type()
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars) for parsing the attr.

        lines and init_lines may each be a single string or a list,
        init_lines and local_vars may be None.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the 'if (type == ...)' block parsing the attr; returns True."""
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # Hooks may return a bare string for a single line
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """Emit the static inline setter function for the attr.

        @ref lists the names of the enclosing nests, outermost first.
        """
        ref = (ref if ref else []) + [self.c_name]
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, name in enumerate(ref):
            # The trailing '' makes join() end the path with a '.'
            path = f"{var}->{'.'.join(ref[:i] + [''])}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{path}_{self.presence_type()}.{name}"
                continue
            presence = f"{path}_present.{name}"
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters which free the old value but allocate nothing get a '__' prefix
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
309
310
class TypeUnused(Type):
    """Attr type which emits no policy, put, get or setter code.

    Only the type policy entry is produced, with YNL_PT_REJECT.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        return ['return YNL_PARSE_CB_ERROR;'], None, None

    def attr_policy(self, cw):
        """Nothing to emit."""
        return None

    def attr_put(self, ri, var):
        """Nothing to emit."""
        return None

    def attr_get(self, ri, var, first):
        """Nothing to emit; returns None instead of True."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        """Nothing to emit."""
        return None
335
336
class TypePad(Type):
    """Attr type which emits no policy, put, get or setter code.

    Only the type policy entry is produced, with YNL_PT_IGNORE.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """Nothing to emit."""
        return None

    def attr_put(self, ri, var):
        """Nothing to emit."""
        return None

    def attr_get(self, ri, var, first):
        """Nothing to emit; returns None instead of True."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        """Nothing to emit."""
        return None
358
359
class TypeScalar(Type):
    """Code generation for integer attributes."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Rendered after the member name in declarations
        self.byte_order_comment = f" /* {attr['byte-order']} */" if 'byte-order' in attr else ''

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Work out is_bitfield and the C type name of the member."""
        self.resolve_up(super())

        has_enum = 'enum' in self.attr

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif has_enum:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if has_enum and not self.is_bitfield:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _init_checks(self):
        """Derive the implied checks (sparse, min, max, range, full-range)."""
        checks = self.checks

        if 'enum' in self.attr:
            low, high = self.family.consts[self.attr['enum']].value_range()
            if low is None and high is None:
                checks['sparse'] = True
            else:
                # No point in a 'min' of 0 for unsigned types
                if 'min' not in checks and (low != 0 or self.type[0] == 's'):
                    checks['min'] = low
                checks.setdefault('max', high)

        if 'min' in checks and 'max' in checks:
            if self.get_limit('min') > self.get_limit('max'):
                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
            checks['range'] = True

        limits = (self.get_limit('min', 0), self.get_limit('max', 0))
        low, high = min(limits), max(limits)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the signed 16 bit range are flagged 'full-range'
        if low < -32768 or high > 32767:
            checks['full-range'] = True

    def _attr_policy(self, policy):
        """Return the policy initializer matching the attr's checks."""
        checks = self.checks

        if 'flags-mask' in checks or self.is_bitfield:
            if self.is_bitfield:
                mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            else:
                flags = self.family.consts[checks['flags-mask']]
                mask = (1 << len(flags['entries'])) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        if 'min' in checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        if 'max' in checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        if 'sparse' in checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # The first character of the type is replaced with 'U'
        suffix = c_upper(self.type[1:])
        return f'.type = YNL_PT_U{suffix}, '

    def arg_member(self, ri):
        decl = f'{self.type_name} {self.c_name}{self.byte_order_comment}'
        return [decl]

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        getter = f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);"
        return getter, None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
457
458
class TypeFlag(Type):
    """Code generation for attributes put with a NULL, zero length payload.

    There is no struct member and no function argument; the base class
    presence bit is all that the get/put/setter code touches.
    """

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # No extra parsing lines
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        # No extra setter lines
        return []
474
475
class TypeString(Type):
    """Code generation for string attributes, tracked by length."""

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        parts = ['.type = YNL_PT_NUL_STR, ']
        if self.is_selector:
            parts.append('.is_selector = 1, ')
        return ''.join(parts)

    def _attr_policy(self, policy):
        """Return the policy initializer honouring exact-len / max-len."""
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'

        fields = ['.type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + self.get_limit_str('max-len'))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        """Copy the string into a freshly allocated, NUL terminated buffer."""
        dst = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        copy_lines = [
            f"{presence} = strlen({self.c_name});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {self.c_name}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
        return copy_lines
528
529
class TypeBinary(Type):
    """Code generation for binary attributes, tracked by length."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # Keep the trailing space, like all the other typol snippets, so
        # that the initializer is closed with ', }' rather than ',}'
        return '.type = YNL_PT_BINARY, '

    def _attr_policy(self, policy):
        """Return the policy initializer for the (at most one) length check.

        Raises an Exception when more than one check is present, or when
        the check is not one of exact-len, min-len, max-len.
        """
        if not self.checks:
            return '{ .type = NLA_BINARY, }'
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')

        check_name = next(iter(self.checks))
        macros = {
            'exact-len': 'NLA_POLICY_EXACT_LEN',
            'min-len': 'NLA_POLICY_MIN_LEN',
            'max-len': 'NLA_POLICY_MAX_LEN',
        }
        if check_name not in macros:
            raise Exception('Unsupported check for binary type: ' + check_name)
        return macros[check_name] + '(' + self.get_limit_str(check_name) + ')'

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->{self.c_name}, {var}->_len.{self.c_name})")

    def _attr_get(self, ri, var):
        """Copy the payload into a freshly allocated buffer of the same size."""
        len_mem = var + '->_len.' + self.c_name
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, ynl_attr_data(attr), len);"], \
               ['len = ynl_attr_data_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [f"{presence} = len;",
                f"{member} = malloc({presence});",
                f'memcpy({member}, {self.c_name}, {presence});']
580
581
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is held as a pointer to a C struct."""

    def struct_member(self, ri):
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        """Copy the payload, allocating at least sizeof(struct) zeroed bytes
        when the payload is shorter than the struct."""
        struct_sz = 'sizeof(struct ' + c_lower(self.get("struct")) + ')'
        dst = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = ynl_attr_data_len(attr);']
        get_lines = [
            f"{var}->_{self.presence_type()}.{self.c_name} = len;",
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, init_lines, local_vars
597
598
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute held as an array of scalars, tracked by element count."""

    def arg_member(self, ri):
        elem = self.get("sub-type")
        return [f'__{elem} *{self.c_name}', 'size_t count']

    def presence_type(self):
        return 'count'

    def struct_member(self, ri):
        elem = self.get("sub-type")
        ri.cw.p(f'__{elem} *{self.c_name};')

    def attr_put(self, ri, var):
        """Emit the put, converting the element count to a size in bytes."""
        count_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem = self.get('sub-type')

        ri.cw.block_start(line=f"if ({count_mem})")
        ri.cw.p(f"i = {count_mem} * sizeof(__{elem});")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        """Copy whole elements only, the length is rounded down to a multiple
        of the element size."""
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        count_mem = var + '->_count.' + self.c_name
        dst = f"{var}->{self.c_name}"

        local_vars = ['unsigned int len;']
        init_lines = ['len = ynl_attr_data_len(attr);']
        get_lines = [
            f"{count_mem} = len / {elem_sz};",
            f"len = {count_mem} * {elem_sz};",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        elem = self.get('sub-type')
        # 'count' is reused to hold the size in bytes once presence is set
        return [
            f"{presence} = count;",
            f"count *= sizeof(__{elem});",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
631
632
class TypeBitfield32(Type):
    """Code generation for attributes carried as struct nla_bitfield32."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """Return the policy initializer; the mask comes from the attr's enum.

        Raises an Exception when the attr has no enum.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, sizeof(struct nla_bitfield32))"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        copy = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));"
        return copy, None, None

    def _setter_lines(self, ri, member, presence):
        copy = f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"
        return [copy]
656
657
class TypeNest(Type):
    """Code generation for attributes which nest another attribute space."""

    def is_recursive(self):
        nested_struct = self.family.pure_nested_structs[self.nested_attrs]
        return nested_struct.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        """Return the call to the nested type's _free() helper.

        When is_recursive_for_op() is true the member is passed without '&'
        and the call is guarded by an 'if' on the member.
        """
        target = f'{var}->{ref}{self.c_name}'
        free_fn = f'{self.nested_render_name}_free'
        if self.is_recursive_for_op(ri):
            return [f'if ({target})', f'{free_fn}({target});']
        return [f'{free_fn}(&{target});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        # Same '&' rule as in _free_lines()
        at = '&'
        if self.is_recursive_for_op(ri):
            at = ''
        put_call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, {at}{var}->{self.c_name})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """Parse the nest with the nested type's _parse() helper."""
        pns = self.family.pure_nested_structs[self.nested_attrs]
        # The external selectors are passed as extra arguments to the parser
        args = ["&parg", "attr"] + [f'{var}->{sel.name}' for sel in pns.external_selectors()]

        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({self.nested_render_name}_parse({', '.join(args)}))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """Emit setters for the non-recursive members of the nested struct."""
        ref = (ref if ref else []) + [self.c_name]

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, attr in nested_struct.member_list():
            if not attr.is_recursive():
                attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref,
                            var=var)
704
705
class TypeMultiAttr(Type):
    """Code generation for multi-value attributes, tracked by count.

    @base_type supplies the policy and type policy for the attribute.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        """Return the element type; None makes callers use arg_member()."""
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type

        kind = self.attr['type']
        if kind == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if kind == 'string':
            return 'struct ynl_string *'
        if kind in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + kind
        raise Exception(f"Sub-type {kind} not supported yet")

    def arg_member(self, ri):
        if self.type == 'binary' and 'struct' in self.attr:
            struct_name = c_lower(self.attr["struct"])
            return [f'struct {struct_name} *{self.c_name}',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def free_needs_iter(self):
        # Strings and nests are freed one element at a time
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        """Return the lines freeing the array and, if needed, its elements."""
        kind = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"
        each = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"

        if kind in scalars or kind == 'binary':
            return [f"free({array});"]
        if kind == 'string':
            return [each,
                    f"free({array}[i]);",
                    f"free({array});"]
        if kind == 'nest':
            return [each,
                    f'{self.nested_render_name}_free(&{array}[i]);',
                    f"free({array});"]
        raise Exception(f"Free of MultiAttr sub-type {kind} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only bump the n_<name> counter here
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """Emit a loop putting each element of the array as its own attr."""
        kind = self.attr['type']
        each = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        elem = f"{var}->{self.c_name}[i]"

        if kind in scalars:
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {elem});")
        elif kind == 'binary' and 'struct' in self.attr:
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{elem}, sizeof(struct {c_lower(self.attr['struct'])}));")
        elif kind == 'string':
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {elem}->str);")
        elif kind == 'nest':
            ri.cw.p(each)
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{elem})")
        else:
            raise Exception(f"Put of MultiAttr sub-type {kind} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # The array pointer is stored as passed in, it is not copied
        assign = f"{member} = {self.c_name};"
        count = f"{presence} = n_{self.c_name};"
        return [assign, count]
792
793
class TypeArrayNest(Type):
    """Code generation for arrays carried inside a nest, tracked by count."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        """Return the element type; None makes callers use arg_member()."""
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type

        sub = self.attr['sub-type']
        if sub in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub
        if sub == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        raise Exception(f"Sub-type {sub} not supported yet")

    def arg_member(self, ri):
        if self.sub_type == 'binary' and 'exact-len' in self.checks:
            exact_len = self.checks["exact-len"]
            return [f'unsigned char (*{self.c_name})[{exact_len}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_policy(self, policy):
        if self.attr['sub-type'] != 'nest':
            return super()._attr_policy(policy)
        return f'NLA_POLICY_NESTED_ARRAY({self.nested_render_name}_nl_policy)'

    def _attr_typol(self):
        sub = self.attr['sub-type']
        if sub in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        if sub == 'binary' and 'exact-len' in self.checks:
            return f'.type = YNL_PT_BINARY, .len = {self.checks["exact-len"]}, '
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Remember the array attr and count (validating) its entries."""
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr) {',
            '\tif (ynl_attr_validate(yarg, attr2))',
            '\t\treturn YNL_PARSE_CB_ERROR;',
            f'\tn_{self.c_name}++;',
            '}',
        ]
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        """Emit the nest with one entry per element, the index is the type."""
        each = f'for (i = 0; i < {var}->_count.{self.c_name}; i++)'
        elem = f"{var}->{self.c_name}[i]"

        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            ri.cw.block_start(line=each)
            ri.cw.p(f"ynl_attr_put_{self.sub_type}(nlh, i, {elem});")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            ri.cw.p(each)
            ri.cw.p(f"ynl_attr_put(nlh, i, {elem}, {self.checks['exact-len']});")
        elif self.sub_type == 'nest':
            ri.cw.p(each)
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{elem});")
        else:
            raise Exception(f"Put for ArrayNest sub-type {self.attr['sub-type']} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        # The array pointer is stored as passed in, it is not copied
        assign = f"{member} = {self.c_name};"
        count = f"{presence} = n_{self.c_name};"
        return [assign, count]
861
862
class TypeNestTypeValue(Type):
    """Code generation for 'nest-type-value' attributes."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return (get_lines, init_lines, local_vars) for the parsing code.

        For every 'type-value' level the emitted C steps one attribute deeper
        and records that attribute's type; the collected values are passed as
        extra arguments to the nested parser.
        """
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        innermost = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            tv_names = [c_lower(x) for x in self.attr["type-value"]]
            attr_decls = ", *attr_".join(tv_names)
            value_decls = ", ".join(tv_names)
            local_vars.append(f'const struct nlattr *attr_{attr_decls};')
            local_vars.append(f'__u32 {value_decls};')

            for name in tv_names:
                get_lines.append(f'attr_{name} = ynl_attr_data({innermost});')
                get_lines.append(f'{name} = ynl_attr_type(attr_{name});')
                innermost = 'attr_' + name

            extra_args = ', ' + value_decls

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{extra_args});")
        return get_lines, init_lines, local_vars
891
892
class TypeSubMessage(TypeNest):
    """Code generation for 'sub-message' attributes, keyed by a selector."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        parts = [f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
                 '.is_submsg = 1, ']
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            parts.append(f'.selector_type = {self.attr_set[self["selector"]].value} ')
        return ''.join(parts)

    def _attr_get(self, ri, var):
        """Return (get_lines, init_lines, local_vars) for the parsing code.

        The emitted C fails the parse when the selector is unset, otherwise
        hands the selector and the attribute to the nested parser.
        """
        selector = self['selector']
        sel = c_lower(selector)
        # External selectors are read from a '_sel_' variable rather than
        # from a member of the struct being filled in
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"

        get_lines = [
            f'if (!{sel_var})',
            'return ynl_submsg_failed(yarg, "%s", "%s");' % (self.name, selector),
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        return get_lines, init_lines, None
923
924
class Selector:
    """Selector attribute of a sub-message.

    A selector found in the sub-message's own attribute set is resolved
    immediately and flagged via ``is_selector``; otherwise it is "external"
    and its attribute is supplied later through set_attr().
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]
        self._external = self.name not in attr_set

        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
        else:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        self.attr = attr

    def is_external(self):
        return self._external
943
944
class Struct:
    """Description of one C struct built from (part of) an attribute set.

    With ``type_list`` the struct holds only the listed attributes; without
    it the struct holds the whole attribute set and is marked as nested.
    """

    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.fixed_header = ('struct ' + c_lower(fixed_header)) if fixed_header else None
        self.submsg = submsg

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            name_source = family.ident_name
        else:
            name_source = family.ident_name + '-' + space_name
        self.render_name = c_lower(name_source)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            # Struct name gets a '_' suffix when the family also has a
            # constant called the same as the attribute space
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or by legacy arrays

        member_names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in member_names]

        # Index members by name and remember the attribute with the highest value
        self.attrs = {}
        self.attr_max_val = None
        highest = 0
        for name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record the inherited member names; they must match the constructor's."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def external_selectors(self):
        """Return the selectors of sub-message members that are external."""
        return [attr.selector
                for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        """Tell whether any member reports free_needs_iter()."""
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1019
1020
class EnumEntry(SpecEnumEntry):
    """Enum entry extended with its C name and a "value changed" marker."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # The value "changes" when it does not simply continue the sequence
        # (previous + 1, or 0 for the first entry); flags always count as
        # changed.
        expected = prev.value + 1 if prev else 0
        self.value_change = (self.value != expected or
                             self.enum_set['type'] == 'flags')

        # Added by resolve (assign-then-delete presumably only declares the
        # attribute for static checkers - TODO confirm):
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1039
1040
class EnumSet(SpecEnumSet):
    """Enum definition extended with the names used in generated C."""

    def __init__(self, family, yaml):
        self.render_name = c_lower('-'.join((family.ident_name, yaml['name'])))

        # 'enum-name' overrides the default name; an empty value means the
        # enum has no name at all
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        # Unnamed enums are carried as plain ints
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high) of the entry values, or (None, None) if sparse."""
        values = [entry.value for entry in self.entries.values()]
        low = min(values)
        high = max(values)

        # Every integer in [low, high] must be taken for the range to be usable
        if high - low + 1 != len(self.entries):
            return None, None

        return low, high
1076
1077
class AttrSet(SpecAttrSet):
    """Attribute set extended with C naming and typed attribute creation."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset takes its naming from the set it is a subset of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        # The set named after the family itself gets an empty C name
        if c_name == self.family.c_name:
            c_name = ''
        self.c_name = c_name

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching the attribute's spec.

        Raises:
            Exception: for an unknown type or an unsupported indexed-array
                sub-type.
        """
        kind = elem['type']
        args = (self.family, self, elem, value)

        if kind in scalars:
            t = TypeScalar(*args)
        elif kind == 'unused':
            t = TypeUnused(*args)
        elif kind == 'pad':
            t = TypePad(*args)
        elif kind == 'flag':
            t = TypeFlag(*args)
        elif kind == 'string':
            t = TypeString(*args)
        elif kind == 'binary':
            if 'struct' in elem:
                t = TypeBinaryStruct(*args)
            elif elem.get('sub-type') in scalars:
                t = TypeBinaryScalarArray(*args)
            else:
                t = TypeBinary(*args)
        elif kind == 'bitfield32':
            t = TypeBitfield32(*args)
        elif kind == 'nest':
            t = TypeNest(*args)
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ('binary', 'nest', 'u32'):
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            t = TypeArrayNest(*args)
        elif kind == 'nest-type-value':
            t = TypeNestTypeValue(*args)
        elif kind == 'sub-message':
            t = TypeSubMessage(*args)
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        # multi-attr wraps whatever base type was selected above
        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1146
1147
class Operation(SpecOperation):
    """Operation extended with the names used in generated C."""

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs)
        for mode in ('do', 'dump', 'event'):
            for direction in ('request', 'reply'):
                try:
                    message = yaml[mode][direction]
                except KeyError:
                    continue
                message.setdefault('attributes', [])

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower('_'.join((family.ident_name, self.name)))

        # True only when both 'do' and 'dump' define a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        family = self.family
        prefix = family.async_op_prefix if self.is_async else family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
1181
1182
class SubMessage(SpecSubMessage):
    """Sub-message definition extended with its C render name."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        # '<family ident>-<sub-message name>', folded into a C identifier
        name_parts = (family.ident_name, yaml['name'])
        self.render_name = c_lower('-'.join(name_parts))

    def resolve(self):
        self.resolve_up(super())
1191
1192
class Family(SpecFamily):
    """Family spec extended with the state the C code generator needs.

    On top of what SpecFamily loads, resolve() computes the C naming
    prefixes, the per-attribute-set request/reply usage (root_sets), the
    structs needed for nested attribute sets (pure_nested_structs), the
    pre/post hooks and, for 'global' kernel policy, the global policy list.
    """

    def __init__(self, file_name, exclude_ops):
        """Load the spec from file_name and derive header / key names."""
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Names of the C defines for family name and version; the spec may
        # override the defaults derived from the family name
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        # Strip the "linux/" prefix and ".h" suffix to get the bare header name
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Compute all derived state; the loader calls below run in a fixed order."""
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        # Async ops share the regular prefix unless the spec gives them their own
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] holds both a 'set' (for membership tests) and
        # a 'list' of hook names; filled in by _load_hooks()
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        """Return True when the family's protocol is 'netlink-raw'."""
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """Flag every operation that some message names in its 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Fill root_sets with the attributes messages use directly.

        For each attribute set referenced by a message, collect the names of
        the attributes appearing in do/dump requests, and in do/dump replies
        or events.
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            # Several messages may share an attribute set; merge their usage
            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so nested types precede their users.

        A struct whose (non-recursive) child has not been placed yet is moved
        to the end of the dict and retried later; the number of attempts is
        bounded, so this is best-effort.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in pns_key_seen:
                    # Dicts are sorted, this will make struct last
                    struct = self.pure_nested_structs.pop(name)
                    self.pure_nested_structs[name] = struct
                    finished = False
                    break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_set_nest(self, spec):
        """Register the struct for a 'nested-attributes' attribute.

        Returns the name of the nested attribute set.

        Raises:
            Exception: if the nested set is also used as a root set, or if
                the inherited members differ from those of an already
                registered struct (raised by Struct.set_inherited()).
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested not in self.root_sets:
            if nested not in self.pure_nested_structs:
                self.pure_nested_structs[nested] = \
                    Struct(self, nested, inherited=inherit,
                           fixed_header=spec.get('fixed-header'))
        else:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')

        # 'inherit' is the same set object handed to the Struct above, so
        # filling it in here is what makes set_inherited() accept it
        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        """Register a synthetic attribute set and struct for a sub-message.

        Each format of the sub-message becomes one fake attribute: a nest
        when it has an attribute set, a binary struct when it only has a
        fixed header, and a flag otherwise. Returns the sub-message name.
        """
        # Fake the struct type for the sub-message itself
        # it's not an attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr |= {
                    "type": "nest",
                    "nested-attributes": fmt['attribute-set'],
                }
                if 'fixed-header' in fmt:
                    attr |= { "fixed-header": fmt["fixed-header"] }
            elif 'fixed-header' in fmt:
                attr |= {
                    "type": "binary",
                    "struct": fmt["fixed-header"],
                }
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        # NOTE(review): AttrSet.__init__ looks for 'name-prefix', not
        # 'name-pfx' - confirm whether this key is meant to be picked up.
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """Discover all nested structs and propagate usage flags to them.

        Starting from the root sets, walk nests / sub-messages breadth-first
        to populate pure_nested_structs, then mark which structs are used in
        requests, in replies, inside multi-value attributes, and which are
        recursive.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

        # Seed request / reply / multi-val flags from the direct members of
        # the root sets
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if nested:
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                # A struct that (transitively) contains its own attr set is recursive
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if child:
                    if not child.recursive:
                        struct.child_nests.update(child.child_nests)
                    child.request |= struct.request
                    child.reply |= struct.reply
                    if spec.is_multi_val():
                        child.in_multi_val = True

        # Sort again, the 'recursive' flags computed above affect the ordering
        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark attributes as used in requests and / or replies.

        Members of nested structs follow their struct's flags; attributes of
        root sets follow the per-direction membership in root_sets.
        """
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """Resolve external selectors against the attribute set of the parent.

        For every nest / sub-message member, each external selector of the
        child struct must name an attribute of the containing attribute set.

        Raises:
            Exception: if the selector is not found one level up.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k, _ in self.root_sets.items():
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, struct in all_structs():
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name in self.attr_sets[attr_set]:
                        sel_attr = self.attr_sets[attr_set][selector.name]
                        selector.set_attr(sel_attr)
                    else:
                        raise Exception("Passing selector thru more than one layer not supported")

    def _load_global_policy(self):
        """Build the single policy shared by all operations.

        Sets global_policy_set to the one attribute set all ops use and
        global_policy to the attributes appearing in any do / dump request,
        listed in attribute set order.

        Raises:
            Exception: if operations use different attribute sets.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect the 'pre' / 'post' hook names of do and dump operations.

        Each name is recorded once per (when, op_mode); the 'list' keeps the
        order in which names were first seen.
        """
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1553
1554
class RenderInfo:
    """Context for rendering one operation (in one mode) or one attribute set.

    Holds the family, the code writer, the kernel / user flavour and the
    request / reply Struct objects built from the operation's attribute lists.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        # An op-specific fixed header can only be sized for classic families
        if op and op.fixed_header and op.fixed_header != family.fixed_header:
            if not family.is_classic():
                raise Exception("Per-op fixed header not supported, yet")
            self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                # Only 'dump' exists, nothing to be inconsistent with
                self.type_oneside = True
            else:
                do_has_reply = 'reply' in op['do']
                dump_has_reply = 'reply' in op["dump"]
                if do_has_reply != dump_has_reply:
                    self.type_consistent = False
                elif do_has_reply and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False

        self.attr_set = attr_set if attr_set else op['attribute-set']

        if op:
            self.type_name_conflict = False
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        self.struct = {}
        # Notifications reuse the attribute lists of 'do' (or 'dump')
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        if op:
            for op_dir in ('request', 'reply'):
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """Tell whether struct ``key`` has no members and the request has no fixed header."""
        no_members = len(self.struct[key].attr_list) == 0
        return no_members and self.struct['request'].fixed_header is None

    def needs_nlflags(self, direction):
        return self.op_mode == 'do' and direction == 'request' and self.family.is_classic()
1622
1623
class CodeWriter:
    """
    Indentation-aware line writer for the generated C code.

    Output goes either to stdout or, when @out_file is given, to a
    temporary file which is copied over @out_file by close_out_file().
    With overwrite=False the destination is left untouched if the newly
    generated contents are identical to what is already there.
    """
    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False
        self._block_end = False
        self._silent_block = False
        self._ind = 0
        self._ifdef_block = None

        # _out_file stays None for as long as we write straight to stdout
        self._out = sys.stdout
        self._out_file = None
        if out_file is not None:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """
        Flush the temporary output into the destination file (if any) and
        switch the writer back to stdout. Safe to call more than once.
        """
        if getattr(self, '_out_file', None) is None:
            return

        tmp = self._out
        try:
            tmp.flush()
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(tmp.name, self._out_file, shallow=False))
            if not unchanged:
                with open(self._out_file, 'w+') as out_file:
                    tmp.seek(0)
                    shutil.copyfileobj(tmp, out_file)
        finally:
            # Always release the temporary file, also on the "unchanged" path
            tmp.close()
            self._out_file = None
            self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        """Return True if @line starts with 'if', 'while' or 'for'."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """
        Print one line of code at the current indentation level.

        Handles a pending closing bracket (merging it with a following
        "else"), a pending empty line, labels (lines ending with ':' are
        printed one level shallower), pre-processor lines (never indented)
        and the single statement following a brace-less if/while/for/else
        (printed one level deeper). An empty @line produces a bare newline.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        # Don't emit indentation-only lines
        prefix = '\t' * ind if line else ''
        self._out.write(prefix + line + '\n')

    def nl(self):
        """Request an empty line before the next printed line."""
        self._nl = True

    def block_start(self, line=''):
        """Print "@line {" and increase the indentation level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """
        Close the current block. Without @line the closing bracket is
        delayed until the next print so that it can be merged with "else".
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Print @doc as ' *' prefixed comment lines, wrapped before col 80."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """
        Print a function prototype, optionally preceded by a @doc comment.
        Prototypes which do not fit in 80 columns are broken up after the
        return type and between arguments, with continuation lines aligned
        to the opening bracket.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        # Continuation lines are aligned with tabs first, then spaces
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """
        Print local variable declarations, longest first, followed by an
        empty line. @local_vars may be a single string or a list; the
        caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() is stable, just like list.sort(), but leaves the input alone
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Print a complete function: prototype, local variables and @body."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """
        Print a list of (name, value) pairs as #defines with tab-aligned
        values; int values are printed as is, str values are quoted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """
        Print (member, value) pairs as tab-aligned designated initializers.
        Nothing is printed for an empty @members list.
        """
        if not members:
            return

        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """
        Switch the current '#ifdef CONFIG_...' section to @config, closing
        the previous section if needed. A false @config ends the section.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1820
1821
# Attribute types which map directly to a C integer
scalars = {
    'u8', 'u16', 'u32', 'u64',
    's8', 's16', 's32', 's64',
    'uint', 'sint',
}

# Message direction -> suffix used in C identifiers
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': '',
}

# Operation mode -> suffix of the C type wrapping the response
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords
_C_KW = {
    'auto', 'bool', 'break', 'case', 'char', 'const', 'continue',
    'default', 'do', 'double', 'else', 'enum', 'extern', 'float',
    'for', 'goto', 'if', 'inline', 'int', 'long', 'register',
    'return', 'short', 'signed', 'sizeof', 'static', 'struct',
    'switch', 'typedef', 'union', 'unsigned', 'void', 'volatile',
    'while',
}
1873
1874
def rdir(direction):
    """
    Return the reverse of a message direction: 'request' <-> 'reply'.
    Any other value is returned unchanged.
    """
    if direction == 'request':
        flipped = 'reply'
    elif direction == 'reply':
        flipped = 'request'
    else:
        flipped = direction
    return flipped
1881
1882
def op_prefix(ri, direction, deref=False):
    """
    Build the C identifier prefix (family name, type name and a suffix
    depending on op mode, @direction and @deref) for the type rendered
    by @ri.
    """
    mode = ri.op_mode

    if not mode:
        tail = ''
    elif mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        # Requests which exist only as dumps don't need the _dump marker
        tail = '_req' if ri.type_oneside else '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[mode]

    return f"{ri.family.c_name}_{ri.type_name}{tail}"
1906
1907
def type_name(ri, direction, deref=False):
    """Return the C struct type ("struct <prefix>") for @ri and @direction."""
    prefix = op_prefix(ri, direction, deref=deref)
    return "struct " + prefix
1910
1911
def print_prototype(ri, direction, terminate=True, doc=None):
    """
    Print the prototype of the user-facing function for the op in @ri.
    The request struct becomes an argument if the op has a request and
    the reply struct pointer becomes the return type if it has a reply
    (otherwise the function returns int).
    """
    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1928
1929
def print_req_prototype(ri):
    """Print the terminated request prototype, documented with the op's doc."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1932
1933
def print_dump_prototype(ri):
    """Print the terminated prototype for @ri, without a doc comment."""
    direction = "request"
    print_prototype(ri, direction)
1936
1937
def put_typol_submsg(cw, struct):
    """
    Print the policy table and the policy nest for a sub-message struct.
    Table entries are indexed by the position of the member.
    """
    policy = f'{struct.render_name}_policy'

    cw.block_start(line=f'const struct ynl_policy_attr {policy}[] =')
    n_members = 0
    for idx, (name, arg) in enumerate(struct.member_list()):
        if arg.type == 'nest':
            nest = f" .nest = &{arg.nested_render_name}_nest,"
        else:
            nest = ""
        cw.p(f'[{idx}] = {{ .type = YNL_PT_SUBMSG, .name = "{name}",{nest} }},')
        n_members = idx + 1
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {n_members - 1},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1958
1959
def put_typol_fwd(cw, struct):
    """Print the extern declaration of the policy nest of @struct."""
    nest_name = f'{struct.render_name}_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest_name};')
1962
1963
def put_typol(cw, struct):
    """
    Print the policy table and the policy nest for @struct; sub-message
    structs are handed over to put_typol_submsg().
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    max_name = struct.attr_set.max_name
    policy = f'{struct.render_name}_policy'

    cw.block_start(line=f'const struct ynl_policy_attr {policy}[{max_name} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {max_name},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1983
1984
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """
    Print the <render_name>_str() function which looks @arg_name up in
    the @map_name string table, returning NULL for out-of-range values.
    For enums of type 'flags' the value is first converted to a bit number.
    """
    arg_type = enum.user_type if enum else 'int'
    is_flags = bool(enum) and enum.type == 'flags'

    cw.write_func_prot('const char *', f'{render_name}_str',
                       [arg_type + ' ' + arg_name])
    cw.block_start()
    if is_flags:
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
    cw.p('return NULL;')
    cw.p(f'return {map_name}[{arg_name}];')
    cw.block_end()
    cw.nl()
1998
1999
def put_op_name_fwd(family, cw):
    """Print the forward declaration of the family's op-to-string function."""
    func = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', func, ['int op'], suffix=';')
2002
2003
def put_op_name(family, cw):
    """
    Print the table mapping response values to operation names, followed
    by the function which performs the lookup.
    """
    map_name = f'{family.c_name}_op_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue

        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue

        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2023
2024
def put_enum_to_str_fwd(family, cw, enum):
    """Print the forward declaration of the to-string function of @enum."""
    func = f'{enum.render_name}_str'
    cw.write_func_prot('const char *', func, [enum.user_type + ' value'],
                       suffix=';')
2028
2029
def put_enum_to_str(family, cw, enum):
    """
    Print the value-to-name string table of @enum, followed by the
    function which performs the lookup.
    """
    strmap = f'{enum.render_name}_strmap'

    cw.block_start(line=f"static const char * const {strmap}[] =")
    for ent in enum.entries.values():
        cw.p(f'[{ent.value}] = "{ent.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, strmap, 'value', enum=enum)
2039
2040
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Print the prototype of the <struct>_put() function of @struct."""
    func_name = f'{struct.render_name}_put'
    func_args = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]

    ri.cw.write_func_prot('int', func_name, func_args, suffix=suffix)
2048
2049
def put_req_nested(ri, struct):
    """
    Print the <struct>_put() function which adds the members of a nested
    struct to a message. Structs which are not sub-messages get wrapped
    in a nest attribute; a fixed header, if any, is copied in first.
    """
    is_plain_nest = struct.submsg is None
    local_vars = []
    init_lines = []

    if is_plain_nest:
        local_vars.append('struct nlattr *nest;')
        init_lines.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        hdr_size = f'sizeof({struct.fixed_header})'
        local_vars.append('void *hdr;')
        init_lines += [
            f"hdr = ynl_nlmsg_put_extra_header(nlh, {hdr_size});",
            f"memcpy(hdr, &obj->_hdr, {hdr_size});",
        ]

    # Helper variables used by the put code of arrays and counted members
    if any(arg.type == 'indexed-array' for _, arg in struct.member_list()):
        local_vars.append('struct nlattr *array;')
    if any(arg.presence_type() == 'count' for _, arg in struct.member_list()):
        local_vars.append('unsigned int i;')

    cw = ri.cw
    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar(local_vars)

    for init in init_lines:
        cw.p(init)

    for _, arg in struct.member_list():
        arg.attr_put(ri, "obj")

    if is_plain_nest:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2090
2091
def _multi_parse(ri, struct, init_lines, local_vars):
    """
    Print the body (opening bracket to closing bracket) of a C function
    parsing the attributes of @struct into 'dst'.

    The caller is expected to have printed the function prototype already.
    @init_lines and @local_vars hold the caller's initialization statements
    and local variable declarations; both lists are extended in place.

    Raises Exception for a per-op fixed header in a non-classic family and
    for array / multi-attr element types which are not supported.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    # Pick the C iterator used to walk the attributes
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception("Per-op fixed header not supported, yet")

    # Work out which members need counting and a second parsing pass
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per array / multi-attr member
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited members are copied from same-named variables of the C function
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if struct.fixed_header:
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # The generated code errors out if an array member is already populated
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # Main loop over the attributes, the per-type code comes from attr_get()
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Indexed arrays: allocate n_<name> elements, then walk the saved
    # attr_<name> nest to fill them in
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Multi-attrs: allocate n_<name> elements, then iterate over all the
    # attributes again, picking out those of the matching type
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            # The generated copy is capped at the size of one array element
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            # Strings are copied into a newly allocated struct ynl_string
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0, message parsers return the callback status
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2243
2244
def parse_rsp_submsg(ri, struct):
    """
    Print the parse function of a sub-message struct. The generated code
    compares the 'sel' argument against the name of each member and runs
    the parsing code of the member which matches.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    # A set, because multiple members may ask for the same local variable
    local_vars = {'const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;'}

    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        local_vars |= set(l_vars) if l_vars else set()

    ri.cw.block_start()
    # Set iteration order depends on string hashing, which is randomized
    # per process. Sort the declarations to keep the generated code
    # reproducible (the writer only orders them by length).
    ri.cw.write_func_lvar(sorted(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    first = True
    for name, arg in struct.member_list():
        kw = 'if' if first else 'else if'
        first = False

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2279
2280
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """
    Print the prototype of the <struct>_parse() function of @struct.
    Sub-message parsers take an extra 'sel' argument, placed before
    'nested'; external selectors and inherited members are appended
    after it.
    """
    func_args = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        func_args.append('const char *sel')
    func_args.append('const struct nlattr *nested')
    func_args += ['const char *_sel_' + sel.name
                  for sel in struct.external_selectors()]
    func_args += ['__u32 ' + arg for arg in struct.inherited]

    func_name = f'{struct.render_name}_parse'
    ri.cw.write_func_prot('int', func_name, func_args, suffix=suffix)
2293
2294
def parse_rsp_nested(ri, struct):
    """
    Print the parse function of a nested struct. Sub-messages are handed
    over to parse_rsp_submsg(); structs without members get a function
    which only returns 0.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    decls = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']
    _multi_parse(ri, struct, [], decls)
2313
2314
def parse_rsp_msg(ri, deref=False):
    """
    Print the function parsing the reply (or event) message of the op in
    @ri. Nothing is printed for ops without a reply, unless they are events.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not (has_reply or ri.op_mode == 'event'):
        return

    rsp_type = type_name(ri, "reply", deref=deref)
    decls = [f'{rsp_type} *dst;',
             'const struct nlattr *attr;']
    inits = ['dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse',
                          ['const struct nlmsghdr *nlh',
                           'struct ynl_parse_arg *yarg'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return YNL_PARSE_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, inits, decls)
2336
2337
def print_req(ri):
    """
    Print the function which performs the request of the op in @ri:
    starts the message, puts the fixed header and the attributes, runs
    ynl_exec() and returns the parsed reply, if the op has one.
    """
    direction = "request"
    cw = ri.cw
    request = ri.struct["request"]
    has_reply = 'reply' in ri.op[ri.op_mode]

    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if has_reply:
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if request.fixed_header:
        local_vars.extend(['size_t hdr_len;', 'void *hdr;'])
    # Counted attributes are put in a loop, which needs an iterator
    if any(attr.presence_type() == 'count'
           for _, attr in request.member_list()):
        local_vars.append('unsigned int i;')

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    if ri.family.is_classic():
        cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{request.render_name}_nest;")
    cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    if request.fixed_header:
        cw.p("hdr_len = sizeof(req->_hdr);")
        cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        cw.nl()

    for _, attr in request.member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()

    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p('return rsp;' if has_reply else 'return 0;')
    cw.nl()

    if has_reply:
        # Error path, release the partially parsed reply
        cw.p('err_free:')
        cw.p(call_free(ri, rdir(direction), 'rsp'))
        cw.p('return NULL;')

    cw.block_end()
2411
2412
def print_dump(ri):
    """
    Print the function which performs the dump of the op in @ri:
    starts the message, puts the fixed header and the request attributes
    (if the dump takes a request), runs ynl_exec_dump() and returns the
    list of parsed replies.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.struct['request'].fixed_header:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # Counted attributes are put in a loop, declare the iterator for them
    # (same as print_req() does), attributes are only put for dumps which
    # take a request.
    if "request" in ri.op[ri.op_mode]:
        for _, attr in ri.struct["request"].member_list():
            if attr.presence_type() == 'count':
                local_vars += ['unsigned int i;']
                break

    ri.cw.write_func_lvar(local_vars)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    if ri.family.is_classic():
        ri.cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.struct['request'].fixed_header:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    # Error path, release whatever was parsed before the failure
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
2467
2468
def call_free(ri, direction, var):
    """Return the C statement calling the free function of the type on @var."""
    free_func = op_prefix(ri, direction) + '_free'
    return f"{free_func}({var});"
2471
2472
def free_arg_name(direction):
    """
    Return the argument name used by free functions: the direction suffix
    without its leading underscore, or 'obj' if there is no direction.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2477
2478
def print_alloc_wrapper(ri, direction, struct=None):
    """
    Print the static inline <type>_alloc() helper, a wrapper for calloc().
    Structs used in multi-value attributes get a version which takes the
    number of elements to allocate.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name

    if struct and struct.in_multi_val:
        args, cnt = ["unsigned int n"], "n"
    else:
        args, cnt = ["void"], "1"

    ri.cw.write_func_prot(f'static inline struct {struct_name} *',
                          f"{name}_alloc", args)
    ri.cw.block_start()
    ri.cw.p(f'return calloc({cnt}, sizeof(struct {struct_name}));')
    ri.cw.block_end()
2496
2497
def print_free_prototype(ri, direction, suffix=';'):
    """Print the prototype of the <type>_free() function for @direction."""
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    free_args = [f"struct {struct_name} *{free_arg_name(direction)}"]

    ri.cw.write_func_prot('void', f"{name}_free", free_args, suffix=suffix)
2505
2506
def print_nlflags_set(ri, direction):
    """
    Emit a static inline <name>_set_nlflags() setter which stores the given
    flags in the request's _nlmsg_flags member.
    """
    name = op_prefix(ri, direction)
    params = [f"struct {name} *req", "__u16 nl_flags"]

    cw = ri.cw
    cw.write_func_prot('static inline void', f"{name}_set_nlflags", params)
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2515
2516
def _print_type(ri, direction, struct):
    """
    Emit the C struct definition for @struct.

    The struct holds, in order: the optional netlink flags and fixed header,
    the '_present', '_len' and '_count' metadata sub-structs (only those
    which have members), inherited values and finally the attribute members.
    """
    cw = ri.cw

    dir_sfx = direction_to_suffix[direction]
    suffix = f'_{ri.type_name}{dir_sfx}'
    if ri.type_name_conflict and not direction:
        suffix += '_'
    if ri.op_mode == 'dump' and not ri.type_oneside:
        suffix += '_dump'

    cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if struct.fixed_header:
        cw.p(struct.fixed_header + ' _hdr;')
        cw.nl()

    for type_filter in ('present', 'len', 'count'):
        candidates = [attr.presence_member(ri.ku_space, type_filter)
                      for _, attr in struct.member_list()]
        meta_lines = [line for line in candidates if line]
        if not meta_lines:
            continue
        # Only open the anonymous struct if it has something inside
        cw.block_start(line="struct")
        for line in meta_lines:
            cw.p(line)
        cw.block_end(line=f'_{type_filter};')
    cw.nl()

    for arg in struct.inherited:
        cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2555
2556
def print_type(ri, direction):
    """
    Emit the struct definition for the operation's @direction struct.
    """
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2559
2560
def print_type_full(ri, struct):
    """
    Emit the definition of a direction-less struct and, for structs used in
    requests as multi-val members, its alloc wrapper, free prototype and
    (user space only) attribute setters.
    """
    _print_type(ri, "", struct)

    if not (struct.request and struct.in_multi_val):
        return

    print_alloc_wrapper(ri, "", struct)
    ri.cw.nl()
    free_rsp_nested_prototype(ri)
    ri.cw.nl()

    # Name conflicts are too hard to deal with with the current code base,
    # they are very rare so don't bother printing setters in that case.
    wants_setters = ri.ku_space == 'user' and not ri.type_name_conflict
    if wants_setters:
        for _, attr in struct.member_list():
            attr.setter(ri, ri.attr_set, "", var="obj")
    ri.cw.nl()
2576
2577
def print_type_helpers(ri, direction, deref=False):
    """
    Emit the helpers accompanying a type: the free prototype, the nlflags
    setter when needed, and attribute setters for user space requests.
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, attr in ri.struct[direction].member_list():
            attr.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2589
2590
def print_req_type_helpers(ri):
    """
    Emit the alloc wrapper and type helpers for the request struct,
    unless the request type is empty.
    """
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2596
2597
def print_rsp_type_helpers(ri):
    """
    Emit the type helpers for the reply struct if the current op mode
    defines a reply.
    """
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2602
2603
def print_parse_prototype(ri, direction, terminate=True):
    """
    Emit the prototype of the <op>_rsp_parse() / <op>_req_parse() function.

    The prototype ends with ';' when @terminate is set, otherwise nothing is
    appended (so that a function body can follow).
    """
    dir_sfx = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{dir_sfx}"
    args = ['const struct nlattr **tb', f"struct {base} *req"]

    if terminate:
        term = ';'
    else:
        term = ''
    ri.cw.write_func_prot('void', f"{base}_parse", args, suffix=term)
2612
2613
def print_req_type(ri):
    """
    Emit the request struct definition, unless the request type is empty.
    """
    if not ri.type_empty("request"):
        print_type(ri, "request")
2618
2619
def print_req_free(ri):
    """
    Emit the free function for the request struct if the current op mode
    defines a request.
    """
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2624
2625
def print_rsp_type(ri):
    """
    Emit the reply struct definition.

    Applies to 'do' and 'dump' modes which define a reply, and to events;
    any other case prints nothing.
    """
    mode = ri.op_mode
    if mode in ('do', 'dump'):
        has_reply = 'reply' in ri.op[mode]
    else:
        has_reply = mode == 'event'

    if has_reply:
        print_type(ri, 'reply')
2634
2635
def print_wrapped_type(ri):
    """
    Emit the wrapper struct which embeds the reply object as 'obj'.

    Dump wrappers carry a 'next' pointer; notification and event wrappers
    carry family, cmd, next and a free callback. The free prototype for the
    wrapper is printed after the struct.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for field in ('__u16 family;', '__u8 cmd;',
                      'struct ynl_ntf_base_type *next;'):
            cw.p(field)
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2650
2651
2652def _free_type_members_iter(ri, struct):
2653    if struct.free_needs_iter():
2654        ri.cw.p('unsigned int i;')
2655        ri.cw.nl()
2656
2657
2658def _free_type_members(ri, var, struct, ref=''):
2659    for _, attr in struct.member_list():
2660        attr.free(ri, var, ref)
2661
2662
def _free_type(ri, direction, struct):
    """
    Emit the body of the _free() function for @struct.

    Members are always freed; the object itself is free()d only when a
    direction is given.
    """
    cw = ri.cw
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2674
2675
def free_rsp_nested_prototype(ri):
    """
    Emit the free prototype for a direction-less (nested) struct.
    """
    print_free_prototype(ri, "")
2678
2679
def free_rsp_nested(ri, struct):
    """
    Emit the free function for a direction-less (nested) struct.
    """
    no_direction = ""
    _free_type(ri, no_direction, struct)
2682
2683
def print_rsp_free(ri):
    """
    Emit the free function for the reply struct if the current op mode
    defines a reply.
    """
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2688
2689
def print_dump_type_free(ri):
    """
    Emit the free function for a dump reply list.

    The generated C walks the list until YNL_LIST_END, freeing the members
    of each element's 'obj' and then the element itself.
    """
    cw = ri.cw
    elem_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{elem_type} *next = rsp;")
    cw.nl()

    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    # Advance before freeing, 'rsp' is released at the end of the iteration
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()

    cw.block_end()
    cw.nl()
2708
2709
def print_ntf_type_free(ri):
    """
    Emit the free function for a notification wrapper: free the members of
    the embedded 'obj', then the wrapper itself.
    """
    cw = ri.cw
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2718
2719
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """
    Emit the first line of a kernel nla_policy array for @struct.

    With @terminate set this is an 'extern' forward declaration, which is
    skipped entirely for per-op policies (@ri given) that are static.
    Without @terminate the line opens the array definition (' = {').
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if ri:
        # Per-op policy, named after the op (and the mode if they differ)
        name = ri.op.render_name
        if ri.op.dual_policy:
            name = name + '_' + ri.op_mode
    else:
        name = struct.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2742
2743
def print_req_policy(cw, struct, ri=None):
    """
    Emit the full kernel nla_policy array definition for @struct.

    For per-op policies the array is wrapped in the op's 'config-cond'
    ifdef block, if it has one.
    """
    op = ri.op if ri else None
    if op:
        cw.ifdef_block(op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2753
2754
def kernel_can_gen_family_struct(family):
    """
    Tell whether the family struct is generated for @family, which is only
    the case for the 'genetlink' protocol.
    """
    proto = family.proto
    return proto == 'genetlink'
2757
2758
def policy_should_be_static(family):
    """
    Tell whether kernel policies of @family are rendered as static: true for
    the 'split' kernel policy or when the family struct is generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2761
2762
def print_kernel_policy_ranges(family, cw):
    """
    Emit netlink_range_validation structs for request attributes which have
    a 'full-range' check. Attribute subsets are skipped.
    """
    header_pending = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if header_pending:
                cw.p('/* Integer value ranges */')
                header_pending = False

            # Types starting with 'u' use the unsigned validation struct
            unsigned = attr.type[0] == 'u'
            sign = '' if unsigned else '_signed'
            suffix = 'ULL' if unsigned else 'LL'

            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(limit, attr.get_limit_str(limit, suffix=suffix))
                       for limit in ('min', 'max') if limit in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2790
2791
def print_kernel_policy_sparse_enum_validates(family, cw):
    """
    Emit validation callbacks for request attributes which carry an enum and
    have the 'sparse' check.

    Each callback is a switch over the attribute value which returns 0 for
    every entry of the enum and -EINVAL (with an extack message) otherwise.
    """
    header_pending = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or not attr.enum_name:
                continue
            if 'sparse' not in attr.checks:
                continue

            if header_pending:
                cw.p('/* Sparse enums validation callbacks */')
                header_pending = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            for idx, entry in enumerate(enum.entries.values()):
                # Every case label but the first is preceded by a fallthrough
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2828
2829
def print_kernel_op_table_fwd(family, cw, terminate):
    """
    Emit the declaration line of the kernel ops table.

    With @terminate set this renders the header side: an 'extern'
    declaration of the table (only if the table is exported, i.e. the family
    struct is not generated) followed by prototypes of the pre/post hooks
    and of the doit/dumpit handlers. Without @terminate only the opening
    line of the table definition is printed.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split ops have a separate entry for do and for dump
            cnt = sum(('do' in op) + ('dump' in op)
                      for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    do_args = ['const struct genl_split_ops *ops',
               'struct sk_buff *skb', 'struct genl_info *info']
    dump_args = ['struct netlink_callback *cb']
    hook_protos = (('pre', 'do', 'int', do_args),
                   ('post', 'do', 'void', do_args),
                   ('pre', 'dump', 'int', dump_args),
                   ('post', 'dump', 'int', dump_args))
    for when, mode, ret, args in hook_protos:
        for name in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret, c_lower(name), list(args), suffix=';')

    cw.nl()

    handler_last_arg = (('do', 'struct genl_info *info'),
                        ('dump', 'struct netlink_callback *cb'))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, last_arg in handler_last_arg:
            if mode not in op:
                continue
            name = c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', last_arg], suffix=';')
    cw.nl()
2895
2896
def print_kernel_op_table_hdr(family, cw):
    """
    Emit the header side of the kernel ops table (declarations only).
    """
    print_kernel_op_table_fwd(family, cw, True)
2899
2900
def print_kernel_op_table(family, cw):
    """
    Emit the definition of the kernel ops table for @family.

    For the 'global' and 'per-op' kernel policies there is one entry per
    non-async operation; for 'split' there is one entry per non-async
    operation and mode ('do' / 'dump'). Each entry is wrapped in the
    operation's 'config-cond' ifdef block, if any.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): op['do'] is read unconditionally here, it looks
                # like a dump-only op would raise - confirm specs rule that out.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the ops struct members taking the pre / post hooks
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Drop the validation flags which only concern the other mode
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Ops with distinct do and dump policies get per-mode names
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # The mode capability flag is always added after the op's own flags
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Close whatever ifdef block the last entry may have opened
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2979
2980
def print_kernel_mcgrp_hdr(family, cw):
    """
    Emit the enum of multicast group ids, one entry per group.
    Nothing is printed for families without multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
2991
2992
def print_kernel_mcgrp_src(family, cw):
    """
    Emit the genl_multicast_group array, indexed by the group id enum.
    Nothing is printed for families without multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        name = grp['name']
        grp_id = c_upper(f"{family.ident_name}-nlgrp-{name}")
        cw.p(f'[{grp_id}] = {{ "{name}", }},')
    cw.block_end(';')
    cw.nl()
3004
3005
def print_kernel_family_struct_hdr(family, cw):
    """
    Emit the extern declaration of the genl_family struct and, if the family
    has a 'sock-priv' type, prototypes of its init / destroy callbacks.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()

    if 'sock-priv' not in family.kernel_family:
        return
    for action in ('init', 'destroy'):
        cw.p(f'void {family.c_name}_nl_sock_priv_{action}({family.kernel_family["sock-priv"]} *priv);')
    cw.nl()
3016
3017
def print_kernel_family_struct_src(family, cw):
    """
    Emit the definition of the genl_family struct for @family.

    Nothing is printed unless the family struct can be generated. If the
    family has a 'sock-priv' type, void-pointer trampolines for its init and
    destroy callbacks are printed ahead of the struct.
    """
    if not kernel_can_gen_family_struct(family):
        return

    has_sock_priv = 'sock-priv' in family.kernel_family

    if has_sock_priv:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Name the struct from c_name, same as the extern declaration printed by
    # print_kernel_family_struct_hdr() and all the symbols referenced below.
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if has_sock_priv:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
3053
3054
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """
    Open a C enum block, picking its tag from @obj.

    If @obj has the @enum_name key its value is used as the tag (an empty
    value means an anonymous enum). Otherwise, if @ckey is given and present
    in @obj, the tag is the family name joined with that value. In any other
    case the enum is anonymous.
    """
    tag = None
    if enum_name in obj:
        if obj[enum_name]:
            tag = c_lower(obj[enum_name])
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])

    if tag is None:
        cw.block_start(line='enum')
    else:
        cw.block_start(line='enum ' + tag)
3063
3064
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """
    Emit the command enum for families with a unified message id space.

    Notifications and events are left out when @separate_ntf is set. The
    max value is rendered as a #define after the enum if @max_by_define is
    set, as the last enum entry otherwise.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        # Only spell the value out when it breaks the implicit numbering
        if op.value == expected:
            suffix = ','
        else:
            suffix = f" = {op.value},"
            expected = op.value
        cw.p(op.enum_name + suffix)
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
3090
3091
def render_uapi_directional(family, cw, max_by_define):
    """
    Emit the two command enums of families with directional message ids.

    The first enum holds the messages with a 'do' which are not events, the
    second one holds replies to 'do' (with a _REPLY suffix), notifications
    and events.
    """
    def enum_open(kind):
        cw.block_start(line='enum')
        cw.p(c_upper(f'{family.name}_MSG_{kind}_NONE = 0,'))

    def enum_entry(name, op, expected):
        """Print one entry, return the implicit value of the next one."""
        suffix = ','
        if op.value and op.value != expected:
            suffix = f" = {op.value},"
            expected = op.value
        cw.p(name + suffix)
        return expected + 1

    def enum_close(kind):
        max_name = f"{family.op_prefix}{kind}_MAX"
        cnt_name = f"__{family.op_prefix}{kind}_CNT"
        max_value = f"({cnt_name} - 1)"

        cw.nl()
        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    enum_open('USER')
    val = 0
    for op in family.msgs.values():
        if 'do' in op and 'event' not in op:
            val = enum_entry(op.enum_name, op, val)
    enum_close('USER')

    enum_open('KERNEL')
    val = 0
    for op in family.msgs.values():
        is_ntf = 'notify' in op or 'event' in op
        if not (is_ntf or ('do' in op and 'reply' in op['do'])):
            continue
        name = op.enum_name if is_ntf else f'{op.enum_name}_REPLY'
        val = enum_entry(name, op, val)
    enum_close('KERNEL')
3144
3145
def render_uapi(family, cw):
    """
    Render the whole uAPI header of @family through @cw.

    The output consists of, in order: the include guard, the family name
    and version defines, the constants / enums / flags from 'definitions',
    one enum per attribute set (subsets excluded), the command enum(s) and
    the multicast group name defines.
    """
    # Include guard, slashes in the header name become underscores
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are collected and written as one group
    defines = []
    for const in family['definitions']:
        # Definitions which name a header are not rendered here
        if const.get('header'):
            continue

        # Anything other than a const flushes the defines collected so far
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                # A kdoc comment ('/**') is only used if entries have docs
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    # Flags get a mask of all values rather than a max
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            name_pfx = const.get('name-prefix', f"{family.ident_name}-")
            # NOTE(review): 'c-define-name' is looked up on family, not on
            # const - confirm this is intended.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{name_pfx}{const['name']}")),
                            const['value']])

    # Flush the defines left over from the end of the definitions list
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    # Attribute sets
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # val tracks the value C would assign implicitly to the next entry
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    # Notifications and events go to an enum of their own if the spec
    # defines an 'async-prefix'
    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3287
3288
def _render_user_ntf_entry(ri, op):
    """
    Emit one entry of the user space notification info table.

    The entry is indexed by the enum name of @op, except for classic
    families where the request matching @op's rsp_value is used instead.
    """
    cw = ri.cw

    if ri.family.is_classic():
        index_op = ri.family.req_by_value[op.rsp_value]
    else:
        index_op = op
    cw.block_start(line=f"[{index_op.enum_name}] = ")

    cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
    cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
    cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
    cw.block_end(line=',')
3300
3301
def render_user_family(family, cw, prototype):
    """
    Emit the user space ynl_family struct of @family.

    With @prototype set only the extern declaration is printed. Otherwise
    the notification info table (if the family has notifications) and the
    family struct definition are rendered.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Notifications are rendered using the op they refer to
                src_op, mode = family.ops[ntf_op['notify']], "notify"
            elif 'event' in ntf_op:
                src_op, mode = ntf_op, "event"
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            ri = RenderInfo(cw, family, "user", src_op, mode)
            _render_user_ntf_entry(ri, ntf_op)
        for op in family.ops.values():
            if 'event' in op:
                ri = RenderInfo(cw, family, "user", op, "event")
                _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    classic = family.is_classic()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if classic:
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')

    # Header length: genlmsghdr for non-classic families plus the family's
    # fixed header if there is one; classic without fixed header sets none.
    hdr_parts = []
    if not classic:
        hdr_parts.append('sizeof(struct genlmsghdr)')
    if family.fixed_header:
        hdr_parts.append(f'sizeof(struct {c_lower(family.fixed_header)})')
    if hdr_parts:
        hdr_len = ' + '.join(hdr_parts)
        cw.p(f'.hdr_len\t= {hdr_len},')

    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3343
3344
def family_contains_bitfield32(family):
    """
    Tell whether any attribute of the family has type bitfield32.

    Attribute sets with subset_of set are not inspected.
    """
    return any(
        attr.type == "bitfield32"
        for _set_name, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _attr_name, attr in attr_set.items()
    )
3353
3354
def find_kernel_root(full_path):
    """
    Find the source tree root above full_path.

    Walk up from full_path, one path component at a time, until a directory
    containing a file called MAINTAINERS is found.

    Returns a (root, sub_path) tuple where root is the directory holding
    MAINTAINERS and sub_path is the part of full_path below that directory.

    Raises FileNotFoundError if the top of the path is reached without
    finding MAINTAINERS (previously this case looped forever).
    """
    start_path = full_path
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Drop the trailing separator left behind by joining onto ''
            return parent, sub_path[:-1]
        if parent == full_path:
            # dirname() no longer makes progress ('/' or ''), nothing left
            # to search
            raise FileNotFoundError(
                f"MAINTAINERS not found in any directory above '{start_path}'")
        full_path = parent
3363
3364
def _main_parse_args():
    """
    Parse the command line, exits via parser.error() if neither --header
    nor --source was given.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    # --header and --source share a destination, None means neither was used
    if args.header is None:
        parser.error("--header or --source is required")
    return args


def _main_load_spec(args):
    """
    Load the spec named on the command line, exits with status 1 on a YAML
    error or if the spec does not carry the required license.
    """
    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    if parsed.license != required_license:
        print('Spec license:', parsed.license)
        print('License must be: ' + required_license)
        sys.exit(1)
    return parsed


def _main_render_banner(cw, args, parsed):
    """
    Write the leading comments: license, spec path and generator arguments.
    """
    _, spec_kernel = find_kernel_root(args.spec)
    # uAPI files and headers use a block comment for SPDX, sources use '//'
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        # Joining onto a leading '' puts the option name before every value
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()


def _main_render_includes(cw, args, parsed):
    """
    Write the #include lines (and, for user mode, what directly follows
    them) of a kernel or user file.
    """
    if args.out_file:
        # Header sits next to the output file, same name with a .h extension
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, in order of first appearance
    for one in dict.fromkeys(headers):
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()


def _main_render_kernel_header(cw, parsed, mode):
    """
    Write the body of a kernel mode header: policy forward declarations
    followed by the op table, mcast group and family struct declarations.
    """
    nested = sorted(parsed.pure_nested_structs.items())
    # The section comment is only wanted if at least one policy follows
    if any(struct.request for _, struct in nested):
        cw.p('/* Common nested types */')
    for _, struct in nested:
        if struct.request:
            print_req_policy_fwd(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy_fwd(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for _, op in parsed.ops.items():
            if 'do' in op and 'event' not in op:
                ri = RenderInfo(cw, parsed, mode, op, "do")
                print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                cw.nl()

    print_kernel_op_table_hdr(parsed, cw)
    print_kernel_mcgrp_hdr(parsed, cw)
    print_kernel_family_struct_hdr(parsed, cw)


def _main_render_kernel_source(cw, parsed, mode):
    """
    Write the body of a kernel mode source file: policies followed by the
    op table, mcast group and family struct definitions.
    """
    print_kernel_policy_ranges(parsed, cw)
    print_kernel_policy_sparse_enum_validates(parsed, cw)

    nested = sorted(parsed.pure_nested_structs.items())
    # The section comment is only wanted if at least one policy follows
    if any(struct.request for _, struct in nested):
        cw.p('/* Common nested types */')
    for _, struct in nested:
        if struct.request:
            print_req_policy(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for _, op in parsed.ops.items():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, mode, op, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _main_render_user_header(cw, parsed, mode):
    """
    Write the body of a user mode header: enum helper declarations, nested
    types and the per-operation types and prototypes.
    """
    cw.p('/* Enums */')
    put_op_name_fwd(parsed, cw)

    for _, const in parsed.consts.items():
        if isinstance(const, EnumSet):
            put_enum_to_str_fwd(parsed, cw, const)
    cw.nl()

    cw.p('/* Common nested types */')
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
        print_type_full(ri, struct)

    for _, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, 'dump')
            print_req_type(ri)
            print_req_type_helpers(ri)
            if not ri.type_consistent or ri.type_oneside:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_wrapped_type(ri)

    for _, op in parsed.ntfs.items():
        if 'event' in op:
            ri = RenderInfo(cw, parsed, mode, op, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)
    cw.nl()


def _main_render_user_source(cw, parsed, mode):
    """
    Write the body of a user mode source file: enum helpers, policies,
    nested type handlers, per-operation code and the family definition.
    """
    cw.p('/* Enums */')
    put_op_name(parsed, cw)

    for _, const in parsed.consts.items():
        if isinstance(const, EnumSet):
            put_enum_to_str(parsed, cw, const)
    cw.nl()

    has_recursive_nests = False
    cw.p('/* Policies */')
    for struct in parsed.pure_nested_structs.values():
        if struct.recursive:
            put_typol_fwd(cw, struct)
            has_recursive_nests = True
    if has_recursive_nests:
        cw.nl()
    for struct in parsed.pure_nested_structs.values():
        put_typol(cw, struct)
    for name in parsed.root_sets:
        struct = Struct(parsed, name)
        put_typol(cw, struct)

    cw.p('/* Common nested types */')
    if has_recursive_nests:
        # Prototypes for all nested handlers go first in this case
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, mode, "", "", attr_set)
            free_rsp_nested_prototype(ri)
            if struct.request:
                put_req_nested_prototype(ri, struct)
            if struct.reply:
                parse_rsp_nested_prototype(ri, struct)
        cw.nl()
    for attr_set, struct in parsed.pure_nested_structs.items():
        ri = RenderInfo(cw, parsed, mode, "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for _, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, "do")
            print_req_free(ri)
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, "dump")
            if not ri.type_consistent or ri.type_oneside:
                parse_rsp_msg(ri, deref=True)
            print_req_free(ri)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if op.has_ntf:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, 'notify')
            if not ri.type_consistent:
                raise Exception(f'Only notifications with consistent types supported ({op.name})')
            print_ntf_type_free(ri)

    for _, op in parsed.ntfs.items():
        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")

            ri = RenderInfo(cw, parsed, mode, op, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, mode, op, "event")
            print_ntf_type_free(ri)
    cw.nl()
    render_user_family(parsed, cw, False)


def main():
    """
    Generate C code from a Netlink spec according to the command line.

    The mode ('user', 'kernel' or 'uapi') and --header / --source select
    what is rendered, the result is written through a CodeWriter created
    for the -o output file.
    """
    args = _main_parse_args()
    parsed = _main_load_spec(args)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    _main_render_banner(cw, args, parsed)

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    _main_render_includes(cw, args, parsed)

    if args.mode == "kernel":
        if args.header:
            _main_render_kernel_header(cw, parsed, args.mode)
        else:
            _main_render_kernel_source(cw, parsed, args.mode)

    if args.mode == "user":
        if args.header:
            _main_render_user_header(cw, parsed, args.mode)
        else:
            _main_render_user_source(cw, parsed, args.mode)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3673
3674
# Run the generator only when this file is executed as a script
if __name__ == "__main__":
    main()
3677