xref: /linux/tools/net/ynl/pyynl/ynl_gen_c.py (revision 099902fc66f87424a5b4aa6d58832c7aab0ac6c6)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import filecmp
6import pathlib
7import os
8import re
9import shutil
10import sys
11import tempfile
12import yaml
13
14sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
15from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
16from lib import SpecSubMessage
17
18
def c_upper(name):
    """Return name as an upper-case C identifier (dashes become underscores)."""
    return name.replace('-', '_').upper()
21
22
def c_lower(name):
    """Return name as a lower-case C identifier (dashes become underscores)."""
    return name.replace('-', '_').lower()
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The name must have the form <s|u><bits>-<min|max>. Anything else raises
    ValueError instead of being silently interpreted as some other limit
    (e.g. a typo in the suffix being treated as "-max").
    """
    match = re.fullmatch(r'([su])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit name "{name}", expected e.g. u32-max or s64-min')

    sign, bits, bound = match.groups()
    width = int(bits)
    if sign == 'u':
        return 0 if bound == 'min' else (1 << width) - 1

    # Signed types spend one bit on the sign
    magnitude = 1 << (width - 1)
    return -magnitude if bound == 'min' else magnitude - 1
40
41
class BaseNlLib:
    """Helper providing C expressions used by the generated code."""

    def get_family_id(self):
        """Return the C expression which evaluates to the family ID."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Base class describing how a single attribute is rendered into C.

    Subclasses override the hooks (_attr_typol(), _attr_get(),
    _setter_lines(), attr_put(), ...) for their attribute type; the default
    implementations here raise for anything a concrete type has to provide.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        # Set via set_request() / set_reply()
        self.request = False
        self.reply = False

        self.is_selector = False

        if 'len' in attr:
            self.len = attr['len']

        # Both plain nests and sub-messages refer to another set of attributes
        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
        elif 'sub-message' in attr:
            nested = attr['sub-message']
        else:
            nested = None

        if nested:
            self.nested_attrs = nested
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # The struct gets a trailing underscore when the nested set
            # shares its name with one of the family's constants
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make sure the member name is a usable C identifier
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attribute (and its counterpart in the parent set) as used in requests."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attribute (and its counterpart in the parent set) as used in replies."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numerical value of check @limit, or @default if not set.

        The check may be an integer, the name of a family constant or
        a string limit understood by limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check @limit as a string to be placed in C code.

        Integers are rendered as-is followed by @suffix, everything else is
        rendered as an upper-case C name. Returns '' if the check is not set.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            # Constants which come from a header are not prefixed
            # with the family name
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name and reject subsets which override checks."""
        if 'parent-sub-message' in self.attr:
            enum_name = self.attr['parent-sub-message'].enum_name
        elif 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        return self.is_recursive() and not ri.op

    def presence_type(self):
        """Return how presence is tracked: 'present', 'len', 'count' or ''."""
        return 'present'

    def presence_member(self, space, type_filter):
        """
        Return the C declaration of the presence member for this attribute,
        or None if the attribute's presence type does not match @type_filter.
        """
        presence = self.presence_type()
        if presence != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if presence == 'present':
            # Plain presence only needs a single bit
            return f"{pfx}u32 {self.c_name}:1;"
        if presence in {'len', 'count'}:
            return f"{pfx}u32 {self.c_name};"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def _free_lines(self, ri, var, ref):
        if self.is_multi_val() or self.presence_type() in {'count', 'len'}:
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Write out the C code freeing this member of @var."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return the list of C function arguments needed to pass this attribute."""
        member = self._complex_member_type(ri)
        if member:
            spc = ' ' if member[-1] != '*' else ''
            arg = [member + spc + '*' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Write out the C struct member(s) holding this attribute."""
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            spc = ' ' if member[-1] != '*' else ''
            ri.cw.p(f"{member}{spc}{ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Write out the policy table entry for this attribute."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian' and self.type in {'u16', 'u32'}:
            policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Write out the type policy table entry for this attribute."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        presence = self.presence_type()
        # Attrs tracked by count get the put line emitted unconditionally
        if presence in {'present', 'len'}:
            ri.cw.p(f"if ({var}->_{presence}.{self.c_name})")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Write out the branch of the parsing if-chain handling this attribute.

        @first selects between opening the chain with 'if' and continuing
        it with 'else if'. Always returns True.
        """
        lines, init_lines, _ = self._attr_get(ri, var)
        # Hooks may return a single line instead of a list
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'present':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """
        Write out the static inline C setter for this attribute.

        @ref is the list of member names leading to the struct which holds
        this attribute, it is empty / None for top level attributes.
        """
        ref = (ref or []) + [self.c_name]
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, name in enumerate(ref):
            path = f"{var}->{'.'.join(ref[:i] + [''])}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'present':
                presence = f"{path}_{self.presence_type()}.{name}"
                continue
            presence = f"{path}_present.{name}"
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        # Setters which free the old value but do not allocate a new one
        # get a '__' prefix
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
305
306
class TypeUnused(Type):
    """Attribute which gets no struct member and no put / get / setter code."""

    def presence_type(self):
        """No presence information is tracked."""
        return ''

    def arg_member(self, ri):
        """No function arguments are needed."""
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        parse_lines = ['return YNL_PARSE_CB_ERROR;']
        return parse_lines, None, None

    def attr_policy(self, cw):
        """Emit nothing."""

    def attr_put(self, ri, var):
        """Emit nothing."""

    def attr_get(self, ri, var, first):
        """Emit nothing."""

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        """Emit nothing."""
331
332
class TypePad(Type):
    """Padding attribute, gets no struct member and no put / get / setter code."""

    def presence_type(self):
        """No presence information is tracked."""
        return ''

    def arg_member(self, ri):
        """No function arguments are needed."""
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def attr_policy(self, cw):
        """Emit nothing."""

    def attr_put(self, ri, var):
        """Emit nothing."""

    def attr_get(self, ri, var, first):
        """Emit nothing."""

    def setter(self, ri, space, direction, deref=False, ref=None, var=None):
        """Emit nothing."""
354
355
class TypeScalar(Type):
    """Integer attribute, optionally carrying an enum and / or value checks."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Comment placed after the member to document its byte order
        self.byte_order_comment = f" /* {attr['byte-order']} */" if 'byte-order' in attr else ''

        # Classic families have some funny enums, don't bother
        # computing checks, since we only need them for kernel policies
        if not family.is_classic():
            self._init_checks()

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve the base class state, then compute is_bitfield and type_name."""
        self.resolve_up(super())

        has_enum = 'enum' in self.attr

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        else:
            self.is_bitfield = has_enum and \
                self.family.consts[self.attr['enum']]['type'] == 'flags'

        if has_enum and not self.is_bitfield:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = f"__{self.type[0]}64"
        else:
            self.type_name = f"__{self.type}"

    def _init_checks(self):
        if 'enum' in self.attr:
            low, high = self.family.consts[self.attr['enum']].value_range()
            if low is None and high is None:
                self.checks['sparse'] = True
            else:
                # Enum values provide the limits unless given explicitly,
                # a minimum of 0 on an unsigned type needs no check
                if 'min' not in self.checks and (low != 0 or self.type[0] == 's'):
                    self.checks['min'] = low
                if 'max' not in self.checks:
                    self.checks['max'] = high

        if 'min' in self.checks and 'max' in self.checks:
            limit_min = self.get_limit('min')
            limit_max = self.get_limit('max')
            if limit_min > limit_max:
                raise Exception(f'Invalid limit for "{self.name}" min: {limit_min} max: {limit_max}')
            self.checks['range'] = True

        bounds = (self.get_limit('min', 0), self.get_limit('max', 0))
        low = min(bounds)
        high = max(bounds)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits which do not fit in 16 bits are marked as 'full-range'
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

    def _attr_policy(self, policy):
        checks = self.checks

        if 'flags-mask' in checks or self.is_bitfield:
            if self.is_bitfield:
                mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            else:
                flags = self.family.consts[checks['flags-mask']]
                mask = (1 << len(flags['entries'])) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        if 'min' in checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        if 'max' in checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        if 'sparse' in checks:
            return f"NLA_POLICY_VALIDATE_FN({policy}, &{c_lower(self.enum_name)}_validate)"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # The policy type name is always rendered with a 'U', whatever the
        # first character of the spec type is
        width = c_upper(self.type[1:])
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        """Return the single C argument carrying the value."""
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        """Write out the put call matching the attribute's type."""
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        get_line = f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);"
        return get_line, None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
453
454
class TypeFlag(Type):
    """Attribute without a payload, only its presence carries information."""

    def arg_member(self, ri):
        """No value arguments are needed, there is no payload."""
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        """Write out the put of an empty attribute."""
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to read out of the attribute
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        # No lines beyond what Type.setter() itself generates
        return []
470
471
class TypeString(Type):
    """String attribute, kept as a malloc()ed NUL-terminated copy plus its length."""

    def presence_type(self):
        """Presence is tracked as the length of the string."""
        return 'len'

    def arg_member(self, ri):
        """Return the single C argument carrying the string."""
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        """Write out the struct member pointing at the string."""
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        selector = '.is_selector = 1, ' if self.is_selector else ''
        return '.type = YNL_PT_NUL_STR, ' + selector

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.get_limit_str('exact-len')})"

        len_part = ''
        if 'max-len' in self.checks:
            len_part = ', .len = ' + self.get_limit_str('max-len')
        return '{ .type = ' + policy + len_part + ', }'

    def attr_policy(self, cw):
        """Write out the policy table entry for this attribute."""
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        """Write out the string put call."""
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        # The generated setter stores a private copy of the caller's string
        return [
            f"{presence} = strlen({self.c_name});",
            f"{member} = malloc({presence} + 1);",
            f'memcpy({member}, {self.c_name}, {presence});',
            f'{member}[{presence}] = 0;',
        ]
524
525
class TypeBinary(Type):
    """Opaque binary attribute, kept as a malloc()ed buffer plus its length."""

    def presence_type(self):
        """Presence is tracked as the length of the data."""
        return 'len'

    def arg_member(self, ri):
        """Return the C arguments carrying the data pointer and its length."""
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        """Write out the struct member pointing at the data."""
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        checks = self.checks

        if len(checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not checks:
            return '{ .type = NLA_BINARY, }'

        # Exactly one check - map it to the matching length policy macro
        policy_macros = {
            'exact-len': 'NLA_POLICY_EXACT_LEN',
            'min-len': 'NLA_POLICY_MIN_LEN',
            'max-len': 'NLA_POLICY_MAX_LEN',
        }
        check_name = next(iter(checks))
        if check_name not in policy_macros:
            raise Exception('Unsupported check for binary type: ' + check_name)

        return policy_macros[check_name] + '(' + self.get_limit_str(check_name) + ')'

    def attr_put(self, ri, var):
        """Write out the put of the raw data."""
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                    f"{var}->{self.c_name}, {var}->_len.{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_len.{self.c_name} = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        # The generated setter stores a private copy of the caller's buffer
        return [
            f"{presence} = len;",
            f"{member} = malloc({presence});",
            f'memcpy({member}, {self.c_name}, {presence});',
        ]
576
577
class TypeBinaryStruct(TypeBinary):
    """Binary attribute whose payload is a C struct."""

    def struct_member(self, ri):
        """Write out the struct member pointing at the payload struct."""
        struct_name = c_lower(self.get("struct"))
        ri.cw.p(f'struct {struct_name} *{self.c_name};')

    def _attr_get(self, ri, var):
        struct_sz = f'sizeof(struct {c_lower(self.get("struct"))})'
        dst = f"{var}->{self.c_name}"
        get_lines = [
            f"{var}->_{self.presence_type()}.{self.c_name} = len;",
            # Payloads shorter than the struct get a zeroed, full size buffer
            f"if (len < {struct_sz})",
            f"{dst} = calloc(1, {struct_sz});",
            "else",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars
593
594
class TypeBinaryScalarArray(TypeBinary):
    """Binary attribute whose payload is an array of scalars of 'sub-type'."""

    def presence_type(self):
        """Presence is tracked as the number of elements."""
        return 'count'

    def arg_member(self, ri):
        """Return the C arguments carrying the array and its element count."""
        elem_type = self.get("sub-type")
        return [f'__{elem_type} *{self.c_name}', 'size_t count']

    def struct_member(self, ri):
        """Write out the struct member pointing at the array."""
        elem_type = self.get("sub-type")
        ri.cw.p(f'__{elem_type} *{self.c_name};')

    def attr_put(self, ri, var):
        """Write out the put of the whole array, the byte size is computed into 'i'."""
        count_mem = f"{var}->_{self.presence_type()}.{self.c_name}"
        elem_type = self.get('sub-type')

        ri.cw.block_start(line=f"if ({count_mem})")
        ri.cw.p(f"i = {count_mem} * sizeof(__{elem_type});")
        ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, {var}->{self.c_name}, i);")
        ri.cw.block_end()

    def _attr_get(self, ri, var):
        count_mem = f"{var}->_count.{self.c_name}"
        elem_sz = f"sizeof(__{self.get('sub-type')})"
        dst = f"{var}->{self.c_name}"
        get_lines = [
            # Only whole elements are counted and copied
            f"{count_mem} = len / {elem_sz};",
            f"len = {count_mem} * {elem_sz};",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        init_lines = ['len = ynl_attr_data_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        elem_type = self.get('sub-type')
        return [
            f"{presence} = count;",
            f"count *= sizeof(__{elem_type});",
            f"{member} = malloc(count);",
            f'memcpy({member}, {self.c_name}, count);',
        ]
627
628
class TypeBitfield32(Type):
    """Attribute carrying a struct nla_bitfield32."""

    _C_STRUCT = "struct nla_bitfield32"

    def _complex_member_type(self, ri):
        return self._C_STRUCT

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        """Write out the put of the whole bitfield32 struct."""
        put_call = (f"ynl_attr_put(nlh, {self.enum_name}, "
                    f"&{var}->{self.c_name}, sizeof(struct nla_bitfield32))")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        get_line = (f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), "
                    "sizeof(struct nla_bitfield32));")
        return get_line, None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
652
653
class TypeNest(Type):
    """Attribute containing a nested set of attributes."""

    def is_recursive(self):
        """Return whether the nested struct is marked as recursive."""
        nested_struct = self.family.pure_nested_structs[self.nested_attrs]
        return nested_struct.recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        target = f"{var}->{ref}{self.c_name}"
        if self.is_recursive_for_op(ri):
            # The member is a pointer in this case, so the free call is
            # guarded by a check and the member is passed without '&'
            return [f'if ({target})',
                    f'{self.nested_render_name}_free({target});']
        return [f'{self.nested_render_name}_free(&{target});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        """Write out the call to the put function of the nested struct."""
        at = '&'
        if self.is_recursive_for_op(ri):
            at = ''
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, {at}{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        pns = self.family.pure_nested_structs[self.nested_attrs]
        # External selectors are passed to the nested parser as extra arguments
        args = ["&parg", "attr"] + [f'{var}->{sel.name}' for sel in pns.external_selectors()]

        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({self.nested_render_name}_parse({', '.join(args)}))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None, var="req"):
        """Write out setters for all non-recursive members of the nested struct."""
        ref = (ref or []) + [self.c_name]

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, attr in nested_struct.member_list():
            if not attr.is_recursive():
                attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref,
                            var=var)
700
701
class TypeMultiAttr(Type):
    """Attribute which may appear multiple times, stored as an array in C."""

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        # Type object of a single element, policies are delegated to it
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        """Presence is tracked as the number of elements."""
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type

        elem_type = self.attr['type']
        if elem_type == 'binary' and 'struct' in self.attr:
            return None  # use arg_member()
        if elem_type == 'string':
            return 'struct ynl_string *'
        if elem_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + elem_type
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def arg_member(self, ri):
        """Return the C arguments carrying the array and its element count."""
        if self.type == 'binary' and 'struct' in self.attr:
            struct_name = c_lower(self.attr["struct"])
            return [f'struct {struct_name} *{self.c_name}',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def free_needs_iter(self):
        """Nests and strings have to be freed element by element."""
        return self.attr['type'] in {'nest', 'string'}

    def _free_lines(self, ri, var, ref):
        elem_type = self.attr['type']
        array = f"{var}->{ref}{self.c_name}"
        loop = f"for (i = 0; i < {var}->{ref}_count.{self.c_name}; i++)"

        # Scalars and binary elements own no memory beyond the array itself
        if elem_type in scalars or elem_type == 'binary':
            return [f"free({array});"]
        if elem_type == 'string':
            return [loop,
                    f"free({array}[i]);",
                    f"free({array});"]
        if elem_type == 'nest':
            return [loop,
                    f'{self.nested_render_name}_free(&{array}[i]);',
                    f"free({array});"]
        raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Parsing only counts the occurrences here
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """Write out a loop putting each element as a separate attribute."""
        elem_type = self.attr['type']
        loop = f"for (i = 0; i < {var}->_count.{self.c_name}; i++)"
        elem = f"{var}->{self.c_name}[i]"

        if elem_type in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {elem});")
        elif elem_type == 'binary' and 'struct' in self.attr:
            struct_name = c_lower(self.attr['struct'])
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, {self.enum_name}, &{elem}, sizeof(struct {struct_name}));")
        elif elem_type == 'string':
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_str(nlh, {self.enum_name}, {elem}->str);")
        elif elem_type == 'nest':
            ri.cw.p(loop)
            self._attr_put_line(ri, var,
                                f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{elem})")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # The caller's array pointer is stored as is, not copied
        return [
            f"{member} = {self.c_name};",
            f"{presence} = n_{self.c_name};",
        ]
788
789
class TypeArrayNest(Type):
    """Attribute which is a nest with array entries of 'sub-type' inside."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        """Presence is tracked as the number of elements."""
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type

        elem_type = self.attr['sub-type']
        if elem_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + elem_type
        if elem_type == 'binary' and 'exact-len' in self.checks:
            return None  # use arg_member()
        raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def arg_member(self, ri):
        """Return the C arguments carrying the array and its element count."""
        if self.sub_type == 'binary' and 'exact-len' in self.checks:
            exact_len = self.checks["exact-len"]
            return [f'unsigned char (*{self.c_name})[{exact_len}]',
                    f'unsigned int n_{self.c_name}']
        return super().arg_member(ri)

    def _attr_policy(self, policy):
        if self.attr['sub-type'] != 'nest':
            return super()._attr_policy(policy)
        return f'NLA_POLICY_NESTED_ARRAY({self.nested_render_name}_nl_policy)'

    def _attr_typol(self):
        elem_type = self.attr['sub-type']
        if elem_type in scalars:
            return f'.type = YNL_PT_U{c_upper(self.sub_type[1:])}, '
        if elem_type == 'binary' and 'exact-len' in self.checks:
            exact_len = self.checks["exact-len"]
            return f'.type = YNL_PT_BINARY, .len = {exact_len}, '
        if elem_type == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        raise Exception(f"Typol for ArrayNest sub-type {self.attr['sub-type']} not supported, yet")

    def _attr_get(self, ri, var):
        # Parsing only records the outer attribute and counts the entries
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'ynl_attr_for_each_nested(attr2, attr) {',
            '\tif (ynl_attr_validate(yarg, attr2))',
            '\t\treturn YNL_PARSE_CB_ERROR;',
            f'\tn_{self.c_name}++;',
            '}',
        ]
        local_vars = ['const struct nlattr *attr2;']
        return get_lines, None, local_vars

    def attr_put(self, ri, var):
        """Write out the array nest, each element is put with its index as the type."""
        loop = f'for (i = 0; i < {var}->_count.{self.c_name}; i++)'
        elem = f"{var}->{self.c_name}[i]"

        ri.cw.p(f'array = ynl_attr_nest_start(nlh, {self.enum_name});')
        if self.sub_type in scalars:
            ri.cw.block_start(line=loop)
            ri.cw.p(f"ynl_attr_put_{self.sub_type}(nlh, i, {elem});")
            ri.cw.block_end()
        elif self.sub_type == 'binary' and 'exact-len' in self.checks:
            exact_len = self.checks['exact-len']
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put(nlh, i, {elem}, {exact_len});")
        elif self.sub_type == 'nest':
            ri.cw.p(loop)
            ri.cw.p(f"{self.nested_render_name}_put(nlh, i, &{elem});")
        else:
            raise Exception(f"Put for ArrayNest sub-type {self.attr['sub-type']} not supported, yet")
        ri.cw.p('ynl_attr_nest_end(nlh, array);')

    def _setter_lines(self, ri, member, presence):
        # The caller's array pointer is stored as is, not copied
        return [
            f"{member} = {self.c_name};",
            f"{presence} = n_{self.c_name};",
        ]
859
860
class TypeNestTypeValue(Type):
    """
    Nest whose payload is reached through the levels listed in the
    attribute's "type-value" property.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """
        Return (get_lines, init_lines, local_vars) for the parser.

        For each type-value level the generated code steps into the
        attribute payload and records the attribute type found there,
        all the recorded types are then passed to the nested parser.
        """
        nest = self.nested_render_name
        init_lines = [f"parg.rsp_policy = &{nest}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        local_vars = []
        get_lines = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            attr_decls = ", *attr_".join(levels)
            value_list = ", ".join(levels)
            local_vars.append(f'const struct nlattr *attr_{attr_decls};')
            local_vars.append(f'__u32 {value_list};')
            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({cursor});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                cursor = f'attr_{level}'
            extra_args = f", {value_list}"

        get_lines.append(f"{nest}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
889
890
class TypeSubMessage(TypeNest):
    """
    Sub-message attribute, a nest whose format is chosen by the value of
    a selector attribute.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.selector = Selector(attr, attr_set)

    def _attr_typol(self):
        parts = [f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, ',
                 '.is_submsg = 1, ']
        # Reverse-parsing of the policy (ynl_err_walk() in ynl.c) does not
        # support external selectors. No family uses sub-messages with external
        # selector for requests so this is fine for now.
        if not self.selector.is_external():
            parts.append(f'.selector_type = {self.attr_set[self["selector"]].value} ')
        return ''.join(parts)

    def _attr_get(self, ri, var):
        """
        Return (get_lines, init_lines, local_vars) for the parser; the
        generated code fails the parse if the selector was not seen.
        """
        sel = c_lower(self['selector'])
        # External selectors are read from a _sel_<name> variable rather
        # than from a member of the struct being filled in
        sel_var = f"_sel_{sel}" if self.selector.is_external() else f"{var}->{sel}"

        selector = self['selector']
        get_lines = [
            f'if (!{sel_var})',
            f'return ynl_submsg_failed(yarg, "{self.name}", "{selector}");',
            f"if ({self.nested_render_name}_parse(&parg, {sel_var}, attr))",
            "return YNL_PARSE_CB_ERROR;",
        ]
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None
921
922
class Selector:
    """
    Selector of a sub-message attribute.

    The selector is "external" when the attribute named by the
    "selector" property is not part of the attr set which holds the
    sub-message itself.
    """

    def __init__(self, msg_attr, attr_set):
        self.name = msg_attr["selector"]
        self._external = self.name not in attr_set

        if self._external:
            # The selector will need to get passed down thru the structs
            self.attr = None
        else:
            self.attr = attr_set[self.name]
            self.attr.is_selector = True

    def set_attr(self, attr):
        self.attr = attr

    def is_external(self):
        return self._external
941
942
class Struct:
    """
    C struct rendered for an attribute set.

    With a type_list the struct holds just the listed attributes, without
    one it is marked as nested and holds every attribute of the set.
    """

    def __init__(self, family, space_name, type_list=None, fixed_header=None,
                 inherited=None, submsg=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []
        self.fixed_header = 'struct ' + c_lower(fixed_header) if fixed_header else None
        self.submsg = submsg

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            ident = family.ident_name
        else:
            ident = family.ident_name + '-' + space_name
        self.render_name = c_lower(ident)
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or legacy arrays

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]

        # Index the members by name, remember the one with the highest value
        self.attrs = {}
        self.attr_max_val = None
        highest = 0
        for name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        return self.attr_list

    def set_inherited(self, new_inherited):
        """
        Record the inherited members, every user of the struct has to
        ask for the same ones.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]

    def external_selectors(self):
        """
        Return the selectors of the sub-message members which are not
        part of this struct's own attr set.
        """
        return [attr.selector for _, attr in self.attr_list
                if isinstance(attr, TypeSubMessage) and attr.selector.is_external()]

    def free_needs_iter(self):
        return any(attr.free_needs_iter() for _, attr in self.attr_list)
1017
1018
class EnumEntry(SpecEnumEntry):
    """
    Enum entry which also knows whether its value has to be spelled out
    and what its C name is.
    """

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # The value "changes" if it does not simply follow the previous
        # entry (or start from 0), flags are always considered changed
        implicit_value = prev.value + 1 if prev else 0
        self.value_change = (self.value != implicit_value or
                             self.enum_set['type'] == 'flags')

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
1037
1038
class EnumSet(SpecEnumSet):
    """
    Enum definition with the names used to render it in C.
    """

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # An explicitly empty enum-name leaves the enum without a name
        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        # Enums without a name are carried around as plain ints
        self.user_type = self.enum_name if self.enum_name else 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """
        Return the (lowest, highest) entry value, or (None, None) if
        the values do not form a range without holes.
        """
        values = [entry.value for entry in self.entries.values()]
        low = min(values)
        high = max(values)

        if high - low + 1 == len(self.entries):
            return low, high
        return None, None
1074
1075
class AttrSet(SpecAttrSet):
    """
    Attribute set with its C naming, also the factory of the typed
    attribute objects.
    """

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are carved out of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.ident_name + '-a-'
            else:
                pfx = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        # The set named after the family itself gets an empty C name
        if name == self.family.c_name:
            name = ''
        self.c_name = name

    def new_attr(self, elem, value):
        """
        Create the typed object for the attribute described by elem,
        wrapped in TypeMultiAttr when the attribute is a multi-attr.
        """
        kind = elem['type']
        args = (self.family, self, elem, value)
        # Types which map to a class without looking at other properties
        plain_types = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
            'sub-message': TypeSubMessage,
        }

        if kind in scalars:
            t = TypeScalar(*args)
        elif kind in plain_types:
            t = plain_types[kind](*args)
        elif kind == 'binary':
            if 'struct' in elem:
                t = TypeBinaryStruct(*args)
            elif elem.get('sub-type') in scalars:
                t = TypeBinaryScalarArray(*args)
            else:
                t = TypeBinary(*args)
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['binary', 'nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            t = TypeArrayNest(*args)
        else:
            raise Exception(f"No typed class for type {elem['type']}")

        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
1144
1145
class Operation(SpecOperation):
    """
    Operation with the names used to render it in C.
    """

    def __init__(self, family, yaml, req_value, rsp_value):
        # Fill in missing operation properties (for fixed hdr-only msgs)
        combos = [(mode, direction)
                  for mode in ('do', 'dump', 'event')
                  for direction in ('request', 'reply')]
        for mode, direction in combos:
            try:
                yaml[mode][direction].setdefault('attributes', [])
            except KeyError:
                # this mode / direction is not part of the operation
                pass

        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True if both do and dump take a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        del self.enum_name

    def resolve(self):
        self.resolve_up(super())

        if self.is_async:
            prefix = self.family.async_op_prefix
        else:
            prefix = self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
1179
1180
class SubMessage(SpecSubMessage):
    """
    Sub-message definition with the name used to render it in C.
    """

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        full_name = f"{family.ident_name}-{yaml['name']}"
        self.render_name = c_lower(full_name)

    def resolve(self):
        self.resolve_up(super())
1189
1190
class Family(SpecFamily):
    """
    Family spec with the extra state needed to generate C code.

    During resolve() the attribute sets are split into the ones used as
    message roots and the ones which are only ever nested, and the
    request / reply usage is propagated to the nested ones.
    """

    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        self.c_name = None
        del self.c_name
        self.op_prefix = None
        del self.op_prefix
        self.async_op_prefix = None
        del self.async_op_prefix
        self.mcgrps = None
        del self.mcgrps
        self.consts = None
        del self.consts
        self.hooks = None
        del self.hooks

        super().__init__(file_name, exclude_ops=exclude_ops)

        family_name = self.yaml["name"]
        self.fam_key = c_upper(self.yaml.get('c-family-name', family_name + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', family_name + '_FAMILY_VERSION'))

        self.yaml.setdefault('definitions', [])

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"

        # Headers of the form linux/<name>.h are known by their <name>
        in_linux_dir = self.uapi_header.startswith("linux/")
        if in_linux_dir and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_lower(self.ident_name)

        ops_spec = self.yaml['operations']
        if 'name-prefix' in ops_spec:
            self.op_prefix = c_upper(ops_spec['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in ops_spec:
            self.async_op_prefix = c_upper(ops_spec['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[pre|post][do|dump] -> 'set' for lookups, 'list' for order
        self.hooks = {when: {op_mode: {'set': set(), 'list': []}
                             for op_mode in ('do', 'dump')}
                      for when in ('pre', 'post')}

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = {}
        # dict space-name -> Struct
        self.pure_nested_structs = {}

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_selector_passing()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def new_sub_message(self, elem):
        return SubMessage(self, elem)

    def is_classic(self):
        return self.proto == 'netlink-raw'

    def _mark_notify(self):
        """
        Flag every operation which some message names in its notify
        property.
        """
        for msg in self.msgs.values():
            if 'notify' in msg:
                self.ops[msg['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' not in op:
                continue
            reply = {'attributes': op['event']['attributes']}
            op['do'] = {'reply': reply}

    def _load_root_sets(self):
        """
        Collect, for each attribute set used directly by a message, the
        attributes seen in requests and in replies.
        """
        for op in self.msgs.values():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ('do', 'dump'):
                if op_mode not in op:
                    continue
                if 'request' in op[op_mode]:
                    req_attrs.update(op[op_mode]['request']['attributes'])
                if 'reply' in op[op_mode]:
                    rsp_attrs.update(op[op_mode]['reply']['attributes'])
            if 'event' in op:
                rsp_attrs.update(op['event']['attributes'])

            root = self.root_sets.setdefault(op['attribute-set'],
                                             {'request': set(), 'reply': set()})
            root['request'].update(req_attrs)
            root['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        # Try to reorder according to dependencies
        pending = list(self.pure_nested_structs.keys())
        placed = set()
        max_rounds = len(pending) ** 2  # it's basically bubble sort
        for _ in range(max_rounds):
            if not pending:
                break
            name = pending.pop(0)
            deps_placed = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    continue

                # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                if self.pure_nested_structs[nested].recursive:
                    continue
                if nested not in placed:
                    # Dicts are sorted, this will make struct last
                    self.pure_nested_structs[name] = self.pure_nested_structs.pop(name)
                    deps_placed = False
                    break
            if deps_placed:
                placed.add(name)
            else:
                pending.append(name)

    def _load_nested_set_nest(self, spec):
        """
        Make sure a Struct exists for the attr set nested under spec and
        record the members it inherits. Return the name of the set.
        """
        inherit = set()
        nested = spec['nested-attributes']
        if nested in self.root_sets:
            raise Exception(f'Using attr set as root and nested not supported - {nested}')
        if nested not in self.pure_nested_structs:
            # The new Struct keeps a reference to inherit, so the members
            # added below are visible to it
            self.pure_nested_structs[nested] = \
                Struct(self, nested, inherited=inherit,
                       fixed_header=spec.get('fixed-header'))

        if 'type-value' in spec:
            if nested in self.root_sets:
                raise Exception("Inheriting members to a space used as root not supported")
            inherit.update(set(spec['type-value']))
        elif spec['type'] == 'indexed-array':
            inherit.add('idx')
        self.pure_nested_structs[nested].set_inherited(inherit)

        return nested

    def _load_nested_set_submsg(self, spec):
        # Fake the struct type for the sub-message itself
        # it's not an attr_set but codegen wants attr_sets.
        submsg = self.sub_msgs[spec["sub-message"]]
        nested = submsg.name

        attrs = []
        for name, fmt in submsg.formats.items():
            attr = {
                "name": name,
                "parent-sub-message": spec,
            }
            if 'attribute-set' in fmt:
                attr["type"] = "nest"
                attr["nested-attributes"] = fmt['attribute-set']
                if 'fixed-header' in fmt:
                    attr["fixed-header"] = fmt["fixed-header"]
            elif 'fixed-header' in fmt:
                attr["type"] = "binary"
                attr["struct"] = fmt["fixed-header"]
            else:
                attr["type"] = "flag"
            attrs.append(attr)

        # NOTE(review): AttrSet.__init__ looks for 'name-prefix', the
        # 'name-pfx' key below appears to be unused - confirm
        self.attr_sets[nested] = AttrSet(self, {
            "name": nested,
            "name-pfx": self.name + '-' + spec.name + '-',
            "attributes": attrs
        })

        if nested not in self.pure_nested_structs:
            self.pure_nested_structs[nested] = Struct(self, nested, submsg=submsg)

        return nested

    def _load_nested_sets(self):
        """
        Create Structs for all the attr sets reachable from the root
        sets and propagate the request / reply / multi-value markings.
        """
        to_visit = list(self.root_sets.keys())
        queued = set(self.root_sets.keys())

        while to_visit:
            a_set = to_visit.pop(0)
            for _, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' in spec:
                    nested = self._load_nested_set_nest(spec)
                elif 'sub-message' in spec:
                    nested = self._load_nested_set_submsg(spec)
                else:
                    continue

                if nested not in queued:
                    to_visit.append(nested)
                    queued.add(nested)

        # Mark the sets nested directly in a root set
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                elif 'sub-message' in spec:
                    nested = spec.sub_message
                else:
                    nested = None

                if not nested:
                    continue

                if attr in rs_members['request']:
                    self.pure_nested_structs[nested].request = True
                if attr in rs_members['reply']:
                    self.pure_nested_structs[nested].reply = True

                if spec.is_multi_val():
                    child = self.pure_nested_structs.get(nested)
                    child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if attr_set in struct.child_nests:
                    struct.recursive = True

                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                struct.child_nests.add(child_name)
                child = self.pure_nested_structs.get(child_name)
                if not child:
                    continue
                if not child.recursive:
                    struct.child_nests.update(child.child_nests)
                child.request |= struct.request
                child.reply |= struct.reply
                if spec.is_multi_val():
                    child.in_multi_val = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """
        Tell each attribute whether it is used in requests, replies
        or both.
        """
        for struct in self.pure_nested_structs.values():
            if struct.request:
                for _, member in struct.member_list():
                    member.set_request()
            if struct.reply:
                for _, member in struct.member_list():
                    member.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_selector_passing(self):
        """
        Point the external selectors of every nested struct at the
        selector attribute, which has to live in the parent attr set.
        """
        def all_structs():
            for k, v in reversed(self.pure_nested_structs.items()):
                yield k, v
            for k in self.root_sets:
                yield k, None  # we don't have a struct, but it must be terminal

        for attr_set, _ in all_structs():
            parent_set = self.attr_sets[attr_set]
            for _, spec in parent_set.items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                elif 'sub-message' in spec:
                    child_name = spec.sub_message
                else:
                    continue

                child = self.pure_nested_structs.get(child_name)
                for selector in child.external_selectors():
                    if selector.name not in self.attr_sets[attr_set]:
                        raise Exception("Passing selector thru more than one layer not supported")
                    selector.set_attr(self.attr_sets[attr_set][selector.name])

    def _load_global_policy(self):
        """
        Build the list of attributes any do / dump request may carry,
        all the operations must be using a single attribute set.
        """
        global_set = set()
        attr_set_name = None
        for op in self.ops.values():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ('do', 'dump'):
                if op_mode not in op:
                    continue
                req = op[op_mode].get('request')
                if req:
                    global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        # Keep the order in which the attribute set lists the attributes
        self.global_policy.extend(attr for attr in self.attr_sets[attr_set_name]
                                  if attr in global_set)

    def _load_hooks(self):
        """
        Collect the pre / post hooks named by the operations, each hook
        is listed once, in the order of first use.
        """
        for op in self.ops.values():
            for op_mode in ('do', 'dump'):
                if op_mode not in op:
                    continue
                for when in ('pre', 'post'):
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    known = self.hooks[when][op_mode]
                    if name not in known['set']:
                        known['set'].add(name)
                        known['list'].append(name)
1551
1552
class RenderInfo:
    """
    Everything needed to render one operation in one mode, or one
    attribute set when there is no operation.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        fixed_hdr = op.fixed_header if op else None
        self.fixed_hdr_len = 'ys->family->hdr_len'
        if op and op.fixed_header and op.fixed_header != family.fixed_header:
            if not family.is_classic():
                raise Exception("Per-op fixed header not supported, yet")
            self.fixed_hdr_len = f"sizeof(struct {c_lower(fixed_hdr)})"

        # 'do' and 'dump' response parsing is identical
        self.type_consistent = True
        self.type_oneside = False
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                self.type_oneside = True
            else:
                do_has_reply = 'reply' in op['do']
                dump_has_reply = 'reply' in op["dump"]
                if do_has_reply != dump_has_reply:
                    self.type_consistent = False
                elif do_has_reply and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False

        self.attr_set = attr_set if attr_set else op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        self.struct = {}
        # Notifications are rendered with the types of the do (or dump)
        if op_mode == 'notify':
            op_mode = 'do' if 'do' in op else 'dump'
        if op:
            for op_dir in ('request', 'reply'):
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                else:
                    type_list = []
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             fixed_header=fixed_hdr,
                                             type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          fixed_header=fixed_hdr,
                                          type_list=op['event']['attributes'])

    def type_empty(self, key):
        """
        Return True if the struct for key has no attributes and the
        request struct has no fixed header.
        """
        has_no_attrs = not self.struct[key].attr_list
        return has_no_attrs and self.struct['request'].fixed_header is None

    def needs_nlflags(self, direction):
        if self.op_mode != 'do' or direction != 'request':
            return False
        return self.family.is_classic()
1620
1621
class CodeWriter:
    """
    Line-oriented writer for the generated C code.

    Keeps track of the indentation level, pending blank lines and delayed
    closing brackets.  Output goes to stdout or, when @out_file is given,
    to a temporary file which close_out_file() copies to @out_file.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # blank line requested before next line
        self._block_end = False     # a closing '}' is being held back
        self._silent_block = False  # previous line opened a brace-less body
        self._ind = 0
        self._ifdef_block = None
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """
        Copy the temporary output to the destination file and switch the
        writer back to stdout.  When overwrite is disabled and the
        destination already has identical contents it is left untouched.
        Calling this again (or when writing to stdout) is a no-op.
        """
        if self._out is sys.stdout:
            return
        # Avoid modifying the file if contents didn't change
        self._out.flush()
        unchanged = (not self._overwrite and
                     os.path.isfile(self._out_file) and
                     filecmp.cmp(self._out.name, self._out_file, shallow=False))
        if not unchanged:
            with open(self._out_file, 'w+') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        # Always release the temporary file, also on the "unchanged" path
        self._out.close()
        self._out = sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """
        Write one line at the current indentation.

        Lines ending with ':' are outdented by one level, lines starting
        with '#' are written at column zero, and a line following a
        brace-less condition or 'else' gets one extra level.  An empty
        @line produces a bare newline.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        self._silent_block |= line.strip() == 'else'
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        if not line:
            # Don't emit whitespace-only lines
            ind = 0
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next line written by p()."""
        self._nl = True

    def block_start(self, line=''):
        """Write @line followed by '{' and indent by one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """
        Close the current block.  Without @line the '}' is held back until
        the next p() so that a following "else" can be joined to it.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write @doc as ' *' comment lines, wrapping before column 79."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """
        Write a function prototype, preceded by a comment when @doc is set.
        Prototypes of 80 characters or more are split across lines with
        the arguments aligned after the opening bracket.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """
        Write local variable declarations, longest first, followed by a
        blank line.  @local_vars may be a single string or a list; the
        caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() rather than list.sort() to leave the caller's list alone
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, locals and @body lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """
        Write '#define name value' lines with the values tab-aligned.
        @defines is a list of (name, value); int values are written as is,
        str values are quoted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """
        Write '.name = value,' designated initializers with the '=' signs
        tab-aligned.  @members is a list of (name, value); an empty list
        writes nothing.
        """
        if not members:
            return

        longest = max(len(x[0]) for x in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """
        Switch the active '#ifdef CONFIG_...' block to @config, closing the
        previous one if needed.  A false @config just closes the open block.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1818
1819
# Names of the scalar attribute types: fixed width u8..u64 / s8..s64
# plus 'uint' and 'sint'.
scalars = {f'{sign}{bits}' for sign in 'us' for bits in (8, 16, 32, 64)}
scalars |= {'uint', 'sint'}

# Struct name suffix for each message direction
direction_to_suffix = dict([
    ('reply', '_rsp'),
    ('request', '_req'),
    ('', ''),
])

# Struct name suffix for each operation mode
op_mode_to_wrapper = dict(do='', dump='_list', notify='_ntf', event='')

# C keywords; presumably consulted to keep generated identifiers from
# clashing with them - the user is not visible in this part of the file.
_C_KW = set("""
    auto bool break case char const continue default do double else enum
    extern float for goto if inline int long register return short signed
    sizeof static struct switch typedef union unsigned void volatile while
""".split())
1871
1872
def rdir(direction):
    """
    Return the opposite direction ('reply' <-> 'request'), any other
    value is returned unchanged.
    """
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1879
1880
def op_prefix(ri, direction, deref=False):
    """
    Build the C name prefix (family name + type name + suffix) for the
    struct used by @ri in the given @direction.

    With no op mode only the type name is appended.  For modes other than
    'do', @deref selects the name of a single reply entry instead of the
    list wrapper.
    """
    mode = ri.op_mode
    parts = [f"_{ri.type_name}"]

    if mode == 'do':
        parts.append(f"{direction_to_suffix[direction]}")
    elif mode:
        if direction == 'request':
            parts.append('_req' if ri.type_oneside else '_req_dump')
        elif not ri.type_consistent:
            parts.append('_rsp_dump' if deref else '_rsp_list')
        elif deref:
            parts.append(f"{direction_to_suffix[direction]}")
        else:
            parts.append(op_mode_to_wrapper[mode])

    return f"{ri.family.c_name}{''.join(parts)}"
1904
1905
def type_name(ri, direction, deref=False):
    """Return the C struct type ("struct <prefix>") for @ri and @direction."""
    prefix = op_prefix(ri, direction, deref=deref)
    return f"struct {prefix}"
1908
1909
def print_prototype(ri, direction, terminate=True, doc=None):
    """
    Write the prototype of the user-facing function for the op in @ri.

    The function takes the socket plus, if the op mode has a request, a
    pointer to the request struct.  It returns a pointer to the reply
    struct when the op mode has a reply, 'int' otherwise.  @terminate
    appends the ';' needed for a declaration.
    """
    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = f"{direction_to_suffix[direction][1:]}"
        args.append(f"{type_name(ri, direction)} *" + arg_name)

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1926
1927
def print_req_prototype(ri):
    """Declare the request function, documented with the op's 'doc'."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1930
1931
def print_dump_prototype(ri):
    """Declare the dump function (no doc comment is written)."""
    direction = "request"
    print_prototype(ri, direction)
1934
1935
def put_typol_submsg(cw, struct):
    """
    Write the policy table and the policy nest of a sub-message struct.

    Each member gets a YNL_PT_SUBMSG entry indexed by its position,
    members of type 'nest' also point at their nested policy.
    """
    policy_name = f'{struct.render_name}_policy'

    cw.block_start(line=f'const struct ynl_policy_attr {policy_name}[] =')
    count = 0
    for idx, (name, arg) in enumerate(struct.member_list()):
        if arg.type == 'nest':
            nest = f" .nest = &{arg.nested_render_name}_nest,"
        else:
            nest = ""
        cw.p(f'[{idx}] = {{ .type = YNL_PT_SUBMSG, .name = "{name}",{nest} }},')
        count = idx + 1
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    for line in (f'.max_attr = {count - 1},',
                 f'.table = {policy_name},'):
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()
1956
1957
def put_typol_fwd(cw, struct):
    """Write the extern declaration of the struct's policy nest."""
    nest_name = f'{struct.render_name}_nest'
    cw.p(f'extern const struct ynl_policy_nest {nest_name};')
1960
1961
def put_typol(cw, struct):
    """
    Write the policy table and the policy nest for @struct.

    Sub-message structs are handed to put_typol_submsg(); otherwise the
    table is sized by the attribute set's max name and every member
    writes its own entry via attr_typol().
    """
    if struct.submsg:
        put_typol_submsg(cw, struct)
        return

    max_name = struct.attr_set.max_name
    policy_name = f'{struct.render_name}_policy'

    cw.block_start(line=f'const struct ynl_policy_attr {policy_name}[{max_name} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    for line in (f'.max_attr = {max_name},',
                 f'.table = {policy_name},'):
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()
1981
1982
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """
    Write the '<render_name>_str' function which looks @arg_name up in the
    @map_name string table, returning NULL when it is out of range.

    The argument is an 'int' unless @enum is given, in which case the
    enum's user type is used.  For 'flags' enums the value is first
    converted to a bit index with ffs().
    """
    if enum:
        args = [enum.user_type + ' ' + arg_name]
    else:
        args = [f'int {arg_name}']
    cw.write_func_prot('const char *', f'{render_name}_str', args)

    body = []
    if enum and enum.type == 'flags':
        body.append(f'{arg_name} = ffs({arg_name}) - 1;')
    body.append(f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))')
    body.append('return NULL;')
    body.append(f'return {map_name}[{arg_name}];')

    cw.block_start()
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
1996
1997
def put_op_name_fwd(family, cw):
    """Declare the family's '<family>_op_str' function."""
    func_name = f'{family.c_name}_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
2000
2001
def put_op_name(family, cw):
    """
    Write the op-value-to-name string table of the family and the
    '<family>_op_str' function which looks names up in it.

    Only ops with a response value get an entry.
    """
    map_name = f'{family.c_name}_op_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue

        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue

        # Use the enum name when request and response share the value
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
2021
2022
def put_enum_to_str_fwd(family, cw, enum):
    """Declare the '<enum>_str' function of @enum."""
    func_name = f'{enum.render_name}_str'
    value_arg = enum.user_type + ' value'
    cw.write_func_prot('const char *', func_name, [value_arg], suffix=';')
2026
2027
def put_enum_to_str(family, cw, enum):
    """
    Write the value-to-name string table of @enum and the '<enum>_str'
    function which looks names up in it.
    """
    render_name = enum.render_name
    map_name = f'{render_name}_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, render_name, map_name, 'value', enum=enum)
2037
2038
def put_local_vars(struct):
    """
    Return the local variable declarations needed to put the members of
    @struct: an array pointer if any member is an indexed-array and a
    loop counter if any member has 'count' presence.
    """
    members = [arg for _, arg in struct.member_list()]
    array_flags = [arg.type == 'indexed-array' for arg in members]
    count_flags = [arg.presence_type() == 'count' for arg in members]

    local_vars = []
    if any(array_flags):
        local_vars.append('struct nlattr *array;')
    if any(count_flags):
        local_vars.append('unsigned int i;')
    return local_vars
2051
2052
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of the '<struct>_put' function of @struct."""
    func_name = f'{struct.render_name}_put'
    func_args = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', func_name, func_args, suffix=suffix)
2060
2061
def put_req_nested(ri, struct):
    """
    Write the '<struct>_put' function which adds @struct to a message.

    Structs which are not sub-messages are wrapped in a nest attribute.
    A fixed header, if the struct has one, is copied in before the
    members are put.
    """
    cw = ri.cw
    in_nest = struct.submsg is None

    local_vars = []
    init_lines = []
    if in_nest:
        local_vars.append('struct nlattr *nest;')
        init_lines.append("nest = ynl_attr_nest_start(nlh, attr_type);")
    if struct.fixed_header:
        struct_sz = f'sizeof({struct.fixed_header})'
        local_vars.append('void *hdr;')
        init_lines += [
            f"hdr = ynl_nlmsg_put_extra_header(nlh, {struct_sz});",
            f"memcpy(hdr, &obj->_hdr, {struct_sz});",
        ]
    local_vars.extend(put_local_vars(struct))

    put_req_nested_prototype(ri, struct, suffix='')
    cw.block_start()
    cw.write_func_lvar(local_vars)

    for init_line in init_lines:
        cw.p(init_line)

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    if in_nest:
        cw.p("ynl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
2094
2095
def _multi_parse(ri, struct, init_lines, local_vars):
    """
    Write the body of a parsing function for @struct (the prototype must
    already have been written by the caller).

    The body consists of the local variables, @init_lines, a loop over the
    attributes calling each member's attr_get(), and the allocation + fill
    loops for indexed arrays and multi-attr members.

    @init_lines and @local_vars are extended in place with whatever the
    members of @struct require.
    """
    if struct.fixed_header:
        local_vars += ['void *hdr;']
    if struct.nested:
        if struct.fixed_header:
            iter_line = f"ynl_attr_for_each_nested_off(attr, nested, sizeof({struct.fixed_header}))"
        else:
            iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
        if ri.op.fixed_header != ri.family.fixed_header:
            if ri.family.is_classic():
                iter_line = f"ynl_attr_for_each(attr, nlh, sizeof({struct.fixed_header}))"
            else:
                raise Exception("Per-op fixed header not supported, yet")

    var_set = set()
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] in {'binary', 'nest'}:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name} = NULL;')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
        needs_parg |= 'sub-message' in aspec

        try:
            _, _, l_vars = aspec._attr_get(ri, '')
            var_set |= set(l_vars) if l_vars else set()
        except Exception:
            pass  # _attr_get() not implemented by simple types, ignore
    # Sort, set iteration order depends on string hashing and would make
    # the order of equal-length declarations change from run to run
    local_vars += sorted(var_set)
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if struct.fixed_header:
        if struct.nested:
            ri.cw.p('hdr = ynl_attr_data(nested);')
        elif ri.family.is_classic():
            ri.cw.p('hdr = ynl_nlmsg_data(nlh);')
        else:
            ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({struct.fixed_header}));")
    # Array members must not have been filled in by an earlier message
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: allocate, then walk the saved nest
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        elif aspec.sub_type == 'binary' and 'exact-len' in aspec.checks:
            # Length is validated by typol
            ri.cw.p(f'memcpy(dst->{aspec.c_name}[i], ynl_attr_data(attr), {aspec.checks["exact-len"]});')
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: allocate, then re-walk the
    # message picking out the attributes of the matching type
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->_count.{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        elif aspec.type == 'binary' and 'struct' in aspec:
            ri.cw.p('size_t len = ynl_attr_data_len(attr);')
            ri.cw.nl()
            ri.cw.p(f'if (len > sizeof(dst->{aspec.c_name}[0]))')
            ri.cw.p(f'len = sizeof(dst->{aspec.c_name}[0]);')
            ri.cw.p(f"memcpy(&dst->{aspec.c_name}[i], ynl_attr_data(attr), len);")
        elif aspec.type == 'string':
            ri.cw.p('unsigned int len;')
            ri.cw.nl()
            ri.cw.p('len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));')
            ri.cw.p(f'dst->{aspec.c_name}[i] = malloc(sizeof(struct ynl_string) + len + 1);')
            ri.cw.p(f"dst->{aspec.c_name}[i]->len = len;")
            ri.cw.p(f"memcpy(dst->{aspec.c_name}[i]->str, ynl_attr_get_str(attr), len);")
            ri.cw.p(f"dst->{aspec.c_name}[i]->str[len] = 0;")
        else:
            raise Exception(f'Nest parsing of type {aspec.type} not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2255
2256
def parse_rsp_submsg(ri, struct):
    """
    Write the '<struct>_parse' function of a sub-message struct.

    The generated function compares the selector string against each
    member name and runs the attribute getter lines of the member which
    matches.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    var = 'dst'
    local_vars = {'const struct nlattr *attr = nested;',
                  f'{struct.ptr_name}{var} = yarg->data;',
                  'struct ynl_parse_arg parg;'}

    for _, arg in struct.member_list():
        _, _, l_vars = arg._attr_get(ri, var)
        local_vars |= set(l_vars) if l_vars else set()

    ri.cw.block_start()
    # Sort, set iteration order depends on string hashing and would make
    # the order of equal-length declarations change from run to run
    ri.cw.write_func_lvar(sorted(local_vars))
    ri.cw.p('parg.ys = yarg->ys;')
    ri.cw.nl()

    first = True
    for name, arg in struct.member_list():
        kw = 'if' if first else 'else if'
        first = False

        ri.cw.block_start(line=f'{kw} (!strcmp(sel, "{name}"))')
        get_lines, init_lines, _ = arg._attr_get(ri, var)
        for line in init_lines or []:
            ri.cw.p(line)
        for line in get_lines:
            ri.cw.p(line)
        if arg.presence_type() == 'present':
            ri.cw.p(f"{var}->_present.{arg.c_name} = 1;")
        ri.cw.block_end()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
2291
2292
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """
    Write the prototype of the '<struct>_parse' function of @struct.

    Arguments are the parse arg, the selector string (sub-messages only),
    the nest, one string per external selector and one __u32 per
    inherited value.
    """
    func_args = ['struct ynl_parse_arg *yarg']
    if struct.submsg:
        func_args.append('const char *sel')
    func_args.append('const struct nlattr *nested')
    func_args += ['const char *_sel_' + sel.name
                  for sel in struct.external_selectors()]
    func_args += ['__u32 ' + arg for arg in struct.inherited]

    func_name = f'{struct.render_name}_parse'
    ri.cw.write_func_prot('int', func_name, func_args, suffix=suffix)
2305
2306
def parse_rsp_nested(ri, struct):
    """
    Write the '<struct>_parse' function for a nested struct.

    Sub-messages are handled by parse_rsp_submsg().  A struct without
    members gets a function which just returns 0.
    """
    if struct.submsg:
        return parse_rsp_submsg(ri, struct)

    parse_rsp_nested_prototype(ri, struct, suffix='')

    init_lines = []
    local_vars = [
        'const struct nlattr *attr;',
        f'{struct.ptr_name}dst = yarg->data;',
    ]

    if not struct.member_list():
        # Empty nest
        ri.cw.block_start()
        ri.cw.p('return 0;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, struct, init_lines, local_vars)
2325
2326
def parse_rsp_msg(ri, deref=False):
    """
    Write the message parsing callback for the reply of the op in @ri.

    Nothing is written if the op mode has no reply, unless the mode is
    'event'.  A reply without members gets a callback which just returns
    YNL_PARSE_CB_OK.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not (has_reply or ri.op_mode == 'event'):
        return

    func_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    func_args = [
        'const struct nlmsghdr *nlh',
        'struct ynl_parse_arg *yarg',
    ]

    local_vars = [
        f'{type_name(ri, "reply", deref=deref)} *dst;',
        'const struct nlattr *attr;',
    ]
    init_lines = ['dst = yarg->data;']

    ri.cw.write_func_prot('int', func_name, func_args)

    reply = ri.struct["reply"]
    if reply.member_list():
        _multi_parse(ri, reply, init_lines, local_vars)
        return

    # Empty reply
    ri.cw.block_start()
    ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
2348
2349
def print_req(ri):
    """
    Write the function which executes the 'do' style request of @ri.

    The generated function starts the message, puts the fixed header and
    the request attributes, and runs ynl_exec().  If the op mode has a
    reply, the function allocates and returns the reply struct (freeing
    it and returning NULL on error), otherwise it returns 0 / -1.
    """
    direction = "request"
    cw = ri.cw
    has_reply = 'reply' in ri.op[ri.op_mode]
    req_struct = ri.struct["request"]

    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if has_reply:
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
    if req_struct.fixed_header:
        local_vars.extend(['size_t hdr_len;', 'void *hdr;'])
    local_vars.extend(put_local_vars(req_struct))

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    # Message start and policy setup
    if ri.family.is_classic():
        cw.p(f"nlh = ynl_msg_start_req(ys, {ri.op.enum_name}, req->_nlmsg_flags);")
    else:
        cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{req_struct.render_name}_nest;")
    cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()

    # Message contents
    if req_struct.fixed_header:
        for line in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(line)
        cw.nl()

    for _, attr in req_struct.member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    # Reply setup and execution
    if has_reply:
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p('return rsp;' if has_reply else 'return 0;')
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        cw.p('return NULL;')

    cw.block_end()
2420
2421
def print_dump(ri):
    """
    Write the function which executes the dump request of @ri.

    The generated function sets up the dump state, starts the message,
    puts the fixed header and the request attributes (if the op mode has
    a request) and runs ynl_exec_dump().  It returns the first list
    entry, or frees the list and returns NULL on error.
    """
    direction = "request"
    cw = ri.cw
    req_struct = ri.struct['request']
    has_request = 'request' in ri.op[ri.op_mode]

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']
    if req_struct.fixed_header:
        local_vars.extend(['size_t hdr_len;', 'void *hdr;'])
    if has_request:
        local_vars.extend(put_local_vars(req_struct))
    cw.write_func_lvar(local_vars)

    # Dump state setup
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p('yds.yarg.ys = ys;')
    cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.p("yds.yarg.data = NULL;")
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.nl()

    # Message start and contents
    if ri.family.is_classic():
        cw.p(f"nlh = ynl_msg_start_dump(ys, {ri.op.enum_name});")
    else:
        cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if req_struct.fixed_header:
        for line in ("hdr_len = sizeof(req->_hdr);",
                     "hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);",
                     "memcpy(hdr, &req->_hdr, hdr_len);"):
            cw.p(line)
        cw.nl()

    if has_request:
        cw.p(f"ys->req_policy = &{req_struct.render_name}_nest;")
        cw.p(f"ys->req_hdr_len = {ri.fixed_hdr_len};")
        cw.nl()
        for _, attr in req_struct.member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    # Execution and error handling
    cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
2479
2480
def call_free(ri, direction, var):
    """Return the C statement which frees @var with the struct's free function."""
    free_func = f"{op_prefix(ri, direction)}_free"
    return f"{free_func}({var});"
2483
2484
def free_arg_name(direction):
    """
    Return the argument name used by free functions: the direction suffix
    without its leading underscore, or 'obj' when there is no direction.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
2489
2490
def print_alloc_wrapper(ri, direction, struct=None):
    """
    Write the static inline '<prefix>_alloc' helper which calloc()s the
    struct.  For structs used in multi-value attributes the helper takes
    the number of elements to allocate.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name

    if struct and struct.in_multi_val:
        args, cnt = ["unsigned int n"], "n"
    else:
        args, cnt = ["void"], "1"

    cw = ri.cw
    cw.write_func_prot(f'static inline struct {struct_name} *',
                       f"{name}_alloc", args)
    cw.block_start()
    cw.p(f'return calloc({cnt}, sizeof(struct {struct_name}));')
    cw.block_end()
2508
2509
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of the ``<prefix>_free()`` function.

    @suffix is passed on to the prototype writer; the default ';' yields a
    plain declaration.
    """
    func_prefix = op_prefix(ri, direction)
    # On a type name conflict only the struct tag gets a trailing underscore
    if ri.type_name_conflict:
        c_type = func_prefix + '_'
    else:
        c_type = func_prefix
    param = f"struct {c_type} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{func_prefix}_free", [param], suffix=suffix)
2517
2518
def print_nlflags_set(ri, direction):
    """Emit the static inline ``<prefix>_set_nlflags()`` helper.

    The generated function stores its ``nl_flags`` argument in the
    ``_nlmsg_flags`` member of the request.
    """
    prefix = op_prefix(ri, direction)
    params = [f"struct {prefix} *req", "__u16 nl_flags"]

    cw = ri.cw
    cw.write_func_prot('static inline void', f"{prefix}_set_nlflags", params)
    cw.block_start()
    cw.p('req->_nlmsg_flags = nl_flags;')
    cw.block_end()
    cw.nl()
2527
2528
def _print_type(ri, direction, struct):
    """Emit the C struct definition for @struct.

    The struct tag is the family C name followed by the type name and the
    direction suffix; '_dump' is appended in dump mode for types that are
    not one-sided.  The body holds the optional netlink flags and fixed
    header, the 'present' / 'len' / 'count' bookkeeping sub-structs, the
    inherited values and finally one member per attribute.
    """
    cw = ri.cw

    tag_sfx = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        tag_sfx += '_'
    if ri.op_mode == 'dump' and not ri.type_oneside:
        tag_sfx += '_dump'

    cw.block_start(line=f"struct {ri.family.c_name}{tag_sfx}")

    if ri.needs_nlflags(direction):
        cw.p('__u16 _nlmsg_flags;')
        cw.nl()
    if struct.fixed_header:
        cw.p(struct.fixed_header + ' _hdr;')
        cw.nl()

    # One anonymous sub-struct per kind of metadata, opened lazily so that
    # a kind without any member produces no output at all.
    for kind in ('present', 'len', 'count'):
        meta_lines = (attr.presence_member(ri.ku_space, kind)
                      for _, attr in struct.member_list())
        opened = False
        for meta in filter(None, meta_lines):
            if not opened:
                cw.block_start(line="struct")
                opened = True
            cw.p(meta)
        if opened:
            cw.block_end(line=f'_{kind};')
    cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
2567
2568
def print_type(ri, direction):
    """Emit the struct of the current op for @direction (e.g. 'request')."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2571
2572
def print_type_full(ri, struct):
    """Emit a direction-less type and, where needed, its helpers.

    Request-side types used in multi-value attributes additionally get an
    alloc wrapper, a free prototype and (user space only) attribute setters.
    """
    _print_type(ri, "", struct)

    if not (struct.request and struct.in_multi_val):
        return

    print_alloc_wrapper(ri, "", struct)
    ri.cw.nl()
    free_rsp_nested_prototype(ri)
    ri.cw.nl()

    # Name conflicts are too hard to deal with in the current code base and
    # they are very rare, so no setters are printed in that case.
    setters_ok = ri.ku_space == 'user' and not ri.type_name_conflict
    if setters_ok:
        for _, member in struct.member_list():
            member.setter(ri, ri.attr_set, "", var="obj")
    ri.cw.nl()
2588
2589
def print_type_helpers(ri, direction, deref=False):
    """Emit the helper declarations accompanying a directional type.

    The free prototype is always printed.  The nlflags setter is added when
    the direction needs netlink flags, and attribute setters are added for
    user space request types (@deref is forwarded to each setter).
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    if ri.needs_nlflags(direction):
        print_nlflags_set(ri, direction)

    is_user_request = ri.ku_space == 'user' and direction == 'request'
    if is_user_request:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2601
2602
def print_req_type_helpers(ri):
    """Emit alloc wrapper and type helpers for a non-empty request type."""
    if not ri.type_empty("request"):
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
2608
2609
def print_rsp_type_helpers(ri):
    """Emit the reply type helpers if the current op mode has a reply."""
    mode_spec = ri.op[ri.op_mode]
    if 'reply' in mode_spec:
        print_type_helpers(ri, "reply")
2614
2615
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of the op's ``_parse`` function.

    The 'reply' direction selects the ``_rsp`` flavour, anything else the
    ``_req`` one.  With @terminate the prototype ends in ';'.
    """
    if direction == "reply":
        dir_sfx = "_rsp"
    else:
        dir_sfx = "_req"
    type_base = f"{ri.op.render_name}{dir_sfx}"

    params = ['const struct nlattr **tb', f"struct {type_base} *req"]
    ri.cw.write_func_prot('void', f"{type_base}_parse", params,
                          suffix=';' if terminate else '')
2624
2625
def print_req_type(ri):
    """Emit the request struct unless the request type is empty."""
    if not ri.type_empty("request"):
        print_type(ri, "request")
2630
2631
def print_req_free(ri):
    """Emit the request free function if the current op mode has a request."""
    mode_spec = ri.op[ri.op_mode]
    if 'request' in mode_spec:
        _free_type(ri, 'request', ri.struct['request'])
2636
2637
def print_rsp_type(ri):
    """Emit the reply struct if the current op mode produces one.

    'do' and 'dump' need a 'reply' section in the op, 'event' always has a
    reply type, and every other mode emits nothing.
    """
    mode = ri.op_mode
    if mode in ('do', 'dump'):
        if 'reply' not in ri.op[mode]:
            return
    elif mode != 'event':
        return
    print_type(ri, 'reply')
2646
2647
def print_wrapped_type(ri):
    """Emit the wrapper struct around a reply object plus its free prototype.

    Dump wrappers carry a typed ``next`` pointer; notification and event
    wrappers carry family, cmd, a base-type ``next`` pointer and a free
    callback.  The reply itself is embedded as the 8-byte aligned ``obj``.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for member in ('__u16 family;',
                       '__u8 cmd;',
                       'struct ynl_ntf_base_type *next;',
                       f"void (*free)({wrapper} *ntf);"):
            cw.p(member)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
2662
2663
2664def _free_type_members_iter(ri, struct):
2665    if struct.free_needs_iter():
2666        ri.cw.p('unsigned int i;')
2667        ri.cw.nl()
2668
2669
2670def _free_type_members(ri, var, struct, ref=''):
2671    for _, attr in struct.member_list():
2672        attr.free(ri, var, ref)
2673
2674
def _free_type(ri, direction, struct):
    """Emit the full definition of the free function for @struct.

    Member cleanup code comes first; the object itself is only free()d
    when @direction is non-empty.
    """
    arg = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
2686
2687
def free_rsp_nested_prototype(ri):
    """Emit the free prototype of a direction-less (nested) type."""
    print_free_prototype(ri, "")
2690
2691
def free_rsp_nested(ri, struct):
    """Emit the free function of a direction-less (nested) type."""
    no_direction = ""
    _free_type(ri, no_direction, struct)
2694
2695
def print_rsp_free(ri):
    """Emit the reply free function if the current op mode has a reply."""
    mode_spec = ri.op[ri.op_mode]
    if 'reply' in mode_spec:
        _free_type(ri, 'reply', ri.struct['reply'])
2700
2701
def print_dump_type_free(ri):
    """Emit the free function for a list of dump replies.

    The generated C walks the list until YNL_LIST_END, freeing the members
    of each element's embedded ``obj`` and then the element itself.
    """
    cw = ri.cw
    list_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{list_type} *next = rsp;")
    cw.nl()

    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    # advance first, so 'next' stays valid after 'rsp' is freed
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()

    cw.block_end()
    cw.nl()
2720
2721
def print_ntf_type_free(ri):
    """Emit the free function for a notification wrapper.

    Members of the embedded ``obj`` are freed first, then the wrapper.
    """
    cw = ri.cw
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
2730
2731
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the declaration or the opening line of an nla_policy array.

    With @terminate an ``extern`` declaration is printed, unless the policy
    of an operation is static, in which case nothing is printed.  Without
    @terminate the line opens the array initializer (``= {``), marked
    ``static`` for static operation policies.

    The array is named after the operation when @ri is given (with the op
    mode appended for dual-policy ops), otherwise after @struct.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            # static policies need no forward declaration
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if ri:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name = name + '_' + ri.op_mode
    else:
        name = struct.render_name
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2754
2755
def print_req_policy(cw, struct, ri=None):
    """Emit the nla_policy array for @struct, one entry per member.

    When rendering for an operation, ifdef_block() is first called with the
    op's 'config-cond' value; it is called with None after the array.
    """
    has_op = ri and ri.op
    if has_op:
        config_cond = ri.op.get('config-cond', None)
        cw.ifdef_block(config_cond)

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2765
2766
def kernel_can_gen_family_struct(family):
    """Tell whether the genl_family struct is generated for @family.

    That is the case only for families using the 'genetlink' protocol.
    """
    proto = family.proto
    return proto == 'genetlink'
2769
2770
def policy_should_be_static(family):
    """Tell whether the policies of @family are emitted as static.

    True for the 'split' kernel policy and whenever the family struct
    itself is generated.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2773
2774
def print_kernel_policy_ranges(family, cw):
    """Emit netlink_range_validation structs for 'full-range' attributes.

    Only request attributes of attribute sets that are not subsets are
    considered.  A section comment precedes the first struct emitted.
    """
    need_banner = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if need_banner:
                cw.p('/* Integer value ranges */')
                need_banner = False

            # Types starting with 'u' use the plain struct and ULL literals,
            # all others the _signed variant and LL literals.
            if attr.type[0] == 'u':
                sign, literal_sfx = '', 'ULL'
            else:
                sign, literal_sfx = '_signed', 'LL'

            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            limits = [(bound, attr.get_limit_str(bound, suffix=literal_sfx))
                      for bound in ('min', 'max') if bound in attr.checks]
            cw.write_struct_init(limits)
            cw.block_end(line=';')
            cw.nl()
2802
2803
def print_kernel_policy_sparse_enum_validates(family, cw):
    """Emit validation callbacks for request attributes with sparse enums.

    For every request attribute that has an enum name and a 'sparse' check
    a ``<enum>_validate()`` function is generated.  It returns 0 for the
    values listed in the enum and otherwise sets an extack message and
    returns -EINVAL.
    """
    need_banner = True
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not (attr.request and attr.enum_name and 'sparse' in attr.checks):
                continue

            if need_banner:
                cw.p('/* Sparse enums validation callbacks */')
                need_banner = False

            cw.write_func_prot('static int', f'{c_lower(attr.enum_name)}_validate',
                               ['const struct nlattr *attr', 'struct netlink_ext_ack *extack'])
            cw.block_start()
            cw.block_start(line=f'switch (nla_get_{attr["type"]}(attr))')
            enum = family.consts[attr['enum']]
            # All case labels share a single 'return 0', chained together
            # with explicit fallthrough markers.
            for idx, entry in enumerate(enum.entries.values()):
                if idx:
                    cw.p('fallthrough;')
                cw.p(f'case {entry.c_name}:')
            cw.p('return 0;')
            cw.block_end()
            cw.p('NL_SET_ERR_MSG_ATTR(extack, attr, "invalid enum value");')
            cw.p('return -EINVAL;')
            cw.block_end()
            cw.nl()
2840
2841
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops table declaration or the opening of its definition.

    With @terminate set this produces header content: the extern
    declaration of the table (only when the family struct is not generated
    and the table therefore has to be exported), followed by prototypes of
    all hooks and of the doit / dumpit callbacks.  Without @terminate only
    the opening line of the table initializer is emitted.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.ident_name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            # no explicit array size for the static table
            cnt = ""
        elif family.kernel_policy == 'split':
            # split ops take one entry per 'do' and one per 'dump'
            cnt = sum(('do' in op) + ('dump' in op)
                      for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        decl = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {decl};")
        else:
            cw.block_start(line=decl + ' =')

    if not terminate:
        return

    cw.nl()
    doit_params = ['const struct genl_split_ops *ops',
                   'struct sk_buff *skb', 'struct genl_info *info']
    dump_params = ['struct netlink_callback *cb']
    hook_protos = (('pre', 'do', 'int', doit_params),
                   ('post', 'do', 'void', doit_params),
                   ('pre', 'dump', 'int', dump_params),
                   ('post', 'dump', 'int', dump_params))
    for when, mode, ret_type, params in hook_protos:
        for hook in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(params), suffix=';')

    cw.nl()

    cb_params = (('do', ['struct sk_buff *skb', 'struct genl_info *info']),
                 ('dump', ['struct sk_buff *skb', 'struct netlink_callback *cb']))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, params in cb_params:
            if mode in op:
                func = c_lower(f"{family.ident_name}-nl-{op_name}-{mode}it")
                cw.write_func_prot('int', func, list(params), suffix=';')
    cw.nl()
2907
2908
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side (terminated) view of the kernel ops table."""
    print_kernel_op_table_fwd(family, cw, True)
2911
2912
def print_kernel_op_table(family, cw):
    """Emit the definition of the kernel ops table for @family.

    The opening line comes from print_kernel_op_table_fwd().  For the
    'global' and 'per-op' kernel policies one initializer is emitted per
    non-async operation; for 'split' one initializer is emitted per
    non-async operation and per mode ('do' / 'dump') the operation has.
    Each initializer is preceded by an ifdef_block() call with the op's
    'config-cond' value.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            # Members are emitted in the order they are appended here
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): this indexes op['do']['request'] directly, so
                # it presumably expects every per-op operation to have a 'do'
                # with a request - confirm against the spec schema.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the split ops members holding the pre / post hooks
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags relevant for this mode: the dump
                    # flags are dropped for 'do', 'strict' is dropped for
                    # 'dump'.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                # The pre hook goes before the main callback, post after it
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have one policy per mode, so the mode
                    # is part of the policy name
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split op entry is tagged with the mode it serves
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2991
2992
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the enum of multicast group IDs; nothing if there are no groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        enumerator = c_upper(f"{family.ident_name}-nlgrp-{grp['name']},")
        cw.p(enumerator)
    cw.block_end(';')
    cw.nl()
3003
3004
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array; nothing if there are no groups.

    Entries are indexed by the group ID enumerators and initialized with
    the group name string.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        grp_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
3016
3017
def print_kernel_family_struct_hdr(family, cw):
    """Emit the header declarations that go with the genl_family struct.

    Prints the extern declaration of the family struct and, if the family
    has a 'sock-priv' type, the prototypes of its init / destroy callbacks.
    Does nothing when the family struct is not generated.
    """
    if not kernel_can_gen_family_struct(family):
        return

    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
    cw.nl()

    if 'sock-priv' not in family.kernel_family:
        return
    for action in ('init', 'destroy'):
        cw.p(f'void {family.c_name}_nl_sock_priv_{action}({family.kernel_family["sock-priv"]} *priv);')
    cw.nl()
3028
3029
def print_kernel_family_struct_src(family, cw):
    """Emit the definition of the genl_family struct for @family.

    Does nothing when the family struct is not generated.  If the family
    has a 'sock-priv' type, ``void *`` trampolines for its init / destroy
    callbacks are emitted first and referenced from the struct.

    The struct is named ``<c_name>_nl_family`` so that the definition
    matches the extern declaration in the generated header and the other
    ``<c_name>_nl_*`` symbols referenced from the initializer.
    """
    if not kernel_can_gen_family_struct(family):
        return

    has_sock_priv = 'sock-priv' in family.kernel_family

    if has_sock_priv:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    # NOTE(review): no ops member is emitted for the 'global' kernel policy
    # here - confirm that combination is not expected for these families.
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if has_sock_priv:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
3065
3066
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a C enum block, choosing its tag from @obj.

    If @obj has the @enum_name key, its value is the tag (a false value
    yields an anonymous enum).  Otherwise, if @ckey is given and present
    in @obj, the tag is the family C name joined with that value.  In all
    other cases the enum is anonymous.
    """
    tag = None
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            tag = c_lower(explicit)
    elif ckey and ckey in obj:
        tag = family.c_name + '_' + c_lower(obj[ckey])

    if tag is None:
        cw.block_start(line='enum')
    else:
        cw.block_start(line='enum ' + tag)
3075
3076
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Render the single command enum of a 'unified' message ID model.

    Explicit values are only printed where a message's value differs from
    the one the C compiler would assign implicitly.  With @separate_ntf
    notifications and events are left out.  The max is rendered either as
    an enumerator or, with @max_by_define, as a #define after the enum.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(op.enum_name + ',')
        else:
            cw.p(op.enum_name + f" = {op.value},")
        expected = op.value + 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
3102
3103
def render_uapi_directional(family, cw, max_by_define):
    """Render the two command enums of a 'directional' message ID model.

    The first enum lists the ops that have a 'do' and are not events; the
    second lists ops with a 'do' reply (suffixed ``_REPLY``) as well as
    notifications and events.  Each enum starts with a NONE entry equal
    to 0 and ends with its count / max pair.
    """
    def emit_entry(op, enum_name, expected):
        """Print one enumerator and return the value expected next."""
        if op.value and op.value != expected:
            cw.p(enum_name + f" = {op.value},")
            expected = op.value
        else:
            cw.p(enum_name + ',')
        return expected + 1

    def emit_footer(side):
        """Print the count / max pair for @side and close the enum."""
        max_name = f"{family.op_prefix}{side}_MAX"
        cnt_name = f"__{family.op_prefix}{side}_CNT"
        max_value = f"({cnt_name} - 1)"

        cw.nl()
        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    # USER enum
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    val = 0
    for op in family.msgs.values():
        if 'do' in op and 'event' not in op:
            val = emit_entry(op, op.enum_name, val)
    emit_footer('USER')

    # KERNEL enum
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    val = 0
    for op in family.msgs.values():
        if not (('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op):
            continue

        enum_name = op.enum_name
        if 'event' not in op and 'notify' not in op:
            enum_name = f'{enum_name}_REPLY'
        val = emit_entry(op, enum_name, val)
    emit_footer('KERNEL')
3156
3157
def render_uapi(family, cw):
    """Render the complete uAPI header for @family.

    Output order: include guard, family name / version defines, the
    definitions from the spec (consts as #defines, enums and flags as C
    enums with optional kdoc), one enum per attribute set that is not a
    subset, the command enum(s), and the multicast group name defines.
    """
    # Include guard derived from the uAPI header name
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are collected and written as a group
    defines = []
    for const in family['definitions']:
        # Definitions with a 'header' set are not rendered here
        if const.get('header'):
            continue

        # Any other definition type ends the current group of defines
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                # A kdoc-style comment is opened only when entries carry
                # docs; otherwise a plain comment holds the enum's doc.
                if enum.has_entry_doc():
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                # Flags get a mask of all values, enums a count / max pair
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            name_pfx = const.get('name-prefix', f"{family.ident_name}-")
            # NOTE(review): 'c-define-name' is looked up on the family, not
            # on the definition itself - confirm this is intended.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{name_pfx}{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    # Attribute sets
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        # 'val' tracks the value C would assign implicitly, so explicit
        # values are only printed where the attribute's value differs
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    # With an 'async-prefix' notifications and events get their own enum
    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
3299
3300
def _render_user_ntf_entry(ri, op):
    """Emit one designated initializer of the ynl_ntf_info table.

    The entry is indexed by @op's own enum name, except for classic
    families where the op found in req_by_value under @op's rsp_value
    provides the index.
    """
    cw = ri.cw

    if ri.family.is_classic():
        index_op = ri.family.req_by_value[op.rsp_value]
    else:
        index_op = op
    cw.block_start(line=f"[{index_op.enum_name}] = ")

    cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
    parse_prefix = op_prefix(ri, 'reply', deref=True)
    cw.p(f".cb\t\t= {parse_prefix}_parse,")
    policy = ri.struct['reply'].render_name
    cw.p(f".policy\t\t= &{policy}_nest,")
    free_prefix = op_prefix(ri, 'notify')
    cw.p(f".free\t\t= (void *){free_prefix}_free,")
    cw.block_end(line=',')
3312
3313
def render_user_family(family, cw, prototype):
    """Emit the user space ynl_family descriptor of @family.

    With @prototype only the extern declaration is printed.  Otherwise the
    notification info table is emitted first (if the family has any
    notifications), followed by the initialized ``ynl_<name>_family``
    struct.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family.c_name}_ntf_info[] = ")
        for ntf_name, ntf_op in family.ntfs.items():
            # Notifications are rendered using the op they refer to,
            # events using their own definition.
            if 'notify' in ntf_op:
                render_op, mode = family.ops[ntf_op['notify']], "notify"
            elif 'event' in ntf_op:
                render_op, mode = ntf_op, "event"
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(RenderInfo(cw, family, "user", render_op, mode), ntf_op)
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.is_classic():
        cw.p('.is_classic\t= true,')
        cw.p(f'.classic_id\t= {family.get("protonum")},')
        # classic families have no genetlink header in front of the fixed one
        if family.fixed_header:
            cw.p(f'.hdr_len\t= sizeof(struct {c_lower(family.fixed_header)}),')
    elif family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family.c_name}_ntf_info,")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({family.c_name}_ntf_info),")
    cw.block_end(line=';')
3355
3356
def family_contains_bitfield32(family):
    """
    Tell whether any attribute of @family is of type ``bitfield32``.

    Attribute sets which are a subset of another set are not inspected.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
3365
3366
def find_kernel_root(full_path):
    """
    Find the source tree root above @full_path.

    Walks up the parent directories of @full_path until one containing a
    MAINTAINERS file is found.

    Returns a ``(root, sub_path)`` tuple, where ``sub_path`` is @full_path
    expressed relative to ``root``.

    Raises FileNotFoundError when no parent directory contains a MAINTAINERS
    file, instead of walking up forever.
    """
    start_path = full_path
    sub_path = ''
    while True:
        # Each step moves one path component from full_path to sub_path.
        # Joining onto the initially empty sub_path leaves a trailing
        # separator, which is stripped when the result is returned.
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            return parent, sub_path[:-1]
        # dirname() of the file system root, or of an exhausted relative
        # path, returns its argument unchanged: nothing is left to search.
        if parent == full_path:
            raise FileNotFoundError(
                f"No MAINTAINERS file found in any parent directory of '{start_path}'")
        full_path = parent
3375
3376
def main():
    """
    Command line entry point of the code generator.

    Parses the command line, loads the spec given with --spec and checks
    its license, then writes the generated code through a CodeWriter:

     - a banner naming the spec and the generator arguments,
     - for --mode uapi the uAPI header (and nothing else),
     - otherwise the include directives followed by the kernel or the
       user space code, as a header (--header) or a source (--source).

    Exits with status 1 if the spec can't be parsed as YAML or carries
    an unexpected license.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination, None means neither was given
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    # Only loading the spec is expected to raise YAML errors
    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
        print('Spec license:', parsed.license)
        print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    # Banner: license, origin of the file and the generator arguments
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Joining onto a leading empty string prefixes every value with its option
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    # Name of the header which accompanies the generated source: the output
    # file name with its extension replaced by ".h"
    if args.out_file:
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"
    else:
        hdr_file = "generated_header_file.h"

    # Includes
    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions'] + parsed['attribute-sets']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, keeping the order of first appearance
    seen_header = []
    for one in headers:
        if one not in seen_header:
            cw.p(f"#include <{one}>")
            seen_header.append(one)
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            # Nested structs used in requests, in sorted attribute set order
            req_nests = [struct for _, struct in sorted(parsed.pure_nested_structs.items())
                         if struct.request]
            if req_nests:
                cw.p('/* Common nested types */')
            for struct in req_nests:
                print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for _, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)
            print_kernel_policy_sparse_enum_validates(parsed, cw)

            # Nested structs used in requests, in sorted attribute set order
            req_nests = [struct for _, struct in sorted(parsed.pure_nested_structs.items())
                         if struct.request]
            if req_nests:
                cw.p('/* Common nested types */')
            for struct in req_nests:
                print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for _, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for _, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for _, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent or ri.type_oneside:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for _, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for _, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            # Recursive nests get forward declarations ahead of the policies
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for struct in parsed.pure_nested_structs.values():
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            # With recursive nests present, prototypes of the nest helpers
            # are written before any of their definitions
            if has_recursive_nests:
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for _, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent or ri.type_oneside:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for _, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    # The message parser of an event is rendered in "do" mode
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3685
3686
# Run the generator only when this file is executed as a script
if __name__ == "__main__":
    main()
3689